# -*- coding: utf-8 -*-
"""
Created on Mon Apr 11 14:38:49 2016

@author: Antoine Mérand, Pierre Kervella
"""

from astropy.io import fits
import numpy as np
from matplotlib import pyplot as plt
import itertools

from . import dpfit

def dopdFunc(uvsb, p):
    """
    dOPD model for the astrometric fit. Returns the dOPD in m.

    uvsb = (u, v, s, b):
        u, v : baseline coordinates [m]
        s    : swap state (0 or 1); the astrometric term changes sign when s == 1
        b    : baseline names such as 'T1T2', one per data point

    p holds the binary separation, either cartesian
        {'X[mas]':, 'Y[mas]':, ...}
    or polar (X = R*sin(Theta), Y = R*cos(Theta))
        {'R[mas]':, 'Theta[deg]':, ...}
    and the zero points, either one per baseline
        {'ZPB_T1T2[um]':, ...}
    or one per telescope, the ZP of baseline 'TiTj' being ZPT_Tj - ZPT_Ti
        {'ZPT_T1[um]':, 'ZPT_T2[um]':, ...}

    A per-baseline ZP is used when present, otherwise the per-telescope ones.
    If neither separation parametrisation is present, only the zero points
    contribute.
    """
    u, v, s, b = uvsb
    mas2rad = np.pi/(180*3600*1000.)

    # -- astrometric term in m*mas (converted to m below)
    if 'X[mas]' in p and 'Y[mas]' in p:
        astro = u*p['X[mas]'] + v*p['Y[mas]']
    elif 'R[mas]' in p and 'Theta[deg]' in p:
        theta = p['Theta[deg]']*np.pi/180.
        astro = u*p['R[mas]']*np.sin(theta) + v*p['R[mas]']*np.cos(theta)
    else:
        astro = 0.
    # -- the swap state flips the sign of the astrometric dOPD only
    signed = ((-1.)**s)*mas2rad*astro

    def zeroPoint(base):
        # -- zero point [um] of one baseline, per-baseline parameter first
        key = 'ZPB_'+base+'[um]'
        if key in p:
            return p[key]
        return p['ZPT_'+base[2:4]+'[um]'] - p['ZPT_'+base[0:2]+'[um]']

    return signed + 1e-6*np.array([zeroPoint(x) for x in b])


def phasorFunc(uvsbw, p):
    """
    Phasor model for the astrometric fit: exp(2j*pi*dOPD/w).

    uvsbw = (u, v, s, b, w):
        u, v, s, b : as in dopdFunc (u, v in m, swap state 0/1, baseline names)
        w          : wavelength, in the same unit as the dOPD (m)

    p : same parameter dictionary as dopdFunc (cartesian or polar separation,
        'ZPB_...' per-baseline or 'ZPT_...' per-telescope zero points).

    The dOPD, including the swap-state sign and the zero points, is taken from
    dopdFunc. The zero points therefore enter as a phase and the swap reverses
    the astrometric OPD instead of shifting the phasor by pi. This matches how
    dopdFit removes a dOPD model from the measured phasors, by multiplying them
    by exp(-2j*pi*model/wave).

    The dOPD has one entry per element of b and is divided element-wise by w,
    so w must broadcast against it.
    """
    u, v, s, b, w = uvsbw
    dopd = dopdFunc((u, v, s, b), p)  # [m]
    return np.exp(2j*np.pi*dopd/w)


def computeZPclosures(fit):
    """
    Compute the closures of the per-baseline zero points of an astrometric fit.

    fit is a result from dpfit.leastsqFit(dopdFunc, ...) obtained with the
    per-baseline parametrisation ('ZPB_T1T2[um]', ...). fit['x'][3] lists the
    baseline name of each data point, and fit['best'] holds the best-fit
    parameters.

    For every set of 3 baselines closing a triangle on 3 telescopes, the signed
    sum of their zero points is stored in fit under a key such as
    'ZPC_[+T1T2-T1T3+T2T3][um]'. The first baseline of the triangle (in sorted
    order) sets the orientation, and the other two get the sign that makes the
    loop consistent with it.

    fit is modified in place and also returned.
    """
    # -- sorted, so that triangles, their sign reference and hence the key names
    #    do not depend on (randomised) set iteration order
    baselines = sorted(set(fit['x'][3]))
    for tri in itertools.combinations(baselines, 3):
        # -- a closed triangle involves exactly 3 telescopes
        tels = {x[:2] for x in tri} | {x[2:] for x in tri}
        if len(tels) != 3:
            continue
        ref = tri[0]
        signs = [1]  # first baseline is the reference for the signs
        for bl in tri[1:]:
            if bl[2:] == ref[:2]:
                signs.append(1)
            elif bl[2:] == ref[2:]:
                signs.append(-1)
            elif bl[:2] == ref[2:]:
                signs.append(1)
            elif bl[:2] == ref[:2]:
                signs.append(-1)
        clo = dict(zip(['ZPB_'+bl+'[um]' for bl in tri], signs))
        # -- name of the ZP closure, with baselines and signs
        name = 'ZPC_[' + ''.join('%s%s' % ('-' if sg < 0 else '+', bl.strip())
                                 for bl, sg in zip(tri, signs)) + '][um]'
        fit[name] = np.sum([sg*fit['best'][k] for k, sg in clo.items()])
    return fit


def _phasorToDopd(phasor, wave, minidx, maxidx):
    """
    Convert residual phasors to a dOPD using the channels [minidx:maxidx].

    phasor : complex array indexed as [frame, baseline, channel]
    wave   : wavelength of each channel; the result has the same unit
    Returns a real array indexed as [frame, baseline].
    """
    # -- average phase over the selected channels, per frame and baseline
    phase = np.mean(np.angle(phasor[:, :, minidx:maxidx]), axis=2)
    # -- re-wrap each baseline's phases into [mean-pi, mean+pi) around that
    #    baseline's own mean phase over the frames
    centre = np.mean(phase, axis=0)
    phase = (phase - centre + np.pi) % (2*np.pi) - np.pi + centre
    # -- phase [rad] -> OPD, using the mean lambda/(2 pi) of the same channels
    return np.mean(wave[minidx:maxidx]/(2*np.pi))*phase


def _printZPclosures(fit):
    """Add the ZP closures to fit (computeZPclosures), print them and return fit."""
    print("-- ZPB closures -----")
    fit = computeZPclosures(fit)
    for k in fit:
        if k.startswith('ZPC_'):
            print('%s = %5.3f [um]' % (k, fit[k]))
    return fit


def dopdFit(dualseries, nZP=6, polarCoord=False):
    """
    Astrometric fit of a series of dual field observations of the same binary.
    Authors: Antoine Mérand, Pierre Kervella

    nZP: number of zero points fitted, either 3 or 6 (default).
         6: one per baseline ('ZPB_TiTj[um]');
         3: one per telescope ('ZPT_Ti[um]'), the first telescope fixed to 0.

    polarCoord: use polar coordinates (R, PA) for the binary separation, else
         Cartesian (default). The PA follows dopdFunc: X = R sin(PA), Y = R cos(PA).

    limitation: assumes 4T (6 baselines).

    For each model suffix m in ('', '_telmet'):
      - fits dopdFunc to dualseries.dopd_gd (group delay), giving 'GD'+m;
      - removes that model from dualseries.phasor+m, converts the residual
        phase to a dOPD and fits dopdFunc to it, giving 'phase2'+m.

    Stores on dualseries: gdmodel+m, phasor_res+m, dopd_phasor_res+m,
    phasor_res2+m and dopd_phasor_res2+m.

    Returns a dict {'GD'+m: fit, 'phase2'+m: fit} of dpfit.leastsqFit results.
    Raises ValueError if nZP is neither 3 nor 6.
    """
    nfiles = dualseries.nfiles
    nframes = dualseries.nframes
    print(" Number of files : ", nfiles)
    print(" Number of frames: ", nframes)
    nwave_sc = dualseries.nwave_sc
    nbase = 6  # limitation: 4 telescopes
    data = dualseries.__dict__  # inputs are read from / results written to dualseries

    # -- prepare first guess parameters
    param = {}
    if polarCoord:
        param['R[mas]'] = np.sqrt(np.square(dualseries.relposX) +
                                  np.square(dualseries.relposY))
        # -- dopdFunc uses X = R sin(Theta), Y = R cos(Theta) -> Theta = atan2(X, Y)
        param['Theta[deg]'] = np.arctan2(dualseries.relposX,
                                         dualseries.relposY)*180./np.pi
    else:
        param['X[mas]'] = dualseries.relposX
        param['Y[mas]'] = dualseries.relposY

    doNotFit = []
    if nZP == 6:
        # -- 6 zero points: 1 per baseline
        for b in dualseries.base[0, :]:
            param['ZPB_'+b+'[um]'] = 0
    elif nZP == 3:
        # -- 1 zero point per telescope; only differences between telescopes
        #    enter the model, so the first one is fixed -> 3 free zero points
        for i, name in enumerate(dualseries.tel.values()):
            t = name[0:2]
            param['ZPT_'+t+'[um]'] = 0
            if i == 0:
                doNotFit.append('ZPT_'+t+'[um]')
    else:
        raise ValueError("I do not know how to fit %d ZP" % nZP)

    # -- dOPD is function of: U, V, swap state and base (for the zero point)
    uvsb = (dualseries.ucoord.flatten(),
            dualseries.vcoord.flatten(),
            dualseries.swapstate.flatten(),
            dualseries.base.flatten())

    # -- channels used to convert phasors to dOPD: central 20% of the band,
    #    at least one channel so the average is never taken over an empty slice
    minidx = int(nwave_sc*0.4)
    maxidx = max(int(nwave_sc*0.6), minidx + 1)
    wave = dualseries.wave_sc

    # -- fits for different formulas of dOPD
    models = ['', '_telmet']
    res = {}
    for i, m in enumerate(models):
        print('\033[%dm' % (34+i))  # colors!
        print('#'*3, 'MODEL:', m, '#'*80)

        # -- fit dOPD from group delay only (FC met only)
        fit0 = dpfit.leastsqFit(dopdFunc, uvsb, param,
                                data['dopd_gd'].flatten(),
                                verbose=2, doNotFit=doNotFit)
        data['gdmodel'+m] = fit0['model'][:, None]
        if nZP == 6:
            fit0 = _printZPclosures(fit0)

        # -- residual = (measured phasor) x conj(phasor of the GD best fit)
        phasor = data['phasor'+m].reshape(nframes*nbase, nwave_sc)
        phasor_res = phasor*np.exp(-2j*np.pi*fit0['model'][:, None]/wave[None, :])
        data['phasor_res'+m] = phasor_res.reshape(nframes, nbase, nwave_sc)
        data['dopd_phasor_res'+m] = _phasorToDopd(data['phasor_res'+m],
                                                  wave, minidx, maxidx)

        # -- fit dOPD from group delay + phase (FC or FC+MET)
        # NOTE(review): the data are residuals after removing the GD model, while
        #   the first guess is fit0['best']; confirm that fit1 is meant as a
        #   correction on top of fit0.
        fit1 = dpfit.leastsqFit(dopdFunc, uvsb, fit0['best'],
                                data['dopd_phasor_res'+m].flatten(),
                                verbose=2, doNotFit=doNotFit)

        # -- compute phasor and dopd residuals after the phase fit
        phasor_res2 = (data['phasor_res'+m].reshape(nframes*nbase, nwave_sc) *
                       np.exp(-2j*np.pi*fit1['model'][:, None]/wave[None, :]))
        data['phasor_res2'+m] = phasor_res2.reshape(nframes, nbase, nwave_sc)
        data['dopd_phasor_res2'+m] = _phasorToDopd(data['phasor_res2'+m],
                                                   wave, minidx, maxidx)

        if nZP == 6:
            fit1 = _printZPclosures(fit1)

        res['GD'+m] = fit0
        res['phase2'+m] = fit1
        print('')
    print('\033[0m')
    return res
