# -*- coding: utf-8 -*-
"""
Created on Wed Sep  9 16:07:58 2015

@author: kervella
"""
import matplotlib.pyplot as plt
import numpy as np
import gravi_rawclass
import gravi_darkclass

# Input data set, given as file name stems (the '.fits' extension is appended on loading).
filename = 'GRAVITY.2015-09-09T18-18-32-blink' # MR
darkname = 'master_dark.2015-09-06T13-45-52' # MR dark

# Values recorded for the MR file (same format as the messages printed by this script):
#   Delta offset for correction of SC-FT time scale = -6.6 microseconds (add to SC offset)
#   *** Delay of start of DIT of FT line N (between 0 and 47) relative to center of DIT of all FT lines:
#             Dt(FT_N - FT_center) = 0.001650 x [N - 23.5] milliseconds
#   The FT readout order is not unidirectional, and this scale may be biased by a few µs for a given output.
#
#   slope: +0.400 milliseconds/line    ZP = +159.765 milliseconds
#   *** Delay of start of DIT of SC line N relative to center of all FT line DIT :
#         Dt(SC-FT) = 0.400 x N + 9.759 milliseconds
#   *** Delay of start of DIT of SC OUTPUT N relative to center of all FT line DIT :
#         Dt(SC OUTPUT) = 2.801 x N + 10.959 milliseconds


# Alternative data set (HR); uncomment to use it instead of the MR one:
# filename = 'GRAVITY.2015-09-11T00-43-43-blink' # HR
# darkname = 'master_dark.2015-08-15T08-15-43' # HR dark

# Values recorded for the HR file:
#   slope: +0.411 milliseconds/line    ZP = +158.811 milliseconds
#   *** Delay of start of DIT of SC line N relative to center of all FT line DIT :
#             Dt(SC-FT) = 0.411 x N + 8.805 milliseconds

# Load the raw sequence and the master dark from their FITS files
raw = gravi_rawclass.Rawdata('%s.fits' % filename)
dark = gravi_darkclass.Darkdata('%s.fits' % darkname)

# Number of FT outputs and of FT spectral channels processed below
n_outputs, n_wave_ft = 48, 5

# Inter-frame time steps, taken from the first two time stamps of each detector
dt_ft = raw.time_ft[1] - raw.time_ft[0] # inter-frame for FT
dt_sc = raw.time_sc[1] - raw.time_sc[0] # inter-frame for SC

# Start from a clean set of figures
plt.close('all')

#==============================================================================
# Fringe tracker photometry
#==============================================================================

if True:
    # Fringe tracker (FT) section.
    #   1. Normalize every (output, wavelength) time series of raw.data_ft so
    #      that its dark level maps to 0 and its bright level to 1.
    #   2. Build the wavelength-averaged signal per output (mean_ft) and the
    #      average over the illuminated outputs (grand_mean_ft).
    #   3. Keep only the samples close to the flux jumps and, for each
    #      illuminated output, find the time shift of the average signal that
    #      minimizes the residual with that output.
    #   4. Fit a linear delay law across outputs and report it.
    # grand_mean_ft and delta_offset_scft are reused by the SC section below.
    print("Fringe Tracker time offset computation")
    #print raw.data_ft.shape #(time,channel,wavelength)

    # dark and bright level normalization
    for output in range(n_outputs):
        for wave in range(n_wave_ft):
            # Slicing a NumPy array returns a view, so the in-place division
            # below already rescales raw.data_ft itself; the dark and bright
            # levels are therefore expressed in mean-normalized units.
            overall_ft = raw.data_ft[:, output, wave]
            overall_ft /= np.mean(overall_ft)
            # samples above the mean are taken as bright, below as dark
            bright_level = np.nanmedian(overall_ft[overall_ft > 1])
            dark_level = np.nanmedian(overall_ft[overall_ft < 1])

            raw.data_ft[:, output, wave] -= dark_level
            raw.data_ft[:, output, wave] /= (bright_level - dark_level)

    #plt.figure()
    #plt.plot(raw.data_ft[:,0,2])

    # Average signal over wavelengths (to correlate the different wavelengths)
    mean_ft = np.mean(raw.data_ft, axis=2) # time, output

    # Average signal over channels and wavelengths (to correlate the different channels)
    open_channels = np.concatenate([np.arange(0, 8), np.arange(16, 24), np.arange(32, 40)])
    center_open_channels = np.mean(open_channels) # effective center of the average signal

    grand_mean_ft = np.mean(mean_ft[:, open_channels], axis=1)
    # trigger equals 1 where the average signal crosses the 0.5 level between
    # two consecutive samples; it is one element shorter than the time axis.
    trigger = np.abs(np.diff(np.sign(grand_mean_ft - 0.5)) / 2)
    n = 30 # half-width of the window around the flux jumps

    # Select the samples i for which a jump lies in [i-n, i+n).  The mask is
    # built with the full length of the time axis (the last sample is never
    # selected): a boolean index shorter than the indexed axis raises
    # IndexError in current NumPy versions.
    mask_ft = np.zeros(grand_mean_ft.size, dtype=bool)
    for jump in np.flatnonzero(trigger == 1):
        mask_ft[max(jump - n + 1, 0):min(jump + n + 1, trigger.size)] = True
    n_masked = int(np.count_nonzero(mask_ft))

    time_ft = np.arange(n_masked) * dt_ft # Compressed time scale
    # Only the illuminated outputs are filled in; the other columns stay at 0
    mean_ft_masked = np.zeros((n_masked, n_outputs))
    mean_ft_masked[:, open_channels] = mean_ft[mask_ft][:, open_channels]

    grand_mean_ft_masked = grand_mean_ft[mask_ft]

    range_ft = 0.000100 # +/- range in seconds
    nsteps = 200
    delays = np.linspace(-range_ft, range_ft, nsteps)
    delays_ft = np.zeros(n_outputs)

    residuals = np.zeros((n_outputs, nsteps))
    print("Computation of the time offsets for each output of the FT")
    # The shifted reference signal is the same for all outputs, so it is
    # interpolated once per trial delay and compared to all illuminated
    # outputs at the same time.
    open_signals = mean_ft_masked[:, open_channels]
    for i, delay in enumerate(delays):
        gmfm_interp = np.interp(time_ft + delay, time_ft, grand_mean_ft_masked)
        residuals[open_channels, i] = np.std(gmfm_interp[:, np.newaxis] - open_signals, axis=0)
    # Reference each residual curve to its minimum, then pick the best delay
    residuals[open_channels, :] -= np.min(residuals[open_channels, :], axis=1, keepdims=True)
    delays_ft[open_channels] = delays[np.argmin(residuals[open_channels, :], axis=1)] # in seconds
    # Debug: plt.figure(); plt.plot(delays, residuals[2,:])

    # Linear delay law across the illuminated outputs
    [slope_ft, zeropoint_ft] = np.polyfit(open_channels, delays_ft[open_channels], 1)

    center_outputs = (n_outputs - 1) / 2.0 # center of all outputs
    delays_ft_int = slope_ft * (np.arange(n_outputs) - center_outputs)
    for output in range(n_outputs):
        print("FT output %02i interpolated offset wrt to center (output %.1f) = %+.1f microseconds" % (output, center_outputs, delays_ft_int[output]*10**6))

    # Offset between the center of the illuminated outputs (reference of the
    # average signal) and the center of all outputs
    delta_offset_scft = slope_ft*(center_open_channels-center_outputs)
    print(" Delta offset for correction of SC-FT time scale = %+.1f microseconds (add to SC offset) " \
        % (delta_offset_scft*10**6))

    # Residual map (delay versus output) with the fitted delay law overplotted
    plt.figure(figsize=(10,10))
    plt.title("FT delay map per output (microsec) = %.3f x [output]%+.3f" % (slope_ft*10**6, zeropoint_ft*10**6))
    x, y = np.meshgrid(delays*10**6, np.arange(n_outputs))
    plt.pcolor(x,y,residuals)
    plt.xlabel("Delay wrt average of illuminated (microseconds)")
    plt.ylabel("FT output")
    plt.plot(delays_ft_int*10**6, np.arange(n_outputs), color='red', lw=3)

    print(" *** Delay of start of DIT of FT line N relative to center of DIT of all FT lines:")
    print("           Dt(FT_N - FT_center) = %.6f x [N - %.1f] milliseconds" % (slope_ft*1000.,center_outputs))
    print("     The FT readout order is not unidirectional, and this scale may be biased by a few µs for a given output.")


    # Skipping one line requires (most probably) 32 ticks of the clock (1 tick = 0.2 µs, f = 5 MHz), or 6.4 µs
    # Reading a full line with the 32 outputs requires 410 µs

#==============================================================================
# Science combiner photometry
#==============================================================================

print("Science combiner time offset computation")

# print raw.data_sc.shape
# The first three axes of the SC cube are indexed below as (frame, line, column)
n_frames, n_lines, n_columns = raw.data_sc.shape[:3]

# Dark frame subtraction
dark_frame = dark.darkimage_sc

#plt.figure()
#plt.imshow(dark_frame,vmin=np.percentile(dark_frame,1),vmax=np.percentile(dark_frame,99),aspect='auto',)

print("Dark frame subtraction")
# The dark image is broadcast along the frame axis, so it is removed from
# every image of the cube in a single in-place operation.
raw.data_sc[:, :, :] -= dark_frame

# Identify the illuminated pixels and build the corresponding boolean mask.
# The reference image is the time average of the first tenth of the sequence.
mean_frame = np.mean(raw.data_sc[:n_frames // 10], axis=0)
display_min, display_max = np.percentile(mean_frame, [1, 99])
plt.figure()
plt.imshow(mean_frame, vmin=display_min, vmax=display_max, aspect='auto')
#
# 3rd percentile: proxy of the dark level
# 99th percentile: proxy of the bright level (excluding bad pixels)
dark_level, bright_level = np.nanpercentile(mean_frame, [3, 99])

print("Determination of the illuminated pixels")
#mean_level = np.mean([dark_level,bright_level])
# A pixel is flagged as illuminated when its mean flux is strictly above 30%
# of the bright level proxy; every other pixel (NaN included) is False.
illuminated = mean_frame > 0.3 * bright_level

#illuminated[:,:120]=False # to cut the non-illuminated columns on the left IN HR ONLY
#illuminated[213:310,:]=False # to cut the non-illuminated lines in the middle IN HR ONLY
#illuminated[410:,:]=False # to cut the non-illuminated lines at the bottom IN HR ONLY

plt.figure()
plt.imshow(illuminated, aspect='auto', interpolation='None')
plt.title("Illuminated pixels")

print("Computation of the normalized signal per illuminated pixel")
# Pixels outside the illuminated mask are set to zero for the whole sequence
raw.data_sc[:, ~illuminated] = 0

# Dark and bright level normalization of each illuminated pixel
# (x runs over detector lines, y over detector columns)
for x, y in zip(*np.nonzero(illuminated)):
    # View on the pixel time series: the in-place division below rescales
    # raw.data_sc itself, so the levels are in mean-normalized units.
    pixel_series = raw.data_sc[:, x, y]
    pixel_series /= np.mean(pixel_series)

    # Samples above the mean are taken as bright, samples below as dark
    below_mean = pixel_series[pixel_series < 1]
    bright_level = np.nanmedian(pixel_series[pixel_series > 1])
    dark_level = np.nanmedian(below_mean)
    noise_level = np.nanstd(below_mean)  # scatter of the dark samples (not used further here)

    # Map the dark level to 0 and the bright level to 1
    raw.data_sc[:, x, y] -= dark_level
    raw.data_sc[:, x, y] /= (bright_level - dark_level)

print("Determination of the SC delays with respect to FT")
range_ftsc = 0.020 # +/- range in seconds
nsteps = 10 # number of steps on each side
resamp = 50     # only one FT sample out of `resamp` is used
subsample = 100 # only the first `subsample` SC frames are used
n_delays = 2*nsteps + 1
residuals = np.zeros((n_lines, n_columns, n_delays))
delay_ftsc = np.zeros((n_lines, n_columns))

# Inputs that do not change from one pixel to the next
ft_times = raw.time_ft[::resamp]
ft_signal = grand_mean_ft[::resamp]
sc_times = raw.time_sc[:subsample]

for x in range(n_lines):
    center_ftsc = 0.160 + 0.000400*x # Approximate delay for each line
    delays = np.linspace(center_ftsc - range_ftsc, center_ftsc + range_ftsc, n_delays)
    lit_columns = illuminated[x, :]
    for y in np.flatnonzero(lit_columns):
        sc_signal = raw.data_sc[:subsample, x, y]
        for i, delay in enumerate(delays):
            # delay is the delay of SC relative to FT: the SC pixel signal is
            # evaluated at the FT time stamps shifted back by this delay
            gmsc_interp = np.interp(ft_times - delay, sc_times, sc_signal)
            residuals[x, y, i] = np.nanstd(ft_signal - gmsc_interp)
        # Residual curve referenced to its minimum; best delay in seconds
        residuals[x, y, :] -= np.nanmin(residuals[x, y, :])
        delay_ftsc[x, y] = delays[np.nanargmin(residuals[x, y, :])]
    print("line x=%i    median delay=%.3f s" % (x, np.nanmedian(delay_ftsc[x, lit_columns])))

# Map of the fitted SC-FT delay per pixel
plt.figure()
plt.imshow(delay_ftsc, vmin=0.17, vmax=0.40, aspect='auto', interpolation='None')
plt.set_cmap('cubehelix')
plt.colorbar()

# Median delay of each detector line; lines with 50 illuminated pixels or
# fewer are left at 0 and excluded from the fit below
median_delay_lines = np.zeros(n_lines)
for line, lit_pixels in enumerate(illuminated):
    if np.count_nonzero(lit_pixels) > 50:
        median_delay_lines[line] = np.nanmedian(delay_ftsc[line, lit_pixels])

# Linear fit of the delay versus line number over the lines with a positive median delay
line_ok = np.squeeze(np.argwhere(median_delay_lines > 0))
slope_sc, zeropoint_sc = np.polyfit(line_ok, median_delay_lines[line_ok], 1)
print("slope: %+.3f milliseconds/line    ZP = %+.3f milliseconds" % (slope_sc*1000.,zeropoint_sc*1000.))

plt.figure()
plt.plot(median_delay_lines)

# Mean delay of each detector column; columns with 20 illuminated pixels or
# fewer are left at 0
mean_delay_columns = np.zeros(n_columns)
for column in range(n_columns):
    lit_pixels = illuminated[:, column]
    if np.count_nonzero(lit_pixels) > 20:
        mean_delay_columns[column] = np.nanmean(delay_ftsc[lit_pixels, column])
plt.figure()
plt.title("Average delay on a line")
plt.plot(mean_delay_columns)

# Half the SC exposure time is subtracted from the fitted zero point, and the
# FT offset between the illuminated outputs and all outputs is added.
offset_SC_FT = zeropoint_sc - raw.exptime_sc/2.0 + delta_offset_scft # relative to the center of the FT DITs of all lines
print(" *** Delay of start of DIT of SC line N relative to center of all FT line DIT :")
print("           Dt(SC-FT) = %.3f x N + %.3f milliseconds" % (slope_sc*1000., offset_SC_FT*1000.))

# NOTE(review): the factors 7 and 3 presumably stand for the line spacing
# between SC outputs and the line of the first output - confirm.
print(" *** Delay of start of DIT of SC OUTPUT N relative to center of all FT line DIT :")
print("           Dt(SC OUTPUT) = %.3f x N + %.3f milliseconds" % (7*slope_sc*1000., offset_SC_FT*1000.+3*slope_sc*1000.))

#if False:
#    #==============================================================================
#    # Science combiner photometry
#    #==============================================================================
#    
#    print "Science combiner time offset computation"
#    
#    startx = 53
#    n_wave = 234
#    print raw.data_sc.shape
#    n_frames = raw.data_sc.shape[0]
#    n_lines = raw.data_sc.shape[1]
#    
#    plt.figure()
#    plt.imshow(raw.data_sc[5,:,:])
#    
#    spectrum_sc = np.zeros((n_frames, n_outputs, n_wave)) #[time,output,wavelength 60:320]
#    
#    output = 0
#    line = 3
#    while line < n_lines:
#        for time in range(n_frames): # box integration of spectra over 3 pixels in height
#            spectrum_sc[time,output,:] = np.mean(raw.data_sc[time,(line-1):(line+1),startx:(startx+n_wave)],axis=0)
#        output += 1
#        line += 7
#    
#    # dark part of the sequence
#    dark_sc = np.mean(spectrum_sc[276:287,:,:],axis=0)
#    spectrum_sc -= dark_sc
#    
#    open_channels = np.array(range(0,16)+range(24,32))
#    
#    # bright part of the sequence 
#    bright_sc = np.mean(spectrum_sc[0:15,open_channels,:],axis=0)
#    spectrum_sc[:,open_channels,:] /= bright_sc
#    
#    plt.figure()
#    #for output in range(0,n_frames):
#    plt.imshow(spectrum_sc[5,:,:])
#    
#    # Average signal over wavelengths (to correlate the different wavelengths)
#    mean_sc = np.mean(spectrum_sc,axis=2) # time, output
#    
#    # Average signal over channels and wavelengths (to correlate the different channels)
#    grand_mean_sc =  np.mean(mean_sc[:,open_channels],axis=1)
#    trigger = np.abs(np.diff(np.sign(grand_mean_sc - 0.5))/2)
#    n = 3 # half-width of the window around the flux jumps
#    mask_sc = np.array([max(trigger[max(i-n, 0):min(i+n, len(trigger))])==1 for i in range(len(trigger))])
#    
#    # plt.figure()
#    time_sc = np.array(range(0,sum(mask_sc)))*dt_sc # compression of the time scale for the delay fit
#    mean_sc_masked = np.zeros((sum(mask_sc),n_outputs))
#    for output in open_channels:
#        mean_sc_masked[:,output] = mean_sc[mask_sc,output]
#        # plt.plot(time_sc,mean_sc_masked[:,output])
#    
#    grand_mean_sc_masked = grand_mean_sc[mask_sc]
#    
#    range_sc = 0.150 # +/- range in seconds
#    nsteps = 600
#    delays = np.linspace(-range_sc,range_sc,nsteps)
#    delays_sc = np.zeros(n_outputs,dtype='d')
#    
#    residuals = np.zeros((n_outputs,nsteps))
#    for output in open_channels:
#        for i,delay in enumerate(delays):
#            time_sc_shift = time_sc + delay
#            #sc_int = interp1d(time_sc, grand_mean_sc_masked, kind='linear',bounds_error=False, fill_value=0.0)
#            # gmsm_interp = sc_int(time_sc_shift)
#            gmsm_interp =np.interp(time_sc_shift, time_sc, grand_mean_sc_masked)
#            residuals[output,i]=np.std(gmsm_interp-mean_sc_masked[:,output])
#        residuals[output,:] -= np.min(residuals[output,:])
#    #        if output == 2:
#    #            plt.figure()
#    #            plt.plot(delays, residuals[output,:])
#        delays_sc[output] = delays[np.argmin(residuals[output,:])] # in seconds
#        
#    [slope_sc,zeropoint_sc] = np.polyfit(open_channels,delays_sc[open_channels],1)
#    
#    delays_sc_int = np.zeros(n_outputs)
#    for output in range(0,n_outputs):
#        delays_sc_int[output] = slope_sc*output + zeropoint_sc
#        print "SC output %02i interpolated offset = %+.1f milliseconds" % (output, delays_sc_int[output]*10**3)
#    
#    plt.figure(figsize=(10,10))
#    plt.title('SC delay map per output (millisec) = %.3f x [output]%+.3f' % (slope_sc*10**3,zeropoint_sc*10**3))
#    x, y = np.meshgrid(delays*1000, np.arange(n_outputs))
#    plt.pcolor(x,y,residuals)
#    plt.xlabel("Delay wrt average of illuminated (milliseconds)")
#    plt.ylabel("SC output")
#    plt.plot(delays_sc_int*10**3,range(0,n_outputs),color='red',lw=3)
#    
#    #==============================================================================
#    # Science Combiner delay per spectral pixel with respect to grand average
#    #==============================================================================
#    
#    range_sc = 0.005 # +/- range in seconds
#    nsteps = 50
#    delays_sc_wave = np.zeros((n_outputs,n_wave),dtype='d')
#    
#    # masked non-averaged spectra
#    spectrum_sc_masked = np.zeros((sum(mask_sc),n_outputs,n_wave))
#    for output in open_channels:
#        for wave in range(0,n_wave):
#            spectrum_sc_masked[:,output,wave] = spectrum_sc[mask_sc,output,wave]
#    
#    residuals = np.zeros((n_outputs,n_wave,nsteps))
#    for output in open_channels:
#        delays = np.linspace(delays_sc[output]-range_sc,delays_sc[output]+range_sc,nsteps)
#        for wave in range(0,n_wave):
#            for i,delay in enumerate(delays):
#                time_sc_shift = time_sc + delay
#                gmsm_interp =np.interp(time_sc_shift, time_sc, grand_mean_sc_masked)
#                residuals[output,wave,i]=np.std(gmsm_interp-spectrum_sc_masked[:,output,wave])
#            residuals[output,wave,:] -= np.min(residuals[output,wave,:])
#            delays_sc_wave[output,wave] = delays[np.argmin(residuals[output,wave,:])]
#    
#    mean_wave_delay = np.mean(delays_sc_wave,axis=0)
#    plt.figure(figsize=(10,10))
#    plt.title("Average delay wrt to column 0 (milliseconds)")
#    plt.plot(mean_wave_delay)
#        
#    plt.figure(figsize=(10,10))
#    plt.title("SC delay map per spectral channel (milliseconds)")
#    #    plt.imshow(delays_sc_wave,aspect='auto')
#    x, y = np.meshgrid(np.arange(n_wave), np.arange(n_outputs))
#    plt.pcolor(x,y,delays_sc_wave)
#    plt.colorbar(fraction=0.15,aspect=15,orientation='horizontal',shrink=0.7,pad=0.1)
#    plt.xlabel("SC spectral channel")
#    plt.ylabel("SC output")
#    
#    
#    #==============================================================================
#    # Science Combiner delay with respect to Fringe Tracker
#    #==============================================================================
#    
#    range_ftsc = 0.200 # +/- range in seconds
#    nsteps = 500 # number of steps
#    center_ftsc = 0.200 # center of the range in seconds
#    delays = np.linspace(center_ftsc-range_ftsc,center_ftsc+range_ftsc,nsteps)
#    residuals = np.zeros(nsteps)
#    
#    for i,delay in enumerate(delays):
#        time_ft_shift = raw.time_ft - delay # delay is the delay of SC relative to FT
#        gmsc_interp =np.interp(time_ft_shift, raw.time_sc, grand_mean_sc)
#        residuals[i]=np.std(grand_mean_ft-gmsc_interp)
#    residuals -= np.min(residuals)
#    plt.figure()
#    plt.plot(delays, residuals)
#    delay_ftsc = delays[np.argmin(residuals[:])] # in seconds
#    print " *** Delay of start of DIT of SC output 0 relative to start of DIT of SC output 0   = %+.5f seconds" % \
#        (delay_ftsc+zeropoint_sc-zeropoint_ft-raw.exptime_sc/2.0)
#    # one SC exposure is missed at the beginning apparently
#    
#    plt.figure(figsize=(10,10))
#    plt.plot(raw.time_ft,grand_mean_ft)
#    plt.plot(raw.time_sc+delay_ftsc,grand_mean_sc)
