import itertools
import sqlite3
import numpy as np

def first(row):
    """Return the element at index 0 of *row* (e.g. the first column of a result row)."""
    head = row[0]
    return head

def row(r):
    """Identity row converter: hand back *r* unchanged (same object)."""
    unchanged = r
    return unchanged

def merge(row):
    """Concatenate the string elements of *row* with no separator."""
    separator = ""
    return separator.join(row)

def merge_(row):
    """Convert every element of *row* with str() and join them with '_'."""
    return "_".join(map(str, row))

def link_tbls_conds(cursor, match_dict):
    """Build the SQL conditions linking the tables of one match.

    Parameters
    ----------
    cursor : dict or object with an ``execute`` method
        If a dict, maps (replacement1, replacement2) pairs to the list of
        column names linking the two tables. Otherwise it is used to read
        the link columns from the ``LINKS`` table, queried with the *real*
        table names.
    match_dict : dict
        replacement (alias) name -> real table name

    Returns
    -------
    str
        conditions of the form ``[t1].[col] = [t2].[col]`` joined with
        " AND ", or an empty string when no link is found.
    """
    tbls = list(match_dict.keys())

    if isinstance(cursor, dict):
        def get_keys(tbl1, tbl2):
            return cursor.get((tbl1, tbl2), [])
    else:
        def get_keys(tbl1, tbl2):
            # LINKS is keyed by real table names, not by the aliases
            tmp = cursor.execute("SELECT column FROM LINKS WHERE tbl1 = ? AND tbl2 = ?",
                                 (match_dict[tbl1], match_dict[tbl2]))
            return [r[0] for r in tmp]

    conds = []
    for i, t1 in enumerate(tbls):
        # only look at pairs (t1, t2) where t2 is not after t1 in the match,
        # so each unordered pair is queried in a single direction
        for t2 in tbls[:i + 1]:
            for cl in get_keys(t1, t2):
                conds.append("[%s].[%s] = [%s].[%s]" % (t1, cl, t2, cl))

    return " AND ".join(conds)


def _parse_table_name(text):
    for look in [" AS ", " as ", " As ", " aS "]:
        if look in text:
           return tuple(s.strip() for s in text.partition(look)[::-2])
    return text, text

def _clean_table_name(text):
    """Normalize a table declaration.

    'toto   as  tutu' becomes 'toto AS tutu'; when the alias and the table
    name are identical the stripped input text is returned unchanged.
    """
    alias, table = _parse_table_name(text)
    if alias == table:
        return text.strip()
    return "%s AS %s" % (table, alias)



class query(object):
    def __init__(self, db, columns=[], tbls=[], conds=[],
                 automatch=False, **kwargs):
        """ define a database query

        Parameters
        ----------
        db : database or None
            If None the query is not usable until defined by q.set_db(db)
        columns :  list or string
            list of columns to return by the query
        tbls : list or string, optional
            list of tables used by the query
            the list of tables can contain tuples of table names. In this case
            the tables defined inside the tuple are matched together with the
            defined links, if any.

        conds : list or string, optional
            list of conditions. A condition can also be a tuple
            (condition, value1, ...) whose values substitute the '?'
            placeholders of the condition.
        automatch :  bool, optional
            if True tables are automatically matched together with the defined links
        **kwargs : name=column pairs, see add_columns

        Methods
        -------
        add_columns : add more select columns
        add_tbls : add more tables
        add_conds :  add more conditions
        set_columns : set the selections, erase the previous
        set_tbls  :  set the tables, erase the previous
        set_conds :  set the conditions, erase previous

        add_match : add matches definition between columns
        clear_matches : clear all matches

        query : copy and update the query
        execute : execute the query
        fetchone, fetchall, fetchmany :  sent to cursor

        Examples:
        ---------
        q = query(db, "MJD,mean(VIS2DATA,Null,Null),TARGET" , "vis2,targets", automatch=False)
        q.add_match( "vis2", "targets", keys=["FILEID", "TARGET_ID"] )
        q.add_match( "vis2", "headers", keys=["FILEID"] )

        c = q.execute(conds="TARGET LIKE '%HD%'")
        # c is a copy of query so we can continue to nail it
        c.execute( conds=("[ESO PRO CATG] like ?", "%VIS%"), tbls=[("vis2","headers")] ) # the tuple asks to match the 2 tables with their default links
        c.execute( "SELECT * FROM  etc ....")  # this sends a new query ignoring what is in c

        q = query(db, tbls="header,targets", automatch=True )  # automatically match *ALL* matchable tables

        q = query(db, tbls=[("header","targets")], automatch=False )  # The tuple indicates that the tables "header" and "target"
                                                                      # must be matched with default links

        str(q) # to check the string query

        """
        self.set_db(db)

        # replacement (alias) name -> real table name, in insertion order
        self.tbls = {}
        # list of {alias: real table} dicts, each one a group of tables
        # that must be linked together
        self.matches = []
        # parallel to self.matches: {(alias1, alias2): [columns]} explicit
        # link keys; an empty dict means "look them up in LINKS"
        self.matchkeys = []
        self.automatch = automatch

        self.columns = []

        self.conds = []
        # substitution values for the '?' placeholders of self.conds, in order
        self.condvals = []
        # keyword name -> position in self.columns (see fetchdata)
        self.posargs = {}
        # keyword name -> callable(data) evaluated by fetchdata
        self.funcargs = {}
        # free-form values filled by iter()/niter()
        self.scalars = {}

        self.add_columns(columns, **kwargs)
        self.add_tbls(tbls)
        self.add_conds(conds)

    def _get_default_links(self, t1, t2):
        """Return the default link columns between real tables t1 and t2.

        The columns are read from the LINKS table of the database and
        returned as a list of (column, column) pairs.

        Raises TypeError if no database is set.
        """
        if self.db is None:
            raise TypeError("cannot determine default link, set a db first")
        c = self.db.execute("SELECT DISTINCT column FROM LINKS WHERE tbl1 = ? AND tbl2 = ?", (t1, t2))
        return [(col, col) for col, in c.fetchall()]

    def _is_table(self, tbl):
        """Hook to validate a table name; currently accepts every name."""
        return True

    def _parse_table(self, text):
        """Split 'table AS alias' into (table, alias).

        Without an alias, a name already registered as an alias resolves to
        its real table, i.e. (real, name) is returned; otherwise (name, name).

        Raises KeyError if the alias is already bound to a different table.
        """
        for look in [" AS ", " as ", " As ", " aS "]:
            if look in text:
                tbl, rep = (t.strip() for t in text.partition(look)[::2])
                break
        else:
            tbl = text.strip()
            try:
                return self.tbls[tbl], tbl
            except KeyError:
                if not self._is_table(tbl):
                    raise KeyError("Unknown table '%s'"%tbl)
            return (tbl, tbl)

        if rep in self.tbls and self.tbls[rep] != tbl:
            raise KeyError("table with name '%s' already exists and linked to bale %s"%(rep, self.tbls[rep]))
        return tbl, rep

    def has_table(self, tbl):
        """True if the table (or 'table AS alias') is part of the query."""
        try:
            tbl, rep = self._parse_table(tbl)
        except KeyError:
            return False
        return rep in self.tbls

    def _parse_tbls(self, tbls, known_tables=None):
        """Filter *tbls*, keeping only declarations that can be used.

        An unaliased name is dropped (with a printed warning) when it is
        neither a registered alias nor listed in *known_tables*.
        """
        # NOTE(review): `_known_tables` is never set in this class; fall back
        # to an empty collection instead of raising AttributeError.
        known_tables = known_tables or getattr(self, "_known_tables", ())
        out = []
        for tbl in tbls:
            rp, rl = _parse_table_name(tbl)
            if rp == rl and rp not in self.tbls and rl not in known_tables:
                print(("WARNING: table '%s' is unknown, will not be added in FROM ... list"%rp))
            else:
                out.append(tbl)
        return out

    def set_db(self, db):
        """ set the sqlite database connection; a new cursor is opened from it

        With a falsy db (e.g. None) the cursor is reset to None.
        """
        # NOTE(review): db must provide cursor() and execute(), as a
        # sqlite3.Connection does; a bare sqlite3.Cursor is not enough.
        self.db = db
        self.cursor = db.cursor() if db else None

    def copy(self):
        """Return an independent copy of the query (with its own cursor)."""
        new = self.__class__(self.db)

        new.tbls = self.tbls.copy()

        new.matches = [m.copy() for m in self.matches]
        new.matchkeys = [m.copy() for m in self.matchkeys]
        new.automatch = self.automatch

        new.columns = list(self.columns)
        new.posargs = self.posargs.copy()
        new.funcargs = self.funcargs.copy()
        new.conds = list(self.conds)
        new.condvals = list(self.condvals)
        new.scalars = self.scalars.copy()

        return new

    def link(self, tables1, tables2, keys=None):
        """ link tables together for the query

        Parameters
        ----------
        tables1 : string or list
            list of 'left' table to link
        tables2 : string or list
            list of 'right' table to link. Must be of same size than table1
        keys : string or list, optional
            list of keys to link. If None (default) the tables are
            linked with their default link keys if any
            keys can be a list of string or a list of 2 tuple in this case
            the tuple element define key of table1, key of table2


        Examples:
        ---------
        q.link( "headers", "vis2")
        q.link("headers as h", "vis2 as v2")
        q.link(["headers as h1", "headers as h2"], ["vis2","vis2"])

        """
        if isinstance(tables1, str):
            tables1 = [t.strip() for t in tables1.split(",")]
        if isinstance(tables2, str):
            tables2 = [t.strip() for t in tables2.split(",")]
        if len(tables1) != len(tables2):
            raise ValueError("Both tables list must have the same size")
        if isinstance(keys, str):
            keys = [k.strip() for k in keys.split(",")]

        conds = []

        if keys is None:
            for t1, t2 in zip(tables1, tables2):
                tn1, rep1 = self._parse_table(t1)
                tn2, rep2 = self._parse_table(t2)

                # default links are declared between real table names
                links = self._get_default_links(tn1, tn2)
                conds.extend(["%s.[%s] = %s.[%s]"%(rep1, k1, rep2, k2) for k1, k2 in links])
        else:
            keys = [(k if isinstance(k, tuple) else (k, k)) for k in keys]
            for t1, t2 in zip(tables1, tables2):
                tn1, rep1 = self._parse_table(t1)
                tn2, rep2 = self._parse_table(t2)
                conds.extend(["%s.[%s] = %s.[%s]"%(rep1, k1, rep2, k2) for k1, k2 in keys])

        self.add_tbls(tables1)
        self.add_tbls(tables2)
        # avoid duplicating conditions already present
        self.add_conds([c for c in conds if c not in self.conds])

    def linksequence(self, tables1, keys1, tables2, keys2):
        """ link tables with sequence of keys

        Parameters
        ----------
        tables1 : string or list
            sequence of tables names (or string of ',' separated tables)
        keys1 : string or list
            list of keys (columns) for each tables in tables
            if list must be of the same size of tables1
            if string keys1 becomes [keys1]*len(tables1)
        tables2: string or list
            one or more table to link to the sequence of table1
        keys2: string or list
            list of tables2 keys to match keys1. Should be of the
            same size than table1 and keys1
            if string keys2 becomes [keys2]*len(tables1)

        Examples
        --------

        q.linksequence( ["ar1","ar2"], "STA_INDEX", "vis", ["STA1_INDEX", "STA2_INDEX"])
        will build : ar1.STA_INDEX = vis.STA1_INDEX AND ar2.STA_INDEX = vis.STA2_INDEX
        q.linksequence( ["ar1","ar2"], "STA_INDEX", ["vis","vis2"], ["STA1_INDEX", "STA2_INDEX"])
        will do the same for both "vis" and "vis2"

        """
        if isinstance(tables1, str):
            tables1 = [t.strip() for t in tables1.split(",")]
        if isinstance(tables2, str):
            tables2 = [t.strip() for t in tables2.split(",")]
        if isinstance(keys1, str):
            keys1 = [keys1]*len(tables1)
        elif len(keys1) != len(tables1):
            raise ValueError("keys1 must be scalar string or must have same size than table1")

        if isinstance(keys2, str):
            keys2 = [keys2]*len(tables1)
        elif len(keys2) != len(tables1):
            raise ValueError("keys2 must be scalar string or must have same size than table1")

        tbrep1 = [self._parse_table(t)[1] for t in tables1]
        tbrep2 = [self._parse_table(t)[1] for t in tables2]

        conds = []
        for t2 in tbrep2:
            conds.extend(["%s.[%s] = %s.[%s]"%(t1, k1, t2, k2) for t1, k1, k2 in zip(tbrep1, keys1, keys2)])

        self.add_tbls(tables1)
        self.add_tbls(tables2)
        self.add_conds([c for c in conds if c not in self.conds])

    def linkall(self, tables1, tables2="*", keys=None):
        """Link every pair of tables taken from tables1 x tables2.

        "*" stands for all the tables already in the query. Each unordered
        pair is linked once (see link for the meaning of keys).
        """
        if isinstance(tables1, str):
            if tables1 == "*":
                # declarations must read 'table AS alias'; self.tbls maps alias -> table
                tables1 = ["%s AS %s"%(tbl, rep) for rep, tbl in self.tbls.items()]
            else:
                tables1 = [t.strip() for t in tables1.split(",")]

        if isinstance(tables2, str):
            if tables2 == "*":
                tables2 = ["%s AS %s"%(tbl, rep) for rep, tbl in self.tbls.items()]
            else:
                tables2 = [t.strip() for t in tables2.split(",")]

        pairs = []
        for t1, t2 in itertools.product(tables1, tables2):
            if t1 == t2:
                continue
            if (t2, t1) in pairs:
                continue  # avoid checking the same pair twice
            self.link(t1, t2, keys)
            pairs.append((t1, t2))

    def add_match(self, *alltbls, **kwargs):
        """Declare tables that must be matched (joined) together.

        Parameters
        ----------
        *alltbls : strings (',' separated) or sequences of table declarations,
            each one of the form 'table' or 'table AS alias'
        keys : string or list, optional
            explicit columns linking every pair of tables of the match.
            Without keys the links are read from the LINKS table when the
            query is built. If explicitly False, the explicit keys of an
            existing match are cleared.

        When one of the tables already belongs to a match, the new tables
        are merged into that (first found) match.
        """
        tbls = []
        for tb in alltbls:
            tbls.extend(tb.split(",") if isinstance(tb, str) else tb)
        tbls = [tb.strip() for tb in tbls]

        # the matched tables are also part of the FROM list
        self.add_tbls(tbls)
        # (alias, real table) pairs
        parsed = [_parse_table_name(tb) for tb in tbls]

        # take the first existing match that already contains one of the tables
        match_dict, match_keys, exists = {}, {}, False
        for replacement, _ in parsed:
            for mkeys, match in zip(self.matchkeys, self.matches):
                if replacement in match:
                    match_dict, match_keys, exists = match, mkeys, True
                    break
            if exists:
                break

        match_dict.update(parsed)

        keys = kwargs.get("keys", None)
        if isinstance(keys, str):
            # a bare string would otherwise be iterated character by character
            keys = [k.strip() for k in keys.split(",")]
        if keys:
            ks = list(match_dict.keys())
            for pair in ((x, y) for x in ks for y in ks if x != y):
                match_keys[pair] = keys

        if not exists:
            self.matches.append(match_dict)
            self.matchkeys.append(match_keys)
        elif keys is False:  # keys is explicitly False
            match_keys.clear()

    def add_tbls(self, tbls):
        """Add tables to the FROM list.

        tbls is a ',' separated string or a list of declarations
        ('table' or 'table AS alias'). A tuple inside the list declares
        tables that must also be matched together (see add_match).
        """
        if isinstance(tbls, str):
            tbls = [s.strip() for s in tbls.split(",")]

        for t in tbls:
            if isinstance(t, tuple):
                self.add_match(*t)
                continue
            # 'toto as t1' and 'toto     AS t1' give the same entry
            tbl, rep = self._parse_table(t)
            self.tbls[rep] = tbl

    def set_tbls(self, tbls):
        """Set the tables, erasing the previous ones (matches are kept)."""
        save = self.tbls
        self.tbls = {}
        try:
            self.add_tbls(tbls)
        except Exception:
            self.tbls = save
            raise

    def query(self, columns=[], tbls=[], conds=[], db=None, **kwargs):
        """ make a new query columns, tbls, conds are added to the new query

        Parameters
        ----------
        columns : string, list, dict
            additional columns
        tbls : string, list
            additional tables
        conds : string, list
            additional condition
        db : change the data base to
        **kwargs : other additional key/keywords pairs

        """
        new = self.copy()
        if db:
            new.set_db(db)
        new.add_columns(columns, **kwargs)
        new.add_tbls(tbls)
        new.add_conds(conds)
        return new

    def add_columns(self, _columns_=[], **kwargs):
        """ add columns to the data base

        Parameters
        ----------
        columns : string, list, dict
            if string, column names must be separated by ','
            if dict, vals are added to columns,
            the keys are saved and returned by fetchdata()
        **kwargs : any other keys/column-name pairs
            a callable value is not selected but evaluated by fetchdata()
            on the fetched data

        Examples:
        ---------
        q = query(db,tbl="vis2")
        q.add_columns( x="MJD", y="mean(vis2,Null,Null)")
        data = q.execute().fetchdata()
        data["x"]
        data["y"]

        # Notes columns can be added in the query definition
        q = query(db, tbl="vis2", x="MJD", y="mean(vis2,Null,Null)")
        """
        if isinstance(_columns_, str):
            _columns_ = [s.strip() for s in _columns_.split(",")]
        elif isinstance(_columns_, dict):
            kwargs = dict(_columns_, **kwargs)
            _columns_ = []

        self.columns.extend(_columns_)
        for k, v in kwargs.items():
            if not isinstance(v, str):
                if not callable(v):
                    raise ValueError("If not a string column should be callable got '%s' for '%s'"%(v, k))
                self.funcargs[k] = v
                if k in self.posargs:
                    self.posargs.pop(k)
                continue

            if k in self.posargs:
                # redefining a named column replaces it in place
                self.columns[self.posargs[k]] = v
            else:
                self.posargs[k] = len(self.columns)
                self.columns.append(v)

    def set_columns(self, _columns_=[], **kwargs):
        """Set the selected columns, erasing the previous ones."""
        save = self.columns, self.posargs
        self.columns = []
        self.posargs = {}
        try:
            self.add_columns(_columns_, **kwargs)
        except Exception:
            self.columns, self.posargs = save
            raise

    def add_conds(self, _conds_=[], **kwargs):
        """Add conditions, joined with AND in the final query.

        _conds_ is a condition string, a tuple (condition, value1, ...)
        whose values substitute the '?' placeholders of the condition, a list
        of those, or a dict of column=value equalities. Each keyword argument
        adds a "column = ?" condition with its value.
        """
        if isinstance(_conds_, (str, tuple)):
            _conds_ = [_conds_]
        elif isinstance(_conds_, dict):
            kwargs = dict(_conds_, **kwargs)
            _conds_ = []

        for cond in _conds_:
            if isinstance(cond, tuple):
                self.conds.append(cond[0])
                self.condvals.extend(cond[1:])
            else:
                self.conds.append(cond)

        for k, v in kwargs.items():
            self.conds.append("%s = ?"%k)
            self.condvals.append(v)

    def set_conds(self, _conds_, **kwargs):
        """Set the conditions (and their values), erasing the previous ones."""
        save = self.conds, self.condvals
        self.conds = []
        self.condvals = []
        try:
            self.add_conds(_conds_, **kwargs)
        except Exception:
            self.conds, self.condvals = save
            raise

    def iter(self, columns, _conds_=[], _keys_=[], _row2val_=None, **kwargs):
        """ iter on some columns each element of iterator return a new query

        For each element scalars are updated in the return query scalars dictionary
        The scalars are declared by keyword, where the keyword is the scalar name
        and its value a function of signature f(row)

        Parameters
        ----------
        columns : list/ string
            list of columns that will be taken as unique
        _conds_: list/strings, optional
            additional condition for the iterator only
        _keys_: list/ string, optional
            A short way to declare additional scalars
        _row2val_ : callable
            the row2val function for _keys_

        **kwargs :  scalarname=callable
            the scalars function declaration

        Examples:
        ---------

        for q in Q.iter("TARGET",  target=lambda r:r[0]):
            print (q.scalars["target"])
            # do something with q

        for q in Q.iter(["ESO INS SPEC RES", "ESO INS POLA MODE"],  setup=lambda r:"_".join(r)):
            print (q.scalars["setup"])
            # do something with q

        Return
        ------
            generator

        See Also:
        ---------
        niter : the same iterator except that scalar functions accept 3 arguments
                which are n:number of rows, i: index of row, row: current row
        """
        if isinstance(columns, str):
            columns = [c.strip() for c in columns.split(",")]
        if _row2val_ is None:
            if len(columns) <= 1:
                _row2val_ = lambda r: r[0]
            else:
                _row2val_ = lambda r: r

        if isinstance(_keys_, str):
            _keys_ = [k.strip() for k in _keys_.split(",")]

        for k in _keys_:
            kwargs.setdefault(k, _row2val_)

        new = self.copy()
        new.add_conds(_conds_)

        new.set_columns(columns)
        c = new.db.execute(*(new.build(distinct=True)[:2]))

        for row in c:
            # each yielded query restricts the original one to this row's values
            tmp = self.copy()
            tmp.scalars.update({k: f(row) for k, f in kwargs.items()})
            tmp.add_conds({cl: v for cl, v in zip(columns, row)})
            yield tmp

    def niter(self, columns, _conds_=[], _keys_=[], _row2val_=None, **kwargs):
        """ Same than iter except that all function for scalars got 3 arguments

        all scalars function must be of signature:
            f(n,i,row)
            where:
                n: the total number of rows
                i: the index of the given row
                row: the given row
        """
        if isinstance(columns, str):
            columns = [c.strip() for c in columns.split(",")]
        if _row2val_ is None:
            if len(columns) <= 1:
                _row2val_ = lambda n, i, r: r[0]
            else:
                _row2val_ = lambda n, i, r: r

        if isinstance(_keys_, str):
            _keys_ = [k.strip() for k in _keys_.split(",")]

        for k in _keys_:
            kwargs.setdefault(k, _row2val_)

        new = self.copy()
        new.add_conds(_conds_)

        new.set_columns(columns)
        c = new.db.execute(*(new.build(distinct=True)[:2]))
        # all rows are needed upfront to know their total number
        rows = c.fetchall()
        N = len(rows)
        for i, row in enumerate(rows):
            tmp = self.copy()
            tmp.scalars.update({k: f(N, i, row) for k, f in kwargs.items()})
            tmp.add_conds({cl: v for cl, v in zip(columns, row)})
            yield tmp

    def clear_matches(self):
        """Remove every match definition."""
        self.matchkeys = []
        self.matches = []

    def _build_tbl_cond(self):
        """Build the link conditions of all matches (and of automatch)."""
        # LINKS lookups go through the database rather than self.cursor so
        # that building (e.g. str(q)) does not overwrite pending results.
        source = self.db if self.db is not None else self.cursor

        matches_cond = [link_tbls_conds(mkeys or source, match)
                        for mkeys, match in zip(self.matchkeys, self.matches)]

        if self.automatch:
            # every table not part of an explicit match is linked
            # automatically; keep the alias -> real table mapping
            remaining = dict(self.tbls)
            for match in self.matches:
                for k in match:
                    remaining.pop(k, None)
            matches_cond.append(link_tbls_conds(source, remaining))

        # drop empty results so no dangling " AND " ends up in the SQL
        return " AND ".join(c for c in matches_cond if c)

    def _build_tbls(self):
        """FROM list: 'table AS alias' for every registered table."""
        return ", ".join("%s AS %s"%(tbl, rep) for rep, tbl in self.tbls.items())

    def _build_columns(self):
        """Selected columns, '*' when none is defined."""
        return ", ".join(self.columns) or "*"

    def _build_conds(self):
        """WHERE conditions and the tuple of their substitution values."""
        return " AND ".join(self.conds), tuple(self.condvals)

    def build(self, distinct=False):
        """Return (sql_string, substitutions) for the current query."""
        conds, substitutions = self._build_conds()
        all_conds = " AND ".join(s for s in [conds, self._build_tbl_cond()] if s.strip())

        return "SELECT %s %s FROM %s %s"%("DISTINCT" if distinct else "",
                                          self._build_columns(),
                                          self._build_tbls(),
                                          " WHERE "+all_conds if all_conds.strip() else ""
                                          ), substitutions

    def execute(self, _query_=None, _subs_=tuple(), columns=[], tbls=[], conds=[], distinct=False):
        """Execute the query.

        With a raw SQL string (_query_), it is executed with _subs_ on the
        cursor and the cursor is returned; columns/tbls/conds are then not
        allowed. Otherwise a copy of the query updated with columns, tbls and
        conds is executed and returned; its rows are read with fetchall,
        fetchone, fetchdata or by iteration.
        """
        if _query_:
            if columns or tbls or conds:
                raise ValueError("With a query string columns, tbls, conds keyword are not allowed")
            return self.cursor.execute(_query_, _subs_)

        newquery = self.copy()
        newquery.add_columns(columns)
        newquery.add_tbls(tbls)
        newquery.add_conds(conds)

        # run on the copy's own cursor so that executing this query again
        # does not overwrite the rows of previously returned copies
        cursor = newquery.cursor if newquery.cursor is not None else self.cursor
        newquery.cursor = cursor.execute(*newquery.build(distinct=distinct))
        return newquery

    def merge(self, *others, **kwargs):
        """ merge with other queries

        columns are added, conditions are added with 'AND', tables and
        match definitions are added. The keywords columns, conds and tbls
        add more of them. A new query is returned.
        """
        new = self.copy()
        new.add_columns(kwargs.pop("columns", []))
        new.add_conds(kwargs.pop("conds", []))
        new.add_tbls(kwargs.pop("tbls", []))
        if kwargs:
            raise TypeError("They are unexpected kwargs : '%s'"%(kwargs))
        for other in others:
            new.conds += other.conds
            new.condvals += other.condvals
            new.columns += other.columns
            new.tbls.update(other.tbls)

            # copy match definitions as they are (aliases, real tables and
            # explicit keys) rather than re-deriving them
            for mkeys, match in zip(other.matchkeys, other.matches):
                new.matches.append(dict(match))
                new.matchkeys.append(dict(mkeys))

        return new

    def __mul__(self, right):
        return self.merge(right)

    def fetchall(self):
        return self.cursor.fetchall()

    def fetchdata(self):
        """Fetch all rows as a dict of numpy arrays.

        Integer keys are column positions; keyword columns (see add_columns)
        are also available by name and callable keywords are evaluated on
        the data dict. An empty dict is returned when there is no row.
        """
        fetched = self.fetchall()
        if not fetched:
            return {}
        data = {n: np.array(d) for n, d in enumerate(zip(*fetched))}
        for k, pos in self.posargs.items():
            data[k] = data[pos]
        for k, func in self.funcargs.items():
            data[k] = func(data)
        return data

    def fetchone(self):
        return self.cursor.fetchone()

    def fetchmany(self, n):
        return self.cursor.fetchmany(n)

    def __iter__(self):
        return self.cursor.__iter__()

    def __str__(self):
        s, _ = self.build()
        return s




