#! /usr/bin/env python3
# -*- coding: iso-8859-15 -*-

import matplotlib as mpl
mpl.use('Agg') # necessary to be able to generate the reports even without X server
import matplotlib.pyplot as plt
#import matplotlib.colors as colors

import os
import datetime
try:
   import pyfits
except:
   from astropy.io import fits as pyfits
import numpy as np
from . import gravi_visual_class



#==============================================================================
# Preparation of the report using reportlab
#==============================================================================
def _new_report_figure(figsize):
    """Close every open figure and start a new one with 5-pt tick labels.

    ``plt.rc`` modifies the global rcParams, so the small tick-label size
    stays in effect for later figures in the same process.
    """
    plt.close('all')
    fig = plt.figure(figsize=figsize)
    plt.rc('xtick', labelsize=5)
    plt.rc('ytick', labelsize=5)
    return fig


def _figure_to_drawing(fig):
    """Render *fig* to an in-memory SVG and return it as a reportlab drawing."""
    from io import BytesIO
    from svglib.svglib import svg2rlg

    imgdata = BytesIO()
    fig.savefig(imgdata, format='svg', bbox_inches='tight')
    imgdata.seek(0)  # rewind so svglib reads the SVG from the start
    return svg2rlg(imgdata)


def _image_map_drawing(wave, image, vmin, vmax, cmap, interpolation, label_color=None):
    """Draw *image* as a 6x9 inch colour map with a horizontal colour bar.

    Parameters
    ----------
    wave : Wave
        Only used for the baseline labels (``base_list_sc`` and the shape of
        ``test_wave``) when *label_color* is given.
    image : 2-D array
        Map passed to ``plt.imshow``.
    vmin, vmax, cmap, interpolation :
        Forwarded to ``plt.imshow``.
    label_color : str or None
        If not None, write "Base <name>" for the first six entries of
        ``wave.base_list_sc`` on the map in this colour.

    Returns the reportlab drawing of the figure.
    """
    fig = _new_report_figure((6, 9))
    plt.imshow(image, vmin=vmin, vmax=vmax, cmap=cmap,
               interpolation=interpolation, aspect='auto')
    plt.colorbar(fraction=0.15, aspect=15, orientation='horizontal', shrink=0.7, pad=0.05)
    if label_color is not None:
        # NOTE(review): test_wave.shape[0] indexes the stacked maps (test_wave[0..2]),
        # so this x position is likely truncated to 0; shape[2] may have been intended.
        xpos = int(wave.test_wave.shape[0] / 10)
        for i in range(6):
            # plt.annotate already attaches the text to the current axes; the former
            # extra add_artist() call made every label render twice.
            plt.annotate("Base " + wave.base_list_sc[i],
                         xy=(xpos, int(32 + wave.test_wave.shape[1] * i / 6)),
                         xycoords='data', fontsize=6, color=label_color)
    return _figure_to_drawing(fig)


def produce_wave_report(wave, filename):
    """Build the GRAVITY WAVE quality-control PDF report.

    Parameters
    ----------
    wave : gravi_visual_class.Wave
        Loaded WAVE product (the script entry point builds it from
        ``filename + '.fits'``).
    filename : str
        Input file name without the ``.fits`` extension. The report is written
        to ``<filename>-WAVE.pdf``.

    Side effects: sets ``gravi_visual_class.TITLE`` / ``PAGEINFO`` (used by the
    page templates), changes matplotlib tick-label rcParams, and prints the
    report name.
    """
    # From gravi_visual_class
    from . import gravi_visual_class
    from .gravi_visual_class import myFirstPage, myLaterPages
    from .gravi_visual_class import plotTitle, plotSubtitle
    from .gravi_visual_class import plotReductionSummary, plotSummary
    from .gravi_visual_class import stepbarchartout, get_reg_number, get_base_number

    from reportlab.platypus import SimpleDocTemplate, Spacer, PageBreak
    from reportlab.lib.units import cm, mm
    from reportlab.lib import colors

    #==============================================================================
    # Global parameters
    #==============================================================================

    # Half-range (microns) of the residual plots after removing the common fit.
    delta_res = 0.005 if wave.nwave_sc > 50 else 0.05
    # One matplotlib colour per base.
    color = ['b', 'g', 'r', 'c', 'm', 'y']

    #==============================================================================
    # Summary pages
    #==============================================================================

    Story = [Spacer(1, 1*mm)]
    plotSummary(Story, filename, wave, onTarget=False)
    Story.append(PageBreak())

    plotReductionSummary(Story, wave)
    Story.append(PageBreak())

    #==============================================================================
    # FT wavelength NON-INTERPOLATED vector map - deviation from the average
    #==============================================================================

    plotTitle(Story, "FT wavelength variation map")
    plotSubtitle(Story, "Average wavelength scale over outputs subtracted, deviation in percent of the pixel size.")
    Story.append(Spacer(1, 2*mm))

    wave_ft = wave.wave_ft
    # Mean wavelength step between columns 2 and 3, used as the "pixel size".
    pixel_size = np.mean(wave_ft[:, 3] - wave_ft[:, 2])
    waveresidual = 100. * (wave_ft - np.mean(wave_ft, axis=0)) / pixel_size
    maxdiff = np.max(np.abs(waveresidual))
    Story.append(_image_map_drawing(wave, waveresidual, -maxdiff, maxdiff,
                                    'cubehelix', 'nearest'))
    Story.append(PageBreak())

    #==============================================================================
    # Image display of the SC spectrum extraction mask
    #==============================================================================

    plotTitle(Story, "SC spectrum extraction mask")
    Story.append(Spacer(1, 2*mm))
    Story.append(_image_map_drawing(wave, wave.test_wave[1, :, :], 0, 0.5,
                                    'cubehelix', 'nearest', 'white'))
    Story.append(PageBreak())

    #==============================================================================
    # SC wavelength NON-INTERPOLATED vector map (microns)
    #==============================================================================

    plotTitle(Story, "SC wavelength NON-INTERPOLATED vector map (microns)")
    Story.append(Spacer(1, 2*mm))
    Story.append(_image_map_drawing(wave, wave.test_wave[2, :, :], wave.minwave_sc, wave.maxwave_sc,
                                    'flag', 'none', 'white'))
    Story.append(PageBreak())

    #==============================================================================
    # SC wavelength INTERPOLATED vector map (microns)
    #==============================================================================

    plotTitle(Story, "SC wavelength INTERPOLATED vector map (microns)")
    Story.append(Spacer(1, 2*mm))
    Story.append(_image_map_drawing(wave, wave.test_wave[0, :, :], wave.minwave_sc, wave.maxwave_sc,
                                    'flag', 'none', 'white'))
    Story.append(PageBreak())

    #==============================================================================
    # Offset of SC wavelength scale [INTERP. - NON-INTERP.] (nm)
    #==============================================================================

    plotTitle(Story, "Offset of SC wavelength scale [INTERP. - NON-INTERP.] (nm)")
    Story.append(Spacer(1, 2*mm))
    offset_interp = (wave.test_wave[0, :, :] - wave.test_wave[2, :, :]) * 1000.  # in nm
    # Symmetric colour scale clipped at the 90th percentile of |offset|.
    valmax = np.percentile(np.abs(offset_interp), 90)
    Story.append(_image_map_drawing(wave, offset_interp, -valmax, valmax,
                                    'seismic', 'none', 'black'))
    Story.append(PageBreak())

    #==============================================================================
    # SC wavelength per base (WAVE_FIBRE_SC)
    #==============================================================================

    # Common 2nd-order polynomial fit of the region-averaged SC wavelength,
    # ignoring the first 3 and last 2 channels; subtracted in the residual plots.
    y = np.mean(wave.wave_sc, axis=0)
    x = np.arange(len(y))
    mean_wave = np.poly1d(np.polyfit(x[3:-2], y[3:-2], 2))(x)

    if wave.polarsplit == False:  # noqa: E712 (comparison kept from the original)
        wavefiber, pol = wave.wavefiber_c, "C"
    else:
        wavefiber, pol = wave.wavefiber_s, "S"

    def fiber_wave_drawing(offset, ylim):
        """Per base: fibre wavelength (solid) and mean WAVE_DATA_SC of the
        regions returned by get_reg_number (dashed), both minus *offset*."""
        fig = _new_report_figure((6, 3))
        for base in range(6):
            plt.plot(wavefiber[base, :] - offset, c=color[base], ls='-')
            ids = get_reg_number(wave.regname_sc, base, pol)
            plt.plot(np.mean(wave.wave_sc[ids, :], axis=0) - offset, c=color[base], ls='--')
        plt.ylim(ylim)
        return _figure_to_drawing(fig)

    plotTitle(Story, "SC wavelength (WAVE_FIBRE_SC) for each base and channel")
    plotSubtitle(Story, "region in horizontal, wavelenght in vertical [um], channel in color")
    Story.append(fiber_wave_drawing(0., [1.8, 2.6]))
    Story.append(Spacer(1, 1*cm))

    plotSubtitle(Story, "the same, but a common mean 2sd order subtracted, region in color")
    Story.append(fiber_wave_drawing(mean_wave, [-delta_res, delta_res]))
    Story.append(PageBreak())

    #==============================================================================
    # SC wavelength per region (WAVE_DATA_SC)
    #==============================================================================

    def region_wave_drawing(offset, ylim):
        """WAVE_DATA_SC of every region minus *offset*, coloured by base."""
        fig = _new_report_figure((6, 3))
        for reg in range(wave.wave_sc.shape[0]):
            if 'P' in wave.regname_sc[reg]:
                continue  # regions whose name contains 'P' are not plotted
            plt.plot(wave.wave_sc[reg, :] - offset,
                     c=color[get_base_number(wave.regname_sc, reg)])
        plt.ylim(ylim)
        return _figure_to_drawing(fig)

    plotTitle(Story, "SC wavelength (WAVE_DATA_SC) for each region and channel")
    plotSubtitle(Story, "channel in horizontal, wavelenght in vertical [um], region in color")
    Story.append(region_wave_drawing(0., [1.8, 2.6]))
    Story.append(Spacer(1, 1*cm))

    plotSubtitle(Story, "the same, but a common mean 2sd order subtracted, region in color")
    Story.append(region_wave_drawing(mean_wave, [-delta_res, delta_res]))
    Story.append(PageBreak())

    #==============================================================================
    # SC wavelength per channel, region on the horizontal axis (skipped when the
    # spectral resolution gives 1000 channels or more)
    #==============================================================================

    if wave.nwave_sc < 1000:
        plotTitle(Story, "SC wavelength (WAVE_DATA_SC) for each region and channel")
        plotSubtitle(Story, "region in horizontal, wavelenght in vertical [um], channel in color")

        fig = _new_report_figure((6, 9))
        for chan in range(wave.nwave_sc):
            plt.plot(wave.wave_sc[:, chan])
        plt.ylim([1.8, 2.6])
        Story.append(_figure_to_drawing(fig))
        Story.append(PageBreak())

    #==============================================================================
    # Spectra step on detectors in pixels
    #==============================================================================

    hsize = 18*cm
    vsize = 4*cm
    # Vertical offset between consecutive region centres.
    step_sc = np.diff(np.asarray(wave.center_sc[:wave.nregion_sc, 1], dtype=float))
    step_ft = np.diff(np.asarray(wave.center_ft[:wave.nregion_ft, 1], dtype=float))

    plotTitle(Story, "Vertical step of SC extracted spectra (pixels)")
    Story.append(Spacer(1, 2*mm))
    # NOTE(review): astype('|S6') yields bytes (truncated to 6 chars) under
    # Python 3; confirm stepbarchartout renders these labels as intended.
    labelout = wave.regname_sc.astype('|S6').tolist()
    Story.append(stepbarchartout([tuple(step_sc)], labelout, np.min(step_sc)-1, np.max(step_sc)+1,
                                 0.5, hsize, vsize, colors.cornflower))
    Story.append(Spacer(1, 1*cm))

    plotTitle(Story, "Vertical step of FT extracted spectra (pixels)")
    Story.append(Spacer(1, 2*mm))
    labelout = wave.regname_ft.astype('|S6').tolist()
    Story.append(stepbarchartout([tuple(step_ft)], labelout, np.min(step_ft)-1, np.max(step_ft)+1,
                                 0.5, hsize, vsize, colors.cornflower))
    Story.append(PageBreak())

    # All figures have been converted to drawings; release them.
    plt.close('all')

    #==============================================================================
    # Create report from Story
    #==============================================================================
    print("Create the PDF")

    # Read by the page templates (myFirstPage / myLaterPages).
    gravi_visual_class.TITLE = "GRAVITY WAVE Quality Control Report"
    gravi_visual_class.PAGEINFO = "WAVE file: "+filename+".fits"
    reportname = filename+"-WAVE.pdf"

    doc = SimpleDocTemplate(reportname)
    doc.build(Story, onFirstPage=myFirstPage, onLaterPages=myLaterPages)

    print(" " + reportname)

#==============================================================================
# MAIN PROGRAM    
#==============================================================================

if __name__ == '__main__':
    import sys

    # Use the single command-line argument as the WAVE file name (without
    # the .fits extension); otherwise ask for it interactively.
    filename = sys.argv[1] if len(sys.argv) == 2 else ''
    if not filename:
        filename = input(" Enter WAVE file name (without .fits) : ")

    produce_wave_report(gravi_visual_class.Wave(filename + '.fits'), filename)

