from astropy.io import fits
import numpy as np
from matplotlib import pyplot as plt
import os
import itertools

import dpfit

# -- station layout, keyed by station name. The names are the same 2-character
#    station names that make up the baseline labels built from OI_ARRAY
#    STA_NAME when reading the data (baseline = name1+name2).
#    Each value is a 5-tuple of numbers; this script only uses fields 2 and 3,
#    as planar coordinates from which the baseline directions (and the angle
#    between two baselines) are computed for the 'Chop Sticks' panel.
#    NOTE(review): the meaning and units of the 5 fields are not documented
#    here; the note on 'I2' suggests fields 0-1 are the station X/Y and the
#    last field an "A0" value -- confirm against the source of this table.
layout = {'A0':(-32.0010, -48.0130, -14.6416, -55.8116, 129.8495),
          'A1':(-32.0010, -64.0210, -9.4342, -70.9489, 150.8475),
          'B0':(-23.9910, -48.0190, -7.0653, -53.2116, 126.8355),
          'B1':(-23.9910, -64.0110, -1.8631, -68.3338, 142.8275),
          'B2':(-23.9910, -72.0110, 0.7394, -75.8987, 150.8275, ),
          'B3':(-23.9910, -80.0290, 3.3476, -83.4805, 158.8455),
          'B4':(-23.9910, -88.0130, 5.9449, -91.0303, 166.8295),
          'B5':(-23.9910, -96.0120, 8.5470, -98.5942, 174.8285),
          'C0':(-16.0020, -48.0130, 0.4872, -50.6071, 118.8405),
          'C1':(-16.0020, -64.0110, 5.6914, -65.7349, 134.8385),
          'C2':(-16.0020, -72.0190, 8.2964, -73.3074, 142.8465),
          'C3':(-16.0020, -80.0100, 10.8959, -80.8637, 150.8375),
          'D0':(0.0100, -48.0120, 15.6280, -45.3973, 97.8375),
          'D1':(0.0100, -80.0150, 26.0387, -75.6597, 134.8305),
          'D2':(0.0100, -96.0120, 31.2426, -90.7866, 150.8275),
          'E0':(16.0110, -48.0160, 30.7600, -40.1959, 81.8405),
          'G0':(32.0170, -48.0172, 45.8958, -34.9903, 65.8357),
          'G1':(32.0200, -112.0100, 66.7157, -95.5015, 129.8255),
          'G2':(31.9950, -24.0030, 38.0630, -12.2894, 73.0153),
          'H0':(64.0150, -48.0070, 76.1501, -24.5715, 58.1953),
          'I1':(72.0010, -87.9970, 96.7106, -59.7886, 111.1613),
          'I2':(80, -24, 83.456, 3.330, 90),# -- XY are correct, A0 is guessed!
          'J1':(88.0160, -71.9920, 106.6481, -39.4443, 111.1713),
          'J2':(88.0160, -96.0050, 114.4596, -62.1513, 135.1843),
          'J3':(88.0160, 7.9960, 80.6276, 36.1931, 124.4875),
          'J4':(88.0160, 23.9930, 75.4237, 51.3200, 140.4845),
          'J5':(88.0160, 47.9870, 67.6184, 74.0089, 164.4785),
          'J6':(88.0160, 71.9900, 59.8101, 96.7064, 188.4815),
          'K0':(96.0020, -48.0060, 106.3969, -14.1651, 90.1813),
          'L0':(104.0210, -47.9980, 113.9772, -11.5489, 103.1823),
          'M0':(112.0130, -48.0000, 121.5351, -8.9510, 111.1763),
          'U1':(-16.0000, -16.0000, -9.9249, -20.3346, 189.0572),
          'U2':(24.0000, 24.0000, 14.8873, 30.5019, 190.5572),
          'U3':(64.0000, 48.0000, 44.9044, 66.2087, 199.7447),
          'U4':(112.0000, 8.0000, 103.3058, 43.9989, 209.2302)}

def dopdFunc(uvsb, p):
    """
    Model of the differential OPD (dOPD) used by the astrometric fit.

    uvsb: tuple (u, v, swap, base) of equal-length sequences -- projected
          baseline coordinates, swap state (0. or 1.) and baseline name.
    p: parameter dict; 'X' and 'Y' are the separation in mas, and every
       baseline name b appearing in `base` needs a 'zero'+b entry.

    Returns (-1)**swap * mas2rad * (u*X + v*Y) + p['zero'+base], element-wise.
    NOTE(review): u, v are presumably in metres (OIFITS UCOORD/VCOORD), which
    would make the result a length in metres -- confirm.
    """
    u_coord, v_coord, swap, bases = uvsb
    # -- milliarcseconds -> radians
    mas2rad = np.pi/(180*3600*1000.)
    # -- swapping the objects flips the sign of the astrometric term
    sign = (-1.)**swap
    projection = u_coord*p['X'] + v_coord*p['Y']
    # -- one free zero point per baseline
    offsets = np.array([p['zero'+name] for name in bases])
    return sign*mas2rad*projection + offsets

# == files to open ================================
directory = '/Volumes/Extra128/'
# -- keep the 'p2vmreduced.fits' products whose name contains '04-10'
#    (presumably the observing date), skipping '._' resource-fork files
files = [name for name in os.listdir(directory)
         if name.endswith('p2vmreduced.fits')
         and not name.startswith('._')
         and '04-10' in name]

# -- sort by the timestamp found between the first '.' and the next '_'
dates = [name.split('.')[1].split('_')[0] for name in files]
files = np.array(files)[np.argsort(dates)]
files = files[3:] # some bad files to be removed

# ==  read data ======================================
data = {'SWAP':np.array([])}
param = {} # first guess parameters for the astrometric fit

def _accumulate(key, values):
    """
    Concatenate `values` onto data[key] along the first axis, or create the
    entry when this is the first file providing that key.
    """
    if key in data:
        data[key] = np.append(data[key], values, axis=0)
    else:
        data[key] = values

for filename in files:
    # -- 'with' guarantees the FITS file is closed even if a read below raises
    with fits.open(os.path.join(directory, filename)) as f:
        print(filename, f[0].header['HIERARCH ESO INS SOBJ NAME'])
        # -- conversion STA_INDEX -> STA_NAME (rebuilt for every file: after
        #    the loop, TEL holds the mapping of the last file read)
        TEL = dict(zip(f['OI_ARRAY'].data['STA_INDEX'],
                       f['OI_ARRAY'].data['STA_NAME']))
        # -- given separation in the header (overwritten for every file: the
        #    last file read provides the first guess of the fit)
        param['X'] = f[0].header['HIERARCH ESO INS SOBJ X']
        param['Y'] = f[0].header['HIERARCH ESO INS SOBJ Y']
        # NOTE(review): HDU numbers 6 (wavelengths) and 11 (data table) are
        # hard-coded for this product layout -- confirm against the files.
        wl_sc = f[6].data['EFF_WAVE']
        hdu = f[11]

        # -- copy all columns of the data table
        for key in hdu.header.keys():
            if key.startswith('TTYPE'):
                column = hdu.header[key].strip()
                _accumulate(column, hdu.data[column].copy())

        # -- compute rank, i.e. fringe order, for the SC group delay...
        rank_sc = np.floor(hdu.data['GDELAY'][:,None]/wl_sc[None,:] -
                           1*(hdu.data['GDELAY']<0)[:,None])
        _accumulate('RANK_SC', rank_sc)
        # -- ...and the same for the FT group delay (also counted in SC wavelengths)
        rank_ft = np.floor(hdu.data['GDELAY_FT'][:,None]/wl_sc[None,:] -
                           1*(hdu.data['GDELAY_FT']<0)[:,None])
        _accumulate('RANK_FT', rank_ft)

        # -- elaborate DOPD, per row and per SC wavelength: SC phase and SC
        #    fringe order, minus PHASE_REF and FT fringe order (all converted to
        #    length with the SC wavelengths), plus the metrology OPD_MET_FC
        wl = wl_sc[None,:]
        _accumulate('DOPD', wl/(2*np.pi)*np.angle(hdu.data['VISDATA'])
                            + rank_sc*wl
                            - wl/(2*np.pi)*hdu.data['PHASE_REF']
                            - rank_ft*wl
                            + hdu.data['OPD_MET_FC'][:,None])

        # -- Naive DOPD, based on group delay only
        _accumulate('DOPDnaive', hdu.data['GDELAY'] - hdu.data['GDELAY_FT']
                                 + hdu.data['OPD_MET_FC'])
        # -- baseline name for each row: concatenation of the 2 station names
        _accumulate('BASE', np.array([TEL[s[0]] + TEL[s[1]]
                                      for s in hdu.data['STA_INDEX']]))

        # -- swapped or not? (1. if the header SWAP value contains 'YES', else 0.)
        swap = 'YES' in f[0].header['HIERARCH ESO INS SOBJ SWAP']
        data['SWAP'] = np.append(data['SWAP'], np.ones(hdu.header['NAXIS2'])*float(swap))


# -- astrometric fit with all data
# -- add a zero point in the parameters for each baseline. Iterate in sorted
#    order so the parameter order is reproducible: the iteration order of a
#    set of strings changes between runs because of hash randomisation.
for b in sorted(set(data['BASE'])):
    param['zero'+b] = 0

N = 6 # number of bases per file (6 for 4T)
# -- dOPD is a function of: U, V, swap state and base (for the zero point)
uvsb = (data['UCOORD'], data['VCOORD'], data['SWAP'], data['BASE'])
# -- average DOPD along wavelength (axis 1 holds the SC wavelengths)
data['DOPD_AVG'] = data['DOPD'].mean(axis=1)
# -- which key will we fit? ('DOPDnaive' is the group-delay-only alternative)
fitted = 'DOPD_AVG'

# -- fit of all data
fitG = dpfit.leastsqFit(dopdFunc, uvsb, param, data[fitted], verbose=1)
# -- best fit model of the DOPD:
mod = fitG['model']

# -- plots: residuals of the global fit, one panel per baseline (figure 0)
plt.close('all')
plt.figure(0, figsize=(14, 9))
plt.subplots_adjust(left=0.05, right=0.95, bottom=0.05, top=0.95)

for b in range(N):
    # -- rows are taken as interleaved by baseline: b, b+N, b+2N, ...
    #    NOTE(review): assumes every file lists its N baselines in the same order
    sel = slice(b, None, N)
    if b == 0:
        ax0 = plt.subplot(3, 2, 1)
    else:
        plt.subplot(3, 2, b + 1, sharex=ax0, sharey=ax0)
    plt.title(data['BASE'][b])
    mjd, model = data['MJD'][sel], mod[sel]
    # -- residuals for DOPD_AVG (scaled by 1e6), coloured by swap state
    plt.scatter(mjd, 1e6*(data['DOPD_AVG'][sel] - model), marker='o',
                c=data['SWAP'][sel], cmap='RdYlGn', alpha=0.4, lw=0)
    # -- residuals for DOPDnaive (scaled by 1e6)
    plt.plot(mjd, 1e6*(data['DOPDnaive'][sel] - model), '+k')
    plt.grid()
ax0.set_xticks([])

# === checking closures: print every triplet of baselines that closes a
#     triangle, i.e. whose 3 baselines involve exactly 3 distinct stations
for triplet in itertools.combinations(set(data['BASE']), 3):
    tri_stations = {base[:2] for base in triplet}.union(base[2:] for base in triplet)
    if len(tri_stations) == 3:
        print(triplet)

# === compare global fit with subsets fits
def _errorEllipse(fit, theta):
    """
    Return the (X, Y) outline of the X/Y uncertainty ellipse of a dpfit result.

    fit: result of dpfit.leastsqFit; 'X' and 'Y' must be in fit['fitOnly'].
    theta: parametric angles (rad) sampling the ellipse.
    Semi-axes and orientation come from dpfit._ellParam applied to the X/Y
    block of fit['cov'].
    """
    i = fit['fitOnly'].index('X')
    j = fit['fitOnly'].index('Y')
    sMa, sma, a = dpfit._ellParam(fit['cov'][i,i], fit['cov'][j,j], fit['cov'][i,j])
    X, Y = sMa*np.cos(theta), sma*np.sin(theta)
    # -- rotate the axis-aligned ellipse by the angle a
    return X*np.cos(a) + Y*np.sin(a), -X*np.sin(a) + Y*np.cos(a)

def _plotFit(ax, fit, fmt, **kwargs):
    """
    Draw on `ax` the error ellipse of `fit`, centred on its best-fit X/Y and
    expressed as an offset from the first guess (param['X'], param['Y']).
    """
    X, Y = _errorEllipse(fit, theta)
    ax.plot(fit['best']['X'] + X - param['X'], fit['best']['Y'] + Y - param['Y'],
            fmt, **kwargs)

def _subsetFit(w, fitOnly):
    """
    Refit dopdFunc on the rows selected by the boolean mask `w`, letting only
    the parameters listed in `fitOnly` vary (the others presumably stay at
    their values in `param` -- see dpfit.leastsqFit).
    """
    uvsb = (data['UCOORD'][w], data['VCOORD'][w], data['SWAP'][w], data['BASE'][w])
    return dpfit.leastsqFit(dopdFunc, uvsb, param, data[fitted][w], fitOnly=fitOnly)

def _unitBaseline(base):
    """Unit vector of baseline `base`, from fields 2 and 3 of `layout` for its 2 stations."""
    B = np.array([layout[base[:2]][2] - layout[base[2:]][2],
                  layout[base[:2]][3] - layout[base[2:]][3]], dtype=float)
    return B/np.sqrt((B**2).sum())

plt.close(2)
plt.figure(2, figsize=(14,4))
plt.subplots_adjust(left=0.05, right=0.95, top=0.92, bottom=0.12)
theta = np.linspace(0, 2*np.pi, 100)
# -- reference ellipse (global fit) drawn in every panel
globalStyle = dict(label='global Fit', linewidth=3, alpha=0.5)
# NOTE(review): TEL holds the station mapping of the last file read only
stations = list(TEL.values())
bases = sorted(set(data['BASE']))

# == fork fits: one TEL connects to ALL
colors = ['r', 'g', 'b', 'orange']
ax0 = plt.subplot(131)
ax0.set_title('Forks')
ax0.set_aspect('equal')
ax0.set_xlabel('X offset / first guess (mas)')
ax0.set_ylabel('Y offset / first guess (mas)')
_plotFit(ax0, fitG, '-k', **globalStyle)
for c, t in enumerate(stations):
    # -- only baselines involving station t; free X, Y and their zero points
    w = np.array([t in d for d in data['BASE']])
    fitOnly = [k for k in param if t in k] + ['X', 'Y']
    _plotFit(ax0, _subsetFit(w, fitOnly), '-', label='Fork '+t,
             color=colors[c % len(colors)])
ax0.legend(fontsize=8)

# == trowel fits: ALL TEL but one (triangle)
ax1 = plt.subplot(132, sharex=ax0, sharey=ax0)
ax1.set_title('Trowel')
ax1.set_aspect('equal')
ax1.set_xlabel('X offset / first guess (mas)')
_plotFit(ax1, fitG, '-k', **globalStyle)
for c, t in enumerate(stations):
    # -- all baselines not involving station t; 'X' and 'Y' are kept free
    #    because they do not contain t
    w = np.array([t not in d for d in data['BASE']])
    fitOnly = [k for k in param if t not in k]
    _plotFit(ax1, _subsetFit(w, fitOnly), '-', label='w/o '+t,
             color=colors[c % len(colors)])
ax1.legend(fontsize=8)

# == independent pairs fits: 2 baselines without any telescope in common
ax2 = plt.subplot(133, sharex=ax0, sharey=ax0)
ax2.set_aspect('equal')
ax2.set_xlabel('X offset / first guess (mas)')
ax2.set_title('Chop Sticks')
pairColors = [(0.6,0,0.8), (0.6,0.8,0), (0,0.6,0.8)]
_plotFit(ax2, fitG, '-k', **globalStyle)
c = 0
for b1, b2 in itertools.combinations(bases, 2):
    # -- keep only pairs of baselines sharing no station
    if not {b1[:2], b1[2:]}.isdisjoint({b2[:2], b2[2:]}):
        continue
    # -- angle between the 2 baselines; clip guards arccos against rounding
    cosang = np.dot(_unitBaseline(b1), _unitBaseline(b2))
    ang = np.rad2deg(np.arccos(np.clip(cosang, -1.0, 1.0)))
    w = np.array([b == b1 or b == b2 for b in data['BASE']])
    fitOnly = [k for k in param if b1 in k or b2 in k] + ['X', 'Y']
    _plotFit(ax2, _subsetFit(w, fitOnly), '-', label=b1+'/'+b2+' %4.0fdeg'%ang,
             color=pairColors[c % len(pairColors)])
    c += 1
ax2.legend(fontsize=8)
ax2.set_xlim(-12,5)
ax2.set_ylim(-9,7)
