# -*- coding: utf-8 -*-
"""
Created on Sat Jan 30 17:31:19 2016

@author: kervella
"""

import matplotlib as mpl
mpl.use('Agg') # necessary to be able to generate the reports even without X server
import matplotlib.pyplot as plt

import os
import sys
sys.path.insert(0, '~/Pipelines/python_tools/gravi_visual')
from . import gravi_visual_class
from . import gravi_visual_p2vm
import datetime
import numpy as np


def generate_p2vm_filelist(directory):
    """Return the P2VM products found in *directory*, without extension.

    Every entry of *directory* whose name ends with ``_p2vm.fits`` is
    returned as ``<directory>/<name without .fits>``.  The result is sorted
    by file name so that the processing order is reproducible
    (``os.listdir`` returns entries in arbitrary order).

    Parameters
    ----------
    directory : str
        Directory to scan (not recursive).

    Returns
    -------
    list of str
        Matching paths with the ``.fits`` extension stripped; empty if
        nothing matches.
    """
    return [os.path.join(directory, os.path.splitext(name)[0])
            for name in sorted(os.listdir(directory))
            if name.endswith("_p2vm.fits")]

def _stats_label(prefix, values):
    """Return the legend label ``'<prefix> <mean>% +/- <rms>%'`` for a series.

    *values* is the 1-D series of one curve across all files; its mean and
    (population) standard deviation are printed in percent with one decimal.
    """
    return prefix + ' %(avgval).1f%% +/- %(rmsval).1f%%' % {
        "avgval": np.mean(values) * 100., "rmsval": np.std(values) * 100.}


def _save_trend_figure(pdf, title, ylabel, ylim, ncol, fontsize):
    """Apply the common trend-plot layout to the current figure, save it to *pdf*, close it.

    Parameters
    ----------
    pdf : matplotlib.backends.backend_pdf.PdfPages
        Multi-page PDF receiving the figure.
    title, ylabel : str
        Plot title and y-axis label (the x axis is always time in hours).
    ylim : sequence of two floats
        Fixed y-axis range.
    ncol, fontsize : int
        Number of legend columns and legend font size.
    """
    plt.title(title)
    plt.xlabel("Time (hours)")
    plt.ylabel(ylabel)
    plt.grid()
    axes = plt.gca()
    # Fixed axis ranges for every page of the report.
    axes.set_xlim([-50, 400])
    axes.set_ylim(ylim)
    plt.legend(bbox_to_anchor=(0., 0., 1., .102), loc=3,
               ncol=ncol, mode="expand", borderaxespad=1, prop={'size': fontsize})
    pdf.savefig()
    # Release the figure once saved instead of keeping every page in memory.
    plt.close()


if __name__ == '__main__':

    #==============================================================================
    #     User parameters
    #==============================================================================
    directory = '.'
    plot_sc = False
    specres = 'MEDIUM'
    polsetup = 'COMBINED'
    #==============================================================================

    from matplotlib.backends.backend_pdf import PdfPages

    filelist = generate_p2vm_filelist(directory)
    if not filelist:
        sys.exit("No *_p2vm.fits file found in directory '%s'" % directory)
    plt.close('all')

    # Key words in the header to be checked for file selection
    if plot_sc:
        combiner = 'SC'
        # NOTE(review): the SC selection also filters on the FT polarisation
        # keyword — confirm this is intended.
        fits_keys = {'HIERARCH ESO PRO CATG': ['P2VM'],
                     'HIERARCH ESO FT POLA MODE': [polsetup],
                     'HIERARCH ESO INS SPEC RES': [specres]}
        setup = combiner + '-' + specres + '-' + polsetup
    else:
        # No spectral-resolution filter for the FT: files of every
        # resolution are kept, hence the 'ALL' tag in the setup name.
        combiner = 'FT'
        fits_keys = {'HIERARCH ESO PRO CATG': ['P2VM'],
                     'HIERARCH ESO FT POLA MODE': [polsetup]}
        setup = combiner + '-ALL-' + polsetup

    p2vmlist = gravi_visual_class.P2vmlist(filelist[:], fits_keys)
    nfiles = p2vmlist.nfiles
    if nfiles == 0:
        sys.exit("No P2VM file in '%s' matches the setup %s" % (directory, setup))

    base_list = ["12", "13", "14", "23", "24", "34"]
    split = 'SPLIT' in setup
    # 'ft' or 'sc': suffix of the per-combiner statistics keys and P2VM attributes.
    tag = combiner.lower()
    p2vms = [p2vmlist.p2vm[i] for i in range(nfiles)]

    # The time scale of the series is set to start at zero
    timescale = (np.array(p2vmlist.timescale) - p2vmlist.timescale[0]) * 24.  # in decimal hours

    # Per-file averages, one row per file: coherence per baseline and
    # relative photometric transmission per output region (mean over the
    # last two axes of the transmission array).
    p2vmstats = [gravi_visual_p2vm.combiner_stats(p2vm, base_list) for p2vm in p2vms]
    coh_avg = np.array([stats["coh_" + tag + "_avg"] for stats in p2vmstats], dtype=float)
    trans_avg = np.array([np.mean(getattr(p2vm, "transmission_" + tag), axis=(1, 2))
                          for p2vm in p2vms], dtype=float)
    if split:
        coh_avg_s = np.array([stats["coh_" + tag + "_avg_s"] for stats in p2vmstats], dtype=float)
        coh_avg_p = np.array([stats["coh_" + tag + "_avg_p"] for stats in p2vmstats], dtype=float)
    regnames = getattr(p2vms[0], "regname_" + tag)

    #==============================================================================
    # Production of the figures
    #==============================================================================

    # The context manager guarantees the PDF is finalised even if a plot fails.
    with PdfPages('P2VM-trend-' + setup + '.pdf') as pdf:

        if 'COMBINED' in setup:
            plt.figure(figsize=(10, 10))
            for base, name in enumerate(base_list):
                plt.plot(timescale, coh_avg[:, base],
                         label=_stats_label('B' + name, coh_avg[:, base]),
                         ls='-', marker='+')
            _save_trend_figure(pdf,
                               combiner + " avg coherence transmission per baseline " + setup,
                               "Coherence transmission", [0.8, 1.05], ncol=3, fontsize=12)

        if split:
            plt.figure(figsize=(10, 10))
            for base, name in enumerate(base_list):
                plt.plot(timescale, coh_avg_s[:, base],
                         label=_stats_label('B' + name + '-S', coh_avg_s[:, base]),
                         ls='-', marker='+')
                plt.plot(timescale, coh_avg_p[:, base],
                         label=_stats_label('B' + name + '-P', coh_avg_p[:, base]),
                         ls='-', marker='x')
            _save_trend_figure(pdf,
                               combiner + " avg coherence transmission per baseline (S and P) " + setup,
                               "Coherence transmission", [0.8, 1.05], ncol=3, fontsize=12)

        plt.figure(figsize=(10, 10))
        for region in range(trans_avg.shape[1]):
            plt.plot(timescale, trans_avg[:, region],
                     label=_stats_label(regnames[region], trans_avg[:, region]),
                     ls='-', marker='+')
        _save_trend_figure(pdf,
                           combiner + " relative photometric transmission per IO output " + setup,
                           "Photometric transmission", [0.6, 1.3], ncol=4, fontsize=8)
    
    
