#! /usr/bin/env python3
# -*- coding: iso-8859-15 -*-

#==============================================================================
# CODE TO PLOT THE ASTRORED FILES
#==============================================================================

import os
import sys
from glob import glob
import matplotlib
matplotlib.use('Agg')
# matplotlib.use('Qt5Agg') # To use to show plots
from matplotlib.backends.backend_pdf import PdfPages
from matplotlib.ticker import (MultipleLocator, AutoMinorLocator)
import numpy as np
from numpy import pi,linspace,sqrt,append,arange,array,zeros,ones,dot,angle,exp,conj,cos,sin
import matplotlib.pyplot as plt
from matplotlib.pyplot import plot,hist,clf,figure,savefig, tight_layout,title,xlabel,ylabel,legend,xlim,ylim,imshow
import time
matplotlib.rcParams['figure.max_open_warning'] = 0
plt.ion()
from gravi_astrored import astrored_loadData

import argparse
from argparse import Namespace


# Text shown at the top of the argparse help output.
usage = (
    "\n"
    "description:\n"
    "  make pdf plot for GRAVITY astroreduced data to check if everything went well\n"
)

# Text shown at the bottom of the argparse help output (the last line of the
# original text consists of two blanks and is kept as such).
examples = (
    "\n"
    "examples:\n"
    "  Get help:\n"
    "  run_gravi_astrored_check.py -h\n"
    "  \n"
)

#
# Command-line parser
#
# RawDescriptionHelpFormatter prints description and epilog with their own
# line breaks instead of re-wrapping them.
parser = argparse.ArgumentParser(
    description=usage,
    epilog=examples,
    conflict_handler="resolve",
    formatter_class=argparse.RawDescriptionHelpFormatter,
)

def metrology_plot(fig,TIME_L,DATA_L,factor=1,h_v=None,tel=None):
    """Draw four stacked curves per file on ``fig`` plus 2*pi reference lines.

    For every file ``f`` and every column ``i`` in 0..3 one curve is plotted in
    colour ``C<i>``, shifted vertically by ``i*2*pi`` so the four curves do not
    overlap.

    Parameters
    ----------
    fig : figure object (matplotlib ``Figure`` in this script)
        Data curves are drawn on ``fig.gca()``; reference lines, axis labels
        and the title (``fig.get_label()``) go to ``fig.axes[0]``.
    TIME_L : sequence of 1-D arrays
        One time vector per file.
    DATA_L : sequence of arrays
        One data array per file, read as ``DATA[:, i]`` or, when ``tel`` is
        given, as ``DATA[:, tel][:, i]``.
    factor : scalar, optional
        Multiplier applied to the data before the stacking shift is added.
    h_v : mapping, optional
        When truthy, ``h_v["MET_ROW"][f]`` is the first row that is plotted
        for file ``f`` and the curve is re-referenced so that it starts at
        ``h_v["ADDMET"][f][i]``.
    tel : int, optional
        Index selecting ``DATA[:, tel]`` before the four columns are read.
    """
    TIME_min = np.array([t[0] for t in TIME_L]).min()
    TIME_max = np.array([t[-1] for t in TIME_L]).max()

    # Draw on the figure that was handed in, not on whatever figure pyplot
    # currently considers active.
    data_ax = fig.gca()

    for f, (TIME, DATA) in enumerate(zip(TIME_L, DATA_L)):
        DATA_D = DATA if tel is None else DATA[:, tel]
        for i in range(4):
            if h_v:
                first_row = h_v["MET_ROW"][f]
                off = h_v["ADDMET"][f][i]
                # Re-reference with the SAME (telescope-selected) array that is
                # plotted; the original subtracted DATA[first_row, i], which is
                # a different quantity as soon as ``tel`` is given.
                curve = DATA_D[first_row:, i] - DATA_D[first_row, i] + off
                data_ax.plot(TIME[first_row:], curve*factor + i*2*pi, "C%i" % i)
            else:
                data_ax.plot(TIME, DATA_D[:, i]*factor + i*2*pi, "C%i" % i)

    # Horizontal guides: coloured dashes and black dots at i*2*pi, solid black
    # lines half-way between (at i*2*pi - pi).
    ref_ax = fig.axes[0]
    time_span = [TIME_min, TIME_max]
    for i in range(4):
        ref_ax.plot(time_span, np.ones(2)*i*2*pi, "--C%i" % i)
    ref_ax.plot(time_span, np.ones((2, 4))*arange(4)*2*pi, "k:")
    ref_ax.plot(time_span, np.ones((2, 5))*arange(5)*2*pi - pi, "k")
    ref_ax.set_xlabel("Time (seconds)")
    ref_ax.set_ylabel("(radians)")
    ref_ax.set_title(fig.get_label())


# Hand the command line to argparse first, so that "-h/--help" prints the help
# text and exits instead of being treated as a file pattern (which used to end
# in "No good files !!!"). parse_known_args() is used so that options this
# script does not know are tolerated, as they were before.
parser.add_argument(
    "files", nargs="*", metavar="FILE",
    help="file names or wild cards, e.g. '*_0001.fits' (default: every "
         "*_astroreduced.fits file of the current directory)")
cli_args, cli_unknown = parser.parse_known_args()

## If the user specifies a file name or wild cards ("*_0001.fits")
if len(sys.argv) > 1 :
    # Everything argparse did not consume is still globbed, like before.
    patterns = cli_args.files + cli_unknown
    longnames = [f for pattern in patterns for f in glob(pattern)]
    filelist = [os.path.splitext(f)[0] for f in longnames]
## Processing of the full current directory
else :
    filelist = [os.path.splitext(entry)[0]
                for entry in os.listdir(".")
                if entry.endswith("_astroreduced.fits")]

filelist.sort() # process the files in alphabetical order
print(sys.argv)
print(filelist)

# Directory used only by the manual override below (interactive use).
files_dir='/Users/slacour/DATA/GRAVITY/reduced'
# files=glob(files_dir+"/*_astroreduced.fits")

# When True, nothing is recomputed if the output PDF already exists.
doNotRedo=False

start_time=time.time()

if not filelist:
    raise ValueError('No good files !!!')

# The extension is always forced to ".fits"; a single sort on the final names
# replaces the three successive sorts of the original code.
files=np.sort([f+".fits" for f in filelist])#[:7]
h_v=astrored_loadData.get_header(files,verbose=True)

# The output name uses the object name that occurs most often in the headers
# (blanks removed) and the first 10 characters of the first DATE entry.
name_unique,number=np.unique(h_v["NAME"],return_counts=True)
Name_object=name_unique[number.argmax()].replace(" ","")
name_file="Acheck_"+Name_object+"_"+h_v["DATE"][0][:10]+'_astroreduced.pdf'

# Skip the whole report when re-doing is disabled and the PDF already exists.
# Logical "and" replaces the bitwise "&" that was applied to two booleans.
if doNotRedo and os.path.isfile(name_file):
    print("skiping becasuse file already there")
else:

    # Load the tables used by the figures below. The call order is kept as in
    # the original code.
    oi_flux=astrored_loadData.get_oi_flux(h_v)
    oi_vis=astrored_loadData.get_oi_vis(h_v)
    oi_acq=astrored_loadData.get_oi_acq(h_v)
    oi_met=astrored_loadData.get_oi_met(h_v)
    # The WAVELENG value of the first file converts between OPD and phase below.
    Lambda_laser=h_v["WAVELENG"][0]
    Nf=len(h_v['file'])
    # Overall time span of the metrology table: first sample of the earliest
    # file to last sample of the latest file.
    TIME_min=np.array([o[0] for o in oi_met["TIME"]]).min()
    TIME_max=np.array([o[-1] for o in oi_met["TIME"]]).max()

    # Labels of the figures to export, in page order.
    figure_list=[]

    ############################
    #%% UV plane
    ############################

    # One colour per column index 0..5. Every 4th sample of each file is
    # drawn together with its point-symmetric counterpart (-u, -v); all
    # coordinates are multiplied by 1e-6.
    fig = figure("UV plane", figsize=(5, 5), clear=True)
    figure_list.append(fig.get_label())

    track_colors = ['C%i' % b for b in range(6)]
    for u, v in zip(oi_vis["ucoord"], oi_vis["vcoord"]):
        for b, color in enumerate(track_colors):
            u_track = u[::4, b].T*1e-6
            v_track = v[::4, b].T*1e-6
            plot(u_track, v_track, color)
            plot(-u_track, -v_track, color)
        # origin marker, drawn once per file
        plot(0, 0, 'ko')
    plt.grid()

    uv_axes = fig.axes[0]
    uv_axes.set_title(fig.get_label())
    # Fixed square window; the x axis runs from +333 down to -333.
    uv_axes.set_ylim([-333, 333])
    uv_axes.set_xlim([333, -333])
    uv_axes.set_aspect(1)

    ############################
    #%% ACQUISITION CAMERA
    ############################

    # The two acquisition-camera mosaics have the same layout (one row per
    # file, four images side by side), so they are produced by one loop
    # instead of two copy-pasted sections. The first section also set
    # xlim/ylim before plotting, only to override them a few lines later;
    # those redundant calls are gone.
    n_rows = len(files)
    for acq_title, acq_key in (("Pupil Acquisition Camera", "PUPIL_ACQ"),
                               ("Field Image Acquisition Camera", "PUPIL_FIELD")):
        if len(oi_acq[acq_key]) == 0:
            continue

        # Figure height grows with the number of files.
        fig = figure(acq_title, figsize=(6, 5.8/4*n_rows+1), clear=True)
        fig.suptitle(fig.get_label())
        figure_list += [fig.get_label()]

        # White vertical separators between the four columns.
        for col in range(1, 4):
            plot([col, col], [0, n_rows+1], 'w')
        # Limits are fixed BEFORE imshow so the images do not rescale the
        # axes; the y axis is inverted (first file on top).
        xlim([0, 4])
        ylim([n_rows, 0])

        for file_i, image in enumerate(oi_acq[acq_key]):
            for col in range(4):
                # each image fills the unit cell (col, file_i)
                plt.imshow(image[col], vmin=0, extent=(col, col+1, file_i, file_i+1))
            # white horizontal separator below this row
            plot([0, 4], [file_i+1, file_i+1], 'w')


    # OPD_FC_CORR converted with 2*pi/Lambda_laser and drawn by metrology_plot.
    fig = figure("OPD caused by pupil offsets", clear=True, figsize=(16, 4.8))
    figure_list.append(fig.get_label())
    metrology_plot(fig, oi_met["TIME"], oi_met["OPD_FC_CORR"], factor=2*pi/Lambda_laser)
    # x range of this figure; reused as x limits by later figures.
    xylims = fig.axes[0].get_xlim()

    if len(oi_acq["TIME"]) > 0:
        fig = figure("Pupil XY offsets", clear=True, figsize=(16, 4.8))
        fig.subplots_adjust(hspace=0, wspace=0)
        figure_list.append(fig.get_label())

        acq_tables = zip(oi_acq["TIME"], oi_acq["PUPIL_X"],
                         oi_acq["PUPIL_Y"], oi_acq["PUPIL_NSPOT"])
        for T_ACQ, PUPIL_X, PUPIL_Y, NSPOT in acq_tables:
            for k in range(4):
                color = "C%i" % k
                # samples whose spot count is zero get a black-edged marker
                no_spot = NSPOT[:, k] == 0
                # X: solid line, Y: dashed line, both shifted by k*2*pi
                plot(T_ACQ, PUPIL_X[:, k]+k*2*pi, color)
                plot(T_ACQ[no_spot], PUPIL_X[no_spot, k]+k*2*pi, "o"+color, markeredgecolor="k")
                plot(T_ACQ, PUPIL_Y[:, k]+k*2*pi, "--"+color)
                plot(T_ACQ[no_spot], PUPIL_Y[no_spot, k]+k*2*pi, "o"+color, markeredgecolor="k")

        # Same horizontal guides as in metrology_plot.
        pupil_axes = fig.axes[0]
        time_span = [TIME_min, TIME_max]
        for k in range(4):
            pupil_axes.plot(time_span, np.ones(2)*k*2*pi, "--C%i" % k)
        pupil_axes.plot(time_span, np.ones((2, 4))*arange(4)*2*pi, "k:")
        pupil_axes.plot(time_span, np.ones((2, 5))*arange(5)*2*pi-pi, "k")
        pupil_axes.set_xlim(xylims)
        pupil_axes.set_xlabel("Time (seconds)")
        pupil_axes.set_ylabel("pixels")
        pupil_axes.set_title(fig.get_label())


    ############################
    #%% Metrology
    ############################

    # Three figures that differ only by title and oi_met column.
    telescope_figures = (
        ("Astigmatism at Telescope", "astRes"),
        ("Separation X at Telescope", "sepXRes"),
        ("Separation Y at Telescope", "sepYRes"),
    )
    for met_title, met_key in telescope_figures:
        fig = figure(met_title, clear=True, figsize=(16, 4.8))
        figure_list.append(fig.get_label())
        metrology_plot(fig, oi_met["TIME"], oi_met[met_key])

    # One figure per telescope index: the four PHASE_TELFC_CORR curves plus,
    # in black and shifted by 8*pi, OPD_TELFC_MCORR converted with
    # 2*pi/Lambda_laser.
    for tel in range(4):
        diode_title = "4 diodes of Telescope %i (dark=mean value)" % (tel+1)
        fig = figure(diode_title, clear=True, figsize=(16, 4.8))
        figure_list.append(fig.get_label())
        metrology_plot(fig, oi_met["TIME"], oi_met["PHASE_TELFC_CORR"], tel=tel)
        for TIME, OPD_TELFC_MCORR in zip(oi_met["TIME"], oi_met["OPD_TELFC_MCORR"]):
            plot(TIME, OPD_TELFC_MCORR[:, tel]*2*pi/Lambda_laser+8*pi, "k")

    # PHASE_FC_DRS, re-referenced by metrology_plot with the MET_ROW / ADDMET
    # entries of h_v.
    fig = figure("Phase at Fiber Coupler (data is shifted if corrected for a metrology jump)",
                 clear=True, figsize=(16, 4.8))
    figure_list.append(fig.get_label())
    metrology_plot(fig, oi_met["TIME"], oi_met["PHASE_FC_DRS"], h_v=h_v)

    #%%
    ############################
    #%% FDDL
    ############################

    # Reference levels: per-column time average of the FIRST file. The
    # original code first computed the average over all files and then
    # immediately overwrote it with these values, so only the first-file
    # reference was ever used; the dead computation is removed and the
    # first-file mean is taken directly.
    mean_fddl_ft = oi_flux["FDDL_FT"][0].mean(axis=0)
    mean_fddl_sc = oi_flux["FDDL_SC"][0].mean(axis=0)
    mean_phase_ft = oi_flux["PHASE_MET_FCFT"][0].mean(axis=0)
    mean_phase_sc = oi_flux["PHASE_MET_FCSC"][0].mean(axis=0)

    # --- Figure 1: FDDL_FT / FDDL_SC relative to the reference, times 1e6 ---
    fig, axs = plt.subplots(4, num="FDDL trajectory from sensor gauge", clear=True, figsize=(16, 4.8))
    fig.suptitle(fig.get_label())
    figure_list += [fig.get_label()]
    fig.subplots_adjust(hspace=0, wspace=0)

    for f in range(Nf):
        fddl_ft = (oi_flux["FDDL_FT"][f]-mean_fddl_ft)
        fddl_sc = (oi_flux["FDDL_SC"][f]-mean_fddl_sc)
        for i in range(4):
            # one panel per column; FT uses colours C0..C3, SC uses C4..C7
            axs[i].plot(oi_flux["TIME"][f], 1e6*fddl_ft[:, i], "C%i" % i)
            axs[i].plot(oi_flux["TIME"][f], 1e6*fddl_sc[:, i], "C%i" % (i+4))
    axs[3].set_xlabel("Time (seconds)")
    axs[1].set_ylabel("Length (microns)")
    for i in range(3):
        # only the bottom panel keeps its x ticks
        axs[i].set_xticks([])

    #%%
    # --- Figure 2: FDDL + metrology phase (converted with Lambda_laser/(2*pi)),
    #     FT and SC drawn separately ---
    fig, axs = plt.subplots(4, num="Difference between metrology and sensor gauge signal (Science AND fringe tracker)", clear=True, figsize=(16, 4.8))
    fig.suptitle(fig.get_label())
    figure_list += [fig.get_label()]
    fig.subplots_adjust(hspace=0, wspace=0)
    for f in range(Nf):
        fddl_ft = (oi_flux["FDDL_FT"][f]-mean_fddl_ft)
        fddl_sc = (oi_flux["FDDL_SC"][f]-mean_fddl_sc)
        phase_sc = (oi_flux["PHASE_MET_FCSC"][f]-mean_phase_sc)*Lambda_laser/(2*pi)
        phase_ft = (oi_flux["PHASE_MET_FCFT"][f]-mean_phase_ft)*Lambda_laser/(2*pi)
        for i in range(4):
            axs[i].plot(oi_flux["TIME"][f], 1e6*(fddl_ft+phase_ft)[:, i], "C%i" % i, label="PHASE_MET_FCFT")
            axs[i].plot(oi_flux["TIME"][f], 1e6*(fddl_sc+phase_sc)[:, i], "C%i" % (i+4), label="PHASE_MET_FCSC")
            if f == 0:
                # the legend is built after the first file only, so every
                # label appears once
                axs[i].legend()
    axs[3].set_xlabel("Time (seconds)")
    axs[1].set_ylabel("Length (microns)")
    for i in range(3):
        axs[i].set_xticks([])


    #%%
    # --- Figure 3: (FT sum) - (SC sum) of the quantities of figure 2 ---
    fig, axs = plt.subplots(4, num="Difference between metrology and sensor gauge signal (Science-Fringe tracker)", clear=True, figsize=(16, 4.8))
    fig.suptitle(fig.get_label())
    figure_list += [fig.get_label()]
    fig.subplots_adjust(hspace=0, wspace=0)
    for f in range(Nf):
        fddl_ft = (oi_flux["FDDL_FT"][f]-mean_fddl_ft)
        fddl_sc = (oi_flux["FDDL_SC"][f]-mean_fddl_sc)
        phase_sc = (oi_flux["PHASE_MET_FCSC"][f]-mean_phase_sc)*Lambda_laser/(2*pi)
        phase_ft = (oi_flux["PHASE_MET_FCFT"][f]-mean_phase_ft)*Lambda_laser/(2*pi)
        for i in range(4):
            axs[i].plot(oi_flux["TIME"][f], 1e6*(fddl_ft+phase_ft-fddl_sc-phase_sc)[:, i], "C%i" % i)
    axs[3].set_xlabel("Time (seconds)")
    axs[1].set_ylabel("Length (microns)")
    for i in range(3):
        axs[i].set_xticks([])
    for i in range(4):
        # horizontal grid every 1.98 y units
        # NOTE(review): 1.98 is presumably tied to the laser wavelength in
        # microns - confirm.
        axs[i].yaxis.set_major_locator(MultipleLocator(1.98))
        axs[i].grid(axis='y')

    ############################
    #%% Visibilities on-axis
    ############################

    # This figure is only produced when at least one file has a 'sep' header
    # value below 10; only those files enter it.
    if np.any(h_v['sep']<10):
        fig, axs = plt.subplots(2, num="Phase and Group delay", clear=True, figsize=(16, 4.8))
        fig.subplots_adjust(hspace=0, wspace=0)
        fig.suptitle(fig.get_label())
        figure_list += [fig.get_label()]

        visC = []
        Tvis = []

        for f in range(Nf):
            if h_v['sep'][f] < 10:
                # VISDATA multiplied by exp(+i*PHASE_REF) and
                # exp(-i*phaseMet), then averaged over the first axis;
                # one mean time stamp per file.
                visCSC = oi_vis["VISDATA"][f].copy()
                visCFT = exp(1j*oi_vis["PHASE_REF"][f])
                visCMET = exp(-1j*oi_vis["phaseMet"][f])
                visC += [(visCSC*visCFT*visCMET).mean(axis=0)]
                Tvis += [oi_vis["TIME"][f].mean(axis=0)]

        try:
            visC = np.array(visC)
            Tvis = np.array(Tvis)
        except ValueError:
            # Stacking fails when the files do not share one array shape.
            # Keep the previous best-effort behaviour (first file only), but
            # say so instead of swallowing every exception silently.
            print("Phase and Group delay: files cannot be stacked, using the first on-axis file only")
            visC = visC[0][None]
            Tvis = Tvis[0][None]


        # Remove the phase of the file-averaged visibility, so every file is
        # shown relative to the mean.
        visC *= exp(-1j*np.angle(visC.mean(axis=0)))

        # Phase delay: phase of the sum over the last axis.
        # Group delay: phase of the summed product of neighbouring elements
        # along the last axis.
        PD = visC.sum(axis=2)
        GD = (visC[:, :, 1:]*conj(visC[:, :, :-1])).sum(axis=2)

        bname = np.array(["12", "13", "14", "23", "24", "34"])

        for b in range(6):
            # both quantities are converted from radians to degrees here
            axs[0].plot(Tvis, np.angle(PD[:, b])*180/pi, 'o-', label="baseline "+bname[b])
            axs[1].plot(Tvis, np.angle(GD[:, b])*180/pi, 'o-', label="baseline "+bname[b])

        axs[0].plot([Tvis.min(), Tvis.max()], [[0, 90, -90], [0, 90, -90]], "k:")
        # NOTE(review): these guide levels are drawn on an axis whose data
        # are in degrees - confirm that 0.65 / 1.3 are meant in degrees too.
        L = [0, 0.65, -0.65, 1.3, -1.3, 0.65/2, -0.65/2, 1.3*3/4, -1.3*3/4]
        axs[1].plot([Tvis.min(), Tvis.max()], [L, L], "k:")
        axs[1].legend()
        axs[0].set_xlim(xylims)
        axs[1].set_xlim(xylims)
        # The curves are multiplied by 180/pi above, so the unit is degrees
        # (the labels used to say "rad").
        axs[1].set_ylabel("Group delay (deg)")
        axs[0].set_ylabel("Phase delay (deg)")
        axs[1].set_xlabel("Time (note, if baseline 13 is one line below zero, add 2pi to met on tel 3 (or -2pi on tel1) ) ")

    #%%

    # NOTE: in the figures below ylim(bottom=...) is called at exactly the
    # same position as in the original code; it freezes the upper limit at
    # that moment, so the statement order must not change.

    # FT visibility amplitude: mean of |VISDATA_FT| over axis 1, divided by
    # the DIT of the file; one colour per column index 0..5.
    fig = figure("VisData FT Amplitude (per pixel, per seconds)", clear=True, figsize=(16, 4.8))
    figure_list.append(fig.get_label())
    for TIME, VISDATA_FT, DIT in zip(oi_vis["TIME"], oi_vis["VISDATA_FT"], h_v["DIT"]):
        for b in range(6):
            ft_amplitude = np.abs(VISDATA_FT[:, b]).mean(axis=1)
            plot(TIME, ft_amplitude/DIT, "-C%i" % b)
    ylim(bottom=0)
    xlabel("Time (seconds)")
    fig.axes[0].set_title(fig.get_label())

    # SC visibility amplitude: same quantity from VISDATA, not divided by DIT.
    fig = figure("VisData SC Amplitude (per pixel, per DIT)", clear=True, figsize=(16, 4.8))
    figure_list.append(fig.get_label())
    for TIME, VISDATA_SC in zip(oi_vis["TIME"], oi_vis["VISDATA"]):
        for b in range(6):
            sc_amplitude = np.abs(VISDATA_SC[:, b]).mean(axis=1)
            plot(TIME, sc_amplitude, "-C%i" % b)
    ylim(bottom=0)
    xlabel("Time (seconds)")
    fig.axes[0].set_title(fig.get_label())

    # Rejection flags with one horizontal guide per flag value 1, 2, 4, 8.
    fig = figure("Rejection Flags", clear=True, figsize=(16, 4.8))
    figure_list.append(fig.get_label())
    for TIME, REJECTION_FLAG in zip(oi_vis["TIME"], oi_vis["REJECTION_FLAG"]):
        for b in range(6):
            plot(TIME, REJECTION_FLAG[:, b], "-oC%i" % b, markeredgecolor="k")
    time_span = [TIME_min, TIME_max]
    plot(time_span, np.ones(2)*1, "red", label="FT tracking not good")
    ylim(bottom=-0.1)
    later_guides = (
        (2, "magenta", "Vfactor too low"),
        (4, "cyan", "OPD_PUPIL too large"),
        (8, "y", "OPD_PUPIL STDEV"),
    )
    for level, guide_color, guide_label in later_guides:
        plot(time_span, np.ones(2)*level, guide_color, label=guide_label)
    legend()
    xlabel("Time (seconds)")
    fig.axes[0].set_title(fig.get_label())

    # Arc length: every sample is drawn as a plain marker; samples whose
    # mean flag over axis 2 is below 0.5 are drawn again with a black edge
    # and a connecting line.
    fig = figure("Arc Curvature of FT data (uncircled data points are flagged as bad)", clear=True, figsize=(16, 4.8))
    figure_list.append(fig.get_label())
    for TIME, arcLength, flag in zip(oi_vis["TIME"], oi_vis["arcLength"], oi_vis["flag"]):
        Good = flag.mean(axis=2) < .5
        for b in range(6):
            keep = Good[:, b]
            plot(TIME, arcLength[:, b], "oC%i" % b)
            plot(TIME[keep], arcLength[keep, b], "-oC%i" % b, markeredgecolor="k")
    plot([TIME_min, TIME_max], np.ones(2)*5, "red", label="poly arc Lenght too long")
    ylim(bottom=-0.1)
    legend()
    xlabel("Time (seconds)")
    fig.axes[0].set_title(fig.get_label())

    ############################
    #%% Flux
    ############################

    # FLUX averaged over axis 1, one colour per column index 0..3.
    fig = figure("SC Flux (per pixel, per DIT)", clear=True, figsize=(16, 4.8))
    figure_list.append(fig.get_label())
    for TIME, FLUX_SC in zip(oi_flux["TIME"], oi_flux["FLUX"]):
        for t in range(4):
            plot(TIME, FLUX_SC[:, t].mean(axis=1), "-C%i" % t)
    # Called before the guides are drawn, exactly as in the original code:
    # the upper limit is frozen at the data range at this point.
    ylim(bottom=0)

    # Two labelled horizontal guides.
    time_span = [TIME_min, TIME_max]
    for level, guide_color, caption in ((100000, "red", "saturation"),
                                        (5000, "blue", "good flux")):
        plot(time_span, np.ones(2)*level, guide_color)
        plt.text(TIME_min, level, caption, color=guide_color)
    xlabel("Time (seconds)")
    fig.axes[0].set_title(fig.get_label())


    #%%
    # Text page: one summary line per file whose 'extension' entry is 10 or
    # 11, followed by one 'verification' line for each of those files.
    selected = [k for k, ext in enumerate(h_v["extension"]) if ext in (11, 10)]

    summary_rows = [
        h_v['NAME'][k] + ":  "
        + h_v["DATE"][k] + "      "
        + str(h_v["NDIT"][k]) + " "
        + ("%8.2f" % (h_v["DIT"][k])) + "s    "
        + h_v["RES"][k] + "   "
        + h_v["POLA"][k] + "   "
        + h_v["AXIS"][k] + "   "
        + str(h_v["SXY"][k])
        + "\n"
        for k in selected
    ]
    verification_rows = [h_v['verification'][k] + "\n" for k in selected]
    string_total = "".join(summary_rows + verification_rows)

    # The page height grows with the number of files; the page is put in
    # front of all other figures.
    fig = figure("intro", figsize=(9, 2+len(files)*.25), clear=True)
    figure_list.insert(0, fig.get_label())
    plt.figtext(.01, .95, string_total, va='top', wrap=1, fontsize=8)


    print(name_file)
    # One PDF page per registered figure label, in list order.
    with PdfPages(name_file) as pdf:
        for page_label in figure_list:
            # figure(label) re-activates the existing figure with that label
            fig = figure(page_label)
            pdf.savefig(fig)

    elapsed = time.time() - start_time
    print("Done after %.3fs" % elapsed)

# %%
