from .db import (DB, DBTables, DBDefinition, MetaDBDefinition,
                 DBFrm, DBKey, DBCond,
                 table, column, insert, columndec, insertdec,
                 convert_array
                 )

from gravi_reduce import (log, ERROR, WARNING, NOTICE)
log = log.Log().log

#try:
#from astropy.io import fits
#except:
#    import pyfits as fits
import numpy as np
from .date_manips import mjd_to_datetime
from datetime import datetime, timedelta

# Type name associated with FILEID values.
# NOTE(review): presumably the SQL column type used for FILEID columns; the
# constant is not referenced in this part of the file, and get_fileid()
# returns the raw DATAMD5 header value -- confirm both agree.
FILEID_TYPE = "FLOAT"
# FILEID_TYPE = "INTEGER"

def get_fileid(fh):
    """Return the identifier of an opened file.

    The identifier is the ``DATAMD5`` entry of the ``header`` mapping carried
    by the first element (index ``0``) of ``fh``.
    """
    # An earlier scheme built a numeric id from MJD-OBS and ESO PRO CATG;
    # the DATAMD5 value is used as-is instead.
    first_hdu = fh[0]
    return first_hdu.header['DATAMD5']

def keyword(key, dbtype="TEXT", hdu="PRIMARY"):
    """Build a column getter that reads one header keyword.

    ``key`` is either a single value used both as the header keyword and as
    the column name, or a ``(header_keyword, column_name)`` tuple.  The
    returned object is the result of wrapping the getter with
    ``columndec(name=column_name, dbtype=dbtype)``; the getter itself returns
    ``fh[hdu].header[header_keyword]``.
    """
    hkey, name = key if isinstance(key, tuple) else (key, key)

    def get_keyword(fh):
        return fh[hdu].header[hkey]

    # Explicit application of the decorator, equivalent to the @-syntax.
    return columndec(name=name, dbtype=dbtype)(get_keyword)


def header(name, keys, hdu="PRIMARY", description='primary header', keys_index=("MJD-OBS",), links=None):
    """Build a table whose columns are header keywords.

    Parameters
    ----------
    name : table name, forwarded to ``table``.
    keys : sequence of argument tuples for ``keyword`` (each entry is
        unpacked as ``keyword(*entry)``, i.e. ``(key[, dbtype[, hdu]])``).
        The sequence is not modified.
    hdu : accepted for interface compatibility.
        NOTE(review): this value is not forwarded to ``keyword``; columns
        read the HDU given in their own entry (default ``"PRIMARY"``).
        Confirm whether it was meant to be the default HDU for all keys.
    description : forwarded to ``table``.
    keys_index : column names that must exist in the table; any that is not
        already provided by ``keys`` is added as a ``"TEXT"`` column.
    links : forwarded to ``table``; a fresh empty list is used when omitted.

    Returns
    -------
    Whatever ``table`` returns.
    """
    def column_name(entry):
        # Entries are keyword() argument tuples; the first item is the key,
        # which itself may be a (header_keyword, column_name) pair.
        if not isinstance(entry, (tuple, list)):
            return entry
        key = entry[0]
        return key[1] if isinstance(key, tuple) else key

    # Work on a copy so the caller's list is not extended behind its back.
    columns = list(keys)
    present = [column_name(entry) for entry in columns]

    # Comparing against column names (not the raw tuples) prevents adding an
    # index column that is already defined, which would duplicate it.
    for k in keys_index:
        if k not in present:
            columns.append((k, "TEXT"))
            present.append(k)

    if links is None:
        links = []

    tbl = table(name, [keyword(*k) for k in columns], description=description, keys_index=keys_index, links=links)
    return tbl
# Alias: inserts for header tables use the generic ``insert`` imported from .db.
headerinsert = insert


def det(insname):
    """Return ``"FT"`` or ``"SC"`` depending on which tag occurs in ``insname``.

    ``"FT"`` takes precedence when both occur; an empty string is returned
    when neither does.
    """
    for tag in ("FT", "SC"):
        if tag in insname:
            return tag
    return ""

def polar(insname):
    """Return 1 if ``"P1"`` occurs in ``insname``, 2 if ``"P2"`` does, else 0.

    ``"P1"`` is checked first, so it wins when both tags are present.
    """
    for tag, number in (("P1", 1), ("P2", 2)):
        if tag in insname:
            return number
    return 0

def basename(STA1_NAME, STA2_NAME):
    """Return the two station names joined with ``+`` (no separator)."""
    joined = STA1_NAME + STA2_NAME
    return joined

def tripletname(STA1_NAME, STA2_NAME, STA3_NAME):
    """Return the three station names joined with ``+`` (no separator)."""
    first_two = STA1_NAME + STA2_NAME
    return first_two + STA3_NAME

def baseline(u,v):
    """Return ``sqrt(u*u + v*v)`` as a Python float.

    Both arguments are first converted with ``float()``.
    """
    uu, vv = float(u), float(v)
    squared_length = uu*uu + vv*vv
    return float(np.sqrt(squared_length))

def baselineangle(u,v):
    """Return ``arctan2(u, v)`` converted to degrees, as a Python float.

    Both arguments are first converted with ``float()``.  Note the argument
    order: ``u`` is the first argument of ``arctan2`` and ``v`` the second.
    """
    uu, vv = float(u), float(v)
    angle_rad = np.arctan2(uu, vv)
    return float(angle_rad*180/np.pi)

def baseline2(w, xyz1, xyz2):
    """Return ``sqrt(|xyz1 - xyz2|**2 / w)`` as a Python float.

    All three arguments are passed through ``convert_array`` first.  The
    squared distance is the sum of the squared differences of components
    0, 1 and 2 of the two coordinate arrays.

    NOTE(review): the division by ``w`` happens inside the square root, and
    ``float()`` requires the result to be a single value -- confirm that this
    is the intended formula and that ``w`` is scalar-like here.
    """
    w = convert_array(w)
    pos1 = convert_array(xyz1)
    pos2 = convert_array(xyz2)

    # The third component is index 2; the previous code read index 3, which
    # is out of range for 3-component coordinates.
    dist2 = ((pos1[0]-pos2[0])**2 +
             (pos1[1]-pos2[1])**2 +
             (pos1[2]-pos2[2])**2)
    return float( np.sqrt(dist2/w) )

def _update_median(db, table, colname, index_cols):
    """Fill ``<colname>_MED`` of ``table`` with per-row medians.

    Shared implementation of ``update_median_1/2/3``.

    Parameters
    ----------
    db : object providing ``execute(sql)`` (returning a cursor) and a ``db``
        attribute with ``executemany(sql, rows)``.
    table : name of the table to update (inserted verbatim into the SQL).
    colname : name of the column to take the median of (inserted verbatim).
    index_cols : station-index column names that, together with MJD, FILEID
        and INSNAME, identify the row to update.

    The median is computed by the SQL function ``median(col, WImin, WImax)``
    over rows of ``table`` joined with ``wavelengths`` on FILEID and INSNAME.
    NOTE(review): ``median`` is not registered in this part of the file;
    presumably it is provided by the DB layer.
    """
    try:
        db.execute("ALTER TABLE "+table+" ADD "+colname+"_MED FLOAT")
    except Exception:
        # Best effort: presumably the column already exists from an earlier
        # run.  Only ordinary errors are ignored, so KeyboardInterrupt and
        # SystemExit still propagate.
        pass

    log("Get median value of "+colname, 1, NOTICE)
    select_idx = ",".join(index_cols)
    med = db.execute("SELECT median("+colname+", WImin, WImax),MJD,"+table+".FILEID,"+select_idx+","+table+".INSNAME from "+
                     table+",wavelengths where wavelengths.[FILEID]="+table+".[FILEID] and wavelengths.[INSNAME]="+table+".[INSNAME]").fetchall()

    log("Set median value of "+colname, 1, NOTICE)
    # Placeholder order must match the SELECT column order above:
    # median, MJD, FILEID, index columns..., INSNAME.
    where_idx = " and ".join(col+"=?" for col in index_cols)
    db.db.executemany("UPDATE "+table+" SET "+colname+"_MED=? where MJD=? and FILEID=? and "+where_idx+" and INSNAME=?", med)


def update_median_1(db, table, colname):
    """Store medians of ``colname`` in ``<colname>_MED`` for rows keyed by STA1_INDEX."""
    _update_median(db, table, colname, ("STA1_INDEX",))


def update_median_2(db, table, colname):
    """Store medians of ``colname`` in ``<colname>_MED`` for rows keyed by STA1_INDEX and STA2_INDEX."""
    _update_median(db, table, colname, ("STA1_INDEX", "STA2_INDEX"))


def update_median_3(db, table, colname):
    """Store medians of ``colname`` in ``<colname>_MED`` for rows keyed by STA1_INDEX, STA2_INDEX and STA3_INDEX."""
    _update_median(db, table, colname, ("STA1_INDEX", "STA2_INDEX", "STA3_INDEX"))


def nightof(mjd):
    """Return the ``YYYY-MM-DD`` date of the night an MJD belongs to.

    ``mjd`` is converted with ``float()`` and ``mjd_to_datetime``.  When the
    hour of the resulting datetime is before 12, the previous calendar day
    is reported.
    """
    moment = mjd_to_datetime(float(mjd))
    before_noon = moment.hour < 12
    night = moment - timedelta(1) if before_noon else moment
    return datetime.strftime(night, "%Y-%m-%d")

class OIVisDB(DB):
    tables = DBTables()
    definitions = []

    def __init__(self, *args, **kwargs):
        """Initialise the base class, then register the module's helpers as SQL functions.

        All arguments are forwarded unchanged to the parent constructor.
        Each helper is registered on ``self.db`` via ``create_function`` under
        its own name with its positional argument count.
        """
        super(OIVisDB,self).__init__(*args,**kwargs)

        # (SQL name, number of arguments, Python callable), in registration order.
        sql_functions = (
            ("det", 1, det),
            ("polar", 1, polar),
            ("basename", 2, basename),
            ("tripletname", 3, tripletname),
            ("nightof", 1, nightof),
            ("baseline", 2, baseline),
            ("baselineangle", 2, baselineangle),
            ("baseline2", 3, baseline2),
        )
        for sql_name, nargs, func in sql_functions:
            self.db.create_function(sql_name, nargs, func)

    def file_exists(self, fh):
        """Return how many rows of ``headers`` carry the FILEID of ``fh``.

        The FILEID is obtained with ``get_fileid(fh)``.  The result is the
        ``count(*)`` value, so it is 0 (falsy) when the file is not recorded.
        """
        fileid = get_fileid(fh)
        cursor = self.execute("SELECT count(*) FROM headers where FILEID = ? ",  (fileid,))
        row = cursor.fetchone()
        return row[0]

    def setup_query(self, query, num):
        """Restrict ``query`` to the setup stored at position ``num`` of ``setups``.

        The ``setup_keys`` / ``setup_vals`` pair of the selected row is turned
        into one condition per key, each bound to a named substitution
        ``setup_val<i>`` whose value is ``vals[i].item()``.  The ``headers``
        table is added to the query's FROM list when it is not there yet.

        Raises
        ------
        ValueError
            If ``num`` is not smaller than the number of stored setups.
        """
        cursor = self.cursor()
        cursor.execute("SELECT setup_keys, setup_vals FROM setups")
        setups = list(cursor)
        nsetups = len(setups)
        if num >= nsetups:
            raise ValueError("There is only %d setups got index of '%d'"%(nsetups, num))

        keys, vals = setups[num]

        setup_conds = (DBCond(DBKey(k), ":setup_val%d"%i) for i,k in enumerate(keys))
        query.conds.extend(setup_conds)

        setup_values = {"setup_val%d"%i:v.item() for i,v in enumerate(vals)}
        query.substitutions.update(setup_values)

        tables_used = [frm.tbl for frm in query.frms]
        if "headers" not in tables_used:
            query.frms.append(DBFrm("headers"))

    def base_query(self, q, base):
        """Add a condition selecting one baseline, in either station order.

        ``base`` is a pair ``(t1, t2)``.  The appended condition is
        ``(STA1_index=t1 AND STA2_index=t2) OR (STA1_index=t2 AND STA2_index=t1)``.
        No table is added to the query's FROM list here.
        """
        sta_a, sta_b = base

        forward = DBCond(DBCond(DBKey("STA1_index"), sta_a),
                         DBCond(DBKey("STA2_index"), sta_b),
                         op="AND")
        swapped = DBCond(DBCond(DBKey("STA1_index"), sta_b),
                         DBCond(DBKey("STA2_index"), sta_a),
                         op="AND")
        q.conds.append(DBCond(forward, swapped, op="OR"))


    def update_wave_index(self, wrange, **kwargs):
        """Store the index window matching a wavelength range and refresh the medians.

        Parameters
        ----------
        wrange : pair ``(wmin, wmax)``; either bound may be ``None`` to leave
            that side open.
        **kwargs : accepted and ignored.

        For every distinct (FILEID, INSNAME, EFF_WAVE) row of ``wavelengths``
        the samples satisfying the bounds are located; ``WImin`` is set to the
        first matching index and ``WImax`` to one past the last matching
        index.  ``Wmin`` / ``Wmax`` record the requested bounds.  Each row
        update is committed individually.  Afterwards the ``*_MED`` columns of
        ``flux``, ``vis2``, ``vis`` and ``t3`` are recomputed and committed.

        Raises
        ------
        ValueError
            If no sample of a row satisfies the bounds.
        """
        wmin, wmax = wrange

        def in_range(a):
            # Boolean mask of the samples of ``a`` inside the requested bounds.
            if wmin is not None and wmax is not None:
                return (a >= wmin) & (a <= wmax)
            if wmin is not None:
                return a >= wmin
            if wmax is not None:
                return a <= wmax
            # ``np.bool`` was removed from NumPy; the builtin is the
            # supported spelling of the same dtype.
            return np.ones(a.shape, dtype=bool)

        # Fetch everything first: the loop below updates and commits the very
        # table this SELECT reads from.
        rows = self.execute("SELECT DISTINCT FILEID, INSNAME, EFF_WAVE FROM wavelengths").fetchall()
        for fileid, insname, w in rows:
            find, = np.where(in_range(w))
            if not find.size:
                raise ValueError("wavelengths out of range")
            WImin = int(find[0])
            WImax = int(find[-1]+1)   # exclusive upper index

            self.execute("UPDATE wavelengths SET WImin = ?, WImax = ?, Wmin = ?, Wmax = ? WHERE FILEID = ? and INSNAME = ?",
                         (WImin, WImax, wmin, wmax, fileid, insname)
                         )
            self.commit()

        # (updater, table, column) in the order the medians are refreshed.
        median_jobs = (
            (update_median_1, "flux", "FLUX"),
            (update_median_1, "flux", "FLUXERR"),
            (update_median_2, "vis2", "VIS2DATA"),
            (update_median_2, "vis2", "VIS2ERR"),
            (update_median_2, "vis", "VISAMP"),
            (update_median_2, "vis", "VISAMPERR"),
            (update_median_2, "vis", "VISPHI"),
            (update_median_2, "vis", "VISPHIERR"),
            (update_median_3, "t3", "T3PHI"),
            (update_median_3, "t3", "T3PHIERR"),
            (update_median_3, "t3", "T3AMP"),
            (update_median_3, "t3", "T3AMPERR"),
        )
        for updater, tbl, col in median_jobs:
            updater(self, tbl, col)
        self.commit()

    def delete_protacg(self, procatg):
        """Delete all rows belonging to files whose ``ESO PRO CATG`` matches ``procatg``.

        ``procatg`` is a SQL ``LIKE`` pattern.  Rows are removed from ``vis``,
        ``vis2``, ``t3``, ``arrays`` and ``wavelengths`` (matched to
        ``headers`` through FILEID) and finally from ``headers`` itself.
        Nothing is committed here.

        NOTE(review): ``flux`` is not cleaned although it also carries a
        FILEID -- confirm whether that is intentional.
        """
        # The pattern is passed as a bound parameter: embedding it in double
        # quotes lets SQLite interpret it as an identifier and breaks on
        # patterns containing quote characters.
        dependent_sql = (' DELETE FROM {tbl} WHERE EXISTS ( SELECT * FROM headers '
                         'WHERE headers.FILEID = {tbl}.FILEID and headers.[ESO PRO CATG] like ?)')
        for tbl in ("vis", "vis2", "t3", "arrays", "wavelengths"):
            self.execute(dependent_sql.format(tbl=tbl), (procatg,))

        # headers must go last: the statements above look the files up in it.
        self.execute(' DELETE FROM headers where headers.[ESO PRO CATG] like ?', (procatg,))

    def delete_constrains(self, const):
        """Delete the rows matching ``const`` from ``vis``, ``vis2``, ``t3`` and ``flux``.

        ``const`` is appended verbatim after ``WHERE`` in each statement, so
        it must be a trusted SQL condition.  Nothing is committed here.
        """
        statement_heads = (
            ' DELETE FROM vis  WHERE ',
            ' DELETE FROM vis2 WHERE ',
            ' DELETE FROM t3   WHERE ',
            ' DELETE FROM flux WHERE ',
        )
        for head in statement_heads:
            self.execute(head+const)

    def _prepare_query(self, q, select=None, setup=None, base=None, **kwargs):
        """Apply the selection, baseline, table links and setup restrictions to ``q``.

        Parameters
        ----------
        q : query object, modified in place.
        select : optional mapping of search criteria; it is not modified.
        setup : optional setup index passed to ``setup_query``.
        base : optional station pair passed to ``base_query``.
        **kwargs : additional search criteria, merged over ``select``.

        The steps run in this order: ``search_query``, ``base_query`` (when
        ``base`` is given), ``link_tbls_query``, ``setup_query`` (when
        ``setup`` is given).
        """
        # Build a private dict: updating a shared default (or the caller's
        # mapping) would leak criteria from one call into the next.
        criteria = {} if select is None else dict(select)
        criteria.update(kwargs)

        self.search_query(q, criteria)
        if base is not None:
            self.base_query(q, base)

        self.link_tbls_query(q)
        if setup is not None:
            self.setup_query(q, setup)




    def update_basename(self, **kwargs):
        """Fill the ``BASENAME`` column of ``vis2``, ``vis``, ``t3`` and ``flux``.

        Station names are looked up in ``arrays`` (matched on FILEID and
        STA_INDEX).  Two- and three-station tables get the names joined with
        ``"-"``; ``flux`` gets the single station name.  ``**kwargs`` is
        accepted and ignored.  Nothing is committed here.
        """
        # Two-station tables share the same statements up to the table name.
        pair_select = ("SELECT ar1.STA_NAME,ar2.STA_NAME,MJD,INSNAME,STA1_INDEX,STA2_INDEX,{tbl}.FILEID "
                       "from {tbl},arrays as ar1,arrays as ar2 "
                       "where ar1.[FILEID]={tbl}.[FILEID] and ar2.[FILEID]={tbl}.[FILEID] "
                       "and ar1.STA_INDEX={tbl}.STA1_INDEX and ar2.STA_INDEX={tbl}.STA2_INDEX")
        pair_update = ("UPDATE {tbl} SET BASENAME=? "
                       "where MJD=? and INSNAME=? and STA1_INDEX=? and STA2_INDEX=? and FILEID=?")
        for tbl in ("vis2", "vis"):
            rows = self.execute(pair_select.format(tbl=tbl)).fetchall()
            names = [(n1+"-"+n2, mjd, ins, i1, i2, fileid)
                     for n1, n2, mjd, ins, i1, i2, fileid in rows]
            self.db.executemany(pair_update.format(tbl=tbl), names)

        # Fill BASENAME for t3
        rows = self.execute("SELECT ar1.STA_NAME,ar2.STA_NAME,ar3.STA_NAME,MJD,INSNAME,STA1_INDEX,STA2_INDEX,STA3_INDEX,t3.FILEID "
                            "from t3,arrays as ar1,arrays as ar2,arrays as ar3 "
                            "where ar1.[FILEID]=t3.[FILEID] and ar2.[FILEID]=t3.[FILEID] and ar3.[FILEID]=t3.[FILEID] "
                            "and ar1.STA_INDEX=t3.STA1_INDEX and ar2.STA_INDEX=t3.STA2_INDEX and ar3.STA_INDEX=t3.STA3_INDEX").fetchall()
        names = [(n1+"-"+n2+"-"+n3, mjd, ins, i1, i2, i3, fileid)
                 for n1, n2, n3, mjd, ins, i1, i2, i3, fileid in rows]
        self.db.executemany("UPDATE t3 SET BASENAME=? where MJD=? and INSNAME=? and STA1_INDEX=? and STA2_INDEX=? and STA3_INDEX=? and FILEID=?", names)

        # Fill BASENAME for flux.  Only one station is involved, so a single
        # join with ``arrays`` is used; a second, unconstrained join would
        # repeat every update once per row of ``arrays``.
        rows = self.execute("SELECT ar1.STA_NAME,MJD,INSNAME,STA1_INDEX,flux.FILEID "
                            "from flux,arrays as ar1 "
                            "where ar1.[FILEID]=flux.[FILEID] and ar1.STA_INDEX=flux.STA1_INDEX").fetchall()
        names = [(n1, mjd, ins, i1, fileid) for n1, mjd, ins, i1, fileid in rows]
        self.db.executemany("UPDATE flux SET BASENAME=? where MJD=? and INSNAME=? and STA1_INDEX=? and FILEID=?", names)


