#! /usr/bin/env python
# -*- coding: iso-8859-15 -*-
"""
Created on Mon May 25 13:51:41 2015

@author: kervella
"""

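# ATTENTION: the pdfrw library must be installed (e.g. `pip install pdfrw`; with MacPorts, install the pyXY-pdfrw port matching your Python version, e.g. formerly `sudo port install py27-pdfrw` for Python 2.7)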

import matplotlib as mpl
mpl.use('Agg') # necessary to be able to generate the reports even without X server
import matplotlib.pyplot as plt

import sys
import datetime
import numpy as np

try:
    from scipy.signal import welch
    welchpresent = True
except ImportError as e:
    welchpresent = False
    pass # module doesn't exist, deal with it.
     
#==============================================================================
# Preparation of the report using reportlab
#==============================================================================
def produce_preproc_report(preproc,filename):
    """Build the GRAVITY PREPROC quality-control PDF report.

    Pages are accumulated in a reportlab ``Story`` list and rendered with
    ``SimpleDocTemplate`` to ``filename + "-PREPROC.pdf"``. The report
    contains, in order: the summary page, the average photometric flux per
    output (SC then FT), the reduction-parameter summary, the SC spectra
    zoomed into a few wavelength windows, and waterfall images of the FT and
    SC spectra (8 outputs per page).

    Parameters
    ----------
    preproc : object
        PREPROC product (e.g. ``gravi_visual_class.Preproc``). Attributes
        read here: ``spectrum_sc`` and ``spectrum_ft`` (either may be
        ``None``, in which case its sections are skipped), ``nregion``,
        ``polarsplit``, ``regname_sc``, ``regname_ft`` and ``wave_sc``.
        The spectra are indexed as ``[region, frame, channel]`` (assumed
        from the indexing and plot subtitles below -- confirm against
        Preproc).
    filename : str
        Input file name without the ``.fits`` extension; used in the page
        header and to build the report file name.
    """
    # From gravi_visual_class (relative import: same package as this module)
    from . import gravi_visual_class
    from .gravi_visual_class import PdfImage, myFirstPage, myLaterPages
    from .gravi_visual_class import clipdata, transbarchartout
    from .gravi_visual_class import plotTitle, plotSubtitle
    from .gravi_visual_class import plotReductionSummary, plotSummary

    # From reportlab
    from reportlab.platypus import SimpleDocTemplate, Spacer, PageBreak
    from reportlab.lib.units import cm, mm
    from reportlab.lib import colors

    import io

    #==============================================================================
    # Global parameters
    #==============================================================================

    nframes = 5000          # Maximum number of frames shown in waterfall images
    bar_hsize = 18*cm       # Size of the photometry bar charts
    bar_vsize = 5*cm
    fig_width = 20*cm       # Bounding box of the embedded matplotlib figures
    fig_height = 22*cm

    #==============================================================================
    # Helpers (closures over Story and preproc)
    #==============================================================================

    def _add_photometry_chart(label, spectrum, regnames):
        # Append a bar chart of the mean flux of each output region to Story.
        plotTitle (Story, label+" average photometric flux per output (ADUs)")
        # Mean flux per region over all frames and channels (raw ADUs, not normalized)
        trans_region = np.array(
            [np.mean(spectrum[region,:,:]) for region in range(0,preproc.nregion)],
            dtype='d')
        yminval = np.min(trans_region)
        ymaxval = np.max(trans_region)+0.1*np.abs(np.max(trans_region))
        ystep = (ymaxval-yminval)/10.
        if preproc.polarsplit == True:
            # Plot the two polarizations separately (outputs alternate S and P)
            transdata = [tuple(clipdata(trans_region[::2],yminval,ymaxval)),
                         tuple(clipdata(trans_region[1::2],yminval,ymaxval))]
            labelbase = regnames[::2].astype('|S4').tolist()
        else:
            transdata = [tuple(clipdata(trans_region,yminval,ymaxval))]
            labelbase = regnames.astype('|S4').tolist()
        Story.append(transbarchartout(transdata,labelbase,yminval,ymaxval,ystep,
                                      bar_hsize,bar_vsize,
                                      colors.aquamarine,colors.cornflower))

    def _add_waterfall_pages(label, spectrum, regnames, title_suffix):
        # Append one page per group of 8 outputs showing flux vs (channel, frame).
        # NOTE(review): assumes nregion is a multiple of 8.
        for i in range(0,preproc.nregion,8):
            plotTitle (Story,label+" baseline "+regnames[i][0:2]+title_suffix)
            plotSubtitle (Story,"Wavelength channel number in horizontal axis, number of frame in sequence in vertical axis.")

            # savefig(format='PDF') writes bytes, hence a binary buffer
            imgdata = io.BytesIO()
            plt.close('all')
            plt.rc('xtick', labelsize='x-small')
            plt.rc('ytick', labelsize='x-small')
            fig, axarr = plt.subplots(nrows=4, ncols=2, sharex=True, sharey=True, figsize=(5,8))
            # Common color scale for the 8 outputs, clipped at 1st/99th percentiles
            group = spectrum[i:i+8,0:nframes,:]
            valmin = np.percentile(group,1)
            valmax = np.percentile(group,99)
            for j in range(0,4):
                for k in range(0,2):
                    reg = i+2*j+k
                    im = axarr[j,k].imshow(spectrum[reg,0:nframes,:],vmin=valmin,vmax=valmax,
                                           cmap='cubehelix', interpolation='nearest', aspect='auto')
                    axarr[j,k].set_title(label+' Output '+regnames[reg], fontsize=7)
            fig.subplots_adjust(hspace=.2)
            # Shared color bar on the right-hand side
            fig.subplots_adjust(right=0.8)
            cbar_ax = fig.add_axes([0.85, 0.15, 0.05, 0.7])
            fig.colorbar(im, cax=cbar_ax)
            plt.savefig(imgdata, format='PDF', dpi=250, bbox_inches='tight')
            Story.append(PdfImage(imgdata, width=fig_width, height=fig_height, kind='bound'))
            plt.close('all')
            Story.append(PageBreak())

    #==============================================================================
    # Start Story
    #==============================================================================
    Story = [Spacer(1,5*mm)]
    plotSummary (Story, filename, preproc, onTarget=False)

    #==============================================================================
    # Photometry per output
    #==============================================================================
    Story.append(Spacer(1,1*cm))

    if preproc.spectrum_sc is not None:
        _add_photometry_chart("SC", preproc.spectrum_sc, preproc.regname_sc)
        Story.append(Spacer(1,5*mm))

    if preproc.spectrum_ft is not None:
        _add_photometry_chart("FT", preproc.spectrum_ft, preproc.regname_ft)

    Story.append(PageBreak())

    #==============================================================================
    # Summary of reduction parameters
    #==============================================================================
    plotReductionSummary(Story, preproc)
    Story.append(PageBreak())

    #==============================================================================
    # SC spectrum zoomed into spectral regions
    #==============================================================================
    if preproc.spectrum_sc is not None:
        plotTitle (Story,"SC spectrums (%i regions) zoomed into spectral features"%preproc.nregion)

        # Mean over frames, transposed to one column per region, then each
        # region normalized by its own mean so the regions can be overlaid.
        mean_spectrum_sc = np.mean (preproc.spectrum_sc, axis=1).T
        mean_spectrum_sc /= np.mean (mean_spectrum_sc, axis=0)[None,:]

        # Wavelength windows: full range followed by three zooms
        wave_limits = [[1.99,2.45],[1.99,2.03],[2.15,2.19],[2.28,2.32]]

        imgdata = io.BytesIO()
        plt.close('all')
        plt.rc('xtick', labelsize='x-small')
        plt.rc('ytick', labelsize='x-small')

        fig, axarr = plt.subplots (len(wave_limits), 1, sharex=False, sharey=False, figsize=(5,8))
        for ax,(wmin,wmax) in zip (axarr, wave_limits):
            ax.plot (preproc.wave_sc, mean_spectrum_sc)
            ax.set_xlim([wmin, wmax])
            ax.set_ylim(bottom=0)

        plt.savefig(imgdata, format='PDF', dpi=250, bbox_inches='tight')
        Story.append(PdfImage(imgdata, width=fig_width, height=fig_height, kind='bound'))
        plt.close('all')
        Story.append(PageBreak())

    #==============================================================================
    # Image (waterfall) display of SC and FT spectral signals
    #==============================================================================
    if preproc.spectrum_ft is not None:
        _add_waterfall_pages("FT", preproc.spectrum_ft, preproc.regname_ft,
                             " flux (ADU, "+str(nframes)+" frames)")

    if preproc.spectrum_sc is not None:
        _add_waterfall_pages("SC", preproc.spectrum_sc, preproc.regname_sc,
                             " flux (ADU) waterfall view")

    #==============================================================================
    # Create PDF report from Story
    #==============================================================================
    print("Create the PDF")

    # Page header/footer text is read from module globals by myFirstPage/myLaterPages
    gravi_visual_class.TITLE = "GRAVITY PREPROC Quality Control Report"
    gravi_visual_class.PAGEINFO = "PREPROC file: "+filename+".fits"
    reportname = filename+"-PREPROC.pdf"

    doc = SimpleDocTemplate(reportname)
    doc.build(Story, onFirstPage=myFirstPage, onLaterPages=myLaterPages)
    print(" "+reportname)

#==============================================================================
# MAIN PROGRAM    
#==============================================================================

if __name__ == '__main__':
    # Command-line entry point: build the QC report for one PREPROC file.
    # Usage: <this module> [PREPROC file name without .fits]
    # gravi_visual_class is only imported inside produce_preproc_report, so it
    # must be imported here as well. Like the function's own imports, this is a
    # relative import, so the module must be run as part of its package
    # (python -m <package>.<module>).
    from . import gravi_visual_class

    filename = ''
    if len(sys.argv) == 2 :
        filename = sys.argv[1]

    if filename == '' :
        filename = input(" Enter PREPROC file name (without .fits) : ").strip()

    preproc = gravi_visual_class.Preproc(filename+'.fits')

    produce_preproc_report(preproc,filename)


