import matplotlib as mpl
#mpl.use('Agg') # necessary to be able to generate the reports even without X server
from matplotlib.pylab import plt
import numpy as np
from .oivisquery import query
from .log import Log, ERROR, WARNING, NOTICE
# Module-level Log instance; nothing in the code visible below uses it yet.
log = Log()


def plot_cursor(cursor, order=("x", "y"), data=None,
                colorDict=None,
                axesDict=None, markerDict=None, areaDict=None,
                scater=True,
                color_key=None, marker_key=None,
                axes_key=None, area_key=None
                ):
    """Plot the rows of ``cursor`` on one axes, as a scatter or as error bars.

    The rows of ``cursor`` are transposed into columns and stored in a working
    dictionary under the names given by ``order``.  The plot is then drawn
    from the entries of that dictionary.

    Parameters
    ----------
    cursor : iterable of rows
        Each row is a sequence; column ``i`` of every row is stored under
        ``order[i]``.
    order : sequence of str
        Names given to the columns.  ``"x"`` and ``"y"`` are required for
        plotting.  Other recognised names are ``"xerr"``, ``"yerr"``,
        ``"colors"``, ``"colors_key"``, ``"areas"`` and ``"areas_key"``.
    data : dict, optional
        Extra named columns.  The dictionary is copied, never modified.
        Columns read from ``cursor`` override entries with the same name.
    colorDict, axesDict, markerDict, areaDict : dict, optional
        Lookup tables for colours, axes, markers and marker areas.
    scater : bool
        If true, draw a scatter plot with faint error bars on top; otherwise
        draw a single ``errorbar`` plot.
    color_key, marker_key, axes_key, area_key
        Keys looked up in the matching dictionary.  Missing keys fall back to
        ``"k"``, ``"*"``, the current axes (``plt.gca()``) and ``60``.

    Returns
    -------
    The axes that was used, even when there was nothing to plot.
    """
    colorDict = {} if colorDict is None else colorDict
    axesDict = {} if axesDict is None else axesDict
    markerDict = {} if markerDict is None else markerDict
    areaDict = {} if areaDict is None else areaDict

    axes = axesDict.get(axes_key, None)
    if axes is None:
        axes = plt.gca()  # fall back to the current axes

    color = colorDict.get(color_key, "k")
    marker = markerDict.get(marker_key, "*")
    area = areaDict.get(area_key, 60)

    # Work on a copy so the dictionary handed in by the caller is not altered.
    data = dict(data) if data else {}
    # Transpose the rows into columns and name them after `order`.
    data.update(zip(order, zip(*cursor)))

    if not data:
        return axes

    xerr = data.get("xerr", None)
    yerr = data.get("yerr", None)

    if not scater:
        axes.errorbar(data["x"], data["y"],
                      xerr=xerr, yerr=yerr,
                      color=color, marker=marker, linestyle='None')
        return axes

    # Per-point colours: explicit values first, then keys into colorDict,
    # then the single colour selected by color_key.
    if "colors" in data:
        colors = data["colors"]
    elif "colors_key" in data:
        colors = np.array([colorDict.get(cid, "k") for cid in data["colors_key"]])
    else:
        colors = color

    # Per-point areas follow the same precedence; unknown keys fall back to
    # the single area selected by area_key.
    if "areas" in data:
        areas = data["areas"]
    elif "areas_key" in data:
        areas = np.array([areaDict.get(aid, area) for aid in data["areas_key"]])
    else:
        areas = area

    axes.scatter(data["x"], data["y"], areas, c=colors, marker=marker, linewidths=0)
    erb = axes.errorbar(data["x"], data["y"],
                        xerr=xerr, yerr=yerr,
                        color="gray", fmt='none', linestyle='None', alpha=0.5, ecolor="k")
    # The errorbar result is a nested group (data line, cap lines, bar
    # collections) whose members can be None or tuples, so it is flattened
    # before the transparency is applied.
    for artist in _iter_artists(erb):
        artist.set_alpha(0.3)

    return axes


def _iter_artists(item):
    """Yield every artist of a possibly nested errorbar result, skipping None."""
    if item is None:
        return
    if isinstance(item, (tuple, list)):
        for sub in item:
            yield from _iter_artists(sub)
    else:
        yield item



def gb(bnames, *arrname):
    """Build the table list and join conditions linking a table to station aliases.

    ``bnames`` is either a table name or a ``(name, alias)`` tuple.  Each entry
    of ``arrname`` is an alias that is joined on a station index column of the
    table.  With a single alias the column is ``STA_INDEX``; with several the
    columns are ``STA1_INDEX``, ``STA2_INDEX``, ...

    Returns a ``(tables, conditions)`` pair of strings; ``conditions`` is
    empty when no alias is given.
    """
    # NOTE(review): the aliases are declared on a table called "array" while
    # the rest of this file joins on "arrays" -- confirm which name is right.
    name, replacement = bnames if isinstance(bnames, tuple) else (bnames, bnames)
    tbls = "%s as %s" % (name, replacement)
    if len(arrname) == 0:
        return tbls, ""

    if len(arrname) > 1:
        stai = list(range(1, len(arrname) + 1))
    else:
        stai = [""]  # a single station column carries no number

    table_parts = [tbls]
    cond_parts = []
    for i, r in zip(stai, arrname):
        table_parts.append("array as %s" % r)
        cond_parts.append("%s.STA%s_INDEX = %s.STA_INDEX" % (replacement, i, r))
    return ",".join(table_parts), " AND ".join(cond_parts)

def gbq(bnames, *arrname):
    """Return a database-less query built from the tables and conditions of ``gb``."""
    tables, conditions = gb(bnames, *arrname)
    return query(None, tbls=tables, conds=conditions)




def qt(tblname):
    """Return a database-less two-station query for the table aliased ``tblname``.

    The query declares the ``ar1`` and ``ar2`` aliases of ``arrays``.  Its
    conditions join them on ``STA1_INDEX`` / ``STA2_INDEX`` of ``tblname`` and
    keep only rows with ``STA1_INDEX<100``.  ``tblname`` itself is not added
    to the table list.
    """
    conds_template = "{0}.STA1_INDEX<100 AND {0}.STA1_INDEX=ar1.STA_INDEX AND ar2.STA_INDEX={0}.STA2_INDEX"
    return query(None,
                 tbls="arrays as ar1, arrays as ar2",
                 conds=conds_template.format(tblname))


# Reusable query fragments, all created with None in place of a database.
#
# q_2station joins two aliases of "arrays" (ar1, ar2) on STA1_INDEX/STA2_INDEX
# of a table that must be aliased as "base" by whoever completes the query.
q_2station = query(None, ## the table aliased as "base" is supplied later
                   tbls= "arrays as ar1, arrays as ar2",
                   conds="base.STA1_INDEX<100 AND base.STA1_INDEX=ar1.STA_INDEX AND ar2.STA_INDEX=base.STA2_INDEX"
                   )
# q_3station is the three-station equivalent; its conditions refer to a table
# that must be aliased as "triplet" (not "base").
q_3station = query(None, ## the table aliased as "triplet" is supplied later
                   tbls= "arrays as ar1, arrays as ar2, arrays as ar3",
                   conds="triplet.STA1_INDEX<100 AND triplet.STA1_INDEX=ar1.STA_INDEX AND triplet.STA2_INDEX=ar2.STA_INDEX AND triplet.STA3_INDEX=ar3.STA_INDEX"
                   )

# Instrument setup columns of the "headers" table; default setup loop of
# plot_doublet / plot_triplet, where it is attached under "marker_key".
loop_setups = query(None,
                    "[ESO INS SPEC RES],[ESO INS POLA MODE],[ESO INS TIM2 PERIOD],[ESO DET2 SEQ1 DIT]",
                    "headers"
                    )

# Station-name pairs; default base loop of plot_doublet ("axes_key").
loop_bases = q_2station.query(columns="ar1.STA_NAME, ar2.STA_NAME")


# Station-name triplets; default triplet loop of plot_triplet ("axes_key").
loop_triplets = q_3station.query(
                        columns="ar1.STA_NAME, ar2.STA_NAME, ar3.STA_NAME"
                    )

# TARGET column of the "targets" table; not referenced by the code visible in
# this file.
loop_target = query(None,
                     "TARGET", "targets"
                    )


def get_figure(**kwargs):
    """Pick the figure to draw on from keyword arguments.

    Resolution order:

    1. ``fig`` when it is given and truthy;
    2. ``figDict[fig_key]`` when ``fig_key`` is truthy and the entry is truthy;
    3. the current figure, ``plt.gcf()``.
    """
    explicit = kwargs.get("fig", None)
    if explicit:
        return explicit

    key = kwargs.get("fig_key", None)
    if key:
        registered = kwargs.get("figDict", {}).get(key, None)
        if registered:
            return registered

    return plt.gcf()


def plot_vis2(db, insname='', fig=None, **kwargs):
    """Plot VIS2DATA (averaged by the SQL ``mean`` function) against MJD.

    Parameters
    ----------
    db
        Database handle given to ``query``.
    insname : str
        Substring that ``base.INSNAME`` must contain (SQL ``LIKE '%...%'``).
    fig : figure, optional
        Figure to draw on.  When omitted it is resolved by ``get_figure``
        from ``fig_key`` / ``figDict`` in ``kwargs`` or the current figure.
    **kwargs
        Forwarded to ``plot_doublet`` (e.g. ``colorDict``, ``markerDict``,
        ``figDict``, ``fig_key``).

    Returns
    -------
    The figure returned by ``plot_doublet``.
    """
    fig = get_figure(fig=fig, **kwargs)

    axm = axes_maker(fig,
                     sharex=True,
                     xlabel="MJD", ylabel="$vis^2$",
                     ylim=(-0.2, 1.2)
                     )

    # NOTE(review): insname is interpolated straight into the SQL text; only
    # pass trusted values.
    q = query(db, "MJD,mean(VIS2DATA,WImin,WImax),mean(VIS2ERR,WImin,WImax),TARGET",
                  "vis2 as base, waveindex, targets",  # tables
                  "base.INSNAME LIKE '%{insname}%'".format(insname=insname),
                  automatch=True
              )

    # Forward the resolved figure so that the figure returned is the one the
    # axes were created on, not whatever plt.gcf() happens to be.
    return plot_doublet(q, ["x", "y", "yerr", "colors_key"],
                        loop_bases=loop_bases.query(tbls="vis2 as base"),
                        axes_maker=axm, fig=fig, **kwargs
                        )

def plot_visAmp(db, insname='', fig=None, **kwargs):
    """Plot VISAMP (averaged by the SQL ``mean`` function) against MJD.

    Parameters
    ----------
    db
        Database handle given to ``query``.
    insname : str
        Substring that ``base.INSNAME`` must contain (SQL ``LIKE '%...%'``).
    fig : figure, optional
        Figure to draw on.  When omitted it is resolved by ``get_figure``
        from ``fig_key`` / ``figDict`` in ``kwargs`` or the current figure.
    **kwargs
        Forwarded to ``plot_doublet`` (e.g. ``colorDict``, ``markerDict``,
        ``figDict``, ``fig_key``).

    Returns
    -------
    The figure returned by ``plot_doublet``.
    """
    fig = get_figure(fig=fig, **kwargs)

    # The data plotted here is an amplitude, so the axis is not labelled vis^2.
    axm = axes_maker(fig,
                     sharex=True,
                     xlabel="MJD", ylabel="$|vis|$",
                     ylim=(-0.2, 1.2)
                     )

    # NOTE(review): insname is interpolated straight into the SQL text; only
    # pass trusted values.
    q = query(db, "MJD,mean(VISAMP,WImin,WImax),mean(VISAMPERR,WImin,WImax),TARGET",
                  "vis as base, waveindex, targets",  # tables
                  "base.INSNAME LIKE '%{insname}%'".format(insname=insname),
                  automatch=True
              )

    # The base loop must read the same table as the data query ("vis"), and
    # the resolved figure is forwarded so that it is the one returned.
    # NOTE(review): plot_doublet still builds its axes dictionary from the
    # default "vis2" table -- confirm both tables hold the same baselines.
    return plot_doublet(q, ["x", "y", "yerr", "colors_key"],
                        loop_bases=loop_bases.query(tbls="vis as base"),
                        axes_maker=axm, fig=fig, **kwargs
                        )








def plot_doublet(query, columns, orders=["x", "y", "yerr"],
                 fig=None, axes_maker=None, colorDict=None, markerDict=None,
                 loop_bases=loop_bases,
                 loop_setups=loop_setups,
                 figDict=None, fig_key=None):
    """Plot a two-station query, looping over station pairs and setups.

    Parameters
    ----------
    query
        Query object; it is copied before the loops are added.
    columns : sequence of str
        Names given, in order, to the columns of each result row when they
        are handed to ``plot_cursor`` (e.g. ``["x", "y", "yerr",
        "colors_key"]``).  An empty value or a plain string falls back to
        that list, which used to be hard-coded.
    orders
        Unused; kept for backward compatibility.
    fig, figDict, fig_key
        Figure to return: ``fig``, else ``figDict[fig_key]``, else the
        current figure.
    axes_maker : callable, optional
        Passed to ``make_base_dictionary`` to create the axes.
    colorDict, markerDict : dict, optional
        Built from the database with ``make_target_dictionary`` /
        ``make_setup_dictionary`` when missing or empty.
    loop_bases, loop_setups
        Loop queries attached under ``"axes_key"`` and ``"marker_key"``.

    Returns
    -------
    The resolved figure.
    """
    figDict = figDict or {}
    fig = fig or figDict.get(fig_key, None) or plt.gcf()
    q = query.copy()

    # A string would be zipped character by character by plot_cursor, so it is
    # treated like a missing value.
    if not columns or isinstance(columns, str):
        columns = ["x", "y", "yerr", "colors_key"]

    # Loop on each unique station pair; the pair becomes the axes key.
    q.add_loop(loop_bases, keys="axes_key")
    # Loop on each instrument setup; the setup becomes the marker key.
    q.add_loop(loop_setups, keys="marker_key")

    # Lookup tables used by plot_cursor.
    axesDict = make_base_dictionary(q.cursor, axes_maker,
                                    conds="STA1_INDEX<100")  # remove the 'S' station
    colorDict = colorDict or make_target_dictionary(q.cursor)
    markerDict = markerDict or make_setup_dictionary(q.cursor)

    # q.iterexecute() yields a database cursor and a dictionary of the current
    # loop keys (e.g. {'axes_key': ('A0', 'B1'), 'marker_key': (...)}), which
    # are forwarded to plot_cursor as keyword arguments.
    for c, keys in q.iterexecute():
        plot_cursor(c, columns,
                    axesDict=axesDict,
                    colorDict=colorDict,
                    markerDict=markerDict,
                    scater=True, **keys
                    )

    return fig

##############################################################################

def plot_triplet(query, columns, orders=["x", "y", "yerr"],
                 fig=None, axes_maker=None, colorDict=None, markerDict=None,
                 loop_triplets=loop_triplets,
                 loop_setups=loop_setups,
                 basetbl="vis2"):
    """Plot a three-station query, looping over station triplets and setups.

    Parameters
    ----------
    query
        Query object; it is copied before the loops are added.
    columns : sequence of str
        Names given, in order, to the columns of each result row when they
        are handed to ``plot_cursor`` (e.g. ``["x", "y", "yerr",
        "colors_key"]``).  An empty value or a plain string falls back to
        that list, which used to be hard-coded.
    orders, basetbl
        Unused; kept for backward compatibility.
    fig : figure, optional
        Figure to return; defaults to the current figure.
    axes_maker : callable, optional
        Passed to ``make_triplet_dictionary`` to create the axes.
    colorDict, markerDict : dict, optional
        Built from the database with ``make_target_dictionary`` /
        ``make_setup_dictionary`` when missing or empty.
    loop_triplets, loop_setups
        Loop queries attached under ``"axes_key"`` and ``"marker_key"``.

    Returns
    -------
    The resolved figure.
    """
    fig = fig or plt.gcf()
    q = query.copy()

    # A string would be zipped character by character by plot_cursor, so it is
    # treated like a missing value.
    if not columns or isinstance(columns, str):
        columns = ["x", "y", "yerr", "colors_key"]

    # Loop on each unique station triplet; the triplet becomes the axes key.
    q.add_loop(loop_triplets, keys="axes_key")
    # Loop on each instrument setup; the setup becomes the marker key.
    q.add_loop(loop_setups, keys="marker_key")

    # Lookup tables used by plot_cursor.
    axesDict = make_triplet_dictionary(q.cursor, axes_maker,
                                       conds="STA1_INDEX<100")  # remove the 'S' station
    colorDict = colorDict or make_target_dictionary(q.cursor)
    markerDict = markerDict or make_setup_dictionary(q.cursor)

    # q.iterexecute() yields a database cursor and a dictionary of the current
    # loop keys, which are forwarded to plot_cursor as keyword arguments.
    for c, keys in q.iterexecute():
        plot_cursor(c, columns,
                    axesDict=axesDict,
                    colorDict=colorDict,
                    markerDict=markerDict,
                    scater=True, **keys
                    )

    return fig








###############################################################################





def make_dictionary(db, values, query, row2key=lambda r:r[0]):
    """Map a key derived from every row of ``query`` to a value.

    ``db.execute(query)`` supplies the rows and ``row2key(row)`` gives the
    key of each row.

    * If ``values`` is callable, it is called as ``values(i, N, row)`` where
      ``i`` is the 0-based row index and ``N`` the total number of rows.
    * Otherwise ``values`` is indexed cyclically: row ``i`` gets
      ``values[i % len(values)]``.

    Rows that produce the same key overwrite each other; the last one wins.
    """
    result_rows = db.execute(query)
    mapping = {}

    if not hasattr(values, "__call__"):
        period = len(values)
        for index, row in enumerate(result_rows):
            key = row2key(row)
            mapping[key] = values[index % period]
        return mapping

    # The callable needs the total row count, so all rows are read first.
    keyed_rows = [(row2key(row), row) for row in result_rows]
    total = len(keyed_rows)
    for index, (key, row) in enumerate(keyed_rows):
        mapping[key] = values(index, total, row)
    return mapping

def make_target_dictionary(db, values=list("bgrkcy")):
    """Map every distinct TARGET of the ``targets`` table to an entry of ``values``.

    ``values`` is used as described in ``make_dictionary``; the default is a
    list of one-letter colour codes.
    """
    sql = "SELECT DISTiNCT TARGET FROM targets"

    def first_column(row):
        return row[0]

    return make_dictionary(db, values, sql, first_column)

def make_setup_dictionary(db, values=list("osx21")):
    """Map every distinct setup of the ``headers`` table to an entry of ``values``.

    A setup is the tuple of the four selected header columns.  ``values`` is
    used as described in ``make_dictionary``; the default is a list of
    one-letter marker codes.
    """
    sql = "SELECT DISTiNCT [ESO INS SPEC RES],[ESO INS POLA MODE],[ESO INS TIM2 PERIOD],[ESO DET2 SEQ1 DIT] FROM headers"
    # The whole row, converted to a tuple, is the dictionary key.
    return make_dictionary(db, values, sql, tuple)

def axes_maker(fig=None, orientation=0,extra=lambda ax,r:None,
               sharex=False, sharey=False, **kwargs):
    """Return a function ``make_ax(i, N, r)`` that adds subplot ``i`` of ``N``.

    Parameters
    ----------
    fig : figure, optional
        Figure receiving the subplots; defaults to ``plt.gcf()``.
    orientation
        Truthy: one row of ``N`` columns.  Falsy: ``N`` rows of one column.
    extra : callable
        Called as ``extra(ax, r)`` after each axes is configured.
    sharex, sharey
        When truthy, later axes share that axis with the first axes created.
        Only ``sharey`` is honoured in the one-row layout and only ``sharex``
        in the one-column layout.
    **kwargs
        Passed to ``ax.set`` on every axes.

    Each axes created gets two extra attributes, ``first_axes`` and
    ``last_axes``.
    """
    if not fig:
        fig = plt.gcf()

    # Keyword arguments for add_subplot; receives the first axes once it
    # exists so that the following ones share an axis with it.
    shared = {}

    if orientation:
        share_name, share_wanted = "sharey", sharey
    else:
        share_name, share_wanted = "sharex", sharex

    def make_ax(i, N, r):
        nrows, ncols = (1, N) if orientation else (N, 1)
        ax = fig.add_subplot(nrows, ncols, i + 1, **shared)
        is_first = not i
        if is_first and share_wanted:
            shared[share_name] = ax
        ax.first_axes = is_first
        ax.last_axes = i == (N - 1)
        ax.set(**kwargs)
        extra(ax, r)
        return ax

    return make_ax

def _extra_base(ax,r):
    """Title ``ax`` with the first two items of row ``r`` joined together."""
    title = r[0] + r[1]
    ax.set_title(title)

def make_base_dictionary(db, maker=None, tbl="vis2", conds=""):
    """Map every distinct station-name pair of ``tbl`` to a value from ``maker``.

    e.g.:  make_base_dictionary(db, axes_maker(fig))

    ``maker`` is used as the ``values`` argument of ``make_dictionary``; when
    it is None an ``axes_maker`` on the current figure is used, which titles
    each axes with the two station names.  ``conds`` is an optional extra SQL
    condition appended with ``AND``.
    """
    axes_factory = maker if maker is not None else axes_maker(extra=_extra_base)

    sql = "SELECT DISTINCT ar1.STA_NAME, ar2.STA_NAME FROM %s, arrays as ar1, arrays as ar2 WHERE STA1_INDEX = ar1.STA_INDEX AND STA2_INDEX = ar2.STA_INDEX "%tbl
    if conds:
        sql = sql + " AND " + conds

    # The key is the pair of station names.
    return make_dictionary(db, axes_factory, sql, lambda r: r[0:2])

def _extra_triplet(ax,r):
    """Title ``ax`` with the first three items of row ``r`` joined together."""
    title = r[0] + r[1] + r[2]
    ax.set_title(title)

def make_triplet_dictionary(db, maker=None, tbl="oi_t3", conds=""):
    """Map every distinct station-name triplet of ``tbl`` to a value from ``maker``.

    e.g.:  make_triplet_dictionary(db, axes_maker(fig))

    ``maker`` is used as the ``values`` argument of ``make_dictionary``; when
    it is None an ``axes_maker`` on the current figure is used, which titles
    each axes with the three station names.  ``conds`` is an optional extra
    SQL condition appended with ``AND``.
    """
    axes_factory = maker if maker is not None else axes_maker(extra=_extra_triplet)

    sql = "SELECT DISTINCT ar1.STA_NAME, ar2.STA_NAME, ar3.STA_NAME FROM %s, arrays as ar1, arrays as ar2, arrays as ar3 WHERE STA1_INDEX = ar1.STA_INDEX AND STA2_INDEX = ar2.STA_INDEX AND STA3_INDEX = ar3.STA_INDEX"%tbl
    if conds:
        sql = sql + " AND " + conds

    # The key is the triplet of station names.
    return make_dictionary(db, axes_factory, sql, lambda r: r[0:3])






