# -*- coding: utf-8 -*-
"""
Created on Mon Sep 14 11:41:26 2015

@author: kervella
"""
try:
   from astropy.io import fits as pyfits
except:
   import pyfits
import numpy as np
from . import gravi_astrometry_lib
from matplotlib import pyplot as plt

# ===============================================================================
# Python classes to read GRAVITY files for astrometry
# ===============================================================================
#
# Dispmodel, Astroreduced
#

class Dispmodel:
    """Dispersion model read from the DISP_MODEL extension of a FITS file.

    Parameters
    ----------
    filename : str
        Path of the FITS file *without* the ``.fits`` extension.

    Attributes
    ----------
    header : FITS header
        Header of the DISP_MODEL extension (the primary header is mostly empty).
    sta_index, n_mean, n_diff, n_sc_second, n_ft_second,
    lin_fddl_sc, lin_fddl_ft, lbd : arrays
        Columns STA_INDEX, N_MEAN, N_DIFF, N_SC_SECOND, N_FT_SECOND,
        LIN_FDDL_SC, LIN_FDDL_FT and LBD of the DISP_MODEL table.
    """

    def __init__(self, filename):
        # The context manager guarantees the file is closed even if one of
        # the expected columns is missing (KeyError).
        with pyfits.open(filename + ".fits", memmap=True) as dispmodel:
            table = dispmodel['DISP_MODEL']
            self.header = table.header
            data = table.data
            self.sta_index = data['STA_INDEX']
            self.n_mean = data['N_MEAN']
            self.n_diff = data['N_DIFF']
            self.n_sc_second = data['N_SC_SECOND']
            self.n_ft_second = data['N_FT_SECOND']
            self.lin_fddl_sc = data['LIN_FDDL_SC']
            self.lin_fddl_ft = data['LIN_FDDL_FT']
            self.lbd = data['LBD']

class Astroreduced:
    """Load a series of GRAVITY ASTROREDUCED files and prepare them for astrometry.

    A file of ``filelist`` is loaded only if its primary header contains the
    ``ESO INS SOBJ SWAP`` keyword and every keyword of ``fits_keywords`` with
    one of the accepted values. The OI_VIS table of every loaded file
    (extension 4 in combined polarization, 5 in split polarization) is
    concatenated into ``self.oi_vis``. Per-frame, per-baseline arrays are
    derived from it. Frames with a non-zero rejection flag on any baseline
    are removed from the quantities used for the fit.

    Parameters
    ----------
    filelist : iterable of str
        File names *without* the ``.fits`` extension.
    fits_keywords : dict
        Maps a primary-header keyword to the collection of accepted values.
    plot : bool, optional
        If True (default), plot ``met_delta`` and ``dopd_gd`` versus MJD.

    Raises
    ------
    Exception
        If no file could be loaded.
    ValueError
        If a loaded file has no OI_WAVELENGTH extension.

    Notes
    -----
    ``opd_met_tel_mean`` and ``met_delta`` are real-valued OPD arrays
    (same units as ``lambda_laser``, i.e. meters).
    ``phasor_met_telfc_mean`` is the principal 4th root of the product of
    the four TEL-FC phasors of each record.
    """

    def __init__(self, filelist, fits_keywords, plot=True):
        self.headers = []
        self.filename = []
        self.oiarrays = []
        self.wave_sc = []
        self.swapstate = []
        self.ucoord = []
        self.vcoord = []
        self.e_u = []
        self.e_v = []
        self.e_w = []
        self.opd_met_fc = []
        self.phase_sc = []
        self.rank = []
        self.nframe_sc = []
        ntel = 4
        nbase = 6

        key_names = list(fits_keywords.keys())
        vis_tables = []  # OI_VIS HDU of each loaded file
        swap_rows = []   # swap flag (0/1) of each OI_VIS row
        base_rows = []   # baseline name of each OI_VIS row

        # ===========
        # Beginning of loop on files
        nfiles = 0
        for filename in filelist:
            print("Check file: " + filename)
            with pyfits.open(filename + ".fits", memmap=True) as gravi_file:
                header = gravi_file[0].header.copy()

            # Data files with the swap keyword and all selection keywords only.
            # The swap check is done even when fits_keywords is empty.
            type_ok = ('ESO INS SOBJ SWAP' in header
                       and all(name in header for name in key_names))
            if not type_ok:
                print("Type not OK... continue")
                continue

            keys_ok = all(header[name] in fits_keywords[name] for name in key_names)
            if not keys_ok:
                print("Key or Type not OK... continue")
                continue

            print(' Loading ASTROREDUCED file %02i: %s.fits' % (nfiles + 1, filename))

            # Deliberately left open: with a single input file, self.oi_vis is
            # this file's OI_VIS HDU itself.
            hdulist = pyfits.open(filename + ".fits")
            primary = hdulist[0].header

            # 1 if the objects are swapped ('YES'), 0 otherwise
            swapsign = int(primary['HIERARCH ESO INS SOBJ SWAP'] == 'YES')

            # relative position of the two objects as specified by user taking swap into account
            # NOTE(review): swapsign is 0/1, so the position is zeroed for
            # non-swapped files; a +1/-1 sign may have been intended - confirm.
            self.relposX = primary['HIERARCH ESO INS SOBJ X'] * swapsign
            self.relposY = primary['HIERARCH ESO INS SOBJ Y'] * swapsign

            self.headers.append(primary)
            self.filename.append(filename + ".fits")
            oi_array = hdulist['OI_ARRAY'].data
            self.oiarrays.append(oi_array)

            # Station index -> station name
            TEL = dict(zip(oi_array['STA_INDEX'], oi_array['STA_NAME']))
            self.tel = TEL
            self.staxyz = oi_array['STAXYZ']
            print(self.staxyz)

            self.basevector = self._baseline_vectors(self.staxyz, ntel)
            print(self.basevector)

            if "COMBINED" in str(primary['HIERARCH ESO FT POLA MODE']):
                self.polarsplit = False
                extv = 4  # combined polarization
            else:
                self.polarsplit = True
                extv = 5  # FIXME: only the 1st polarization is considered at the moment

            # The wave scales of both polarizations of SC must be the same;
            # the last OI_WAVELENGTH extension of the file is the one kept.
            wl_sc = None
            for hdu in hdulist:
                if hdu.name == 'OI_WAVELENGTH':
                    wl_sc = hdu.data['EFF_WAVE']
            if wl_sc is None:
                raise ValueError("No OI_WAVELENGTH extension in %s.fits" % filename)
            self.wave_sc = wl_sc
            self.nwave_sc = len(wl_sc)

            # Number of SC frames (DET2 NDIT) of each loaded file
            self.nframe_sc.append(primary['HIERARCH ESO DET2 NDIT'])

            vis = hdulist[extv]
            nrows = vis.data.shape[0]
            vis_tables.append(vis)
            swap_rows.extend([swapsign] * nrows)
            # -- baseline name for each point (first 2 characters of each station name)
            base_rows.extend(str(TEL[s[0]])[0:2] + str(TEL[s[1]])[0:2]
                             for s in vis.data['STA_INDEX'])

            nfiles += 1

        # End loop on files
        # =================

        if nfiles < 1:
            print("")
            print("No files could be loaded")
            print("")
            raise Exception("No files could be loaded")

        self.swapstate = np.array(swap_rows)
        self.base = np.array(base_rows)

        # Mean wavelength of SC
        self.lambda_mean = np.mean(self.wave_sc)

        # Wavelength of metrology laser in nanometers converted to meters (last loaded file)
        self.lambda_laser = self.headers[-1]['HIERARCH ESO INS MLC WAVELENG'] * 1E-9

        # Large table containing all columns of the original OI_VIS extensions,
        # kept so that the original (not reshaped nor masked) data stay accessible
        oi_vis = self._concatenate_tables(vis_tables)
        self.oi_vis = oi_vis
        self.nfiles = nfiles
        data = oi_vis.data

        # Number of frames (rows are grouped by nbase baselines)
        self.total_frames = len(data['MJD'][::nbase])
        shape = (self.total_frames, nbase)
        shape_wave = shape + (self.nwave_sc,)

        self.swapstate = self.swapstate.reshape(shape)
        self.mjd = np.array(data['MJD']).reshape(shape)
        # Rejection flag (bad if >0)
        self.rejection_flag = np.array(data['REJECTION_FLAG']).reshape(shape)
        self.base = self.base.reshape(shape)
        # Complex visibilities
        self.visdata = np.array(data['VISDATA']).reshape(shape_wave)
        # (u,v) plane coordinates
        self.ucoord = np.array(data['UCOORD']).reshape(shape)
        self.vcoord = np.array(data['VCOORD']).reshape(shape)
        # (eu, ev, ew) vector coordinates to the observed target
        self.e_u = np.array(data['E_U']).reshape(shape + (3,))
        self.e_v = np.array(data['E_V']).reshape(shape + (3,))
        self.e_w = np.array(data['E_W']).reshape(shape + (3,))
        # Group delay signals (in meters)
        self.gd_sc = np.array(data['GDELAY']).reshape(shape)
        self.gd_ft = np.array(data['GDELAY_FT']).reshape(shape)
        # Phase of the SC fringes vs wavelength (in radians)
        self.phase_sc = np.array(np.angle(data['VISDATA'])).reshape(shape_wave)
        # Phase_ref is the opposite of the phase of the FT (in radians)
        self.phase_ref = np.array(data['PHASE_REF']).reshape(shape_wave)
        # Fiber coupler metrology
        self.opd_met_fc = np.array(data['OPD_MET_FC']).reshape(shape)
        # Telescope diode metrology OPD and differential TEL-FC phasor
        self.opd_met_tel = np.array(data['OPD_MET_TEL']).reshape(shape + (4,))
        self.phasor_met_telfc = np.array(data['PHASOR_MET_TELFC']).reshape(shape + (4,))

        # Average metrology signal of the four telescope diodes
        self.opd_met_tel_mean, self.phasor_met_telfc_mean = self._mean_metrology(
            self.opd_met_tel, self.phasor_met_telfc, self.lambda_laser)

        # Group delay including the metrology and dispersion correction
        self.gd_disp = np.array(data['GDELAY_DISP']).reshape(shape)
        # OPD_DISP(t,lbd) =  [ FDDL1(t) - FDDL2(t) + cst ] * n(lbd)
        self.opd_disp = np.array(data['OPD_DISP']).reshape(shape_wave)

        # Differential metrology from the telescope diodes with respect to the
        # FC diodes, wrapped modulo lambda_laser through its phasor
        self.met_delta_phasor = np.exp(2 * np.pi * 1j
                                       * (self.opd_met_tel_mean - self.opd_met_fc)
                                       / self.lambda_laser)
        self.met_delta = np.angle(self.met_delta_phasor) * self.lambda_laser / (2 * np.pi)

        # dOPD from group delay (with FC metrology only)
        self.dopd_gd = (self.gd_ft - self.gd_sc + self.gd_disp).astype(float)

        # PHASOR = conj(VISDATA(lbd)) * expi(2*pi * OPD_DISP(lbd)/lbd) * expi(-PHASE_REF(lbd))
        wave = self.wave_sc[None, None, :]
        conj_vis = np.conj(self.visdata)
        disp_phasor = np.exp(2 * np.pi * 1j * self.opd_disp / wave)
        ref_phasor = np.exp(-1j * self.phase_ref)
        # Phasor with FC metrology only
        self.phasor = conj_vis * disp_phasor * ref_phasor
        # Phasor with FC+TEL metrology
        self.phasor_telmet = (conj_vis * disp_phasor
                              * np.exp(-2 * np.pi * 1j * self.met_delta[:, :, None] / wave)
                              * ref_phasor)

        # Removal of the rejected frames in the quantities used for the fit
        mask = self._good_frames(self.rejection_flag)
        self.gd_sc = self.gd_sc[mask, :]  # Time, baseline
        self.gd_ft = self.gd_ft[mask, :]  # Time, baseline
        self.gd_disp = self.gd_disp[mask, :]  # Time, baseline
        self.mjd = self.mjd[mask, :]  # Time, baseline
        self.ucoord = self.ucoord[mask, :]  # Time, baseline
        self.vcoord = self.vcoord[mask, :]  # Time, baseline
        self.e_u = self.e_u[mask, :]  # Time, baseline, xyz
        self.e_v = self.e_v[mask, :]  # Time, baseline, xyz
        self.e_w = self.e_w[mask, :]  # Time, baseline, xyz
        self.swapstate = self.swapstate[mask, :]  # Time, baseline
        self.base = self.base[mask, :]  # Time, baseline
        self.met_delta = self.met_delta[mask, :]  # Time, baseline
        self.dopd_gd = self.dopd_gd[mask, :]  # Time, baseline
        self.phasor = self.phasor[mask, :, :]  # Time, baseline, wavelength
        self.phasor_telmet = self.phasor_telmet[mask, :, :]  # Time, baseline, wavelength
        self.nframes = len(mask)

        print(np.shape(self.mjd))
        print(np.shape(self.gd_disp))
        print(np.shape(self.phasor))

        if plot:
            self._plot_diagnostics(nbase)

    @staticmethod
    def _baseline_vectors(staxyz, ntel):
        """Return the (ntel, ntel, 3) array of station-pair vectors.

        Element [t1, t2] is (x1 - x2, y1 - y2, z2 - z1), built from the first
        ``ntel`` rows of ``staxyz``.
        """
        xyz = np.asarray(staxyz)[:ntel]
        delta = xyz[None, :, :] - xyz[:, None, :]  # delta[t1, t2] = xyz[t2] - xyz[t1]
        basevector = np.zeros((ntel, ntel, 3))
        basevector[...] = delta * np.array([-1., -1., 1.])
        return basevector

    @staticmethod
    def _mean_metrology(opd_met_tel, phasor_met_telfc, lambda_laser):
        """Average the per-diode metrology signals over the last axis.

        Returns
        -------
        opd_mean : ndarray
            OPD (units of ``lambda_laser``) of the summed phasors
            exp(2*pi*i*opd/lambda_laser), i.e. wrapped modulo lambda_laser.
        phasor_mean : ndarray
            Principal n-th root of the product of the n phasors.
        """
        diode_phasors = np.exp(2 * np.pi * 1j * np.asarray(opd_met_tel) / lambda_laser)
        opd_mean = np.angle(diode_phasors.sum(axis=-1)) * lambda_laser / (2 * np.pi)
        phasors = np.asarray(phasor_met_telfc)
        # The product starts from 1 so the geometric mean carries no extra phase.
        product = np.prod(phasors, axis=-1, dtype=complex)
        phasor_mean = np.power(product, 1. / phasors.shape[-1])
        return opd_mean, phasor_mean

    @staticmethod
    def _good_frames(rejection_flag):
        """Return the indices of frames whose rejection flags sum to zero over baselines."""
        return np.flatnonzero(np.sum(rejection_flag, axis=1) == 0)

    @staticmethod
    def _concatenate_tables(tables):
        """Concatenate the rows of binary-table HDUs sharing the same columns.

        A single table is returned as is (no copy).
        """
        if len(tables) == 1:
            return tables[0]
        nrows = [table.data.shape[0] for table in tables]
        # Rows beyond the first table are zero-filled, then overwritten below.
        merged = pyfits.BinTableHDU.from_columns(tables[0].columns, nrows=sum(nrows))
        start = nrows[0]
        for table, n in zip(tables[1:], nrows[1:]):
            for colname in merged.columns.names:
                merged.data[colname][start:start + n] = table.data[colname]
            start += n
        return merged

    def _plot_diagnostics(self, nbase):
        """Plot met_delta and dOPD from group delay (in microns) versus MJD."""
        plt.figure()
        plt.title("met_delta")
        for baseline in range(nbase):
            plt.plot(self.mjd[:, baseline], 1E6*self.met_delta[:, baseline], 'o')

        plt.figure()
        plt.title("dOPD from group delay")
        for baseline in range(nbase):
            plt.plot(self.mjd[:, baseline], 1E6*self.dopd_gd[:, baseline], 'o')
