#! /usr/bin/env python3
# -*- coding: iso-8859-15 -*-
"""
Created on Mon May 25 13:51:41 2015

@author: kervella
"""

import matplotlib as mpl
mpl.use('Agg') # necessary to be able to generate the reports even without X server
import matplotlib
import matplotlib.pyplot as plt
from matplotlib import gridspec

import sys
import numpy as np
#import pdb
#import warnings

try:
    from scipy.signal import welch
    welchpresent = True
except ImportError as e:
    welchpresent = False
    pass # module doesn't exist, deal with it.
 
#==============================================================================
# Preparation of the vibrations report (matplotlib periodograms saved as PDF)
#==============================================================================
def produce_vibrations_report(p2vmreduced,filename):
    """
    Plot the vibration periodograms of a P2VMREDUCED file and save them as a PDF.

    Layout of the 5x4 grid of panels:
      * rows 1-3: amplitude spectrum (nm rms per frequency bin) of the Kalman
        piston of each of the 6 baselines, with its reverse cumulative rms
        drawn on a linear scale (only if ``opdc_kalman_piezo`` exists);
      * row 4: amplitude spectrum of the wavelength-averaged FT flux of each
        telescope, normalised to its maximum, with reverse cumulative rms;
      * row 5: amplitude spectrum (nm rms) of the metrology optical path of
        each telescope, with reverse cumulative rms.

    Parameters
    ----------
    p2vmreduced : object
        Reduced data, e.g. a ``gravi_visual_class.P2vmreduced`` instance.
        Attributes read: opdc_kalman_piezo (optional), opdc_time, basenames,
        polarsplit, time_ft, oi_flux_ft or oi_flux_ft_s/oi_flux_ft_p,
        stations, header, oi_vis_met_phase, oi_vis_met_time.
    filename : str
        Base name of the report; the figure is saved as
        ``filename + "-VIBRATIONS.pdf"``.

    Returns
    -------
    None. Nothing is written if scipy.signal.welch cannot be imported.
    """
    # scipy is optional at module level; without welch no spectrum can be computed.
    try:
        from scipy.signal import welch
    except ImportError:
        print("Cannot produce vibrations report, scipy.signal.welch is not available")
        return

    ntel = 4
    nbase = 6

    # axis limits
    xlims = [0,250] # Hz
    xlimsopd = [0,250] # Hz
    ylimsopd  = [1E-0,1E+3] # nm rms
    ylimsphot = [1E-3,1E-1] # relative flux rms
    ylimsmet  = [5E-1,1E+2] # nm rms

    def rms_spectra(signal, fs):
        """Return (frequencies, rms per bin, reverse cumulative rms) of a 1-D signal."""
        freq, power = welch(signal, fs=fs, detrend='linear', nperseg=1024, scaling='spectrum')
        # scaling='spectrum' gives power per bin: summing from the highest
        # frequency downwards and taking the square root gives the rms
        # contained above each frequency.
        return freq, np.sqrt(power), np.sqrt(power[::-1].cumsum()[::-1])

    plt.close('all')
    fig = plt.figure(figsize=(10,10),dpi=150)
    matplotlib.rcParams.update({'font.size': 6})
    fig.suptitle("Periodograms for %s"%(filename),fontsize=12)
    # Panels created by subplot2grid take their spacing from the figure's
    # subplot parameters, so the spacing between axes is set here.
    fig.subplots_adjust(top=0.95, wspace=0.2, hspace=0.2)

    # OPD PSDs: baselines T0T1, T0T2, T0T3, T1T2, T1T3, T2T3, two per row on rows 0-2
    opd_axes = [plt.subplot2grid((5,4), (b//2, 2*(b%2)), colspan=2) for b in range(nbase)]
    # Photometry PSDs (row 3) and metrology PSDs (row 4): one panel per telescope
    phot_axes = [plt.subplot2grid((5,4), (3,tel)) for tel in range(ntel)]
    met_axes  = [plt.subplot2grid((5,4), (4,tel)) for tel in range(ntel)]
    for ax in opd_axes + phot_axes + met_axes:
        # set while the y axis still has its default linear formatter
        ax.ticklabel_format(style='sci', axis='y', scilimits=(0,0))
        ax.grid()
        ax.set_yscale('log')

    #==============================================================================
    # Power spectral density of the Kalman OPD
    #==============================================================================

    if hasattr(p2vmreduced,'opdc_kalman_piezo'):
        # Kalman vibration OPD = the incoming (uncorrected) OPD. Row b of
        # M_matrix forms baseline b (panel order above) as a difference of
        # two telescope pistons.
        M_matrix = np.array([[-1., 1., 0., 0.],
                             [-1., 0., 1., 0.],
                             [-1., 0., 0., 1.],
                             [ 0.,-1., 1., 0.],
                             [ 0.,-1., 0., 1.],
                             [ 0., 0.,-1., 1.]])
        opdc_kalman_piezo_opd = np.dot(M_matrix,p2vmreduced.opdc_kalman_piezo.T).T # Time, baseline
        fs_opd = 1./np.nanmean(np.diff(p2vmreduced.opdc_time))

        for baseline, ax in enumerate(opd_axes):
            f_k, amp_k, revcum_k = rms_spectra(opdc_kalman_piezo_opd[:,baseline], fs_opd)
            ax.plot(f_k, amp_k*1000., lw=0.5, color='mediumblue')  # x1000 -> nm rms
            # Reverse cumulative drawn linearly inside the log axis: one decade
            # above ylimsopd[0] per unit of revcum_k (1 micron rms, see scale bar).
            ax.plot(f_k, 10**revcum_k*ylimsopd[0], lw=0.5, color='cornflowerblue')
            ax.set_ylim(ylimsopd)
            ax.set_xlim(xlims)
            if baseline == 0:
                label = "Piston %s (log) Rev. cumul. (linear)"
            else:
                label = "Piston %s"
            ax.text(0.95,0.85,label%(p2vmreduced.basenames[baseline]),horizontalalignment='right',transform=ax.transAxes)
            if baseline % 2 == 0:  # left-hand column only
                ax.set_ylabel(r"$(nm)\ rms$")

        # Scale bar for the linear reverse cumulative: one decade = 1 micron rms
        minseg = ylimsopd[0]*2
        maxseg = minseg*10
        xseg = xlimsopd[0]+(xlimsopd[1]-xlimsopd[0])/20.
        opd_axes[0].plot([xseg, xseg],[minseg,maxseg],color='cornflowerblue',lw=2,alpha=0.7)
        opd_axes[0].text(0.02,0.35,r"1$\mu$m rms",horizontalalignment='left',rotation=90,transform=opd_axes[0].transAxes,color='cornflowerblue')
    else:
        print("Cannot plot piston PSD, no Kalman OPD in file")

    #==============================================================================
    # Power spectral density of the FT flux signal vs. time
    #==============================================================================

    nmeasure = p2vmreduced.time_ft.shape[0]
    avgflux = np.zeros((ntel,nmeasure),dtype='d')
    for tel in range(ntel):
        if p2vmreduced.polarsplit:
            # average each polarization over wavelength, then average both
            # polarizations for the power spectrum computation
            avgfluxs = np.nanmean(p2vmreduced.oi_flux_ft_s[:,tel,:],axis=1)
            avgfluxp = np.nanmean(p2vmreduced.oi_flux_ft_p[:,tel,:],axis=1)
            avgflux[tel,:] = np.nanmean([avgfluxs,avgfluxp],axis=0)
        else:
            avgflux[tel,:] = np.nanmean(p2vmreduced.oi_flux_ft[:,tel,:],axis=1) # average flux over wavelength

    fs_ft = 1./np.mean(np.diff(p2vmreduced.time_ft))
    for tel, ax in enumerate(phot_axes):
        # flux normalised to its maximum -> relative flux rms
        f, amp, revcum = rms_spectra(avgflux[tel,:]/avgflux[tel,:].max(), fs_ft)
        ax.plot(f,amp,lw=0.5,color='red')
        ax.plot(f,revcum,lw=0.5,color='darkorange')
        ax.set_ylim(ylimsphot)
        ax.set_xlim(xlims)
        label = "PHOT %s (log) Rev. cumul. (log)" if tel == 0 else "PHOT %s"
        ax.text(0.9,0.85,label%(p2vmreduced.stations[tel]),horizontalalignment='right',transform=ax.transAxes)
    phot_axes[0].set_ylabel(r"$Rel.\ flux\ rms$")

    #==============================================================================
    # Power spectral density of the METROLOGY unwrapped phase vs. time
    #==============================================================================

    if 'HIERARCH ESO INS MLC WAVELENG' in p2vmreduced.header:
        Lambda_met=p2vmreduced.header['HIERARCH ESO INS MLC WAVELENG']/1000. # Metrology wavelength in microns (1.908 microns)
    else:
        Lambda_met = 1.908287 # in microns

    resamp = 1
    metfreq = 1./(resamp * (p2vmreduced.oi_vis_met_time[1]-p2vmreduced.oi_vis_met_time[0]))
    for tel, ax in enumerate(met_axes):
        # np.unwrap works with a 2*pi period, so the phase is unwrapped in
        # radians first and only then converted to optical path in microns.
        phase = np.asarray(p2vmreduced.oi_vis_met_phase[::resamp,tel], dtype='d')
        opl_met = np.unwrap(phase)*Lambda_met/(2*np.pi)
        f_met, amp_met, revcum_met = rms_spectra(opl_met, metfreq)
        ax.plot(f_met, amp_met*1000.,lw=0.5,color='darkgreen')     # nm rms
        ax.plot(f_met, revcum_met*1000.,lw=0.5,color='limegreen')  # nm rms
        ax.set_ylim(ylimsmet)
        ax.set_xlim(xlims)
        label = "MET %s (log) Rev. cumul. (log)" if tel == 0 else "MET %s"
        ax.text(0.9,0.85,label%(p2vmreduced.stations[tel]),horizontalalignment='right',transform=ax.transAxes)
        ax.set_xlabel(r"$Hz$")
    met_axes[0].set_ylabel(r"$(nm)\ rms$")

    #==============================================================================
    #     Save report file
    #==============================================================================
    reportname = filename+"-VIBRATIONS.pdf"
    fig.savefig(reportname, dpi=150, bbox_inches='tight')
    # Release the figure; pyplot would otherwise keep it alive after saving.
    plt.close(fig)


#==============================================================================
# MAIN PROGRAM    
#==============================================================================

if __name__ == '__main__':
    # The relative import works when run as a module (python -m package.module).
    # When the file is executed directly (as the shebang allows) there is no
    # parent package, so fall back to an absolute import.
    try:
        from . import gravi_visual_class
    except ImportError:
        import gravi_visual_class

    # Default test file, used when no file name is given on the command line
    filename='../test-files/GRAVI.2017-07-23T23-22-9.453_singlescip2vmred'

    if len(sys.argv) == 2 :
        filename=sys.argv[1]

    if filename == '' :
        filename=input(" Enter P2VMREDUCED file name (without .fits) : ")

    # Accept the file name with or without its .fits extension
    if filename.endswith('.fits'):
        filename = filename[:-len('.fits')]

    p2vmreduced = gravi_visual_class.P2vmreduced(filename+'.fits')

    produce_vibrations_report(p2vmreduced,filename)


