try:
    from astropy.io import fits as pyfits
except:
    import pyfits
import numpy as np
import glob
import os
import re
import sys
from joblib import Parallel, delayed
import shutil
import time

from .log import Log, WARNING, NOTICE
# Module-level logging callable (a bound method of a Log instance); used
# throughout this file as log(message, 1, NOTICE) or log(message, 1, WARNING).
log = Log().log

#
# Load files and define useful header attributes (e.g. tag)
#


def get_tag(h):
    '''Get the SOF tag for this header.

    `h` is any mapping supporting `in` and `[]` on FITS keyword names.

    Returns:
    - for a header with a PRO.CATG: 'DIAMETER_CAT' for 'JSDC_CAT',
      '<x>_TF' for 'TF_<x>', '<x>_VIS' for 'VIS_<x>_RAW', otherwise the
      PRO.CATG unchanged;
    - for a header with only a DPR.TYPE: a tag built from the DPR.TYPE
      (and, for OBJECT/STD frames, the first three letters of DPR.CATG)
      suffixed with '_RAW';
    - 'UNKNOWN' when neither keyword is present.
    '''
    # Has a PRO.CATG
    if 'HIERARCH ESO PRO CATG' in h:
        procatg = h['HIERARCH ESO PRO CATG']
        if procatg == 'JSDC_CAT':
            return 'DIAMETER_CAT'
        # The previous patterns ('TF_*.' and 'VIS_*.RAW') were glob-like:
        # as regexes, the first accepted any 'TF?' string and the second
        # could never match 'VIS_<x>_RAW'.
        if re.match(r'TF_.*', procatg):
            # Drop the leading 'TF_' and move it to the end
            return procatg[3:] + "_TF"
        if re.match(r'VIS_.*RAW$', procatg):
            # Drop the leading 'VIS_' and the trailing 4 characters ('_RAW')
            return procatg[4:-4] + "_VIS"
        return procatg

    # Has no PRO.CATG, return TYPE_RAW
    if 'HIERARCH ESO DPR TYPE' in h:
        tag = h['HIERARCH ESO DPR TYPE']
        if 'SKY' in tag:
            # e.g. 'SKY,<mode>' -> '<mode>_SKY'
            tag = tag[4:] + '_SKY'
        elif 'OBJECT' in tag:
            # e.g. 'OBJECT,<mode>' -> '<mode>_' + first 3 letters of DPR.CATG
            tag = tag[7:] + "_" + h['HIERARCH ESO DPR CATG'][:3]
        elif 'STD' in tag:
            # e.g. 'STD,<mode>' -> '<mode>_' + first 3 letters of DPR.CATG
            tag = tag[4:] + "_" + h['HIERARCH ESO DPR CATG'][:3]
        elif 'WAVE,SC' in tag:
            tag = 'WAVESC'
        elif 'WAVE,LAMP' in tag:
            tag = 'WAVELAMP'
        return tag + '_RAW'

    return 'UNKNOWN'


def get_shutters(h):
    '''Get the shutter string for this header.

    Reads the four keywords 'HIERARCH ESO INS SHUT1<i> ST' (i = 1..4) and
    returns a 4-character string with 'T' for a truthy value and 'F'
    otherwise. If the keywords cannot be read, 'TTTT' is returned.
    '''
    try:
        return ''.join(
            'T' if h['HIERARCH ESO INS SHUT1%i ST' % i] else 'F'
            for i in range(1, 5))
    except Exception:
        # Best-effort default when the shutter keywords are unavailable.
        # `Exception` (not a bare except) so that KeyboardInterrupt and
        # SystemExit still propagate.
        return 'TTTT'


def add_gravity(headers, files):
    '''Add gravity headers to the list.

    For every path in `files` not already present in `headers` (compared
    by the `.name` attribute), the primary-HDU header is read and given
    the extra attributes `name`, `tag`, `shutters`, `is_raw` and `mode`.
    Headers whose tag contains 'TEST' or 'UNKNOWN' are skipped.

    `files` is sorted in place and `headers` is extended in place; the
    same `headers` list is returned.

    Raises IOError (naming the resolved path) when a file cannot be opened.
    '''
    files.sort()

    # Names already loaded; built once instead of once per file
    known = {h.name for h in headers}

    for f in files:
        if f in known:
            continue

        # Get entire FITS header
        path = os.path.realpath(f)
        try:
            hdulist = pyfits.open(path)
        except IOError as e:
            raise IOError("IOError in file " + path + ": " + str(e)) from e
        try:
            h = hdulist[0].header.copy()
        finally:
            # Release the file even if the header cannot be read/copied
            hdulist.close()

        # Set additional information in header
        h.name = f
        h.tag = get_tag(h)
        h.shutters = get_shutters(h)
        h.is_raw = 'HIERARCH ESO PRO CATG' not in h
        h.mode = 'DUAL' if 'DUAL' in h.tag else 'SINGLE'

        # Keep or reject
        if 'TEST' in h.tag or 'UNKNOWN' in h.tag:
            log('Skip %s as %s' % (h.name, h.tag), 1, NOTICE)
        else:
            log('Load %s as %s' % (h.name, h.tag), 1, NOTICE)
            headers.append(h)
            known.add(h.name)
    return headers


#
# Search for headers with a given tag and
# matching criteria
#
# Lists of FITS keywords used as matching criteria (the `keys` and
# `prefered_key` arguments of assoc()) by the reduce_* routines below.

# Optical setup keywords: FT part, SC part, and both combined
# (FT/SC presumably stand for fringe-tracker / science channel -- confirm).
opti_ft_setup = ['HIERARCH ESO FT POLA MODE']
opti_sc_setup = ['HIERARCH ESO INS SPEC RES', 'HIERARCH ESO INS POLA MODE']
opti_setup = opti_sc_setup + opti_ft_setup

# Detector setup keywords: FT gain, FT DIT, SC DIT, window, and all combined
det_setup_ftgain = ['HIERARCH ESO INS DET3 GAIN']
det_setup_ftdit = ['HIERARCH ESO DET3 SEQ1 DIT']
det_setup_scdit = ['HIERARCH ESO DET2 SEQ1 DIT']
det_setup_win = ['HIERARCH ESO INS FDDL WINDOW']
det_setup = det_setup_ftdit+det_setup_ftgain+det_setup_scdit+det_setup_win

# Keys used by calibrate_vis() to pick transfer-function files.
# NOTE(review): 'ESO PRO NIGHT OBS' is the only key in this file written
# without the 'HIERARCH ' prefix -- confirm the header lookup accepts both.
calib_setup = opti_setup + ['ESO PRO NIGHT OBS']
# Key identifying frames that share the same TPL.START value
tpl_setup = ['HIERARCH ESO TPL START']

# Metrology mode keyword
met_setup = ['HIERARCH ESO INS MET MODE']


def _same_setup(h, a, keys):
    '''Return True when `h` and `a` hold equal values for every key in `keys`.'''
    # A keyword missing from a header is read as 0, so two headers that
    # both lack a keyword are considered to match on it.
    return all(h.get(k, 0) == a.get(k, 0) for k in keys)


def assoc(h, allh, tag, keys, prefered_key=None, which='closest', required=0):
    '''Search for headers with tag and matching criteria.

    h: reference header (needs `.name`, `.get()` and `['MJD-OBS']`).
    allh: candidate headers (each needs `.tag`, `.get()`, `['MJD-OBS']`).
    tag: only candidates whose `.tag` equals this value are considered.
    keys: keywords whose values must be equal in `h` and the candidate.
    prefered_key: extra keywords applied on top of `keys` when more than
        `required` candidates match; the stricter selection is kept only
        if it still contains at least max(required, 1) headers.
    which: 'closest' keeps the single candidate closest in MJD-OBS,
        'before' keeps the closest one taken before `h` (falling back to
        the closest overall, with a warning); any other value (e.g.
        'all') keeps every match. The time selection is applied only when
        more than `required` candidates remain and `required` < 2.
    required: expected number of files; a warning is logged when fewer
        are found.

    Returns the list of selected headers (possibly empty).
    '''
    keys = list(keys)
    prefered = [] if prefered_key is None else list(prefered_key)

    # Keep only the requested tag matching the criteria
    atag = [a for a in allh if a.tag == tag]
    out = [a for a in atag if _same_setup(h, a, keys)]

    # Check required
    if len(out) < required:
        log('Cannot find %i %s for %s' % (required, tag, h.name), 1, WARNING)

    # If more than required, retry with the prefered keys added and keep
    # that stricter selection when it is still large enough
    if len(out) > required:
        out2 = [a for a in atag if _same_setup(h, a, keys + prefered)]
        if len(out2) >= max(required, 1):
            out = out2

    # Time-based selection
    if len(out) > required and which in ('closest', 'before'):
        if required >= 2:
            # Selecting several files by time is not supported: keep all
            if which == 'before':
                log('Case with "before" and more than 1 is not supported yet (report)', 1, WARNING)
            return out

        time_diffs = np.array([o['MJD-OBS'] - h['MJD-OBS'] for o in out])
        nearest = int(np.abs(time_diffs).argmin())
        if which == 'before':
            # Indices of the candidates taken strictly before h
            earlier = np.flatnonzero(time_diffs < 0)
            if earlier.size:
                nearest = int(earlier[np.abs(time_diffs[earlier]).argmin()])
            else:
                log('Cannot find %s before, thus use after' % tag, 1, WARNING)
        out = [out[nearest]]

    return out


def new_file_in_dir(directory, prevhash_point):
    """ Check whether the set of regular files in `directory` has changed.

        prevhash_point must be a list with one element (it acts like a
        pointer): it holds the snapshot from the previous call and is
        overwritten with the current snapshot, a sorted tuple of the names
        of the regular files found directly in `directory`.

        Returns True when the current snapshot differs from the previous
        one, False otherwise.
    """
    prevhash = prevhash_point[0]
    # Entries returned by os.listdir are relative to `directory`, so they
    # must be joined with it before testing; sorting makes the snapshot
    # independent of the listing order.
    newhash = tuple(sorted(
        name for name in os.listdir(directory)
        if os.path.isfile(os.path.join(directory, name))))
    prevhash_point[0] = newhash
    return prevhash != newhash


#
# Execute the SOF
#
def run_sof(sof, name=None, recipe='unknown', options=None, overwrite=False,
            outputdir='reduced', averagesky=False, multithread=False,
            onlycreate=False):
    '''Write the SOF and the esorex shell script for one reduction and,
    unless deferred, execute the script.

    sof: headers to list in the SOF; each needs `.name` and `.tag`.
    name: file name whose basename (extension removed) is used as prefix
        for the SOF, script and log files; defaults to the name of the
        first SOF entry.
    recipe: esorex recipe name; also selects the recipe options.
    options: parsed command-line options (or None), see
        get_esorex_options() and get_recipe_options().
    overwrite: when False, nothing is done if a file matching
        '<outputdir>/<prefix>_*.fits' already exists.
    outputdir: directory receiving the SOF, script, logs and products.
    averagesky: adds '--average-sky=TRUE' to the recipe options.
    multithread: write paths relative to a sub-directory of the current
        directory (where reduce_object_raw() later runs the script) and
        do not execute.
    onlycreate: write the files but do not execute.

    Returns None when the reduction is skipped, the path of the script
    when execution is deferred (multithread or onlycreate), otherwise the
    status returned by os.system().
    '''
    if name is None:
        name = sof[0].name

    def from_workdir(path):
        # In multithread mode the script runs one directory below the
        # current one, so relative paths need a '../' prefix; absolute
        # paths are valid from anywhere.
        if multithread and not os.path.isabs(path):
            return '../' + path
        return path

    # Define the name of SOF and of executable files
    master = os.path.splitext(os.path.basename(name))[0]
    sof_filename = outputdir+'/'+master+'_esorex.sof'
    cmd_filename = outputdir+'/'+master+'_esorex.sh'

    # Check outputdir exist (create missing parent directories too)
    os.makedirs(outputdir, exist_ok=True)

    # Check if product exist (defined as outputdir/master_*fits)
    if (not overwrite) and glob.glob("%s/%s_*.fits" % (outputdir, master)):
        log("Skiping %s : product already exist" % master, 1, NOTICE)
        return None

    # Esorex options, returned as dictionary {'opt-name':'value'}
    e_opt = get_esorex_options(options, 'esorex')
    e_opt['suppress-prefix'] = 'TRUE'
    e_opt['log-file'] = master+'_esorex.log'
    # e_opt['output-prefix'] = master
    e_opt['check-sof-exist'] = 'TRUE'
    e_opt['log-dir'] = from_workdir(outputdir)
    e_opt['output-dir'] = from_workdir(outputdir)

    # Write the sof
    log("Write %s " % sof_filename, 1, NOTICE)
    with open(sof_filename, "w") as sof_file:
        for s in sof:
            sof_file.write('%s    %s\n' % (from_workdir(s.name), s.tag))

    # Build the esorex option string; entries whose value is None
    # (including the {None: None} placeholder) are left out
    e_options = ' '.join('--%s=%s' % (k, v)
                         for k, v in e_opt.items() if v is not None)

    # Recipe options, returned as dictionary {'opt-name':'value'}
    r_opt = get_recipe_options(options, recipe)
    if averagesky:
        r_opt['average-sky'] = 'TRUE'

    # Build the recipe option string
    r_options = ' '.join('--%s=%s' % (k, v)
                         for k, v in r_opt.items() if v is not None)

    # Build the esorex command
    cmd = 'esorex %s %s %s %s' % (e_options, recipe, r_options,
                                  from_workdir(sof_filename))

    # Write the esorex command and make it executable
    log("Write %s " % cmd_filename, 1, NOTICE)
    with open(cmd_filename, "w") as cmd_file:
        cmd_file.write(cmd)
    os.chmod(cmd_filename, 0o775)

    sys.stdout.flush()

    # Execution is deferred to the caller in these modes
    if multithread or onlycreate:
        return cmd_filename

    # Execute esorex command
    erc = os.system(cmd_filename)
    # Refresh the timestamps of the SOF and script once the run is over
    os.utime(sof_filename, None)
    os.utime(cmd_filename, None)
    if erc:
        log("shell esorex command returned error code %d" % erc, 1, WARNING)
    return erc


def get_esorex_options(argoptions, recipe="esorex"):
    '''Collect the options whose name is prefixed by `recipe`.

    `argoptions` is an object whose attributes are named
    '<prefix>.<option>' (or None). Returns a dictionary
    {'<option>': value} for every attribute with prefix `recipe` and a
    value that is not None. The dictionary always contains the
    placeholder entry {None: None}.
    '''
    selected = {None: None}
    if argoptions is not None:
        for dest, value in vars(argoptions).items():
            parts = dest.split('.')
            if parts[0] == recipe and value is not None:
                selected[parts[1]] = value
    return selected

def get_recipe_options(argoptions, recipe):
    '''Collect the options that belong to `recipe`.

    `argoptions` is an object whose attributes are named
    '<prefix>.<option>' (or None). An attribute is selected when its
    prefix is `recipe` or '--' + `recipe` and its value is not None.
    Returns a dictionary {'<option>': value}, which always contains the
    placeholder entry {None: None}.
    '''
    options = {None: None}
    if argoptions is None:
        return options

    prefixes = (recipe, '--' + recipe)
    for dest, value in vars(argoptions).items():
        # Unset options are skipped for both prefix spellings (the former
        # `a or b and c` test let None through for the undashed prefix)
        if value is None:
            continue
        parts = dest.split('.')
        # Names without a '.' carry no option part
        if len(parts) > 1 and parts[0] in prefixes:
            options[parts[1]] = value
    return options

def implement_recipe_options(parser, options):
    '''Declare one '--<name>' argument on `parser` per entry of `options`.

    Each entry is indexable: item 0 is the option name
    ('<prefix>.<option>', also used as destination) and item 1 its
    default value. The help text points to 'esorex -h <prefix>', with
    'esorex' itself removed from the prefix.
    '''
    for entry in options:
        dest = entry[0]
        recipe_name = dest.split('.')[0].replace('esorex', '')
        help_text = '(see esorex -h ' + recipe_name + ')'
        parser.add_argument('--' + dest, dest=dest, default=entry[1],
                            help=help_text, metavar='')


def verbose_sof(i, sofs):
    '''Log a banner announcing that entry `i` (0-based) of `sofs` is
    being reduced, surrounded by empty log lines.'''
    log("", 1, NOTICE)
    log("", 1, NOTICE)
    current_tag = sofs[i].tag
    position = i + 1
    total = len(sofs)
    banner = "***** Now reducing %s %i over %i *****" % (current_tag, position, total)
    log(banner, 1, NOTICE)
    log("", 1, NOTICE)


#
# High-level routine to reduce the DARK_RAW, P2VM_RAW,
# DISP_RAW, OBJECT_RAW...
#
# All are built following the same structure:
# - search for the list of individual reduction, then for each
# - associate RAW and CALIB files to have a SOF
# - run esorex
#

###
def reduce_piezotf_raw(raw, outputdir, options=None):
    '''Run the gravity_piezo recipe on every PIEZOTF_RAW / VLTITF_RAW
    frame of `raw` that was taken with the matching template.'''
    # Frame tag -> template ID the frame must come from
    expected_tpl = {"PIEZOTF_RAW": 'GRAVITY_gen_cal_piezoTF',
                    "VLTITF_RAW": 'GRAVITY_gen_cal_vltiTF'}

    selected = []
    for h in raw:
        tpl_id = expected_tpl.get(h.tag)
        if tpl_id is not None and h['HIERARCH ESO TPL ID'] == tpl_id:
            selected.append(h)

    for index, frame in enumerate(selected):
        verbose_sof(index, selected)
        run_sof([frame], name=frame.name, recipe='gravity_piezo',
                options=options, outputdir=outputdir)


###
def reduce_dark_raw(raw, outputdir, options=None):
    '''Run the gravity_dark recipe on every DARK_RAW frame of `raw` taken
    with the GRAVITY_gen_cal_dark template.'''
    def is_dark(h):
        return (h.tag == "DARK_RAW"
                and h['HIERARCH ESO TPL ID'] == 'GRAVITY_gen_cal_dark')

    selected = list(filter(is_dark, raw))
    for index, frame in enumerate(selected):
        verbose_sof(index, selected)
        run_sof([frame], name=frame.name, recipe='gravity_dark',
                options=options, outputdir=outputdir)


###
# reduce WAVE (but why recipe p2vm??)
###
def reduce_wave_raw(raw, calibs, outputdir, options=None):
    '''Run the gravity_p2vm recipe on every WAVE_RAW frame of `raw` taken
    with the GRAVITY_gen_cal_wave template.

    Each SOF holds the frame itself, one DARK, BAD and FLAT from `calibs`
    with the same optical setup, one WAVESC_RAW from `raw` with the same
    optical setup and TPL.START, and the closest WAVE_PARAM (if any)
    found in `raw` or `calibs`.
    '''
    def is_wave(h):
        return (h.tag == "WAVE_RAW"
                and h['HIERARCH ESO TPL ID'] == 'GRAVITY_gen_cal_wave')

    selected = list(filter(is_wave, raw))
    for index, frame in enumerate(selected):
        verbose_sof(index, selected)
        frames = [frame]
        # FIXME: shall we check for the same DIT or same TPL.START in the DARK ??
        for calib_tag in ('DARK', "BAD", "FLAT"):
            frames.extend(assoc(frame, calibs, calib_tag, opti_setup,
                                required=1))
        frames.extend(assoc(frame, raw, 'WAVESC_RAW', opti_setup+tpl_setup,
                            required=1))
        frames.extend(assoc(frame, raw+calibs, "WAVE_PARAM", []))
        run_sof(frames, name=frame.name, recipe='gravity_p2vm',
                options=options, outputdir=outputdir)


###
# Reduce P2VM
###
def reduce_p2vm_raw(raw, calibs, outputdir, options=None):
    '''Run the gravity_p2vm recipe once per P2VM sequence found in `raw`.

    A sequence is identified by its P2VM_RAW frame with shutters 'TTFF'.
    The accepted template ID depends on MJD-OBS (see below). Each SOF
    gathers the DARK_RAW, FLAT_RAW, P2VM_RAW, WAVE_RAW and WAVESC_RAW
    frames of `raw` sharing the same TPL.START and optical setup, plus a
    WAVE_PARAM from `calibs` selected with which='before'.
    '''
    def select(tpl_id, in_period):
        return [h for h in raw
                if h.tag == "P2VM_RAW"
                and h['HIERARCH ESO TPL ID'] == tpl_id
                and h.shutters == 'TTFF'
                and in_period(h["MJD-OBS"])]

    # Hack for the SV new p2vms: the template to look for depends on the date
    p2vms = select('GRAVITY_gen_cal_p2vm', lambda mjd: mjd <= 57550)
    p2vms += select('GRAVITY_gen_cal_p2vmWAVESC', lambda mjd: mjd > 57550)
    p2vms += select('GRAVITY_gen_cal_p2vm', lambda mjd: mjd > 57610)

    # (tag, required number of frames) searched in `raw`, in SOF order
    raw_needs = (('DARK_RAW', 1), ('FLAT_RAW', 4), ('P2VM_RAW', 6),
                 ('WAVE_RAW', 1), ('WAVESC_RAW', 1))

    for index, p2vm in enumerate(p2vms):
        verbose_sof(index, p2vms)
        frames = []
        for raw_tag, count in raw_needs:
            frames.extend(assoc(p2vm, raw, raw_tag, tpl_setup+opti_setup,
                                required=count))
        frames.extend(assoc(p2vm, calibs, "WAVE_PARAM", [], which='before'))
        run_sof(frames, name=p2vm.name, recipe='gravity_p2vm',
                options=options, outputdir=outputdir)


###
# Reduce Wavelamp
###
def reduce_wavelamp_raw(raw, calibs, outputdir, options=None):
    '''Run the gravity_wavelamp recipe on every WAVELAMP_RAW frame of `raw`.

    Each SOF holds the frame itself, one DARK_RAW from `raw` with the same
    optical, detector and metrology setup, one BAD, FLAT, WAVE and P2VM
    from `calibs` with the same optical setup, and the closest WAVE_PARAM
    (if any) from `calibs`.
    '''
    selected = []
    for h in raw:
        if h.tag == "WAVELAMP_RAW":
            selected.append(h)

    for index, frame in enumerate(selected):
        verbose_sof(index, selected)
        frames = [frame]
        frames.extend(assoc(frame, raw, 'DARK_RAW',
                            opti_setup+det_setup+met_setup, required=1))
        for calib_tag in ("BAD", "FLAT", "WAVE", "P2VM"):
            frames.extend(assoc(frame, calibs, calib_tag, opti_setup,
                                required=1))
        frames.extend(assoc(frame, calibs, "WAVE_PARAM", []))
        run_sof(frames, name=frame.name, recipe='gravity_wavelamp',
                options=options, outputdir=outputdir)


###
# Reduce dispersion
###
def reduce_disp_raw(raw, calibs, outputdir, options=None):
    '''Run the gravity_disp recipe once per dispersion sequence of `raw`.

    A sequence is identified by its DISP_RAW frame with TPL.EXPNO equal
    to 2. Each SOF holds all DISP_RAW frames of `raw` sharing the same
    optical setup and TPL.START, one DARK_RAW from `raw`, and one
    WAVELAMP, BAD, FLAT, WAVE, P2VM and WAVE_PARAM from `calibs`.
    '''
    def starts_sequence(h):
        return h.tag == 'DISP_RAW' and h['HIERARCH ESO TPL EXPNO'] == 2

    selected = list(filter(starts_sequence, raw))
    for index, frame in enumerate(selected):
        verbose_sof(index, selected)
        frames = []
        frames.extend(assoc(frame, raw, "DISP_RAW", opti_setup+tpl_setup,
                            which='all'))
        # NOTE(review): the DARK must match the full optical, detector and
        # metrology setup. An earlier comment announced a fallback on a
        # DARK with a different FDDL.WINDOW, but no prefered key is passed
        # here, so no such fallback takes place -- confirm the intent.
        frames.extend(assoc(frame, raw, "DARK_RAW",
                            opti_setup+det_setup+met_setup, required=1))
        for calib_tag in ("WAVELAMP", "BAD", "FLAT", "WAVE", "P2VM"):
            frames.extend(assoc(frame, calibs, calib_tag, opti_setup,
                                required=1))
        frames.extend(assoc(frame, calibs, "WAVE_PARAM", [], required=1))
        run_sof(frames, name=frame.name, recipe='gravity_disp',
                options=options, outputdir=outputdir)


###
#
# Reduce GRAVITY DATA (gravity_vis recipe)
###
def reduce_object_raw(raw, calibs, outputdir, options=None, overwrite=False,
                      averagesky=False, onlycreate=False):
    objects = [h for h in raw if '_SCI_RAW' in h.tag or '_CAL_RAW' in h.tag]

    ncores = options.ncores
    if ncores > 1 and not onlycreate:
        multithread = True
        cmds = []
        curdir = os.getcwd()

        def run_multithread(idx):
            usecore = (idx % ncores)
            newdir = curdir + '/.%02d' % usecore
            try:
                os.mkdir(newdir)
            except FileExistsError:
                pass
            os.chdir(newdir)

            _cmd = '../' + cmds[idx]
            print('Reducing %s, on core %i/%i'
                  % (_cmd, (idx % ncores + 1), ncores))
            os.system(_cmd)
            os.chdir(curdir)
            shutil.rmtree(newdir, ignore_errors=True)

    else:
        multithread = False

    for i, obj in enumerate(objects):
        if not onlycreate:
            verbose_sof(i, objects)
        sof = [obj]
        # Search for a SKY with full-matching, if none, accept wrong SC_DIT
        if averagesky:
            sof += assoc(obj, raw, re.sub('_.+_', '_SKY_', obj.tag),
                         opti_ft_setup + det_setup_ftdit + det_setup_ftgain + met_setup,
                         opti_sc_setup + det_setup_scdit, required=1, which='all')
        else:
            sof += assoc(obj, raw, re.sub('_.+_', '_SKY_', obj.tag),
                         opti_ft_setup + det_setup_ftdit + det_setup_ftgain + met_setup,
                         opti_sc_setup + det_setup_scdit, required=1)
        # Search for a DARK with full-matching, if none, accept wrong FT_CONFIG
        sof += assoc(obj, calibs, "DARK", opti_setup+det_setup_scdit+met_setup,
                     det_setup, required=1)
        sof += assoc(obj, calibs, "BAD", opti_setup, det_setup_ftgain, required=1)
        sof += assoc(obj, calibs, "FLAT", opti_setup, det_setup_ftgain, required=1)
        sof += assoc(obj, calibs, "WAVE", opti_setup, det_setup_ftgain, required=1)
        sof += assoc(obj, calibs, "P2VM", opti_setup, det_setup_ftgain, required=1)
        sof += assoc(obj, raw+calibs, "DISP_MODEL", [], required=1)
        sof += assoc(obj, raw+calibs, "DIAMETER_CAT", [], required=1)
        sof += assoc(obj, raw+calibs, "EOP_PARAM", [], required=1)
        sof += assoc(obj, raw+calibs, "DIODE_POSITION", [], required=1)
        sof += assoc(obj, raw+calibs, "STATIC_PARAM", [], required=1,
                     which='before')

        if multithread:
            cmd = run_sof(sof, name=obj.name, recipe='gravity_vis',
                          options=options, outputdir=outputdir,
                          overwrite=overwrite, averagesky=averagesky,
                          multithread=True)
            if cmd is not None:
                cmds.append(cmd)
        else:
            run_sof(sof, name=obj.name, recipe='gravity_vis', options=options,
                    outputdir=outputdir, overwrite=overwrite,
                    averagesky=averagesky, onlycreate=onlycreate)

    if multithread:
        Parallel(n_jobs=ncores)(delayed(run_multithread)(idx) for idx in range(len(cmds)))


def average_p2vmred(raw, outputdir, options=None, overwrite=False):
    '''Run the gravity_vis_from_p2vmred recipe on every header of `raw`
    whose tag matches '.*P2VMRED.*'.'''
    pattern = re.compile('.*P2VMRED.*')
    selected = [h for h in raw if pattern.match(h.tag)]
    for index, frame in enumerate(selected):
        verbose_sof(index, selected)
        run_sof([frame], name=frame.name, recipe='gravity_vis_from_p2vmred',
                options=options, outputdir=outputdir, overwrite=overwrite)


def compute_tf(raw, calibs, outputdir, options=None, overwrite=False):
    '''Run the gravity_viscal recipe on every header of `raw` whose tag
    matches '.*_CAL_VIS', together with the closest DIAMETER_CAT (if any)
    from `calibs`.'''
    pattern = re.compile('.*_CAL_VIS')
    selected = [h for h in raw if pattern.match(h.tag)]
    for index, frame in enumerate(selected):
        verbose_sof(index, selected)
        frames = [frame]
        frames.extend(assoc(frame, calibs, "DIAMETER_CAT", []))
        run_sof(frames, name=frame.name, recipe='gravity_viscal',
                options=options, outputdir=outputdir, overwrite=overwrite)


def calibrate_vis(raw, calibs, outputdir, options=None, overwrite=False):
    '''Run the gravity_viscal recipe on every header of `raw` whose tag
    matches '.*_SCI_VIS', together with all the transfer-function files
    of `calibs` that share the same `calib_setup` keyword values.'''
    # Transfer-function tags added to the SOF, in this order
    tf_tags = ('SINGLE_CAL_TF', 'DUAL_CAL_TF',
               'SINGLE_CAL_TF_VISPHI', 'DUAL_CAL_TF_VISPHI')

    pattern = re.compile('.*_SCI_VIS')
    selected = [h for h in raw if pattern.match(h.tag)]
    for index, frame in enumerate(selected):
        verbose_sof(index, selected)
        frames = [frame]
        for tf_tag in tf_tags:
            frames.extend(assoc(frame, calibs, tf_tag, calib_setup,
                                required=0, which='all'))
        run_sof(frames, name=frame.name, recipe='gravity_viscal',
                options=options, outputdir=outputdir, overwrite=overwrite)
