#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""Plot incoming OPD PSD for a GRAVITY raw file
Created on Tue May  2 03:39:17 2017

@author: slacour
"""

class DATAOI:
    """Columns of the 'OPDC' table and selected header keywords of one FITS file.

    Attributes set by ``__init__``:
        time, piezo, kpiezo, opd, kopd: the 'TIME', 'PIEZO_DL_OFFSET',
            'KALMAN_PIEZO', 'OPD' and 'KALMAN_OPD' columns, unmodified.
        piezoM, kpiezoM: ``piezo`` / ``kpiezo`` projected through the 6x4
            difference matrix (one output column per pair of input columns).
        katm: ``kopd - kpiezoM``.
        FTname, SCname, dpr, dit, rate: header keyword values.
    """

    def __init__(self, file):
        """Read the 'OPDC' table and the header of *file* (each only once).

        Args:
            file: FITS file specification, passed as-is to ``getdata`` and
                ``getheader``.

        Raises:
            KeyError: if one of the requested columns or header keywords is
                missing (raised by the underlying lookups).
        """
        # Each row holds a single -1 and a single +1: the product below turns
        # 4 per-input columns into the 6 pairwise differences (j minus i, i<j).
        M_matrix = array([
            [-1., 1., 0., 0.],
            [-1., 0., 1., 0.],
            [-1., 0., 0., 1.],
            [0., -1., 1., 0.],
            [0., -1., 0., 1.],
            [0., 0., -1., 1.],
        ])

        # Read the table once instead of re-opening the file for every column.
        opdc = getdata(file, 'OPDC')
        self.time = opdc.field('TIME')
        self.piezo = opdc.field('PIEZO_DL_OFFSET')
        self.piezoM = dot(M_matrix, self.piezo.T).T
        self.kpiezo = opdc.field('KALMAN_PIEZO')
        self.kpiezoM = dot(M_matrix, self.kpiezo.T).T
        self.opd = opdc.field('OPD')
        self.kopd = opdc.field('KALMAN_OPD')
        # Reuse the projection computed above rather than repeating the product.
        self.katm = self.kopd - self.kpiezoM

        # Likewise, a single header read serves all keyword lookups.
        header = getheader(file)
        self.FTname = header["HIERARCH ESO FT ROBJ NAME"]
        print(self.FTname)
        self.SCname = header["HIERARCH ESO INS SOBJ NAME"]
        print(self.SCname)
        self.dpr = header["HIERARCH ESO DPR TYPE"]
        print(self.dpr)
        self.dit = header["HIERARCH ESO DET3 SEQ1 DIT"]
        self.rate = header["HIERARCH ESO FT RATE"]
        print(self.rate)

from matplotlib.pyplot import *
from numpy import *
from astropy.io.fits import *
from glob import glob
from scipy.linalg import pinv2
import argparse

import matplotlib as mpl
# Select matplotlib's 'classic' style for every figure this script creates.
mpl.style.use('classic')

# Command-line interface. The module docstring doubles as the help
# description and is shown verbatim thanks to the raw formatter.
parser = argparse.ArgumentParser(
    description=__doc__,
    formatter_class=argparse.RawDescriptionHelpFormatter,
    prefix_chars='-+',
)

# Positional input file names (zero or more are accepted by the parser).
parser.add_argument('fname', nargs='*')

# Window passed to compute_ffts as Nsmooth.
parser.add_argument(
    "-N", "--Nsmooth",
    type=int,
    default=2000,
    help="Smoothing width [2000]",
)

# Window passed to the plotting functions as Ns.
parser.add_argument(
    "-n", "--Ns",
    type=int,
    default=100,
    help="for plotting",
)

def _subtract_running_mean(signal, window):
    """Return *signal* minus its boxcar average over *window* samples."""
    # convolve(..., 'same') returns max(len(kernel), len(signal)) samples, so a
    # kernel longer than the signal would break the subtraction; cap it.
    window = min(window, len(signal))
    return signal - convolve(ones(window) / window, signal, 'same')


def compute_ffts(data, Nsmooth=2000):
    """Compute FFT magnitudes of the detrended OPD, piezo and their difference.

    Args:
        data: object exposing ``opd`` and ``kpiezoM`` (2-D, samples along the
            first axis, at least 6 columns) and ``time`` (1-D).
        Nsmooth: width, in samples, of the running mean removed from every
            column before the transform. Capped at the series length.

    Returns:
        Tuple ``(freq, ft_P, ft_O, ft_A)``: the ``fftfreq`` axis and the FFT
        magnitudes (samples x 6) of the piezo term, the unwrapped OPD and
        their difference ``O - P``.
    """
    O = array([_subtract_running_mean(unwrap(data.opd[:, i]), Nsmooth)
               for i in range(6)]).T
    P = array([_subtract_running_mean(data.kpiezoM[:, i], Nsmooth)
               for i in range(6)]).T

    A = O - P

    # Transform along the sample axis, independently for each column.
    ft_P = abs(fft.fft(P, axis=0))
    ft_O = abs(fft.fft(O, axis=0))
    ft_A = abs(fft.fft(A, axis=0))

    # The 1e-6 factor suggests data.time is in microseconds, which would make
    # freq come out in Hz -- TODO confirm against the file format.
    freq = fft.fftfreq(len(data.time), median(diff(data.time)) * 1e-6)

    return (freq, ft_P, ft_O, ft_A)

def plot_for_baseline(ax, freq, ft_P, ft_O, ft_A, base, Ns=100):
    """Draw the three boxcar-smoothed spectra of column *base* on *ax*.

    Args:
        ax: object with a ``plot(x, y, label=...)`` method.
        freq: x values shared by the three curves.
        ft_P, ft_O, ft_A: 2-D arrays; column *base* of each one is plotted.
        base: column index to plot.
        Ns: width, in samples, of the boxcar applied to each curve.
    """
    boxcar = ones(Ns) / Ns
    # Drawing order (and hence legend order): atmosphere, piezo, OPD residuals.
    curves = (
        (ft_A, "FFT Atmosphere"),
        (ft_P, "FFT Piezo"),
        (ft_O, "FFT OPD residuals"),
    )
    for spectrum, name in curves:
        smoothed = convolve(boxcar, spectrum[:, base], 'same')
        ax.plot(freq, smoothed, label=name)


def plot_all_baselines(freq, ft_P, ft_O, ft_A, Ns=100):
    """Show a 3x2 grid of spectra, one panel per column 0..5.

    Args:
        freq: x values shared by all panels.
        ft_P, ft_O, ft_A: 2-D arrays with at least 6 columns.
        Ns: boxcar width forwarded to ``plot_for_baseline`` for every panel
            (previously the first panel ignored it and always used 100).

    The first panel carries the legend and the log y-scale; the other panels
    share both of its axes. Ends with ``show()``.
    """
    figure(figsize=(16, 12))

    ax1 = subplot(3, 2, 1)
    plot_for_baseline(ax1, freq, ft_P, ft_O, ft_A, 0, Ns)
    ax1.set_yscale("log")
    ax1.legend(fontsize='x-small', fancybox=True, framealpha=0.5)

    for base in range(1, 6):
        ax = subplot(3, 2, base + 1, sharex=ax1, sharey=ax1)
        plot_for_baseline(ax, freq, ft_P, ft_O, ft_A, base, Ns)

    # Labels the current axes, i.e. the last panel created above.
    xlabel("Frequency")
    # Final x-range; applies to every panel because they share ax1's x-axis.
    ax1.set_xlim(0, 250)
    show()

def process_file(fname, Nsmooth=2000, Ns=100):
    """Load *fname*, compute its spectra and display them.

    Args:
        fname: file handed to ``DATAOI``.
        Nsmooth: detrending window forwarded to ``compute_ffts``.
        Ns: plot smoothing window forwarded to ``plot_all_baselines``.
    """
    dataset = DATAOI(fname)
    frequencies, piezo_spec, opd_spec, atm_spec = compute_ffts(dataset, Nsmooth)
    plot_all_baselines(frequencies, piezo_spec, opd_spec, atm_spec, Ns)

if __name__ == "__main__":
    args = parser.parse_args()
    if not args.fname:
        # nargs='*' lets the parser accept zero files; report that as a usage
        # error instead of failing later with an IndexError.
        parser.error("at least one input file is required")
    # Handle every file given on the command line, one figure after another.
    for fname in args.fname:
        process_file(fname, args.Nsmooth, args.Ns)

