import sqlite3
import numpy as np
import io
from gravi_reduce import (log, ERROR, WARNING, NOTICE)
log = log.Log().log

# Module-level registries.
# NOTE(review): neither dict is read or written anywhere in the visible
# source -- confirm they are still needed before relying on them.
class_record = {}
db_table_record = {}

## If True, numpy arrays are serialized and stored inside the database
## itself; if False they are kept in the in-memory `all_data` list and
## only their index is written to the database.
RECORD_ARRAY_IN_DB = True

class MoreData(object):
    """Read-only lookup chained over several mappings.

    The keyword arguments form the first mapping that is searched, followed
    by the positional mappings in the order they were given.  The first
    mapping that contains the requested key wins.
    """

    def __init__(self, *lst, **kwargs):
        # *lst arrives as a tuple; it must be converted before it can be
        # concatenated with a list (list + tuple raises TypeError).
        self.lst = [kwargs] + list(lst)

    def __getitem__(self, item):
        """Return the value of `item` from the first mapping that has it.

        Raises
        ------
        KeyError
            If none of the mappings contains `item`.
        """
        for data in self.lst:
            try:
                return data[item]
            except KeyError:
                continue
        raise KeyError("'%s'"%item)


def record_column(tables, c):
    """Register column `c` in the `tables` mapping (table name -> table).

    When the table named by ``c.tbl`` is already present the column is
    merged into it through ``update_column``; otherwise a new ``DBTable``
    holding only this column is created under that name.
    """
    if c.tbl in tables:
        tables[c.tbl].update_column(c)
        return
    tables[c.tbl] = DBTable(c.tbl, [c])


def record_insert(tables, i):
    """Make sure the table targeted by insert `i` exists in `tables`.

    Only the table is registered (as an empty ``DBTable``) when missing;
    the insert object itself is not stored by this function.
    """
    if i.tbl in tables:
        return
    tables[i.tbl] = DBTable(i.tbl, [])

def record_table(tables, t):
    """Store table `t` in `tables` under its name.

    Raises
    ------
    TypeError
        If a table with the same name is already registered.
    """
    table_name = t.name
    if table_name in tables:
        raise TypeError("A table with name '%s' Already exists"%table_name)
    tables[table_name] = t

def update_table(tables, tblname, columns):
    """Create table `tblname` in `tables`, or merge `columns` into it.

    Parameters
    ----------
    tables : dict
        Mapping of table name -> table object.
    tblname : str
        Name of the table to create or update.
    columns : iterable or dict
        Column objects or ``(name, dbtype)`` tuples; for a dict its values
        are used.
    """
    if isinstance(columns, dict):
        columns = list(columns.values())

    if tblname not in tables:
        # DBTable takes (name, columns) and registers the columns itself.
        tables[tblname] = DBTable(tblname, columns)
        return

    # Merge the *requested* columns into the existing table (the previous
    # code re-applied the table's own column names, which always failed).
    tbl = tables[tblname]
    for c in columns:
        tbl.update_column(c)



class MetaDBDefinition(type):
    """Metaclass collecting table/insert/column declarations of a class.

    Every class built with this metaclass must define ``ftype`` and
    ``parse_file`` in its own body.  When ``ftype`` is None the class is
    created untouched.  Otherwise the class body is scanned for ``table``,
    ``insert`` and ``column`` instances; they are gathered (together with
    any pre-existing ``tables``/``inserts``/``columns`` entries of the
    body) into the class attributes ``tables``, ``inserts`` and
    ``columns``, and registered on every object listed in the body's own
    ``databases`` entry.
    """

    def __new__(cl, name, parents, dct):
        if "ftype" not in dct:
            raise TypeError("ftype attribute missing in class '%s'"%name)
        if "parse_file" not in dct:
            raise TypeError("the classmethod 'parse_file' is missing in class '%s'"%name)

        file_type = dct["ftype"]
        if file_type is None:
            # abstract definition: nothing to collect or register
            return super().__new__(cl, name, parents, dct)

        found_inserts = list(dct.get("inserts", []))
        found_columns = list(dct.get("columns", []))
        found_tables = list(dct.get("tables", []))

        # iterate over a snapshot of the class body
        for attr_name, attr in list(dct.items()):
            if isinstance(attr, table):
                # tables are used as-is and named after the attribute
                if attr.name is None:
                    attr.name = attr_name
                found_tables.append(attr)

            if isinstance(attr, insert):
                duplicate = attr()  # work on a copy of the declaration
                if duplicate.tbl is None:
                    raise ValueError("missing table name for insert @ '%s'"%attr_name)
                if duplicate.ftype is None:
                    duplicate.ftype = file_type
                found_inserts.append(duplicate)

            elif isinstance(attr, column):
                if attr.tbl is None:
                    raise TypeError("missing table name for column @ '%s'"%attr_name)
                duplicate = attr()  # work on a copy of the declaration
                if duplicate.name is None:
                    duplicate.name = attr_name
                if duplicate.ftype is None:
                    duplicate.ftype = file_type
                found_columns.append(duplicate)

        dct.update(tables=found_tables, columns=found_columns, inserts=found_inserts)
        created = super().__new__(cl, name, parents, dct)

        # only a 'databases' entry declared in this class body is honoured
        for db in dct.get("databases", []):
            db.tables.add_tables(found_tables)
            db.tables.add_inserts(found_inserts)
            db.tables.add_columns(found_columns)
            db.definitions.append(created)

        return created


class DBDefinition(object, metaclass=MetaDBDefinition):
    """Base class for file definitions handled by ``MetaDBDefinition``.

    ``ftype`` is None here, so the metaclass creates this base class without
    collecting or registering anything.  Subclasses are expected to set
    ``ftype`` and provide their own ``parse_file`` classmethod in their body.
    """
    # objects on which the metaclass registers the subclass declarations
    databases = []
    ftype = None

    @classmethod
    def parse_file(cls, fh):
        """Placeholder; concrete definitions must override it.

        Raises
        ------
        NotImplementedError
            Always, in this base class.
        """
        raise NotImplementedError('parse_file not implemented in base DBDefinition')



class DBTables(object):
    """Registry of table definitions with their inserts and columns.

    Attributes
    ----------
    tables : dict
        Mapping table name -> table object.
    inserts, columns : list
        Insert / column declarations registered so far.
    updates : list
        Initialised empty; not used by the methods of this class.
    """

    def __init__(self, tables=()):
        self.tables = {t.name: t for t in tables}
        self.inserts = []
        self.updates = []
        self.columns = []

    def __iter__(self):
        # iterate over a snapshot so the registry may change meanwhile
        return iter(list(self.tables.values()))

    def add(self, obj):
        """Register `obj`, dispatching on its type.

        Raises
        ------
        ValueError
            If `obj` is not a DBTables, DBTable, DBInsert or DBColumn.
        """
        if isinstance(obj, DBTables):
            self.add_tables(obj)
        elif isinstance(obj, DBTable):
            self.add_table(obj)
        elif isinstance(obj, DBInsert):
            self.add_insert(obj)
        elif isinstance(obj, DBColumn):
            self.add_column(obj)
        else:
            # only reached for unsupported types
            raise ValueError("Invalid object for add, expecting tables, table, insert or column got '%s'"%type(obj))

    def add_tables(self, tables):
        """Register every table of the iterable `tables`."""
        for table in tables:
            self.add_table(table)

    def add_insert(self, insert):
        """Register `insert`, creating an empty table for it if needed."""
        if not self.has_table(insert.tbl):
            self.add_table(DBTable(insert.tbl, []))
        self.inserts.append(insert)

    def add_inserts(self, inserts):
        """Register every insert of the iterable `inserts`."""
        for insert in inserts:
            self.add_insert(insert)

    def get_inserts(self, ftype):
        """Return the registered inserts whose ``ftype`` equals `ftype`."""
        return [i for i in self.inserts if i.ftype == ftype]

    def add_table(self, table):
        """Register `table`; raise TypeError if its name is already taken."""
        if self.has_table(table.name):
            raise TypeError("A table with name '%s' Already exists"%table.name)
        self.tables[table.name] = table

    def has_table(self, tablename):
        """Return True when a table named `tablename` is registered."""
        return tablename in self.tables

    def get_table(self, tablename):
        """Return the table named `tablename` (KeyError if unknown)."""
        return self.tables[tablename]

    def add_column(self, column):
        """Register `column`, creating its table when it does not exist."""
        if not self.has_table(column.tbl):
            self.add_table(DBTable(column.tbl, [column]))
        else:
            self.get_table(column.tbl).update_column(column)
        self.columns.append(column)

    def add_columns(self, columns):
        """Register every column of the iterable `columns`."""
        for column in columns:
            self.add_column(column)

    def get_columns(self, ftype):
        """Return the registered columns whose ``ftype`` equals `ftype`."""
        return [c for c in self.columns if c.ftype == ftype]

    def update_table(self, tblname, columns):
        """Create table `tblname` from `columns`, or merge them into it."""
        if isinstance(columns, dict):
            columns = list(columns.values())
        if not self.has_table(tblname):
            # DBTable takes (name, columns) and registers the columns itself
            self.tables[tblname] = DBTable(tblname, columns)
            return
        tbl = self.tables[tblname]
        for c in columns:
            tbl.update_column(c)


class DBInsert(object):
    """Declaration of how to build and insert data for one table.

    `tbl` is either a table name or a ``DBTable``; in the latter case the
    object is kept in ``tblobj`` and its name in ``tbl``.  ``finsert`` is
    called as ``finsert(fh, tblobj)`` and its result is handed to
    ``db.insert`` together with the ``many`` flag.
    """

    def __init__(self, tbl, finsert=None, many=False, ftype=None):
        given_table_object = isinstance(tbl, DBTable)
        self.tbl = tbl.name if given_table_object else tbl
        self.tblobj = tbl if given_table_object else None

        # without an explicit function the class-level default is used
        if finsert is not None:
            self.finsert = finsert
        self.many = many
        self.ftype = ftype

    def inserter(self, finsert):
        """Set the data-building function; returns self (decorator use)."""
        self.finsert = finsert
        return self

    def insert(self, fh, db):
        """Build the data from `fh` and insert it into `db`."""
        if self.finsert is None:
            raise TypeError("This insert has not insert func define, use .inserter to define ")

        payload = self.finsert(fh, self.tblobj)
        return db.insert(self.tbl, payload, self.many)

    @staticmethod
    def finsert(fh, tbl):
        """Default builder: one value per table column via its ``fget``.

        Raises TypeError when `tbl` is None or a column has no ``fget``.
        """
        if tbl is None:
            raise TypeError("cannot figure out the table object")

        row = {}
        for col in list(tbl.columns.values()):
            if col.fget is None:
                raise TypeError("no fget method for column '%s' tbl %s"%(col.name,tbl.name))
            row[col.name] = col.fget(fh)
        return row

    def __call__(self, tbl=None, **kwargs):
        """Return a copy of this insert, optionally overriding settings."""
        if tbl is None:
            tbl = self.tbl if self.tblobj is None else self.tblobj
        settings = {k: getattr(self, k) for k in ("finsert", "many", "ftype")}
        settings.update(kwargs)
        return self.__class__(tbl, **settings)




class DBTable(object):
    """Definition of one database table.

    Parameters
    ----------
    name : str
        Table name.
    columns : iterable or dict
        ``DBColumn`` objects or ``(name, dbtype)`` tuples; for a dict its
        values are used.
    description : str
    keys_index : list of str, optional
        Names of the columns forming the primary key.
    links : list, optional
        Link declarations, stored as given.
    """

    def __init__(self, name, columns, description='', keys_index=None, links=None):
        self.columns = {}
        self.name = name
        if isinstance(columns, dict):
            # self.columns is a name -> column dict; accept the same shape
            columns = list(columns.values())
        for c in columns:
            self.update_column(c)

        self.description = description
        # fresh lists per instance: no state shared through default values
        self.keys_index = [] if keys_index is None else list(keys_index)
        self.links = [] if links is None else list(links)

    def update_column(self, column):
        """Add `column` to the table or merge it with the existing one.

        A new column is stored as a copy bound to this table.  When the
        name already exists only the description is updated.

        Raises
        ------
        ValueError
            If `column` is neither a DBColumn nor a 2-tuple.
        TypeError
            If the column already exists with a different dbtype.
        """
        if isinstance(column, DBColumn):
            name, c = column.name, column(tbl=self.name)
        elif isinstance(column, tuple):
            name, dbtype = column
            c = DBColumn(name=name, dbtype=dbtype, tbl=self.name)
        else:
            raise ValueError("expecting a column object or a 2 tuple got %s"%column)

        try:
            pc = self.columns[name]
        except KeyError:
            self.columns[name] = c
        else:
            if c.dbtype != pc.dbtype:
                raise TypeError(
                    "column '%s' tbl '%s' is already declared but with a dbtype of '%s' instead of '%s'"
                    % (c.name, self.name, pc.dbtype, c.dbtype)
                )
            if c.description is not None:
                pc.description = c.description

    def __call__(self, name, columns=None, **kwargs):
        """Return a copy of this table under a new `name`.

        Columns, description, keys_index and links default to those of
        this table unless overridden.
        """
        if columns is None:
            # pass the column objects, not the name -> column dict keys
            columns = list(self.columns.values())
        for k in ["description", "keys_index", "links"]:
            kwargs.setdefault(k, getattr(self, k))
        return self.__class__(name, columns, **kwargs)




class DBColumn(object):
    """Declaration of one table column.

    Holds the column name, its database type, a description, the name of
    the owning table, and two optional callables: ``fget`` (set through
    ``getter``) and ``finsert`` (set through ``inserter``, called as
    ``finsert(fh)`` by ``insert``).
    """
    # class-level fallbacks for the two optional callables
    finsert = None
    fget = None

    def __init__(self, name=None, dbtype="TEXT", description='',
                 tbl=None, many=False, finsert=None, fget=None,
                 ftype=None
                ):
        self.name, self.dbtype, self.description = name, dbtype, description
        self.tbl, self.many = tbl, many
        self.finsert, self.fget = finsert, fget
        self.ftype = ftype

    def __call__(self, **kwargs):
        """Return a copy of this column with `kwargs` overriding fields."""
        fields = ("name", "dbtype", "description", "tbl",
                  "many", "finsert", "fget", "ftype")
        settings = {field: getattr(self, field) for field in fields}
        settings.update(kwargs)
        return self.__class__(**settings)

    def insert(self, fh, db):
        """Build the data with ``finsert(fh)`` and insert it into `db`."""
        builder = self.finsert
        if builder is None:
            raise TypeError("This tbl has not insert func define, use .inserter to define ")
        return db.insert(self.tbl, builder(fh), self.many)

    def inserter(self, finsert):
        """Set the insert function; returns self (decorator use)."""
        self.finsert = finsert
        return self

    def getter(self, fget):
        """Set the value getter; returns self (decorator use)."""
        self.fget = fget
        return self

def insertdec(*args, **kwargs):
    """Decorator factory: build a ``DBInsert`` and return its ``inserter``.

    The decorated function becomes the insert's ``finsert`` and the
    decoration evaluates to the ``DBInsert`` object itself.
    """
    return DBInsert(*args, **kwargs).inserter

def columndec(*args, **kwargs):
    """Decorator factory: build a ``DBColumn`` and return its ``getter``.

    The decorated function becomes the column's ``fget`` and the
    decoration evaluates to the ``DBColumn`` object itself.
    """
    new_column = DBColumn(*args, **kwargs)
    return new_column.getter

# Short lowercase aliases of the declaration classes; MetaDBDefinition
# uses these names in its isinstance checks when scanning a class body.
insert = DBInsert
column = DBColumn
table = DBTable


def db_insert(db, tbl, data, many=False):
    """Insert (or replace) one or several rows into table `tbl`.

    Parameters
    ----------
    db : object
        Must provide ``get_tbl_keys_index(tbl)``, ``cursor()`` and
        ``commit()``.
    tbl : str
        Table name.
    data :
        With ``many=True`` a 2-tuple ``(keys, rows)``.  Otherwise a dict,
        a 2-tuple ``(keys, values)`` or an iterable of ``(key, value)``
        pairs.
    many : bool
        When True the rows are inserted with ``executemany``.

    Returns
    -------
    The cursor's ``lastrowid`` after the statement ran.

    Raises
    ------
    ValueError
        If ``many`` is True and `data` is not a tuple, or if one of the
        table's index keys is missing from the given keys.
    """
    keys_index = db.get_tbl_keys_index(tbl)

    if many:
        if not isinstance(data, tuple):
            raise ValueError("if many = True, expecting a 2 tuple key,data")
        keys, data = data
    elif isinstance(data, dict):
        keys, data = list(data.keys()), list(data.values())
    elif isinstance(data, tuple):
        keys, data = data
    else:
        # this is a list of key/value pairs
        keys, data = list(zip(*data))

    if not all(ki in keys for ki in keys_index):
        raise ValueError("missing one index key (%s) to insert in db got %s"%(keys_index, keys))

    # Identifiers cannot be bound as parameters, so table and column names
    # are bracket-quoted into the statement; values always go through '?'.
    knames = ", ".join("[%s]"%key for key in keys)
    subvals = ",".join(["?"]*len(keys))
    q = "INSERT OR REPLACE into [{tbl}]({knames}) VALUES ({subvals})".format(
        tbl=tbl, knames=knames, subvals=subvals
    )

    c = db.cursor()
    try:
        if many:
            c.executemany(q, data)
        else:
            c.execute(q, data)
        # read the row id while the cursor is still open
        lastrowid = c.lastrowid
    finally:
        # release the cursor even when the statement fails
        c.close()
    db.commit()
    return lastrowid


class DBKey(object):
    """A selected key (column), optionally qualified and wrapped in a function.

    `name` is parsed by ``parse_name``; explicit `tbl`, `func` and
    `funcargs` arguments override the parsed values.  ``str(key)`` renders
    the bracket-quoted key, ``key.r`` renders it wrapped in its function
    call when a function is set.
    """

    def __init__(self, name, tbl=None, func=None, funcargs=None):
        self.name, self.tbl, self.func, self.funcargs = self.parse_name(name)
        if tbl is not None:
            self.tbl = tbl
        if func is not None:
            self.func = func
        if funcargs is not None:
            self.funcargs = list(funcargs)

    def __str__(self):
        if self.tbl:
            return "[%s].[%s]"%(self.tbl, self.name)
        return "[%s]"%self.name

    def parse_name(self, name):
        """Split `name` into ``(key, tbl, func, funcargs)``.

        Accepted forms are ``key``, ``key.tbl``, and ``func(key, arg, ...)``
        where the key may itself be dotted.  Function arguments are turned
        into int or float when possible; quoted strings, ``:name``
        placeholders and a few keywords stay strings; anything else becomes
        a nested ``DBKey``.

        Raises
        ------
        TypeError
            If the closing parenthesis or the first argument is missing.
        """
        if "(" not in name:
            # NOTE(review): the part before the dot is taken as the key and
            # the part after it as the table, mirroring the function branch
            # below -- confirm this order is the intended one.
            if "." in name:
                k, tbl = name.split(".", 1)
            else:
                k, tbl = name, None
            return k.strip(), tbl, None, []

        func, rest = name.split("(", 1)
        if ")" not in rest:
            raise TypeError("missing ')' in '%s'"%name)
        rest = rest.split(")", 1)[0]
        args = [r.strip() for r in rest.split(",")]

        # str.split never returns an empty list, so test the first item
        if not args[0]:
            raise TypeError("Missing arguments in function")

        k = args[0]
        args = args[1:]
        if func in stat_funcs:
            # statistic functions always take two extra (range) arguments
            args.extend(['Null']*(2 - len(args)))

        for i, a in enumerate(args):
            for t in (int, float):
                try:
                    a = t(a)
                except ValueError:
                    continue
                else:
                    break
            else:
                # not a number: keep literals/keywords, wrap the rest
                a = str(a)
                if not a.startswith(("'", '"', ':')):
                    if a not in ("Null", "?", "INTEGER", "FLOAT", "TEXT"):
                        a = DBKey(a)
            args[i] = a

        if "." in k:
            k, tbl = k.split(".", 1)
        else:
            tbl = None
        return k, tbl, func, args

    @property
    def r(self):
        """Rendered form: ``func(key, args...)`` or just the key."""
        if self.func:
            fa = ", "+(", ".join(str(a) for a in self.funcargs)) if self.funcargs else ""
            return "%s(%s%s)"%(self.func, self, fa)
        return str(self)

class DBCond(object):
    """One condition: either ``(key1 op key2)`` or ``func(key, ...)``.

    Keyword arguments: ``op`` (default "=") and ``func`` (default None).
    Without ``func`` exactly two keys are required.
    """

    def __init__(self, *keys, **kwargs):
        operator = kwargs.pop("op", "=")
        function = kwargs.pop("func", None)
        if kwargs:
            raise TypeError("DBCond accept only two keywords, 'op' and 'func'" )
        if function is None and len(keys) != 2:
            raise ValueError("if no func need 2 keys only got %d"%len(keys))
        self.keys = keys
        self.op = operator
        self.func = function

    def __str__(self):
        if not self.func:
            left, right = self.keys
            return "(%s %s %s)"%(left, self.op, right)
        rendered = ", ".join(map(str, self.keys))
        return "%s(%s)"%(self.func, rendered)

class DBFrm(object):
    """A table reference of a FROM clause, rendered as ``[tbl]``."""

    def __init__(self, tbl):
        self.tbl = tbl

    def __str__(self):
        return "[{0}]".format(self.tbl)

class DBKeys(list):
    """List of keys with lookup-by-name helpers.

    ``str()`` joins the plain string form of the items; ``.r`` joins the
    rendered form (``item.r``) of ``DBKey`` items and other items as-is.
    """

    def __str__(self):
        return ", ".join(map(str, self))

    @property
    def r(self):
        parts = []
        for entry in self:
            parts.append(entry.r if isinstance(entry, DBKey) else entry)
        return ", ".join(parts)

    def get(self, key, default=None):
        """Return the item named `key`, else `default` or a new DBKey."""
        found, _is_new = self._get(key, default)
        return found

    def _get(self, key, default):
        """Return ``(item, is_new)``; ``is_new`` is True when not found."""
        for entry in self:
            if entry.name == key:
                return entry, False
        fallback = DBKey(key) if default is None else default
        return fallback, True

    def getorappend(self, key, default=None):
        """Like ``get`` but appends the fallback item when `key` is absent."""
        found, is_new = self._get(key, default)
        if is_new:
            self.append(found)
        return found

class DBFrms(list):
    """List of FROM entries, rendered as a comma separated string."""

    def __str__(self):
        return ", ".join(map(str, self))

class DBConds(list):
    """List of conditions joined by a logical operator (default "AND")."""

    def __init__(self, lst=[], op="AND"):
        super().__init__(lst)
        self.op = op

    def __str__(self):
        separator = " %s "%(self.op)
        return separator.join(map(str, self))

class DBQuery(object):
    """Container assembling a SELECT statement from its parts.

    ``keys`` are the selected keys, ``frms`` the FROM entries and ``conds``
    the WHERE conditions.  ``skeys`` and ``substitutions`` are initialised
    empty and are not used when rendering the statement.
    """

    def __init__(self):
        self.keys = DBKeys()
        self.skeys = DBKeys()
        self.conds = DBConds()
        self.frms = DBFrms()
        self.substitutions = {}

    def __str__(self):
        statement = "SELECT %s FROM %s"%(self.keys.r, self.frms)
        if self.conds:
            # the WHERE clause is only emitted when there are conditions
            statement += " WHERE %s"%(self.conds,)
        return statement



# In-memory array store used when RECORD_ARRAY_IN_DB is False: adapt_array
# appends a copy of each array here and convert_array reads it back by index.
all_data = []




# Array adapter, selected once at import time by RECORD_ARRAY_IN_DB.
# Reference kept from the original source:
# http://stackoverflow.com/a/31312102/190597 (SoulNibbler)
if RECORD_ARRAY_IN_DB:
    def adapt_array(arr):
        """Serialize `arr` with ``np.save`` and wrap the bytes for sqlite."""
        buffer = io.BytesIO()
        np.save(buffer, arr)
        return sqlite3.Binary(buffer.getvalue())
else:
    def adapt_array(arr):
        """Keep a copy of `arr` in ``all_data`` and return its index."""
        global all_data
        index = len(all_data)
        all_data.append(arr.copy())
        return index

# Array converter, selected once at import time by RECORD_ARRAY_IN_DB.
if RECORD_ARRAY_IN_DB:
    def convert_array(text):
        """Load the array saved in the bytes `text`.

        Zero-dimensional data is returned as a plain Python scalar.
        """
        loaded = np.load(io.BytesIO(text))
        return loaded if loaded.shape else loaded.item()
else:
    def convert_array(text):
        """Return the array stored in ``all_data`` at index ``int(text)``."""
        global all_data
        index = int(text)
        return all_data[index]

def test_range(v, mn, mx):
    if isinstance(v, str):
        return 0
    if mn is not None:
        if v<mn:
            return 0
    if mx is not None:
        if v>mx:
            return 0
    return 1

def test_range2(text, mn, mx):
    """Range test on a stored value: 1 if within ``[mn, mx]``, else 0.

    `text` is decoded with ``convert_array``.  Only scalar (0-d) numeric
    values can match; arrays with a shape and strings return 0.  A bound
    given as None is not checked.
    """
    # convert_array may hand back a plain Python scalar (which has no
    # .shape), so normalise to an ndarray before inspecting it.
    a = np.asarray(convert_array(text))
    if a.shape:
        return 0
    v = a.item()
    if isinstance(v, str):
        return 0

    if mn is not None and v < mn:
        return 0
    if mx is not None and v > mx:
        return 0
    return 1

def test_eq(text, vtst):
    """Return 1 when the stored scalar decoded from `text` equals `vtst`.

    Values that decode to an array with a shape return 0.
    """
    # convert_array may hand back a plain Python scalar (which has no
    # .shape), so normalise to an ndarray before inspecting it.
    a = np.asarray(convert_array(text))
    if a.shape:
        return 0
    v = a.item()
    return int(v == vtst)

def test_eq2(text1, text2):
    """Return 1 when the two stored values are equal, else 0.

    Both arguments are decoded with ``convert_array``.  Values of different
    shape are never equal; arrays are equal when all elements are.
    """
    # convert_array may hand back plain Python scalars (no .shape), so
    # normalise both sides to ndarrays before comparing shapes.
    a1 = np.asarray(convert_array(text1))
    a2 = np.asarray(convert_array(text2))

    if a1.shape != a2.shape:
        return 0
    return int(bool(np.all(a1 == a2)))


# Reduction functions selectable by name in return_reduce; "" and None
# map to the identity.
# NOTE(review): "median" uses np.nanmedian here while stat_funcs uses
# np.median -- confirm the difference is intended.
func_lookup = {"" : lambda a:a,
               None: lambda a:a,
               "mean": np.mean,
               "std": np.std,
               "min": np.min,
               "max": np.max,
               "median": np.nanmedian,
               "sum": np.sum
               }

def return_reduce(text, i1, i2, func):
    """Decode `text` and reduce its ``[i1:i2]`` slice with the named `func`.

    Values that do not decode to an ndarray are returned unchanged.
    Raises Exception when `func` is not a key of ``func_lookup``.
    """
    decoded = convert_array(text)
    if isinstance(decoded, np.ndarray):
        window = decoded[i1:i2]
        if func in func_lookup:
            return func_lookup[func](window)
        raise Exception("unknown reduce func '%s' "%func)
    return decoded

def dec_reduce(func):
    """Wrap the reduction `func` into a 3-argument function.

    The returned ``ret_reduce(text, i1, i2)`` decodes `text` with
    ``convert_array``, slices arrays with ``[i1:i2]`` (falsy bounds mean
    "no bound") and returns ``float(func(values))``.
    """
    def ret_reduce(text, i1, i2):
        i1 = int(i1) if i1 else None
        i2 = int(i2) if i2 else None
        v = convert_array(text)

        # convert_array may return a plain scalar, which cannot be sliced;
        # reduce it as-is, like return_reduce leaves scalars untouched.
        if isinstance(v, np.ndarray):
            v = v[i1:i2]
        return float(func(v))

    return ret_reduce

# numpy arrays passed as SQL parameters are adapted through adapt_array
sqlite3.register_adapter(np.ndarray, adapt_array)
# Values of columns declared with type "array" are passed through
# convert_array when selected (DB connects with PARSE_DECLTYPES)
sqlite3.register_converter("array", convert_array)

# Statistic functions, each wrapped by dec_reduce into a 3-argument
# callable (text, i1, i2).  DB registers every entry as an SQL function
# under its key, and DBKey.parse_name pads the arguments of these
# functions with 'Null'.
stat_funcs = {"mean":dec_reduce(np.mean),
              "std":dec_reduce(np.std),
              "median":dec_reduce(np.median),
              "min":dec_reduce(np.min),
              "max":dec_reduce(np.max),
              "sum":dec_reduce(np.sum)
            }


class DB(object):
    """sqlite3 connection wrapper that builds its tables from declarations.

    The class attributes ``tables`` (a ``DBTables`` registry) and
    ``definitions`` are shared by all instances; ``MetaDBDefinition``
    fills them for the objects listed in a definition's ``databases``.

    On creation the bookkeeping tables TBLS, CLMS and LINKS are created
    when missing, every registered table is created in the database, and
    the helper SQL functions are registered on the connection.
    """
    tables = DBTables()
    definitions = []

    def __init__(self, filename=":memory:"):
        self.db = sqlite3.connect(filename, detect_types=sqlite3.PARSE_DECLTYPES)
        # cache: table name -> list of index key names
        self._key_indexes_ = {}

        c = self.db.cursor()
        # The bookkeeping tables are created together, so the presence of
        # TBLS tells whether this database has already been initialised.
        c.execute("""SELECT count(*) FROM sqlite_master WHERE type='table' AND name='TBLS'""")
        if not c.fetchone()[0]:
            self._create_tables()

        self._build_user_table()

        self.db.create_function("test_range", 3, test_range)
        self.db.create_function("test_range2", 3, test_range2)
        self.db.create_function("test_eq", 2, test_eq)
        self.db.create_function("test_eq2", 2, test_eq2)
        self.db.create_function("return_reduce", 4, return_reduce)

        for name, f in stat_funcs.items():
            self.db.create_function(name, 3, f)

    def cursor(self):
        """Return a new cursor on the underlying connection."""
        return self.db.cursor()

    def execute(self, *args):
        """Forward to ``connection.execute``."""
        return self.db.execute(*args)

    def executescript(self, *args):
        """Forward to ``connection.executescript``."""
        return self.db.executescript(*args)

    def commit(self):
        """Commit the current transaction."""
        return self.db.commit()

    def _create_tables(self):
        """Create the bookkeeping tables TBLS, CLMS and LINKS."""
        script = '''
            CREATE TABLE TBLS(id INTEGER PRIMARY KEY, name TEXT, description TEXT, keys_index TEXT, ccount INTEGER);
            CREATE TABLE CLMS(id INTEGER PRIMARY KEY, name TEXT, dbtype TEXT, tbl TEXT, description TEXT);
            CREATE TABLE LINKS(tbl1 TEXT, tbl2 TEXT, column TEXT, PRIMARY KEY(tbl1,tbl2,column))
            '''
        c = self.cursor()
        c.executescript(script)
        return self.commit()

    def _build_user_table(self):
        """Create in the database every table of the shared registry."""
        for table in self.tables:
            self.add_table(table.name, table.columns, table.description, table.keys_index, table.links)

    def parse_file(self, fh):
        raise NotImplementedError("parse_file is not implemented in DB definition")

    def file_exists(self, fh):
        """Tell whether `fh` is already recorded; currently always False."""
        return False
        # mjdobs = fh[0].header.get("MJD-OBS", 0.0)
        # c = self.cursor()
        # c.execute("SELECT count(*) FROM headers WHERE [FILEID] = ?", (mjdobs,))
        # return c.fetchone()[0]>0

    def feed(self, fh):
        """Insert the content of `fh` using the first matching definition.

        Each definition's ``parse_file`` is tried in turn; a ValueError
        means "not my file type".  For the first one that succeeds, all
        inserts and columns registered for the parsed file's ``ftype`` are
        executed.

        Returns
        -------
        int
            1 when the file was fed, 0 otherwise.
        """
        count = 0
        for d in self.definitions:
            try:
                parsed = d.parse_file(fh)
            except ValueError:
                continue

            info = parsed[0].fileinfo()
            fname = info['file'].name if info else "?"

            if self.file_exists(parsed):
                log("file exist in db '%s'"%fname, 1, NOTICE)
                break

            log("Feeding file '%s'"%fname, 1, NOTICE)

            ftype = parsed.ftype
            for ins in self.tables.get_inserts(ftype):
                ins.insert(parsed, self)

            for col in self.tables.get_columns(ftype):
                col.insert(parsed, self)
            count += 1
            break
        return count

    def feedall(self, lst):
        """Feed every file of `lst`; return how many were inserted."""
        count = 0
        for fh in lst:
            count += self.feed(fh)
        return count

    def add_table(self, name, columns, description='', keys_index=(), links=()):
        """Create table `name` and record it in the bookkeeping tables.

        Nothing is done when TBLS already lists a table with this name.

        Parameters
        ----------
        name : str
        columns : dict
            Mapping column name -> column object (``name``, ``dbtype`` and
            ``description`` attributes are used).
        description : str
        keys_index : sequence of str
            Column names forming the primary key.
        links : sequence of (linked_table, keys)
            Recorded in LINKS in both directions.
        """
        cur = self.db.cursor()
        cur.execute("""SELECT count(*) FROM TBLS WHERE name=?""", (name,))
        if cur.fetchone()[0]:
            # table already exists
            return

        cur.execute("INSERT INTO TBLS(name,description,keys_index) VALUES(?,?,?)",
                    (name, description, ",".join(keys_index)))

        for tbl_linked, keys in links:
            subs = [(name, tbl_linked, k) for k in keys]
            cur.executemany("INSERT OR REPLACE INTO LINKS(tbl1,tbl2,column) VALUES(?,?,?)", subs)
            cur.executemany("INSERT OR REPLACE INTO LINKS(tbl2,tbl1,column) VALUES(?,?,?)", subs)

        fnames = []
        for col in columns.values():
            self.add_column(name, col.name, col.dbtype, col.description)
            fnames.append("[%s] %s"%(col.name, col.dbtype))

        # identifiers cannot be bound as parameters: bracket-quote them
        tname = "[%s]"%name
        primary_keys = ["[%s]"%k for k in keys_index]
        strpprim = ", PRIMARY KEY("+",".join(primary_keys)+")" if primary_keys else ""

        q = '''CREATE TABLE %s(%s %s)'''%(tname, ",".join(fnames), strpprim)
        cur.execute(q)
        return self.commit()

    def add_column(self, tbl, name, dbtype, description=''):
        """Record column `name` of table `tbl` in CLMS; return its id.

        An existing record is left untouched and its id returned.
        """
        c = self.cursor()
        c.execute("SELECT id from CLMS WHERE tbl = ? AND name = ?", (tbl, name))
        dbk = c.fetchone()
        if dbk is None:
            c.execute("INSERT INTO CLMS(tbl,name,dbtype,description) VALUES(?,?,?,?)",
                      (tbl, name, dbtype, description))
            self.commit()
            return c.lastrowid
        return dbk[0]

    def get_tbl_keys_index(self, tbl):
        """Return the list of index key names of table `tbl`.

        Raises
        ------
        ValueError
            If the table is not listed in TBLS.
        """
        if tbl in self._key_indexes_:
            return self._key_indexes_[tbl]
        c = self.cursor()
        c.execute('''SELECT keys_index FROM TBLS WHERE name = ?''', (tbl,))
        r = c.fetchone()
        if r is None:
            raise ValueError("unknown table '%s'"%tbl)

        keys = [k for k in r[0].split(",") if k]  # remove empty strings
        # keys_index is only written when the table is created, so the
        # result can be cached for later inserts
        self._key_indexes_[tbl] = keys
        return keys

    def get_tables_of_keys(self, keys):
        """Return the table name of every CLMS record whose name is in `keys`."""
        keys = list(keys)
        c = self.cursor()
        c.execute("SELECT tbl FROM CLMS WHERE name IN (%s)"%(",".join(["?"]*len(keys))),
                  keys)
        return [r[0] for r in c]

    # db_insert takes the DB instance as its first argument
    insert = db_insert




