
def get_files(d,sof):
    """Parse a "set of frames" (SOF) file into a tag -> path mapping.

    Every line of ``d + sof`` with exactly two whitespace-separated fields is
    read as ``<filename> <TAG>``; all other lines (blank, comments with more
    fields, unknown tags) are ignored. Each path in the result is ``d`` plus
    the file name (plain string concatenation, so ``d`` should end with a
    path separator).

    Tags DARK_RAW, WAVE, DARK, FLAT, P2VM and BAD map directly to their file;
    a later occurrence overwrites an earlier one. Tags that may repeat are
    numbered from 1 in file order, with an independent counter per tag:

    * ``FLAT_RAW`` -> ``FLAT_RAW1``, ``FLAT_RAW2``, ...
    * ``DUAL_SCIENCE_RAW``, ``SINGLE_SCIENCE_RAW`` and ``DISP_RAW`` -> ``<TAG>n``
      plus two derived entries, ``..._PREPROCn`` (``gravi_preproc_`` + name)
      and ``..._REDUCEDn`` (``dual_reduced_`` + name).

    Parameters
    ----------
    d : str
        Directory prefix for the SOF file and for every file it lists.
    sof : str
        Name of the SOF file inside ``d``.

    Returns
    -------
    dict
        Mapping from (possibly numbered) tag to full path.
    """
    single_tags = ('DARK_RAW', 'WAVE', 'DARK', 'FLAT', 'P2VM', 'BAD')
    # Repeating science-like tags -> key prefixes of their derived products.
    science_tags = {
        'DUAL_SCIENCE_RAW': ('DUAL_SCIENCE_PREPROC', 'DUAL_SCIENCE_REDUCED'),
        'SINGLE_SCIENCE_RAW': ('SINGLE_SCIENCE_PREPROC', 'SINGLE_SCIENCE_REDUCED'),
        'DISP_RAW': ('DISP_RAW_PREPROC', 'DISP_RAW_REDUCED'),
    }
    # One running counter per repeating tag, all starting at 0 (previously the
    # DUAL_SCIENCE_RAW counter was never initialised and crashed on first use).
    counts = dict.fromkeys(('FLAT_RAW',) + tuple(science_tags), 0)

    files = {}
    with open(d + sof, 'r') as handle:
        for line in handle:
            fields = line.split()
            if len(fields) != 2:
                continue
            name, tag = fields
            if tag in single_tags:
                files[tag] = d + name
            elif tag in counts:
                counts[tag] += 1
                n = str(counts[tag])
                files[tag + n] = d + name
                if tag in science_tags:
                    preproc_key, reduced_key = science_tags[tag]
                    files[preproc_key + n] = d + "gravi_preproc_" + name
                    # NOTE(review): every mode, SINGLE_SCIENCE_RAW included, is
                    # mapped to a "dual_reduced_" product name - confirm this
                    # matches the pipeline output for single-field data.
                    files[reduced_key + n] = d + "dual_reduced_" + name
    return files
        
        
def gauss(x, *p):
    """Evaluate a Gaussian profile at ``x``.

    ``p`` must contain exactly three values, in this order: the peak
    amplitude, the centre and the standard deviation. ``x`` may be a scalar
    or an array (evaluated element-wise through ``exp``). The varargs form
    allows calls such as ``gauss(x, *params)``.
    """
    amplitude, centre, width = p
    variance_term = 2. * width ** 2
    return amplitude * exp(-(x - centre) ** 2 / variance_term)
        
        
def get_visdata(file_i):
    """Build metrology-referenced science visibilities from one reduced file.

    Processing steps (all data read from ``file_i``):

    1. Keep the science (SC) frames whose time stamps lie strictly inside the
       fringe-tracker (FT) time span.
    2. For each kept frame, average the FT visibilities in a window of width
       ``PER`` around the SC time stamp, then model the FT phase on the SC
       wavelength grid (mean phase + degree-2 fit of the residual phase).
    3. Average the baseline metrology in a window of width ``DIT`` around each
       SC time stamp.
    4. Remove the FT phase from the SC visibilities, apply the metrology term
       and average over frames.
    5. Remove a group-delay ramp and the mean phase, then add them back
       together with the time-mean metrology term to form a per-channel phase.

    Parameters
    ----------
    file_i : str
        Path of the FITS file (the commented-out driver in this script passes
        the ``DISP_RAW_REDUCEDn`` entries returned by ``get_files``).

    Returns
    -------
    tuple
        ``(VISDATA1_T, VISDATA2_T, Vis_Phi1, Vis_Phi2, FDDL, FMet, WS1, WS2)``.
        The elements are:

        * ``VISDATA1_T``/``VISDATA2_T``: frame-averaged SC visibilities per
          polarisation, with the FT phase removed and the full metrology
          term ``M * Lambda_met / WS`` applied.
        * ``Vis_Phi1``/``Vis_Phi2``: the reconstructed phases per
          polarisation.
        * ``FDDL``: ``array([mean FT_POS, mean SC_POS])`` from HDU 2.
        * ``FMet``: the time-averaged metrology ``PHI`` from HDU 8.
        * ``WS1``/``WS2``: the SC wavelengths per polarisation.

    Notes
    -----
    Uses the module-level ``fits`` module and ``M_matrix``.
    """
    import numpy as np  # used below; do not depend on a star import exporting ``np``

    # --- Headers and tables: read each HDU once. ---
    hdr = fits.getheader(file_i, 0)
    DIT = hdr['HIERARCH ESO DET2 SEQ1 DIT'] * 1e6
    PER = hdr['HIERARCH ESO INS TIM1 PERIOD'] * 1e6
    print(hdr['DATE-OBS'])

    fddl = fits.getdata(file_i, 2)
    FDDL_FT = fddl.field('FT_POS')
    FDDL_SC = fddl.field('SC_POS')

    # Metrology: PHI projected onto the six baselines with M_matrix.
    met_table = fits.getdata(file_i, 8)
    met_phi = met_table.field('PHI')
    FMet = met_phi.mean(axis=0)
    met = dot(M_matrix, met_phi.T).T
    metT = met_table.field('TIME')

    # Effective wavelengths: SC (S) and FT (F), polarisations 1 and 2.
    WS1 = fits.getdata(file_i, 4).field('EFF_WAVE')
    WS2 = fits.getdata(file_i, 5).field('EFF_WAVE')
    WF1 = fits.getdata(file_i, 6).field('EFF_WAVE')
    WF2 = fits.getdata(file_i, 7).field('EFF_WAVE')

    # Visibility tables. Rows come in groups of 6 (one per baseline, as the
    # reshapes below require), so every 6th TIME entry gives one stamp per
    # frame. SC stamps are shifted by PER.
    sc1 = fits.getdata(file_i, 9)
    sc2 = fits.getdata(file_i, 12)
    ft1 = fits.getdata(file_i, 15)
    ft2 = fits.getdata(file_i, 18)
    SC1T = sc1.field('TIME')[::6] + PER
    SC2T = sc2.field('TIME')[::6] + PER
    FT1T = ft1.field('TIME')[::6]
    FT2T = ft2.field('TIME')[::6]

    Nw = sc1.field('VISDATA').shape[1]
    S1 = sc1.field('VISDATA').reshape(len(SC1T), 6, Nw)
    S2 = sc2.field('VISDATA').reshape(len(SC2T), 6, Nw)
    FT1 = ft1.field('VISDATA').reshape(len(FT1T), 6, 5)
    FT2 = ft2.field('VISDATA').reshape(len(FT2T), 6, 5)

    # --- Time sampling: SC frames strictly inside the FT time span. ---
    # S1/S2 are cut with the same mask so they stay aligned with T; without
    # this the frame-wise products below fail to broadcast as soon as a single
    # frame is dropped.
    keep = (SC1T > FT1T.min()) & (SC1T < FT1T.max())
    T = SC1T[keep]
    S1 = S1[keep]
    S2 = S2[keep]  # assumes polarisation-2 frames align with polarisation 1 - TODO confirm

    # FT visibilities averaged in a window of width PER around each SC stamp,
    # each polarisation windowed on its own time stamps.
    F1 = array([FT1[(FT1T > t - PER / 2) & (FT1T < t + PER / 2)].mean(axis=0) for t in T])
    F2 = array([FT2[(FT2T > t - PER / 2) & (FT2T < t + PER / 2)].mean(axis=0) for t in T])

    def _ft_phase_model(F, WF, WS):
        # FT phase per (frame, baseline) on the SC grid WS: the mean phase plus
        # a degree-2 fit (over WF) of the phase measured relative to that mean.
        model = []
        for frame in F:
            per_baseline = []
            for vis in frame:
                mean_vis = vis.mean()
                coeffs = np.polyfit(WF, angle(vis * conj(mean_vis)), 2)
                per_baseline.append(angle(mean_vis) + np.poly1d(coeffs)(WS))
            model.append(per_baseline)
        return array(model)

    F1W = _ft_phase_model(F1, WF1, WS1)
    F2W = _ft_phase_model(F2, WF2, WS2)

    # Baseline metrology averaged in a window of width DIT around each SC stamp.
    M = array([met[(metT > t - DIT / 2) & (metT < t + DIT / 2)].mean(axis=0) for t in T])

    # Split the scaled metrology into its time mean (OPD_Mean) and the
    # fluctuation around that mean (OPD_Met).
    Lambda_met = 1.905e-6
    OPD_Mean = M.mean(axis=0) * Lambda_met
    OPD_Met = (M - M.mean(axis=0)) * Lambda_met

    # Frame-averaged SC visibilities with the FT phase removed and the
    # metrology applied: fluctuation only (VISDATA1/2) or full term (*_T).
    phi1 = -F1W + OPD_Met[:, :, None] / WS1
    phi2 = -F2W + OPD_Met[:, :, None] / WS2
    VISDATA1 = (S1 * exp(1j * phi1)).mean(axis=0)
    VISDATA2 = (S2 * exp(1j * phi2)).mean(axis=0)
    VISDATA1_T = (S1 * exp(1j * (-F1W + M[:, :, None] * Lambda_met / WS1))).mean(axis=0)
    VISDATA2_T = (S2 * exp(1j * (-F2W + M[:, :, None] * Lambda_met / WS2))).mean(axis=0)

    # Group delay from the summed phase step between adjacent channels,
    # turned into a phase ramp over 1/WS.
    GD1 = (VISDATA1[:, 1:] * conj(VISDATA1[:, :-1])).sum(axis=1)
    GD2 = (VISDATA2[:, 1:] * conj(VISDATA2[:, :-1])).sum(axis=1)
    OPD_GD1 = (angle(GD1) / diff(1 / WS1).mean())[:, None] / WS1
    OPD_GD2 = (angle(GD2) / diff(1 / WS2).mean())[:, None] / WS2

    # Remove the group-delay ramp, then the remaining mean phase per baseline.
    VISDATA1_2 = VISDATA1 * exp(-1j * OPD_GD1)
    VISDATA2_2 = VISDATA2 * exp(-1j * OPD_GD2)
    VISDATA1_2_Mean = VISDATA1_2.mean(axis=1)
    VISDATA2_2_Mean = VISDATA2_2.mean(axis=1)
    VISDATA1_3 = VISDATA1_2 * exp(-1j * angle(VISDATA1_2_Mean)[:, None])
    VISDATA2_3 = VISDATA2_2 * exp(-1j * angle(VISDATA2_2_Mean)[:, None])

    # Residual phase plus everything removed above: mean phase, time-mean
    # metrology term and group-delay ramp.
    Vis_Phi1 = angle(VISDATA1_3) + angle(VISDATA1_2_Mean)[:, None] + OPD_Mean[:, None] / WS1 + OPD_GD1
    Vis_Phi2 = angle(VISDATA2_3) + angle(VISDATA2_2_Mean)[:, None] + OPD_Mean[:, None] / WS2 + OPD_GD2

    return (VISDATA1_T, VISDATA2_T, Vis_Phi1, Vis_Phi2,
            array([FDDL_FT.mean(axis=0), FDDL_SC.mean(axis=0)]), FMet, WS1, WS2)
    
def fit_phase(params, xdata, ydata):
    """Return wrapped residuals of the linear phase model ``params[0] + params[1]*x``.

    The residual ``ydata - params[0] - params[1]*xdata`` is reduced modulo
    2*pi and values above pi are shifted down by 2*pi, so the result lies in
    (-pi, pi]. The ``(params, *data)`` signature is the residual-function form
    expected by least-squares routines such as ``leastsq``.
    """
    intercept, slope = params[0], params[1]
    residual = (ydata - intercept - slope * xdata) % (2 * pi)
    too_high = residual > pi
    residual[too_high] = residual[too_high] - 2 * pi
    # Kept as a safety fold; after the modulo above no value is below -pi.
    too_low = residual < -pi
    residual[too_low] = residual[too_low] + 2 * pi
    return residual
    
from numpy import *   
import cmath
from matplotlib.pyplot import *    
import pyfits as fits
from scipy import signal
from scipy.interpolate import interp1d
from scipy.optimize import leastsq
from scipy.signal import savgol_filter
from numpy.polynomial.chebyshev import *

from numpy import *   
from matplotlib.pyplot import *    
import pyfits as fits
import scipy.interpolate as inter

# Pair-difference matrix mapping 4 per-input values (the 4 columns of the
# metrology PHI / FDDL arrays it is applied to) onto 6 pairs: row k yields
# x[i] - x[j] for the pairs (0,1), (0,2), (0,3), (1,2), (1,3), (2,3).
M_matrix = -array([
    [-1., 1., 0., 0.],
    [-1., 0., 1., 0.],
    [-1., 0., 0., 1.],
    [0., -1., 1., 0.],
    [0., -1., 0., 1.],
    [0., 0., -1., 1.],
])

#
#d = "../DATA/Dispersion/"
#sof = "disp_data.sof"
#
#files=get_files(d,sof)
#
#V1=[]
#V2=[]
#P1=[]
#P2=[]
#FDDL=[]
#MET=[]
#W1=[]
#W2=[]

#for i in range(100):
#    if "DISP_RAW_REDUCED"+str(i) in files:
#        a,b,c,d,e,f,g,h=get_visdata(files["DISP_RAW_REDUCED"+str(i)])
#        V1.append(a)
#        V2.append(b)
#        P1.append(c)
#        P2.append(d)
#        FDDL.append(e)
#        MET.append(f)
#        W1.append(g)
#        W2.append(h)
#        
#V1=array(V1)
#V2=array(V2)
#P1=array(P1)
#P2=array(P2)
#FDDL=array(FDDL)
#MET=array(MET)
#W1=array(W1)
#W2=array(W2)
#
#savez("disp_data2",V1,V2,P1,P2,FDDL,MET,W1,W2)
# Reload the per-file products. The commented-out driver above saves them
# with savez in the order V1, V2, P1, P2, FDDL, MET, W1, W2 (arr_0..arr_7).
e = load("disp_data2.npz")
V1, V2, P1, P2, FDDL, MET, W1, W2 = [
    e[key] for key in ('arr_0', 'arr_1', 'arr_2', 'arr_3',
                       'arr_4', 'arr_5', 'arr_6', 'arr_7')
]

# Scale the metrology phases by 1.905e-6/(2*pi); 1.905e-6 is the same
# constant used as Lambda_met in get_visdata (presumably radians -> metres).
MET *= 1.905e-6 / (2 * pi)
FDDL_pos = []
FDDL_met = []
for i in range(4):
    # For input i, least-squares fit the metrology against a quadratic in each
    # of the two FDDL position columns:
    #   MET ~ a0 + a1*F0 + a2*F0**2 + a3*F1 + a4*F1**2
    # (if disp_data2.npz came from get_visdata, column 0 is the mean FT_POS
    # and column 1 the mean SC_POS - confirm for other inputs).
    M = MET[:, i]
    F = FDDL[:, :, i]
    F2 = array([F[:, 0]*0 + 1, F[:, 0], F[:, 0]**2, F[:, 1], F[:, 1]**2])
    a = dot(linalg.pinv(F2.T), M)
    # Keep the raw fit coefficients. A plain ``QC_lin = a`` only aliased the
    # array, so the in-place edits below silently altered QC_lin as well.
    QC_lin = a.copy()
    a[0] = 0  # drop the constant term
    FDDL_met.append(dot(F2.T, a))
    a[3:] *= -1  # flip the sign of the column-1 terms
    FDDL_pos.append(dot(F2.T, a))

FDDL_pos = array(FDDL_pos).T
FDDL_met = array(FDDL_met).T

# Project metrology and FDDL-derived positions onto the six pairs.
M = dot(M_matrix, MET.T).T
F = dot(M_matrix, FDDL_pos.T).T


# Wavelength grid: EFF_WAVE of the second file scaled by 1e6; Nw channels.
W = W1[1]*1e6
Nw = len(W)
# Phase of the mean product of adjacent channels, per file and baseline.
GD1 = angle((V1[:, :, 1:]*conj(V1[:, :, :Nw-1])).mean(axis=2))
GD2 = angle((V2[:, :, 1:]*conj(V2[:, :, :Nw-1])).mean(axis=2))
# Visibilities with the median per-channel phase ramp removed.
# NOTE(review): V1c / V2c are not used further down in this script.
V1c = V1*exp(-1j*arange(Nw)*median(GD1, axis=0)[:, None])
V2c = V2*exp(-1j*arange(Nw)*median(GD2, axis=0)[:, None])

# For each baseline i and channel l, grid-search the factor in ``a`` that
# maximises |sum_k V1[k, i, l] * exp(1j * a * M[k, i])| over files k.
a = linspace(-6e3, 3e3, 500)
A = []
for i in range(6):
    Ms = sort(M[:, i])
    aM = argsort(M[:, i])
    B = []
    for l in range(Nw):
        Vs = V1[:, i, l][aM]
        amp = abs((Vs*exp(1j*a[:, None]*Ms)).sum(axis=1))
        B.append(a[amp.argmax()])
    B = array(B)
    A.append(B)
A = array(A)

# Model phase: metrology times the fitted factor, minus the median channel
# ramp. GD1 from above is reused; V1 is unchanged since it was computed, so
# recomputing it would yield the identical array.
Slopes = M[:, :, None]*A - arange(Nw)*median(GD1, axis=0)[:, None]
V1b = V1*exp(1j*Slopes)
K = unwrap(angle(V1b.mean(axis=0)))
# Wrap the corrected phase into (-pi, pi] with angle(), then subtract the
# model term (Slopes - K) again.
P1c = angle(V1*exp(1j*(Slopes - K))) - (Slopes - K)

# Remove a per-baseline linear trend in 1/W fitted to the file-averaged phase.
x = polyfit(1/W, P1c.mean(axis=0).T, 1)
for i in range(6):
    P1c[:, i] -= poly1d(x[:, i])(1/W)

# Per baseline i, fit the detrended phase P1c[:, i] (files x channels) as
#   offset + coef_F * F[:, i] + coef_M * M[:, i]
# independently for every channel, then smooth each coefficient across
# channels with a Chebyshev series in 1/W (degree 30 for the offset and the
# M term, 10 for the F term). D[i] stacks the smoothed [offset, coef_F, coef_M].
B = []
C = []
D = []
for i in range(6):
    F2 = array([F[:, i]*0 + 1, F[:, i], M[:, i]])
    r = dot(linalg.pinv(F2.T), P1c[:, i])
    a = Chebyshev(chebfit(1/W, r[0], 30))(1/W)
    b = Chebyshev(chebfit(1/W, r[1], 10))(1/W)
    c = Chebyshev(chebfit(1/W, r[2], 30))(1/W)
    D.append(array([a, b, c]))
    B.append(r)
D = array(D)
B = array(B)

# Model phases rebuilt from the smoothed coefficients. P1c is copied so that
# filling P2c does not overwrite the measured phases (``P2c = P1c`` aliased them).
P2c = P1c.copy()
for i in range(6):
    F2 = array([F[:, i]*0 + 1, F[:, i], M[:, i]])
    P2c[:, i] = dot(F2.T, D[i])

#V2b=V1*exp(-1j*P2c)
#savez("disp_param",D)

# D has shape (6, 3, Nw). For each baseline it holds the smoothed coefficients
# of the design-matrix columns [1, F, M]: offset, FDDL-derived term (F) and
# metrology term (M).
figure(0)
plot(W, D[:, 0].T)
xlabel("wavelength")
ylabel("Dispersion - offset")
figure(1)
plot(W, D[:, 1].T)  # coefficient of F (FDDL-derived positions)
xlabel("wavelength")
ylabel("Dispersion - fddl factor")
figure(2)
plot(W, D[:, 2].T)  # coefficient of M (metrology)
xlabel("wavelength")
ylabel("Dispersion - met factor")
