# -*- coding: utf-8 -*-
"""
Created on Sat Sep 12 18:02:35 2015

@author: kervella
"""

import matplotlib.pyplot as plt
import numpy as np
import pyfits as fits
import datetime

#import sys
# gravi_visual_class.py import in other directory
#sys.path.insert(0, '../../gravi_visual')
#import gravi_visual_class
import gravi_visual_class2
import gravi_dpfit
import itertools
#import os
import fitsio

        
        
def multigauss(xy, param, retX0=False):
    """Evaluate Gaussian line profiles whose centre moves with the row coordinate.

    Parameters
    ----------
    xy : tuple (x, y)
        x : 1D sequence of column (pixel) coordinates.
        y : 1D sequence of row coordinates.  The line centre for each row is
            ``gravi_dpfit.polyN(y, coeffs)`` where ``coeffs`` is built from the
            ``param`` entries whose key ends in ``'x'``.
        Lists and arrays are both accepted.
    param : dict
        'a0x', 'a1x', ... : coefficients of the centre polynomial.  The
            trailing ``'x'`` is stripped ('a0x' -> 'a0') before calling polyN.
        'sigma' : width parameter; the profile is ``exp(-(dx / sigma)**2)``
            (note: no factor 2 in the denominator).
        'amp'   : common peak amplitude.
    retX0 : bool, optional
        If True, return only the per-row line centres.

    Returns
    -------
    If ``retX0`` is True, the centres returned by polyN (one per y).
    Otherwise the model on the (len(y), len(x)) grid, flattened row-major so
    that it matches ``image[:, xmin:xmax].flatten()``.
    """
    x, y = xy
    # Accept plain lists as documented: x[None, :] needs an ndarray.
    x = np.asarray(x)
    y = np.asarray(y)
    coeffs = {key[:-1]: value for key, value in param.items() if key.endswith('x')}
    x0 = gravi_dpfit.polyN(y, coeffs)
    if retX0:
        return x0
    x0 = np.asarray(x0)
    return param['amp']*np.exp(-(x[None, :]-x0[:, None])**2/param['sigma']**2).flatten()

def polyfit2d(x, y, z, order=3):
    """Least-squares fit of a full 2-D polynomial ``z ~ sum m_k * x**i * y**j``.

    Both exponents ``i`` and ``j`` run from 0 to ``order``, so the model has
    ``(order + 1)**2`` coefficients including all cross terms.  Coefficients
    are returned in ``itertools.product(range(order + 1), repeat=2)`` order of
    ``(i, j)``, i.e. (0, 0), (0, 1), ..., (order, order).  This is the layout
    that ``polyval2d`` expects.

    Parameters
    ----------
    x, y, z : array_like, 1-D, same length
        Sample coordinates and values.
    order : int, optional
        Maximum exponent of each of x and y (default 3).

    Returns
    -------
    m : ndarray of shape ((order + 1)**2,)
        Polynomial coefficients.
    """
    # Work in floating point so that integer inputs cannot overflow in x**i.
    x = np.asarray(x, dtype=float)
    y = np.asarray(y, dtype=float)
    exponents = itertools.product(range(order + 1), repeat=2)
    G = np.column_stack([x**i * y**j for i, j in exponents])
    # rcond=None selects NumPy's machine-precision cutoff explicitly and
    # avoids the FutureWarning emitted when rcond is left unspecified.
    m, _, _, _ = np.linalg.lstsq(G, z, rcond=None)
    return m

def polyval2d(x, y, m):
    """Evaluate a 2-D polynomial from ``polyfit2d`` at the points (x, y).

    Parameters
    ----------
    x, y : array_like or scalar
        Evaluation coordinates.  They are broadcast against each other, and
        integer inputs are accepted.
    m : array_like
        ``(order + 1)**2`` coefficients in the ``(i, j)`` order produced by
        ``itertools.product(range(order + 1), repeat=2)``.

    Returns
    -------
    ndarray
        ``sum m_k * x**i * y**j`` with the broadcast shape of x and y.  This is
        a 0-d array for scalar inputs.

    Raises
    ------
    ValueError
        If the number of coefficients is not a perfect square.
    """
    m = np.asarray(m)
    order = int(round(np.sqrt(m.size))) - 1
    if (order + 1) ** 2 != m.size:
        raise ValueError("polyval2d: expected a square number of coefficients, got %d" % m.size)
    x = np.asarray(x)
    y = np.asarray(y)
    # Accumulate in floating point: an integer accumulator (what zeros_like
    # gives for integer x) cannot receive the float terms in place.
    dtype = np.result_type(x, y, m, np.float64)
    x = x.astype(dtype, copy=False)
    y = y.astype(dtype, copy=False)
    z = np.zeros(np.broadcast(x, y).shape, dtype=dtype)
    exponents = itertools.product(range(order + 1), repeat=2)
    for a, (i, j) in zip(m, exponents):
        z += a * x**i * y**j
    return z



#==============================================================================
#   Preprocessing of the argon frames
#==============================================================================
def preproc_argon(argon_frames, argon_dark_frames, badpix, profile, wave_cal=None):
    """Dark-subtract and clean argon frames, then extract one spectrum per output.

    Steps:
    1. Median-combine the argon and dark stacks along axis 0 and subtract the dark.
    2. Replace flagged bad pixels by the local 3x3 median.
    3. Replace outliers by the same median.  A pixel counts as an outlier when
       its absolute value exceeds 10x the absolute local median.
    4. Keep detector columns ``[startx, startx + nwave)`` and sum each output
       region weighted by its extraction profile.

    Parameters
    ----------
    argon_frames, argon_dark_frames : array_like
        Stacks of lamp and dark frames; each is median-combined along axis 0
        into a 2-D image.
    badpix : object
        Must provide ``badpiximage_sc`` with the shape of one frame.  Pixels
        with a value > 0 are treated as bad.
    profile : object
        Must provide ``profile_data`` indexable as ``profile_data[output, :, :]``.
        Each slice must broadcast against the trimmed image.
    wave_cal : object, optional
        Must provide ``nwave_sc``, ``nregion_sc`` and ``startx``.  When omitted,
        the module-level ``wave`` object created in the ``__main__`` section
        is used, as the original signature implicitly required.

    Returns
    -------
    ndarray of shape (nregion_sc, nwave_sc)
        Extracted spectra, one row per output.
    """
    from scipy.ndimage import median_filter

    if wave_cal is None:
        wave_cal = wave  # module-level calibration object (see __main__)
    nwave = wave_cal.nwave_sc
    noutputs = wave_cal.nregion_sc
    startx = wave_cal.startx

    print("startx = ",startx)

    # Median-combine the lamp and dark stacks, then subtract the dark.
    argon_med = np.median(argon_frames, axis=0)
    argon_dark_med = np.median(argon_dark_frames, axis=0)
    argon_med -= argon_dark_med

    # A single 3x3 median image (computed before any replacement) serves
    # both the bad-pixel and the outlier corrections.
    median_filt = median_filter(argon_med, size=3, mode='nearest')

    bad = badpix.badpiximage_sc > 0
    argon_med[bad] = median_filt[bad]

    outliers = np.abs(argon_med) > 10*np.abs(median_filt)
    argon_med[outliers] = median_filt[outliers]

    # Keep only the spectral window of the science channel.
    argon = np.copy(argon_med[:, startx:startx+nwave])

    # Profile-weighted sum over detector rows for each output.
    argon_spec = np.zeros((noutputs, nwave))
    for output in range(noutputs):
        argon_spec[output, :] = np.sum(argon*profile.profile_data[output, :, :], axis=0)

    # Standard deviation of the median-combined dark, reported as a rough noise level.
    print("Noise RMS level : %.2f ADU" % np.nanstd(argon_dark_med))

    return argon_spec
    
#==============================================================================
#   Fit and interpolation of the wavelength scale of the argon frames
#==============================================================================
def _argon_line_table(nwave):
    """Return ``(fitwidth, line_wave)`` for the argon lines fitted at this resolution.

    More than 500 spectral channels is treated as high resolution (HR, fit
    half-window of 15 pixels).  Anything else is medium resolution (MR,
    half-window of 4 pixels), which also uses the lines at 2.385154 and
    2.397306 microns.  Wavelengths are in microns.
    """
    common = [#1.982291,
              1.997118,
              2.032256,
              2.062186,
              #2.065277,
              #2.073922,
              #2.081672,
              2.099184,
              2.133871,
              2.154009,
              2.208321,
              2.313952]
    if nwave > 500:
        return 15, np.array(common)  # HR
    return 4, np.array(common + [2.385154, 2.397306])  # MR


def gravi_fit_argon(argon_spec, polyorder, wave_cal=None):
    """Fit a 2-D polynomial wavelength scale to extracted argon spectra.

    1. First-guess wavelength per channel: the NaN-mean over outputs of
       ``wave_cal.wave_sc``, converted from the glass to the vacuum scale
       by a linear correction (slope 0.020 micron/micron).
    2. Each tabulated argon line is located on that guess.  Its position
       across all outputs is fitted with ``multigauss``, a Gaussian whose
       centre is a quadratic function of the (centred) output number.
    3. A 2-D polynomial wavelength(pixel, output) of order ``polyorder`` is
       fitted to the line centres with ``polyfit2d``.  It is then evaluated
       at every integer pixel and output index, the same coordinates used
       for the fit.

    Diagnostic figures are created and fit statistics are printed.

    Parameters
    ----------
    argon_spec : ndarray, shape (noutputs, nwave)
        Extracted argon spectra (e.g. from ``preproc_argon``).
    polyorder : int
        Maximum exponent of pixel and of output in the 2-D polynomial.
    wave_cal : object, optional
        Must provide ``wave_sc`` (per-output wavelength tables in microns).
        When omitted, the module-level ``wave`` object created in the
        ``__main__`` section is used.

    Returns
    -------
    ndarray, shape (noutputs, nwave)
        Fitted wavelength (microns) of every pixel of every output.

    Raises
    ------
    ValueError
        If a tabulated line lies redward of the whole first-guess scale.
    """
    if wave_cal is None:
        wave_cal = wave  # module-level calibration object (see __main__)
    noutputs, nwave = argon_spec.shape
    fitwidth, line_wave = _argon_line_table(nwave)
    nlines = line_wave.shape[0]

    # Correction of the wavelength scale in glass to go to vacuum scale.
    # Each channel is shifted by channel * (local step) * slope, using the
    # uncorrected step; the last channel is left unchanged.
    slope = 0.020  # mic/mic
    wavescale = np.nanmean(wave_cal.wave_sc, axis=0)
    steps = wavescale[1:nwave] - wavescale[:nwave-1]
    wavescale[:nwave-1] -= np.arange(nwave-1, dtype=wavescale.dtype) * steps * slope

    # Approximate pixel of each line: first channel redder than the line.
    xlines = np.empty(nlines, dtype=int)
    for k, lam in enumerate(line_wave):
        redder = np.flatnonzero(wavescale > lam)
        if redder.size == 0:
            raise ValueError("gravi_fit_argon: argon line at %.6f microns is beyond "
                             "the first-guess wavelength scale" % lam)
        xlines[k] = redder[0]

    # Fit each line across all outputs; the centre is a quadratic in output number.
    yaxis = np.arange(noutputs) - noutputs/2.
    all_fits = []
    for k in range(nlines):
        # Clip the window to the detector so that the slice and the
        # coordinate axis always have the same length.
        xmin = max(int(xlines[k]) - fitwidth, 0)
        xmax = min(int(xlines[k]) + fitwidth, nwave)
        xaxis = np.arange(xmin, xmax)
        window = argon_spec[:, xmin:xmax]
        fit = gravi_dpfit.leastsqFit(multigauss, (xaxis, yaxis),
              {"amp": np.max(window), "sigma": 0.5, "a0x": (xmin+xmax)/2., "a1x": 0.0, "a2x": -0.005},
              window.flatten(), verbose=0)
        fit['X0'] = multigauss((xaxis, yaxis), fit['best'], retX0=True)
        all_fits.append(fit)

    plt.figure(figsize=(10, 7))  # overplot of fits on the emission line image
    plt.imshow(argon_spec, vmin=0, vmax=100, aspect='auto', cmap='CMRmap', interpolation='none')
    plt.xlabel('Pixel')
    plt.ylabel('Output')
    plt.colorbar()
    for f in all_fits:
        plt.plot(f['X0'], list(range(noutputs)), 'o', mec='g', mew=1, markersize=6, mfc='None')
    plt.title('Argon lamp exposure and fitted line positions (circles)')

    # fmt='o' + color='r': the color keyword overrode the 'k' in the former
    # fmt='ok' (with a matplotlib warning), so the points stay red.
    plt.figure(figsize=(10, 7))
    plt.clf()
    plt.subplot(211)
    plt.errorbar([f['best']['a0x'] for f in all_fits], [f['best']['a1x'] for f in all_fits],
                 xerr=[f['uncer']['a0x'] for f in all_fits], yerr=[f['uncer']['a1x'] for f in all_fits],
                 fmt='o', color='r')
    plt.xlabel('X position (pix)')
    plt.ylabel('Slope (deg1)')
    plt.subplot(212)
    plt.errorbar([f['best']['a0x'] for f in all_fits], [f['best']['a2x'] for f in all_fits],
                 xerr=[f['uncer']['a0x'] for f in all_fits], yerr=[f['uncer']['a2x'] for f in all_fits],
                 fmt='o', color='r')
    plt.xlabel('X position (pix)')
    plt.ylabel('Curvature (deg2)')

    # One (pixel, output, wavelength) sample per line per output, flattened
    # row-major over (output, line).
    xfit = np.column_stack([np.asarray(f['X0'], dtype=float) for f in all_fits]).ravel()
    yfit = np.repeat(np.arange(noutputs, dtype=float), nlines)
    zfit = np.tile(line_wave, noutputs)

    # Fit a 'polyorder' order, 2d polynomial
    m = polyfit2d(xfit, yfit, zfit, order=polyorder)

    print("Polynomial coefficients:\n", m)

    # Evaluate at the integer pixel and output indices used as fit coordinates.
    xx, yy = np.meshgrid(np.arange(nwave, dtype=float),
                         np.arange(noutputs, dtype=float))
    wavescale = polyval2d(xx, yy, m)
    print("Min wavelength: %.3f microns"%np.min(wavescale))
    print("Max wavelength: %.3f microns"%np.max(wavescale))

    # Residuals of the fit at the measured line positions.
    residuals = zfit - polyval2d(xfit, yfit, m)
    print("Residual RMS of the 2nd fit : %.3f nm" % (np.std(residuals)*1000))
    print("Peak residual of the 2nd fit: %.3f nm" % (np.max(np.abs(residuals))*1000))

    return wavescale


def _read_primary_image(path):
    """Return the primary-HDU data of the FITS file ``path`` as float64 and close the file."""
    hdulist = fits.open(path)
    try:
        # Copy (and convert) before closing so the array outlives the file.
        return np.array(hdulist[0].data, dtype=float)
    finally:
        hdulist.close()


def gravi_fit_argon_mr(x0, x1, nwave,
                       filename='2015-09-09_Argon_MedPol.fits',
                       darkname='2015-09-09_Argon_MedPol_dark.fits'):
    """Fit a per-output quadratic wavelength scale from a medium-resolution argon image.

    Processing:
    - Subtract the dark image from the argon image.
    - Extract 48 output spectra.  Each spectrum is the sum of detector rows
      ``7*output + 2`` and ``7*output + 3`` over columns
      ``x0 + x1 .. x0 + x1 + nwave - 1``.
    - Normalise each spectrum by the mean of its channels 51-53.
    - Centroid 11 argon lines in 3-pixel windows around tabulated positions.
      The positions are shifted by ``67 - x1``; presumably they were measured
      for ``x1 = 67`` (TODO confirm).
    - Fit a quadratic wavelength(pixel) relation per output.
    - Print the RMS residual.

    Parameters
    ----------
    x0, x1 : int
        Column offsets; the extraction starts at column ``x0 + x1``.
    nwave : int
        Number of spectral channels to extract.
    filename, darkname : str, optional
        Argon and dark FITS files (defaults are the original hard-coded names).

    Returns
    -------
    ndarray, shape (48, nwave)
        Wavelength (microns) of each channel of each output.
    """
    # Subtract in floating point so unsigned raw data cannot wrap around.
    image = _read_primary_image(filename) - _read_primary_image(darkname)

    # Extraction of the spectra
    noutputs = 48
    startx = x0 + x1
    starty = 3
    stepy = 7
    dx = 1  # half width of the centroiding window around each line

    columns = startx + np.arange(nwave)
    spectrum = np.zeros((noutputs, nwave))
    for output in range(noutputs):
        row = stepy*output + starty
        # Two detector rows (row - 1 and row); the slice end is exclusive.
        spectrum[output, :] = np.sum(image[row-1:row+1, columns], axis=0)

    # Normalization of spectra
    spectrum /= np.mean(spectrum[:, 51:54], axis=1, keepdims=True)

    line_x     = np.array([#3,
                           9,
                           25,
                           39,
                           44,
                           # 48,
                           56,
                           71,
                           80,
                           105,
                           153,
                           185,
                           190])+(67-x1) # x coordinates in the image
    wavelength = np.array([#1.982291,
                           1.997118,
                           2.032256,
                           2.062186,
                           2.073922,
                           # 2.081672,
                           2.099184,
                           2.133871,
                           2.154009,
                           2.208321,
                           2.313952,
                           2.385154,
                           2.397306]) # wavelength (microns)
    nlines = line_x.shape[0]

    # Flux-weighted centroid of each line on each output.
    centroid = np.zeros((noutputs, nlines))
    for output in range(noutputs):
        for line, x in enumerate(line_x):
            centroid[output, line] = np.average(np.arange(x-dx, x+dx+1),
                                                weights=spectrum[output, x-dx:x+dx+1])

    pixels = np.arange(nwave)
    wavescale = np.zeros((noutputs, nwave))
    residual = np.zeros((noutputs, nlines))
    for output in range(noutputs):
        # Quadratic wavelength(pixel) relation for this output.
        poly = np.polyfit(centroid[output, :], wavelength, 2)
        wavescale[output, :] = np.polyval(poly, pixels)
        residual[output, :] = np.polyval(poly, centroid[output, :]) - wavelength

    print(" RMS residual of wavelength scale fit to argon: %.6f microns" % np.std(residual))

    return wavescale
    
    



    
    
    
#==============================================================================
# MAIN PROGRAM    
#==============================================================================

if __name__ == '__main__':

    plt.close('all')

    # Medium resolution:
    badpix = gravi_visual_class2.Badpix("GRAVI.2016-01-16T11:00:41.244_badpix.fits")
    profile = gravi_visual_class2.Profile("GRAVI.2016-01-16T11:03:23.687_flat.fits")
    # NOTE: `wave` must stay a module-level name: preproc_argon() and
    # gravi_fit_argon() read it as a global.
    wave = gravi_visual_class2.Wave("GRAVI.2016-01-16T11:04:47.649_wave.fits")
    argondir = "argon-mr/"
    argonfiles = np.array(["GRAVI.2015-12-08T14:25:54.732",
                  "GRAVI.2016-01-12T09:48:49.688",
                  "GRAVI.2016-01-13T08:11:41.830",
                  "GRAVI.2016-01-15T10:05:14.017",
                  #"GRAVI.2016-01-16T10:01:15.111", # bad
                  "GRAVI.2016-01-17T12:26:59.349",
                  "GRAVI.2016-01-18T10:06:20.902",
                  "GRAVI.2016-01-19T10:07:37.577",
                  "GRAVI.2016-01-20T10:01:57.546",
                  "GRAVI.2016-01-21T10:34:08.250"])
    darkfiles = np.array(["GRAVI.2015-12-08T14:36:29.247",
                 "GRAVI.2016-01-12T09:59:27.203",
                 "GRAVI.2016-01-13T08:22:19.345",
                 "GRAVI.2016-01-15T10:15:51.531",
                 #"GRAVI.2016-01-16T10:11:52.626", # bad
                 "GRAVI.2016-01-17T12:37:36.863",
                 "GRAVI.2016-01-18T10:16:58.417",
                 "GRAVI.2016-01-19T10:18:15.093",
                 "GRAVI.2016-01-20T10:12:35.061",
                 "GRAVI.2016-01-21T10:44:45.764"])

#    # High resolution:
#    badpix = gravi_visual_class2.Badpix("GRAVI.2015-12-08T15:56:45.982_badpix.fits")
#    profile = gravi_visual_class2.Profile("GRAVI.2015-12-08T15:58:15.988_flat.fits")
#    wave = gravi_visual_class2.Wave("GRAVI.2015-12-08T16:00:48.996_wave.fits")
#    argondir = "argon-hr/"
#    argonfiles = np.array(["GRAVI.2015-12-06T15:03:19.274",
#                           "GRAVI.2016-01-17T10:46:38.463"])
#    darkfiles = np.array(["GRAVI.2015-12-06T15:24:01.344",
#                          "GRAVI.2016-01-17T11:07:26.112"])


    wavescale = []
    argon_spec = []
    date_obs = []
    polyorder = 2 # Order of the polynome for the wavelength fit

    for argonfile, darkfile in zip(argonfiles, darkfiles):
        print("Processing file ", argonfile)

        # Copy the frames out and close each file right away, so the
        # handles are not left open for the rest of the run.
        argon_hdul = fits.open(argondir+argonfile+'.fits')
        try:
            argon_frames = np.array(argon_hdul[5].data)
            date_obs.append(argon_hdul[0].header['MJD-OBS'])
        finally:
            argon_hdul.close()
        dark_hdul = fits.open(argondir+darkfile+'.fits')
        try:
            argon_dark_frames = np.array(dark_hdul[5].data)
        finally:
            dark_hdul.close()

        spec = preproc_argon(argon_frames, argon_dark_frames, badpix, profile)
        argon_spec.append(spec)
        # clobber=True: by default fitsio appends a new HDU to an existing
        # file, so reruns would accumulate extensions.
        fitsio.write(argonfile+"-preproc.fits", spec, clobber=True)

        plt.figure(figsize=(10, 7))
        plt.title('Dark subtracted argon lamp exposure')
        plt.imshow(spec, vmin=0, vmax=400, aspect='auto', cmap='CMRmap', interpolation='none')
        plt.xlabel('Pixel')
        plt.ylabel('Output')
        plt.colorbar()

        plt.figure(figsize=(10, 7))
        plt.title('Profiles of the argon emission spectra')
        for output_spectrum in spec:
            plt.plot(output_spectrum)
        plt.xlabel('Pixel')
        plt.ylabel('Flux (ADU)')

        scale = gravi_fit_argon(spec, polyorder)
        wavescale.append(scale)
#        saveArgonScale(scale,profile,argonfile+"-MRWavescale")
        fitsio.write(argonfile+"-WaveImage.fits", scale, clobber=True)

    pixel = 2.2 # pixel size in nm for pixel offset display

    # Mean offset of each night's scale relative to the grand mean of all scales (nm).
    grand_mean = np.mean(wavescale)
    avgres = []
    for argonfile, scale in zip(argonfiles, wavescale):
        avgres.append(np.mean(scale - grand_mean)*1000.)
        print("Avg diff. "+argonfile+" - mean = %(residual)+.4f nm = %(respix)+.4f pixels" % {"residual":avgres[-1],"respix":avgres[-1]/pixel})

    plt.figure(figsize=(10, 7))
    plt.title('Relative argon wavelength offsets MEDIUM-COMBINED (wrt mean)')
    plt.plot(date_obs, avgres, marker='o')
    plt.xlabel("MJD")
    plt.ylabel("Offset (nm)")
    axes = plt.gca()
    axes.set_ylim([-0.2, 0.2])
    plt.grid()

    rmsres = np.std(avgres)
    print("RMS over period = %(rmsres).4f nm = %(rmspix).4f pixels" % {"rmsres":rmsres,"rmspix":rmsres/pixel})







#def saveArgonScale(wavescale,profile,filename):
##==============================================================================
## Write the resulting wavelength tables to a FITS table file
##==============================================================================
#    print "*** Saving the resulting wavelength tables in FITS file"
#    
#    noutputs = wavescale.shape[0]
##    nwave = wavescale.shape[1]
#    now = datetime.datetime.now()
#    prihdr = fits.Header()
#    prihdr['DATE'] = now.strftime("%Y-%m-%dT%H:%M:%S")
#    prihdr['COMMENT'] = "GRAVITY Argon wavelength scale from the testbed python code."
#    prihdu = fits.PrimaryHDU(header=prihdr)
#    hdulist = fits.HDUList(prihdu)
#    
#    # Wavelength table with one column per output
#    col_sc = []
#    for output in range(0,noutputs): # DETECTOR order
#        col_sc.append(fits.Column(name=("%2i"%output), format='E', unit='microns', array=wavescale[output,:]))
#    cols_sc = fits.ColDefs(col_sc)
#    tbhdu = fits.BinTableHDU.from_columns(cols_sc,name='ARGONSCALE_SC')
#    hdulist.append(tbhdu)
#
##    # Synthetic image of pixel wavelengths on the detector
##    # Computation of the individual spectra
##    prof = np.zeros((noutputs,nwave,2))
##    for output in range(0,noutputs):
##        for wave in range(0,nwave):
##            prof[:,:,0] += wavescale[output,wave]*np.where(profile.profile_data[output][:,wave]>0)
##            prof[:,:,1] += profile.profile_data[output]
##    imghdu = fits.ImageHDU(data=prof,name=("TEST_WAVE"))
##    hdulist.append(imghdu)
#       
#    hdulist.writeto(filename+'.fits',clobber=True)
