# -*- coding: utf-8 -*-
"""
Created on Wed May 11 13:57:16 2016

@author: kervella
"""

import numpy as np

#==============================================================================
# Functions
#==============================================================================

def get_key_withdefault(dict, key, default):
    """Return ``dict[key]`` if the key is present, otherwise ``default``.

    A key written with the FITS long-keyword prefix, ``"HIERARCH " + k``, is
    also matched when the mapping stores the bare name ``k``; the value is then
    looked up under ``k`` (a plain mapping has no entry for the prefixed form).

    Parameters
    ----------
    dict : mapping
        Any object providing ``keys()`` and ``__getitem__``. The name shadows
        the builtin but is kept so keyword callers keep working.
    key : hashable
        Key to look up, optionally with the ``"HIERARCH "`` prefix.
    default : object
        Value returned when neither ``key`` nor its un-prefixed form is found.
    """
    # Materialise once: some mappings return a one-shot iterator from keys().
    keys = list(dict.keys())
    if key in keys:
        return dict[key]
    prefix = "HIERARCH "
    if isinstance(key, str) and key.startswith(prefix):
        bare = key[len(prefix):]
        if bare in keys:
            # Look up the stored (un-prefixed) name, not the prefixed one.
            return dict[bare]
    return default

def mean_angle(angles,axis=None,deg=False):
    """Return the circular mean of ``angles``.

    The mean direction is the argument of the sum of unit phasors
    ``exp(1j*angle)``, so it is insensitive to 0/360 (or 0/2*pi) wrap-around.

    Parameters
    ----------
    angles : array_like
        Angles in radians, or in degrees when ``deg`` is True.
    axis : int or None, optional
        Axis along which to average; None averages over all elements.
    deg : bool, optional
        If True, input and output are in degrees instead of radians.

    Returns
    -------
    numpy.ndarray or numpy scalar
        Mean angle in (-pi, pi] radians, or (-180, 180] degrees if ``deg``.
    """
    angles = np.asarray(angles)
    if deg:
        # Wrap into [0, 360) and convert to radians exactly once.
        radians = (angles % 360) * np.pi/180.
    else:
        radians = angles % (2*np.pi)  # wrap into [0, 2*pi)
    phasors = np.exp(1j*radians)
    meanangle = np.angle(np.sum(phasors,axis=axis))
    if deg:
        return meanangle*180./np.pi
    return meanangle

def sub_mean_angle(angles,axis=None,deg=False):
    """Return ``angles`` with their circular mean removed.

    Each angle is turned into a unit phasor, divided by the mean phasor along
    ``axis``, and converted back to an angle, giving the residual relative to
    the mean direction.

    Parameters
    ----------
    angles : array_like
        Angles in radians, or in degrees when ``deg`` is True.
    axis : int or None, optional
        Axis along which the mean is computed; None uses all elements.
    deg : bool, optional
        If True, input and output are in degrees instead of radians.

    Returns
    -------
    numpy.ndarray
        Same shape as ``angles``; residuals in (-pi, pi] radians, or
        (-180, 180] degrees if ``deg``.
    """
    angles = np.asarray(angles)
    if deg:
        # Wrap into [0, 360) and convert to radians exactly once.
        radians = (angles % 360) * np.pi/180.
    else:
        radians = angles % (2*np.pi)  # wrap into [0, 2*pi)
    phasors = np.exp(1j*radians)
    # keepdims=True so the mean broadcasts back along the reduced axis,
    # whichever axis that is.
    meanphasor = np.mean(phasors,axis=axis,keepdims=True)
    residual = np.angle(phasors/meanphasor)
    if deg:
        return residual*180./np.pi
    return residual

def std_angle(angles,axis=None,deg=False):
    """Return the standard deviation of ``angles`` about their mean direction.

    The angles are first expressed relative to their circular mean (argument
    of the NaN-ignoring mean phasor) and wrapped into (-pi, pi]. The plain
    NaN-ignoring standard deviation of those residuals is then returned, so
    clusters straddling the wrap point (e.g. +179 and -179 degrees) get a
    small spread. When the mean phasor is exactly 0 the centre is 0 and the
    residuals are simply the angles wrapped into (-pi, pi].

    Parameters
    ----------
    angles : array_like
        Angles in radians, or in degrees when ``deg`` is True. NaNs are ignored.
    axis : int or None, optional
        Axis along which to compute; None uses all elements.
    deg : bool, optional
        If True, input and output are in degrees instead of radians.

    Returns
    -------
    numpy.ndarray or numpy scalar
        Standard deviation (ddof=0) in radians, or degrees if ``deg``.
    """
    angles = np.asarray(angles)
    radians = angles*np.pi/180. if deg else angles
    phasors = np.exp(1j*radians)
    # Centre on the circular mean so wrapping happens far from the data.
    center = np.angle(np.nanmean(phasors,axis=axis,keepdims=True))
    residuals = np.angle(phasors*np.exp(-1j*center))
    stdangle = np.nanstd(residuals,axis=axis)
    if deg:
        return stdangle*180./np.pi  # return degrees
    return stdangle  # return radians

def clean_gdelay_is(visdata, wave):
    """Estimate and remove a group delay from the inter-channel phase step.

    The delay is taken from the phase of the mean product of each channel
    with the complex conjugate of the previous one, divided by the mean
    step in ``2*pi/wave``. ``visdata`` is then multiplied by
    ``exp(-2j*pi*gd/wave)``, i.e. a signal ``exp(2j*pi*gd/wave)`` is flattened.
    The estimate is unambiguous only while the per-channel phase step stays
    within (-pi, pi].

    Parameters
    ----------
    visdata : array_like
        Complex values, one per spectral channel (assumed 1-D; slicing is
        along the first axis).
    wave : numpy.ndarray
        Wavelength of each channel; the delay comes out in the same unit.

    Returns
    -------
    numpy.ndarray
        The corrected data. With fewer than two channels no delay can be
        estimated and ``visdata`` is returned unchanged.
    """
    wave = np.asarray(wave)
    visdata = np.asarray(visdata)
    nw = wave.size
    if nw < 2:
        return visdata
    wavenumber = 1./wave
    # Use every adjacent channel pair (1..nw-1 paired with 0..nw-2).
    delta = (2*np.pi) * np.mean(wavenumber[1:] - wavenumber[:-1])
    gd = np.angle(np.mean(visdata[1:] * np.conj(visdata[:-1]))) / delta
    return visdata * np.exp(-2.j*np.pi * gd / wave)

def clean_gdelay_fft(visdata, wave, zp=100):
    """Estimate a group delay from the peak of a zero-padded FFT and remove it.

    The spectrum is treated as samples at (roughly) uniform wavenumber
    ``1/wave`` spacing. Its zero-padded FFT peaks at the bin matching the
    delay. Bins in the upper half of the transform are mapped to negative
    delays. ``visdata`` is then multiplied by ``exp(-2j*pi*gd/wave)``.

    Parameters
    ----------
    visdata : array_like
        Complex values, one per spectral channel (assumed 1-D).
    wave : numpy.ndarray
        Wavelength of each channel; the delay comes out in the same unit.
    zp : int, optional
        Zero-padding factor; the FFT length is ``zp * wave.size``, which sets
        the delay grid resolution.

    Returns
    -------
    numpy.ndarray
        The corrected data. With fewer than two channels no delay can be
        estimated and ``visdata`` is returned unchanged.
    """
    wave = np.asarray(wave)
    visdata = np.asarray(visdata)
    nw = wave.size
    if nw < 2:
        return visdata
    nfft = nw*zp
    # Mean wavenumber step between adjacent channels (nw-1 steps, not nw).
    dsigma = (1./wave[nw-1] - 1./wave[0]) / (nw-1)
    # FFT bin k corresponds to k/nfft cycles per channel, i.e. delay k*delta.
    delta = 1. / (dsigma * nfft)
    psd = np.abs(np.fft.fft(visdata, n=nfft))
    mnx = np.argmax(psd)
    gd = mnx*delta if mnx < nfft/2 else (mnx-nfft)*delta
    return visdata * np.exp(-2.j*np.pi * gd / wave)

def clean_gdelay_full(visdata, wave):
    """Remove the group delay from every frame of ``visdata``.

    Each frame goes through three stages:

    1. ``clean_gdelay_is``, which uses the inter-channel phase step;
    2. ``clean_gdelay_fft``, which uses the peak of a zero-padded FFT;
    3. a fine grid search over delays from -10 to 10, in the units of
       ``wave``, maximising ``|mean(exp(-2j*pi*x/wave) * data)|``.

    Every stage corrects with ``exp(-2j*pi*gd/wave)``.

    Parameters
    ----------
    visdata : numpy.ndarray
        Frames along the first axis, spectral channels along the second.
    wave : numpy.ndarray
        Wavelength of each channel.

    Returns
    -------
    list of numpy.ndarray
        One corrected spectrum per frame.
    """
    # Loop-invariant fine-search grid and its phasors, shared by all frames.
    x = np.linspace(-10.0, 10.0, num=2000)
    # Negative exponent: for data exp(2j*pi*r/wave) the mean peaks at x == r,
    # so the correction below removes the residual instead of doubling it.
    e = np.exp(-2.j*np.pi * np.outer(x, 1./wave))

    gdelay_full = []
    for frame in range(visdata.shape[0]):
        # start with two coarse estimations
        tmp = clean_gdelay_is(visdata[frame,:], wave)
        tmp = clean_gdelay_fft(tmp, wave)

        # refine by searching for the maximum on the fine grid
        P = np.abs(np.mean(e * tmp, axis=1))
        gd = x[np.argmax(P)]
        gdelay_full.append(tmp * np.exp(-2.j*np.pi * gd / wave))

    return gdelay_full
       
def clipdata(datalist,minval,maxval):
    """Return ``datalist`` as a new array clipped to ``[minval, maxval]``.

    NaNs in the data and in the bounds are replaced by 0, and infinities by
    large finite values (``np.nan_to_num``). Values below ``minval`` are then
    set to ``minval``, and afterwards values above ``maxval`` are set to
    ``maxval``. If a comparison or assignment against a bound fails (e.g.
    the data holds ``None``), 'cannot clipdata' is printed and the whole
    array is set to that bound (best-effort behaviour).

    Parameters
    ----------
    datalist : array_like or scalar
        Data to clip; the input itself is not modified.
    minval, maxval : scalar
        Lower and upper bounds.

    Returns
    -------
    numpy.ndarray
        The clipped copy (0-d for scalar input).
    """
    # asarray: nan_to_num returns a numpy scalar (not item-assignable) for
    # 0-d input.
    array_np = np.asarray(np.nan_to_num(np.asarray(datalist)))
    minval = np.nan_to_num(minval)+0
    maxval = np.nan_to_num(maxval)+0
    try:
        array_np[array_np < minval] = minval  # All low values set to minval
    except (TypeError, ValueError, IndexError):
        print ('cannot clipdata')
        array_np[...] = minval  # [...] also works on 0-d arrays
    try:
        array_np[array_np > maxval] = maxval  # All high values set to maxval
    except (TypeError, ValueError, IndexError):
        print ('cannot clipdata')
        array_np[...] = maxval
    return array_np


#==============================================================================
# Reportlab class and functions
#==============================================================================


from reportlab.platypus import SimpleDocTemplate, Paragraph, Spacer, Table, TableStyle, PageBreak
from reportlab.lib import colors
from reportlab.graphics.shapes import Drawing, String #, Rect, Path
#    from reportlab.graphics.charts.legends import LineLegend
from reportlab.lib.styles import getSampleStyleSheet 
from reportlab.rl_config import defaultPageSize
#from reportlab.lib.pagesizes import A4
from reportlab.lib.units import inch, cm, mm
# from reportlab.pdfgen import canvas
from reportlab.graphics.charts.lineplots import LinePlot, ScatterPlot
#    import cStringIO

# For insertion of matplotlib images (or others)in platypus, through cStringIO
from pdfrw import PdfReader
from pdfrw.buildxobj import pagexobj
from pdfrw.toreportlab import makerl
from reportlab.platypus import Flowable
from reportlab.lib.enums import TA_LEFT, TA_CENTER, TA_RIGHT

# ReportLab's sample stylesheet, shared at module level; plotTitle and
# plotSubtitle read its "Heading2" and "Normal" paragraph styles.
styles = getSampleStyleSheet()

class PdfImage(Flowable):
    """Wrap the first page of a PDF as a ReportLab Platypus Flowable.

    The page is imported as a form XObject through pdfrw (``pagexobj``) and
    drawn with ``doForm``, so it can be placed in a Platypus story. Based on
    the vectorpdf extension in rst2pdf (http://code.google.com/p/rst2pdf/).

    Parameters
    ----------
    filename_or_object : str or file-like
        PDF file name, or a readable buffer (rewound with ``seek(0)`` first).
    width, height : float, optional
        Requested drawing size in ReportLab drawing units. A missing or zero
        value falls back to the size of the page bounding box.
    kind : str, optional
        'direct' or 'absolute': draw at ``width`` x ``height`` as given.
        'bound' or 'proportional': scale uniformly to fit inside
        ``width`` x ``height``. The choice only matters when both sizes are
        given; otherwise the 'direct' sizing is used.

    Raises
    ------
    ValueError
        If both ``width`` and ``height`` are given and ``kind`` is none of
        the values above (previously this failed later, with an
        AttributeError in ``wrap``).
    """

    def __init__(self, filename_or_object, width=None, height=None, kind='direct'):
        # Validate before any I/O; an unknown kind would otherwise leave
        # drawWidth/drawHeight undefined.
        if (width is not None and height is not None
                and kind not in ('direct', 'absolute', 'bound', 'proportional')):
            raise ValueError("Bad kind value " + str(kind))
        # If using a StringIO/BytesIO buffer, set the pointer to the beginning
        if hasattr(filename_or_object, 'read'):
            filename_or_object.seek(0)
        page = PdfReader(filename_or_object, decompress=False).pages[0]
        self.xobj = pagexobj(page)
        self.imageWidth = width
        self.imageHeight = height
        x1, y1, x2, y2 = self.xobj.BBox

        # Natural size of the imported page
        self._w, self._h = x2 - x1, y2 - y1
        if not self.imageWidth:
            self.imageWidth = self._w
        if not self.imageHeight:
            self.imageHeight = self._h
        self.__ratio = float(self.imageWidth)/self.imageHeight
        if kind in ['direct','absolute'] or width is None or height is None:
            self.drawWidth = width or self.imageWidth
            self.drawHeight = height or self.imageHeight
        else:  # 'bound' / 'proportional' (validated above)
            # Largest uniform scale that fits inside width x height
            factor = min(float(width)/self._w,float(height)/self._h)
            self.drawWidth = self._w*factor
            self.drawHeight = self._h*factor

    def wrap(self, aW, aH):
        """Return the fixed drawing size, ignoring the available space."""
        return self.drawWidth, self.drawHeight

    def drawOn(self, canv, x, y, _sW=0):
        """Draw the page at (x, y), shifted by ``_sW`` according to hAlign."""
        if _sW > 0 and hasattr(self, 'hAlign'):
            a = self.hAlign
            if a in ('CENTER', 'CENTRE', TA_CENTER):
                x += 0.5*_sW
            elif a in ('RIGHT', TA_RIGHT):
                x += _sW
            elif a not in ('LEFT', TA_LEFT):
                raise ValueError("Bad hAlign value " + str(a))

        xobj = self.xobj
        xobj_name = makerl(canv._doc, xobj)

        xscale = self.drawWidth/self._w
        yscale = self.drawHeight/self._h

        # Compensate for a bounding box that does not start at the origin
        x -= xobj.BBox[0] * xscale
        y -= xobj.BBox[1] * yscale

        canv.saveState()
        canv.translate(x, y)
        canv.scale(xscale, yscale)
        canv.doForm(xobj_name)
        canv.restoreState()
        

_DEFAULT_TITLE_SPACER = object()  # sentinel: build a fresh Spacer(1, -2*mm)


def plotTitle(Story,title,spacer=_DEFAULT_TITLE_SPACER):
    """Print ``title`` and append it to ``Story`` as a "Heading2" paragraph.

    Parameters
    ----------
    Story : list
        List of Platypus flowables, extended in place.
    title : str
        Heading text.
    spacer : Flowable or None, optional
        Flowable appended after the heading. By default a new
        ``Spacer(1, -2*mm)`` is created on each call; pass None to append
        nothing.
    """
    if spacer is _DEFAULT_TITLE_SPACER:
        # Created per call: a Spacer default argument would be a single
        # shared instance appended to every story.
        spacer = Spacer(1,-2*mm)
    print (title)
    Story.append(Paragraph(title,  style = styles["Heading2"]))
    if spacer is not None:
        Story.append(spacer)

_DEFAULT_SUBTITLE_SPACER = object()  # sentinel: build a fresh Spacer(1, 5*mm)


def plotSubtitle(Story,title,spacer=_DEFAULT_SUBTITLE_SPACER):
    """Append ``title`` to ``Story`` as a "Normal" paragraph.

    Parameters
    ----------
    Story : list
        List of Platypus flowables, extended in place.
    title : str
        Subtitle text.
    spacer : Flowable or None, optional
        Flowable appended after the text. By default a new
        ``Spacer(1, 5*mm)`` is created on each call; pass None to append
        nothing.
    """
    if spacer is _DEFAULT_SUBTITLE_SPACER:
        # Created per call: a Spacer default argument would be a single
        # shared instance appended to every story.
        spacer = Spacer(1,5*mm)
    Story.append(Paragraph(title,  style = styles["Normal"]))
    if spacer is not None:
        Story.append(spacer)
  
def graphoutnoxaxes(data,xmin,xmax,ymin,ymax,sizex,sizey,ticky,title):
    """Return a ``sizex`` x ``sizey`` Drawing holding a joined-line LinePlot.

    The x axis spans [xmin, xmax] with no ticks or labels. The y axis spans
    [ymin, ymax] with ticks and '%.2e' labels every ``ticky`` (a fifth of the
    y range when ``ticky`` is falsy). Both grids are shown; the first six
    series are coloured red, orange, blue, green, cyan and purple; ``title``
    is written at (10%, 80%) of the drawing.
    """
    series_colours = (colors.red, colors.orange, colors.blue,
                      colors.green, colors.cyan, colors.purple)
    drawing = Drawing(sizex,sizey)

    plot = LinePlot()
    plot.x, plot.y = 0, 0
    plot.height, plot.width = sizey, sizex
    plot.data = data
    plot.joinedLines = 1
    for idx, colour in enumerate(series_colours):
        plot.lines[idx].strokeColor = colour
    plot.strokeColor = colors.black

    # Horizontal axis: range and grid only
    xaxis = plot.xValueAxis
    xaxis.valueMin, xaxis.valueMax = xmin, xmax
    xaxis.visibleTicks = 0
    xaxis.visibleLabels = 0

    # Vertical axis: ticked and labelled
    yaxis = plot.yValueAxis
    yaxis.valueMin, yaxis.valueMax = ymin, ymax
    yaxis.valueStep = ticky or (ymax-ymin)/5
    yaxis.visibleTicks = 1
    yaxis.visibleLabels = 1
    yaxis.labelTextFormat = '%.2e'
    yaxis.labels.fontSize = 8

    xaxis.visibleGrid = 1
    yaxis.visibleGrid = 1

    drawing.add(plot)
    # Sub-plot title
    drawing.add(String(0.1*sizex, 0.8*sizey, title, fontSize=8))
    return drawing
    
def graphoutaxes(data,xmin,xmax,ymin,ymax,sizex,sizey,ticky,title):
    """Return a ``sizex`` x ``sizey`` Drawing holding a joined-line LinePlot.

    Both axes are ticked and labelled in '%.2e' format (font size 8). The x
    axis spans [xmin, xmax]; the y axis spans [ymin, ymax] with step
    ``ticky``. Both grids are shown; the first six series are coloured red,
    orange, blue, green, cyan and purple; ``title`` is written at (10%, 80%)
    of the drawing.
    """
    series_colours = (colors.red, colors.orange, colors.blue,
                      colors.green, colors.cyan, colors.purple)
    drawing = Drawing(sizex,sizey)

    plot = LinePlot()
    plot.x, plot.y = 0, 0
    plot.height, plot.width = sizey, sizex
    plot.data = data
    plot.joinedLines = 1
    for idx, colour in enumerate(series_colours):
        plot.lines[idx].strokeColor = colour
    plot.strokeColor = colors.black

    xaxis = plot.xValueAxis
    xaxis.valueMin, xaxis.valueMax = xmin, xmax
    xaxis.visibleTicks = 1
    xaxis.visibleLabels = 1
    xaxis.labelTextFormat = '%.2e'
    xaxis.labels.fontSize = 8

    yaxis = plot.yValueAxis
    yaxis.valueMin, yaxis.valueMax = ymin, ymax
    yaxis.valueStep = ticky
    yaxis.visibleTicks = 1
    yaxis.visibleLabels = 1
    yaxis.labelTextFormat = '%.2e'
    yaxis.labels.fontSize = 8

    xaxis.visibleGrid = 1
    yaxis.visibleGrid = 1

    drawing.add(plot)
    # Sub-plot title
    drawing.add(String(0.1*sizex, 0.8*sizey, title, fontSize=8))
    return drawing

def graphoutnoaxis(data,xmin,xmax,ymin,ymax,sizex,sizey,ticky,title):
    """Return a ``sizex`` x ``sizey`` Drawing holding a bare joined-line LinePlot.

    Neither axis shows ticks or labels; only the grids are drawn. The x range
    is [xmin, xmax]; the y range is [ymin, ymax] with grid step ``ticky``. The
    first six series are coloured red, orange, blue, green, cyan and purple;
    ``title`` is written at (10%, 80%) of the drawing. Values outside the
    limits are not clipped here (see ``clipdata``).
    """
    series_colours = (colors.red, colors.orange, colors.blue,
                      colors.green, colors.cyan, colors.purple)
    drawing = Drawing(sizex,sizey)

    plot = LinePlot()
    plot.x, plot.y = 0, 0
    plot.height, plot.width = sizey, sizex
    plot.data = data
    plot.joinedLines = 1
    for idx, colour in enumerate(series_colours):
        plot.lines[idx].strokeColor = colour
    plot.strokeColor = colors.black

    xaxis = plot.xValueAxis
    xaxis.valueMin, xaxis.valueMax = xmin, xmax
    xaxis.visibleTicks = 0
    xaxis.visibleLabels = 0

    yaxis = plot.yValueAxis
    yaxis.valueMin, yaxis.valueMax = ymin, ymax
    yaxis.valueStep = ticky
    yaxis.visibleTicks = 0
    yaxis.visibleLabels = 0

    xaxis.visibleGrid = 1
    yaxis.visibleGrid = 1

    drawing.add(plot)
    # Sub-plot title
    drawing.add(String(0.1*sizex, 0.8*sizey, title, fontSize=8))
    return drawing

def graphscatteraxes(data,xmin,xmax,ymin,ymax,sizex,sizey,tickx,ticky,title):
    """Return a ``sizex`` x ``sizey`` Drawing holding an unjoined ScatterPlot.

    The first six series use size-3 'Circle' markers coloured red, orange,
    blue, green, cyan and purple; line labels and axis titles are disabled.
    Both axes are ticked and labelled in '%.2e' format (font size 8): x over
    [xmin, xmax] with step ``tickx``, y over [ymin, ymax] with step ``ticky``.
    Both grids are shown; ``title`` is written at (10%, 80%) of the drawing.
    """
    series_colours = (colors.red, colors.orange, colors.blue,
                      colors.green, colors.cyan, colors.purple)
    drawing = Drawing(sizex,sizey)

    plot = ScatterPlot()
    plot.x, plot.y = 0, 0
    plot.height, plot.width = sizey, sizex
    plot.data = data
    plot.joinedLines = False
    for idx, colour in enumerate(series_colours):
        series = plot.lines[idx]
        series.symbol.size = 3
        series.strokeColor = colour
        series.symbol.kind = 'Circle'
    plot.lineLabelFormat  = None
    plot.xLabel = ''
    plot.yLabel = ''
    plot.strokeColor = colors.black

    xaxis = plot.xValueAxis
    xaxis.valueMin, xaxis.valueMax = xmin, xmax
    xaxis.valueStep = tickx
    xaxis.visibleTicks = 1
    xaxis.visibleLabels = 1
    xaxis.labelTextFormat = '%.2e'
    xaxis.labels.fontSize = 8

    yaxis = plot.yValueAxis
    yaxis.valueMin, yaxis.valueMax = ymin, ymax
    yaxis.valueStep = ticky
    yaxis.visibleTicks = 1
    yaxis.visibleLabels = 1
    yaxis.labelTextFormat = '%.2e'
    yaxis.labels.fontSize = 8

    xaxis.visibleGrid = 1
    yaxis.visibleGrid = 1

    drawing.add(plot)
    # Sub-plot title
    drawing.add(String(0.1*sizex, 0.8*sizey, title, fontSize=8))
    return drawing
