import numpy as np
import glob
from astropy.io import fits
from scipy import optimize
import matplotlib
import matplotlib.pyplot as plt
import os
import argparse


# Command-line interface.  The parsed namespace is passed to the Poscor class,
# which reads ``calibrator``, ``folder``, ``ignorefiles``, ``plot``, ``nosave``,
# ``lambdaMIN`` and ``lambdaMAX`` from it.  The calibrator file name is joined
# onto --folder by Poscor.
parser = argparse.ArgumentParser()

parser.add_argument('calibrator', help="File used as the phase calibrator")

parser.add_argument("-f", "--folder", help="Folder in which poscor is applied", type=str, default='./')
# nargs='*' collects one or more whitespace-separated file names into a list.
# (type=list would split a single argument string into its characters.)
parser.add_argument("-ig", "--ignorefiles", help="List of files to be ignored", type=str, nargs='*', default=[])
parser.add_argument("-p", "--plot", help="Show plot of poscor", action='store_true', default=False)
parser.add_argument("-nos", "--nosave", help="Do not save the data", action='store_true', default=False)
parser.add_argument("-lmax", "--lambdaMAX", help="Maximum acceptable wavelength [um]", type = float, metavar = '', default=2.441)
parser.add_argument("-lmin", "--lambdaMIN", help="Minimum acceptable wavelength [um]",type = float, metavar = '', default=2.04)
args = parser.parse_args()


class Poscor:
    """Fit and apply a position correction ("poscor") to GRAVITY visibility phases.

    Files in ``par.folder`` matching ``GRAVI*dualscivis.fits`` or
    ``GRAVI*dualsciviscalibrated.fits`` whose fringe-tracker target is one of
    ``FT_TARGETS`` are split into calibrator frames (same ``ESO INS SOBJ OFFX``
    and ``ESO INS SOBJ NAME`` as the phase calibrator ``par.calibrator``) and
    science frames (everything else).  ``process_night`` fits a 2-vector
    offset ``dS`` (reported in mas) that models the calibrator phases as
    ``360 * dB . dS`` degrees, where ``dB`` is the change of the baseline
    spatial frequency relative to the phase calibrator, and then subtracts
    that phase ramp from every frame.

    ``par`` must provide ``calibrator``, ``folder``, ``ignorefiles``, ``plot``,
    ``nosave``, ``lambdaMIN`` and ``lambdaMAX`` (as this script's command-line
    parser does).
    """

    # Fringe-tracker targets (header 'ESO FT ROBJ NAME') whose files are used.
    FT_TARGETS = ('IRS16C', 'IRS16NW')

    def __init__(self, par):
        """Sort the reduced files of ``par.folder`` into calibrator and science frames.

        Sets ``cal_files``, ``sci_files``, ``caldx`` (index of the phase
        calibrator in ``cal_files``) and ``ndit``.

        Raises:
            ValueError: if no reduced files are found, or if the phase
                calibrator is not among the selected calibrator files.
        """
        self.par = par
        folder = par.folder

        allfiles = sorted(glob.glob(os.path.join(folder, 'GRAVI*dualscivis.fits')))
        allfiles += sorted(glob.glob(os.path.join(folder, 'GRAVI*dualsciviscalibrated.fits')))
        if len(allfiles) == 0:
            raise ValueError('No files found, most likely something is wrong'
                             ' with the given reduction folder')

        cal = os.path.join(folder, par.calibrator)
        cal_header = fits.getheader(cal, 0)
        calOFFX = cal_header['ESO INS SOBJ OFFX']
        calname = cal_header['ESO INS SOBJ NAME']

        # Ignored files may be given either as a path or as a bare file name.
        ignored = {os.path.normpath(f) for f in (par.ignorefiles or [])}

        sci_files = []
        cal_files = []
        for file in allfiles:
            if (os.path.normpath(file) in ignored
                    or os.path.basename(file) in ignored):
                continue
            h = fits.getheader(file, 0)
            if h['ESO FT ROBJ NAME'] not in self.FT_TARGETS:
                continue
            if (h['ESO INS SOBJ OFFX'] == calOFFX
                    and h['ESO INS SOBJ NAME'] == calname):
                cal_files.append(file)
            else:
                sci_files.append(file)

        print('%i SCI files \n%i CAL files'
              % (len(sci_files), len(cal_files)))
        self.cal_files = cal_files
        self.sci_files = sci_files

        # Locate the phase calibrator before indexing cal_files, so an empty
        # calibrator list gives this explanatory error rather than an
        # IndexError.  Paths are normalised so "dir//f" and "dir/f" match.
        cal_norm = os.path.normpath(cal)
        matches = [idx for idx, f in enumerate(cal_files)
                   if os.path.normpath(f) == cal_norm]
        if not matches:
            print()
            print('Calibrator files:')
            print(cal_files)
            print('Given Phase calibrator:')
            print(cal)
            print()
            print()
            raise ValueError('Phase calibrator not in list of calibrator files')
        self.caldx = matches[0]

        # OI_VIS holds 6 rows (one per baseline) for each DIT; the rest of the
        # class addresses baseline ``bl`` as rows ``bl::6``.
        with fits.open(cal_files[0]) as hdul:
            self.ndit = len(hdul['OI_VIS', 11].data['TIME']) // 6
        print('NDIT: %i' % self.ndit)
        print()

    def process_night(self):
        """Fit the position correction on the calibrator frames and apply it.

        Sets ``self.wave`` (EFF_WAVE of the first science file, or of the
        first calibrator file if there are no science files) and ``self.dS``
        (fitted offset).  Optionally shows a diagnostic plot (``par.plot``)
        and, unless ``par.nosave`` is set, writes the corrected files to
        ``<folder>/poscor/``.
        """
        ndit = self.ndit
        cal_files = self.cal_files
        sci_files = self.sci_files

        wave_file = sci_files[0] if sci_files else cal_files[0]
        with fits.open(wave_file) as hdul:
            wave = np.array(hdul['OI_WAVELENGTH', 11].data['EFF_WAVE'])
        self.wave = wave

        chmin, chmax = self._channel_range(wave, self.par.lambdaMIN,
                                           self.par.lambdaMAX)
        print('Fitting between channel %i and %i' % (chmin, chmax))

        sci = self._read_oi_vis(sci_files, wave, ndit)
        cal = self._read_oi_vis(cal_files, wave, ndit)

        # Time axis of the plot: minutes since the earliest frame of the night.
        tstart = np.nanmin(np.concatenate((cal['t'].ravel(), sci['t'].ravel())))
        cal_t = (cal['t'] - tstart) * 24 * 60

        cal_u = cal['u']
        cal_v = cal['v']
        if ndit != 1:
            # Exact zeros are treated as missing u/v.  This modifies cal_u/cal_v
            # in place, so it also affects the per-DIT correction further below.
            cal_u[cal_u == 0] = np.nan
            cal_v[cal_v == 0] = np.nan

        # DIT-averaged baseline vectors per calibrator file, relative to the
        # phase calibrator.
        u_avg = self._baseline_mean(cal_u, ndit)
        v_avg = self._baseline_mean(cal_v, ndit)
        SpFreq = np.sqrt(u_avg**2 + v_avg**2)
        cal_B = np.stack((u_avg, v_avg), axis=-1)
        cal_dB = cal_B - cal_B[self.caldx]

        # Average the two phase columns (EXTVER 11 and 12) and combine errors.
        visphi_fit = (self._baseline_mean(cal['visphi_p1'], ndit)
                      + self._baseline_mean(cal['visphi_p2'], ndit)) / 2
        visphi_err_fit = (np.sqrt(self._baseline_mean(cal['visphi_err_p1'], ndit)**2
                                  + self._baseline_mean(cal['visphi_err_p2'], ndit)**2)
                          / np.sqrt(2))
        # Only channels inside [lambdaMIN, lambdaMAX] enter the fit.
        visphi_fit[:, :, :chmin] = np.nan
        visphi_fit[:, :, chmax:] = np.nan

        dS, chi2 = self._fit_offset(cal_dB, visphi_fit, visphi_err_fit)
        print('Applied poscor: (%.3f,%.3f) mas ' % (dS[0], dS[1]))
        print('Chi2 of poscor: %.2f \n' % chi2)
        self.dS = dS

        if self.par.plot:
            self._plot_fit(cal_t, SpFreq, visphi_fit, visphi_err_fit, cal_dB, dS)

        # Per-DIT correction: baseline change of every row relative to the
        # DIT-averaged phase-calibrator baseline of the same baseline index.
        cal_B_full = np.stack((cal_u, cal_v), axis=-1)
        sci_B_full = np.stack((sci['u'], sci['v']), axis=-1)
        B_calib = np.stack([np.nanmean(cal_B_full[self.caldx][bl::6], 0)
                            for bl in range(6)])
        B_ref = np.tile(B_calib, (ndit, 1, 1))  # row r -> B_calib[r % 6]

        sci_shift = np.dot(sci_B_full - B_ref, dS) * 360
        cal_shift = np.dot(cal_B_full - B_ref, dS) * 360
        sci_visphi_p1 = self._wrap_phase(sci['visphi_p1'] - sci_shift)
        sci_visphi_p2 = self._wrap_phase(sci['visphi_p2'] - sci_shift)
        cal_visphi_p1 = self._wrap_phase(cal['visphi_p1'] - cal_shift)
        cal_visphi_p2 = self._wrap_phase(cal['visphi_p2'] - cal_shift)

        if not self.par.nosave:
            savefolder = os.path.join(self.par.folder, 'poscor', '')
            print('Saving poscored data in %s' % savefolder)
            os.makedirs(savefolder, exist_ok=True)
            self._write_corrected(sci_files, sci_visphi_p1, sci_visphi_p2, savefolder)
            self._write_corrected(cal_files, cal_visphi_p1, cal_visphi_p2, savefolder)

    @staticmethod
    def _channel_range(wave, lambda_min, lambda_max):
        """Return ``(chmin, chmax)`` delimiting the fit band ``wave[chmin:chmax]``.

        ``lambda_min``/``lambda_max`` are in micron and compared against
        ``wave * 1e6``.  ``chmin`` is one past the last channel shorter than
        ``lambda_min`` (0 if none); ``chmax`` is the first channel longer
        than ``lambda_max`` (``len(wave)`` if none).  The band is only
        contiguous if ``wave`` increases with channel index.
        """
        wave = np.asarray(wave)
        below = np.flatnonzero(wave < lambda_min * 1e-6)
        above = np.flatnonzero(wave > lambda_max * 1e-6)
        chmin = int(below.max()) + 1 if below.size else 0
        chmax = int(above.min()) if above.size else len(wave)
        return chmin, chmax

    @staticmethod
    def _to_spatial_freq(coord, wave):
        """Convert baseline coordinates to spatial frequencies per milliarcsecond.

        Returns an array of shape (len(coord), len(wave)) holding
        ``coord / wave`` times the radians-per-mas factor.  Assumes ``coord``
        uses the same length unit as ``wave`` (metres, per OIFITS) — confirm.
        """
        coord = np.asarray(coord, dtype=float)
        wave = np.asarray(wave)
        return coord[:, None] / wave[None, :] * np.pi / 180. / 3600. / 1000

    def _read_oi_vis(self, files, wave, ndit):
        """Load time stamps, spatial frequencies and flagged phases of ``files``.

        Returns a dict with ``'t'`` (MJD of every 6th row of OI_VIS EXTVER 12,
        shape (nfiles, ndit)) and ``'u'``, ``'v'``, ``'visphi_p1'``,
        ``'visphi_p2'``, ``'visphi_err_p1'``, ``'visphi_err_p2'`` (shape
        (nfiles, ndit*6, nchannel)).  u/v and the ``_p1`` entries come from
        OI_VIS EXTVER 11, the ``_p2`` entries from EXTVER 12.  Phases and
        phase errors whose FLAG is set are replaced by NaN.
        """
        nrow = ndit * 6
        shape = (len(files), nrow, len(wave))
        out = {'t': np.full((len(files), ndit), np.nan)}
        for key in ('u', 'v', 'visphi_p1', 'visphi_p2',
                    'visphi_err_p1', 'visphi_err_p2'):
            out[key] = np.full(shape, np.nan)

        for fdx, file in enumerate(files):
            with fits.open(file) as hdul:
                vis1 = hdul['OI_VIS', 11].data
                vis2 = hdul['OI_VIS', 12].data
                out['t'][fdx] = vis2['MJD'][::6]
                out['u'][fdx] = self._to_spatial_freq(vis1['UCOORD'][:nrow], wave)
                out['v'][fdx] = self._to_spatial_freq(vis1['VCOORD'][:nrow], wave)
                for suffix, vis in (('p1', vis1), ('p2', vis2)):
                    flagged = np.asarray(vis['FLAG'], dtype=bool)
                    phi = out['visphi_' + suffix][fdx]      # views into out
                    err = out['visphi_err_' + suffix][fdx]
                    phi[...] = vis['VISPHI']
                    err[...] = vis['VISPHIERR']
                    phi[flagged] = np.nan
                    err[flagged] = np.nan
        return out

    @staticmethod
    def _baseline_mean(arr, ndit):
        """Average ``arr`` of shape (nfiles, ndit*6, ...) over DITs, per baseline.

        Row ``r`` belongs to baseline ``r % 6``; the result has shape
        (nfiles, 6, ...).  With a single DIT the rows are returned unchanged
        (as a float copy); otherwise NaNs are ignored in the average.
        """
        if ndit == 1:
            return np.array(arr[:, :6], dtype=float)
        return np.stack([np.nanmean(arr[:, bl::6], axis=1) for bl in range(6)],
                        axis=1)

    @staticmethod
    def _fit_offset(dB, visphi, visphi_err):
        """Least-squares fit of ``visphi / 360 = dB . dS`` for the 2-vector ``dS``.

        ``dB`` has shape (..., 2); ``visphi`` and ``visphi_err`` (degrees)
        share its leading shape.  Only points with finite phase, finite and
        positive error, and finite ``dB`` are used.  Returns ``(dS, chi2)``.
        If the fit cannot be set up (fewer usable points than parameters),
        ``'PosCor failed'`` is printed and ``dS`` is ``[0, 0]``.
        """
        dB1 = np.reshape(dB, (-1, 2))
        Vphi = np.ravel(visphi)
        Vphi_err = np.ravel(visphi_err)
        valid = (np.isfinite(Vphi) & np.isfinite(Vphi_err)
                 & np.isfinite(dB1).all(axis=1))
        valid &= np.where(valid, Vphi_err, 0) > 0   # zero errors would give inf
        Vphi2 = Vphi[valid] / 360
        Vphi_err2 = Vphi_err[valid] / 360
        dB2 = dB1[valid]

        def residuals(dS):
            return (Vphi2 - np.dot(dB2, dS)) / Vphi_err2

        try:
            dS, _cov, _info, errmsg, ier = optimize.leastsq(
                residuals, x0=[0, 0], full_output=1)
        except TypeError:
            # leastsq raises TypeError when there are fewer residuals than
            # parameters (e.g. no usable data).
            print('PosCor failed')
            dS = [0, 0]
        else:
            if ier not in (1, 2, 3, 4):
                print('PosCor fit did not converge: %s' % errmsg)
        chi2 = np.sum(residuals(dS)**2)
        return dS, chi2

    @staticmethod
    def _wrap_phase(phi):
        """Wrap phases given in degrees into [-180, 180); NaNs stay NaN."""
        return ((phi + 180) % 360) - 180

    @staticmethod
    def _plot_fit(cal_t, SpFreq, visphi_fit, visphi_err_fit, cal_dB, dS):
        """Plot averaged calibrator phases vs. spatial frequency with the fitted model.

        Points and model curves are coloured by the mean time (minutes) of
        each calibrator file; the two outermost channels on each side are
        left out of the error-bar points.
        """
        # cal_t is measured from the earliest frame, so it starts at >= 0.
        norm = matplotlib.colors.Normalize(vmin=0, vmax=np.max(cal_t) + 10)
        s_m = matplotlib.cm.ScalarMappable(cmap=plt.cm.inferno, norm=norm)
        s_m.set_array([])

        fitres = np.dot(cal_dB, dS) * 360
        for idx in range(visphi_fit.shape[0]):
            color = s_m.to_rgba([np.mean(cal_t[idx])])[0]
            for bl in range(6):
                plt.errorbar(SpFreq[idx, bl, 2:-2] * 1000,
                             visphi_fit[idx, bl, 2:-2],
                             visphi_err_fit[idx, bl, 2:-2],
                             ls='', marker='o', color=color, alpha=0.5)
                plt.plot(SpFreq[idx, bl] * 1000, fitres[idx, bl], color=color)
        plt.ylim(-74, 74)
        # The ScalarMappable is not attached to any Axes, so name the Axes to
        # take space from (required by recent matplotlib versions).
        plt.colorbar(s_m, ax=plt.gca(), label='Time [min]')
        plt.xlabel('Spatial frequency [1/as]')
        plt.ylabel('Visibility phase [deg]')
        plt.title('Poscor:  (%.3f,%.3f) mas ' % (dS[0], dS[1]))
        plt.show()

    @staticmethod
    def _write_corrected(files, visphi_p1, visphi_p2, savefolder):
        """Write copies of ``files`` into ``savefolder`` with corrected VISPHI.

        The output keeps each input's file name.  Files whose corrected
        EXTVER 11 phases are all NaN are reported and not written.
        """
        for fdx, file in enumerate(files):
            # basename, not the first "GRAVI" in the path: the folder itself
            # may contain that substring (e.g. .../GRAVITY/...).
            fname = os.path.basename(file)
            if np.isnan(visphi_p1[fdx]).all():
                print('%s is all nan' % fname)
                continue
            with fits.open(file) as hdul:
                hdul['OI_VIS', 11].data['VISPHI'] = visphi_p1[fdx]
                hdul['OI_VIS', 12].data['VISPHI'] = visphi_p2[fdx]
                hdul.writeto(os.path.join(savefolder, fname), overwrite=True)
