#! /usr/bin/env python3
# -*- coding: iso-8859-15 -*-

from glob import glob
import numpy as np
import matplotlib.pyplot as plt
from numpy import pi,linspace,sqrt,append,arange,array,zeros,ones,dot,angle,exp,conj,cos,sin
from astropy.io import fits
from astropy.time import Time as astroTime
from sklearn.cluster import MeanShift
import time
from scipy.linalg import pinv as pinv2
from scipy.interpolate import interp1d
from tqdm import tqdm
from datetime import datetime

#%%

def make_plot_chi2(name_fig,Data,xmap,ymap,Y_label,X_title,sxy,Data_star=None,Data_circle=None,figsize=None,samevaxis=True,label_axis_inside=True,show_chi2=False,show_contrast=False,show_amp=False):
    """Plot a grid of 2-D maps (e.g. chi2 or contrast maps), one panel per map.

    Panels are arranged with one row per index ``b`` of the last axis of each
    ``Data`` array and one column per (npoly, nr) pair, npoly varying slowest.

    Parameters
    ----------
    name_fig : str
        Matplotlib figure label (``num``); also used as the suptitle.
    Data : list of ndarray
        One array per entry ``nr``, indexed ``Data[nr][:, :, npoly, b]`` to get
        the 2-D map of one panel.
    xmap, ymap : list of ndarray
        Grid coordinates per entry, indexed ``xmap[nr][:, npoly]``; their
        min/max define the image extent ``[xmax, xmin, ymin, ymax]``.
    Y_label : sequence of str
        Y-axis label of each panel row.
    X_title : sequence of str
        Title of each panel column.
    sxy : pair
        ``(x, y)`` of a black star drawn whenever a ``show_*`` option is set.
    Data_star : list of ndarray, optional
        Per entry, indexed ``Data_star[nr][:, npoly, b]``; items 0 and 1 are
        plotted as red stars.
    Data_circle : list of ndarray, optional
        Per entry, indexed ``Data_circle[nr][:, npoly, :]``; items 0 and 1 are
        plotted as circles.
    figsize : tuple, optional
        Figure size; derived from the number of panels when None.
    samevaxis : bool
        Use a common colour scale (0.1-99.9 percentile of all maps).
    label_axis_inside : bool
        Draw tick labels inside the axes and tighten the panel spacing.
    show_chi2, show_contrast, show_amp : bool
        Annotate each panel with max-min, max, or min of its map.

    Returns
    -------
    matplotlib.figure.Figure
    """
    Nx1=len(Data)
    Nx2=len(Data[0][0,0])
    Nx=Nx1*Nx2
    Ny=len(Data[0][0,0,0])

    # Re-arrange the inputs panel by panel: [row b][column], where columns run
    # over npoly (outer) then over entries nr (inner).
    Data_star_XY=[]
    Data_circle_XY=[]
    Data_XY=[]
    extent_XY=[]
    extent_XY_all=[]
    for b in range(Ny):
        Data_star_Y=[]
        Data_circle_Y=[]
        Data_Y=[]
        extent_Y=[]
        extent_Y_all=[]
        for npoly in range(Nx2):
            # Extents of every entry for this npoly (the panel's own extent is
            # one of them; all of them are kept to outline other entries' grids).
            extents_npoly=[[max(xmap[nrb][:,npoly]),min(xmap[nrb][:,npoly]),min(ymap[nrb][:,npoly]),max(ymap[nrb][:,npoly])] for nrb in range(Nx1)]
            for nr in range(Nx1):
                Data_Y.append(Data[nr][:,:,npoly,b])
                extent_Y.append(extents_npoly[nr])
                extent_Y_all.append(extents_npoly)
                # `is not None`: `!= None` on an ndarray is element-wise and
                # makes the `if` raise.
                if Data_star is not None:
                    Data_star_Y.append(Data_star[nr][:,npoly,b])
                if Data_circle is not None:
                    Data_circle_Y.append(Data_circle[nr][:,npoly,:])
        Data_XY.append(Data_Y)
        extent_XY.append(extent_Y)
        extent_XY_all.append(extent_Y_all)
        Data_star_XY.append(Data_star_Y)
        Data_circle_XY.append(Data_circle_Y)

    extent_XY=np.array(extent_XY)
    extent_XY_all=np.array(extent_XY_all)

    # Common colour scale from robust percentiles of all maps; contrast maps
    # are not allowed a negative lower bound.
    Datac=np.concatenate([d.ravel() for d in Data])
    vmax=np.percentile(Datac,99.9)
    vmin=np.percentile(Datac,0.1)
    if show_contrast:
        vmin=np.max([vmin,0])

    # A single row of 9 panels is wrapped into a 3x3 grid.
    Nx_plot=Nx
    Ny_plot=Ny
    if Ny==1 and Nx==9:
        Ny_plot=Nx_plot=3

    if figsize is None:
        figsize=(1+Nx*3,(1+Nx*3)*Ny/Nx)

    fig,axs=plt.subplots(Ny_plot,Nx_plot,num=name_fig,clear=True, figsize=figsize,squeeze=False)
    fig.suptitle(fig.get_label())
    if label_axis_inside:
        fig.subplots_adjust(hspace=0.02,wspace=0.02)

    color_limits=dict(vmin=vmin,vmax=vmax) if samevaxis else {}

    def _annotate(ax,text):
        # Text box in the upper-right corner plus a black star at sxy.
        props=dict(boxstyle='round', facecolor='wheat', alpha=0.7)
        ax.text(0.9, 0.95, text, fontsize=10,va='top',ha='right',transform=ax.transAxes, bbox=props)
        ax.plot(sxy[0],sxy[1],'k*')

    flat_axes=axs.ravel()
    for i in range(Ny):
        for j in range(Nx):
            ax=flat_axes[i*Nx+j]
            # Reorient the map to match the extent [xmax, xmin, ymin, ymax]
            # (x axis increasing to the left).
            image=Data_XY[i][j][::-1,::-1].T
            ax.imshow(image,aspect="auto",extent=extent_XY[i][j],interpolation='none',**color_limits)
            ax.set_aspect('equal', adjustable='box')
            if label_axis_inside:
                ax.tick_params(axis="y",direction="in", pad=-22)
                ax.tick_params(axis="x",direction="in", pad=-15)
            # Freeze the limits so the overlays below do not rescale the view.
            ax.set_xlim(ax.get_xlim())
            ax.set_ylim(ax.get_ylim())

            if show_chi2:
                _annotate(ax,r"$\delta\chi^2=$%.1f"%(image.max()-image.min()))
            if show_contrast:
                _annotate(ax,r"$contrast=$%.2e"%(image.max()))
            if show_amp:
                _annotate(ax,r"$\eta_{\rm min}=$%.2f"%(image.min()))

            # Outline the grids of entries 1 (blue) and 2 (red), inset by `dist`.
            # Entries that do not exist (fewer than 3 inputs) are skipped
            # instead of raising IndexError.
            dist=0.04
            extents=extent_XY_all[i,j]
            for idx,color in ((1,'b'),(2,'r')):
                if idx<len(extents):
                    e=extents[idx]
                    ax.plot((e[0]-dist,e[0]-dist,e[1]+dist,e[1]+dist,e[0]-dist),
                            (e[2]+dist,e[3]-dist,e[3]-dist,e[2]+dist,e[2]+dist),color)

            if Data_circle is not None:
                ax.plot(Data_circle_XY[i][j][0],Data_circle_XY[i][j][1],'C6o')
            if Data_star is not None:
                ax.plot(Data_star_XY[i][j][0],Data_star_XY[i][j][1],'r*')

            if i == 0:
                ax.set_title(X_title[j])
            if j == 0:
                ax.set_ylabel(Y_label[i])

    fig.tight_layout()

    return fig

def make_plot_vis(name_fig,V,Npoly_list,Next_files,real_only=False):
    """Show complex visibilities as images, one panel per entry of ``Npoly_list``.

    Parameters
    ----------
    name_fig : str
        Matplotlib figure label (``num``); also used as the suptitle.
    V : complex ndarray
        ``V[n]`` is the 3-D cube shown in panel ``n``, wavelength on the last
        axis. When ``Npoly_list`` is None, ``V`` is a single cube and a leading
        panel axis is added.
    Npoly_list : sequence of int or None
        One panel per entry; with several panels each is titled with the
        polynomial order ``Npoly_list[n]-1``.
    Next_files : array_like
        Row positions after which thin black separator lines are drawn.
    real_only : bool
        Plot only the real part instead of real and imaginary side by side.

    Returns
    -------
    matplotlib.figure.Figure
    """
    # `is None`: `== None` on an ndarray is element-wise and makes the `if` raise.
    if Npoly_list is None:
        V=V[None]
        Npoly_list=[0]

    Nwave=V.shape[-1]
    Npanel=len(Npoly_list)
    # squeeze=False always gives a 2-D array of axes; row 0 holds the panels.
    fig,axs=plt.subplots(1,Npanel,num=name_fig,clear=True, sharex=True,sharey=True, layout='constrained',figsize=(10,12),squeeze=False)
    axs=axs[0]
    fig.suptitle(fig.get_label())

    # Symmetric colour scale: 99th percentile of |Re| and |Im| over all panels.
    vmax=np.percentile(np.array([np.abs(V.real),np.abs(V.imag)]).ravel(),99.)
    for n,ax in enumerate(axs):
        parts=[V[n].real] if real_only else [V[n].real,V[n].imag]
        # Rows enumerate (axis 1, axis 0) of V[n]; the parts sit side by side
        # along each row, Nwave columns each.
        Vplot=np.array(parts).transpose((2,1,0,3)).reshape((-1,Nwave*len(parts)))
        ims=ax.imshow(Vplot,aspect="auto",vmin=-vmax,vmax=vmax,interpolation='none')
        ax.set_xlim(ax.get_xlim())
        ax.set_ylim(ax.get_ylim())
        if Npanel>1:
            ax.set_title("Poly_order = %i"%(Npoly_list[n]-1))

        # White separators every len(Vplot)//6 rows; the max(..., 1) avoids a
        # zero np.arange step (which raises) when there are fewer than 6 rows.
        step=max(len(Vplot)//6,1)
        ax.plot(ax.get_xlim(),np.ones((2,1))*np.arange(0,len(Vplot),step)[1:]-0.5,'w')
        ax.plot(ax.get_xlim(),np.ones((2,1))*Next_files+0.5,'k',linewidth=0.3)

    fig.colorbar(ims,ax=axs[-1])

    return fig



def get_chi2_swap(visdata_swap,ucoord_swap,vcoord_swap,swap,ext_day,xy,xyrange_map,Nrange):
    """Map the phase-rotated, summed visibility modulus over a square position grid.

    For each grid point (xmap[i], ymap[j]) and each distinct value of
    ``ext_day``, every matching file ``k`` is multiplied by
    ``exp(1j*phase)``. Here ``phase = x*u + y*v`` if ``swap[k] > 0`` and its
    negative otherwise. The result is summed over axis 0 and over the files,
    and the modulus is summed over the last (wavelength) axis. Despite its name,
    ``chi2_map`` holds this summed modulus, not a chi-square.

    Parameters
    ----------
    visdata_swap : sequence of complex ndarray
        One array per file; wavelength on axis 2, and summing over axis 0 must
        give shape (6, Nwave).
    ucoord_swap, vcoord_swap : sequence of ndarray
        Per-file coordinates, broadcastable against ``visdata_swap[k]``.
    swap : sequence
        Per-file flag selecting the sign of the phase (see above).
    ext_day : sequence
        Per-file label; one output slice per unique label (``np.unique`` order).
    xy : pair
        Grid centre (x, y).
    xyrange_map : float
        Grid half-width.
    Nrange : int
        Number of grid points per axis.

    Returns
    -------
    chi2_map : ndarray, shape (Ne, Nrange, Nrange, 6)
    xmap, ymap : ndarray, shape (Nrange,)
    """
    Nf=len(visdata_swap)
    Nwave=visdata_swap[0].shape[2]
    epochs=np.unique(ext_day)
    Ne=len(epochs)
    # Files belonging to each epoch, computed once instead of re-running
    # np.unique and re-scanning ext_day at every grid point. Ascending k keeps
    # the original summation order.
    files_per_epoch=[[k for k in range(Nf) if ext_day[k]==ext] for ext in epochs]

    chi2_map=np.zeros((Ne,Nrange,Nrange,6))

    xmap=np.linspace(-xyrange_map,xyrange_map,Nrange)+xy[0]
    ymap=np.linspace(-xyrange_map,xyrange_map,Nrange)+xy[1]
    for i in tqdm(range(Nrange)):
        for j in range(Nrange):
            for e,files in enumerate(files_per_epoch):
                visdata_summed=np.zeros((6,Nwave),dtype=np.complex128)
                for k in files:
                    if swap[k] > 0:
                        phase=xmap[i]*ucoord_swap[k] + ymap[j]*vcoord_swap[k]
                    else:
                        phase=- xmap[i]*ucoord_swap[k] - ymap[j]*vcoord_swap[k]
                    visdata_summed+=(visdata_swap[k]*exp(1j*phase)).sum(axis=0)
                chi2_map[e,i,j]=np.abs(visdata_summed).sum(axis=1)

    return chi2_map,xmap,ymap

def get_xy_from_results(result):
    """Locate the chi2 minima of a ``get_chi2_target`` result on its position grid.

    ``result`` is the 10-tuple returned by ``get_chi2_target`` (non-visPlanet
    mode). Only the three chi2 cubes and the ``xmap``/``ymap`` grids are read.
    For every trailing index (polynomial setting, and exposure or baseline) the
    argmin is taken over the two leading grid axes.

    Returns
    -------
    xy_total : ndarray, shape (2, Npoly)
    xy_base : ndarray, shape (2, Npoly, Nb)
    xy_exp : ndarray, shape (2, Npoly, Nexp)
        Row 0 holds x from ``xmap[:, n]`` and row 1 holds y from ``ymap[:, n]``.
    """
    (chi2_total, xmap, ymap, chi2_exp, chi2_base,
     _amp, _uv, _c_total, _c_exp, _c_base) = result
    n_grid = len(ymap)
    n_poly, n_exp = chi2_exp.shape[2], chi2_exp.shape[3]
    n_base = chi2_base.shape[3]

    def best_ij(cube, trailing_shape):
        # Flat argmin over the (grid, grid) plane for each trailing index,
        # converted back to (row, column) grid indices.
        flat = cube.reshape((n_grid * n_grid, -1)).argmin(0)
        ij = np.array(np.unravel_index(flat, (n_grid, n_grid)))
        return ij.reshape((2,) + trailing_shape)

    def lookup(ij, column):
        # Turn grid indices into coordinates, using the column of xmap/ymap
        # that belongs to each polynomial setting.
        return np.array([xmap[ij[0], column], ymap[ij[1], column]])

    poly = np.arange(n_poly)
    xy_total = lookup(best_ij(chi2_total, (n_poly,)), poly)
    xy_base = lookup(best_ij(chi2_base, (n_poly, n_base)), poly[:, None])
    xy_exp = lookup(best_ij(chi2_exp, (n_poly, n_exp)), poly[:, None])
    return xy_total, xy_base, xy_exp


def get_chi2_target(visData_cleaned,ucoord,vcoord,ucoord_diff,vcoord_diff,weight,amplitude_reference,M_clean,J_clean,sxy,xy,xyrange_map,Nrange,Npoly_list,index_exp,index_base,visPlanet=False,use_weight2=True,contrast_fixed=None):
    """Grid-search a point-source position by fitting a cleaned model to the data.

    For every polynomial setting ``n`` (one per entry of ``Npoly_list``) and every
    grid point ``(xmap[i, n], ymap[j, n])`` the function:

    1. builds the model ``amplitude_reference * exp(-1j*(x*ucoord + y*vcoord))``;
    2. cleans it like the data: ``v - J_clean[n] @ (M_clean[n] @ v)``, batched
       over rows via ``np.matmul``;
    3. fits a contrast by weighted least squares,
       ``sum Re(conj(w*m)*d) / sum Re(conj(w*m)*m)``, clipped at 0 (or uses
       ``contrast_fixed``). This is done over all rows, per exposure
       (``index_exp``) and per baseline (``index_base``);
    4. stores the corresponding chi2 values.

    The weight ``w`` is ``weight`` scaled row-wise by ``amp_cleaned``. That is
    the summed modulus of the cleaned model divided by ``amplitude_reference``
    summed over axis 1, i.e. the fraction of model amplitude surviving cleaning.
    Each chi2 also gets the term ``sum(|data|**2 * weight * (1 - amp_cleaned))``.

    Parameters
    ----------
    visData_cleaned : complex ndarray
        Cleaned data indexed ``visData_cleaned[n]`` per polynomial setting; each
        slice must match the model's (rows, Nwave) shape, presumably
        (Ndit*6, Nwave). TODO confirm.
    ucoord, vcoord : ndarray
        Factors multiplying x and y in the model phase; ``len(ucoord[0])`` sets
        Nwave. Their units must match ``xy``/``xyrange_map`` (not checked here).
    ucoord_diff, vcoord_diff, sxy : ndarray
        Only used to build ``uv_cleaned_tmp``, which is currently forced to 1
        (see the ``*0+1`` line), so they do not change finite results. ``sxy`` is
        indexed ``sxy[:, 0]`` / ``sxy[:, 1]``, presumably one (x, y) per row.
    weight : ndarray
        Per-element weights (presumably inverse variances; confirm).
        ``len(weight)//6`` is taken as the number of DITs.
    amplitude_reference : ndarray
        Model amplitude per row and wavelength.
    M_clean, J_clean : sequence of ndarray
        Cleaning matrices per polynomial setting (see step 2).
    xy : pair of array_like
        Grid centre per polynomial setting, used as ``xy[0][n]``, ``xy[1][n]``.
    xyrange_map : float
        Grid half-width.
    Nrange : int
        Number of grid points per axis.
    Npoly_list : sequence of int
        One entry per setting. Only its length and ``Npoly_list[n]-1`` (printed
        as the polynomial order) are used.
    index_exp, index_base : ndarray of int
        Per-row exposure index (0..Nexp-1, with Nexp = max+1) and baseline index
        (0..5).
    visPlanet : bool
        If True, evaluate only at ``xy`` (Nrange forced to 1) and return the
        fitted model instead of the chi2 maps.
    use_weight2 : bool
        If False, ``amp_cleaned`` is set to 1. The plain ``weight`` is then used
        and the ``(1 - amp_cleaned)`` term vanishes.
    contrast_fixed : float or None
        If given, used as the contrast everywhere instead of being fitted.

    Returns
    -------
    If ``visPlanet``: ``planet_signal``, complex array (Npoly, Ndit*6, Nwave),
    the contrast-scaled cleaned model at ``xy``.

    Otherwise the tuple ``(chi2_total, xmap, ymap, chi2_exp, chi2_base,
    amp_cleaned, uv_cleaned, contrast_total, contrast_exp, contrast_base)``:

    - chi2/contrast arrays have shape (Nrange, Nrange, Npoly), with a trailing
      Nexp or 6 axis for the per-exposure and per-baseline variants.
    - ``amp_cleaned``/``uv_cleaned`` are averaged over DITs to
      (Nrange, Nrange, Npoly, 6).
    - ``xmap``/``ymap`` are ``linspace + xy`` with a trailing axis, presumably
      (Nrange, Npoly).

    Notes
    -----
    Rows are assumed ordered DIT-major, baseline-minor (row = dit*6 + baseline),
    as the final ``reshape(..., Ndit, Nb)`` requires. There is no guard against
    a zero denominator in the contrast fit.
    """

    amplitude_reference_sum=amplitude_reference.sum(axis=1)
    Nb = 6  # number of baselines (hard-coded, also in Ndit below)
    Ndit = len(weight)//6
    Npoly=len(Npoly_list)
    Nexp=index_exp.max()+1
    Nwave=len(ucoord[0])

    xmap=np.linspace(-xyrange_map,xyrange_map,Nrange)
    ymap=np.linspace(-xyrange_map,xyrange_map,Nrange)

    # Offsets shifted by the per-setting centre: column n belongs to setting n.
    xmap=xmap[:,None]+xy[0]
    ymap=ymap[:,None]+xy[1]

    # Single evaluation at the given position(s) instead of a grid.
    if visPlanet:
        xmap=xy[0][None]
        ymap=xy[1][None]
        Nrange=1

    amp_cleaned=np.zeros((Nrange,Nrange,Npoly,Ndit*Nb))
    uv_cleaned=np.zeros((Nrange,Nrange,Npoly,Ndit*Nb))
    contrast_total=np.zeros((Nrange,Nrange,Npoly))
    contrast_exp=np.zeros((Nrange,Nrange,Npoly,Nexp))
    contrast_base=np.zeros((Nrange,Nrange,Npoly,Nb))
    chi2_total=np.zeros((Nrange,Nrange,Npoly))
    chi2_exp=np.zeros((Nrange,Nrange,Npoly,Nexp))
    chi2_base=np.zeros((Nrange,Nrange,Npoly,Nb))
    planet_signal=np.zeros((Npoly,Ndit*Nb,Nwave),dtype=np.complex128)

    # Weighted data power, used by the (1 - amp_cleaned) chi2 term.
    chi2_zero=np.abs(visData_cleaned)**2*weight

    for n in range(Npoly):
        print("fitting data with Npoly = %i"%(Npoly_list[n]-1))
        for i in tqdm(range(Nrange)):   
            for j in range(Nrange):
                phase=xmap[i,n]*ucoord + ymap[j,n]*vcoord
                phase_diff=(xmap[i,n]-sxy[:,0,None])*ucoord_diff + (ymap[j,n]-sxy[:,1,None])*vcoord_diff
                # For finite input this is all ones: the sinc factor is multiplied
                # by 0, so the attenuation it would model is disabled (NaNs in
                # phase_diff would still propagate).
                uv_cleaned_tmp=np.sinc(phase_diff/(2*np.pi))*0+1
                uv_cleaned[i,j,n]=uv_cleaned_tmp[:,Nwave//2]
                vis_data_planet=amplitude_reference*np.exp(-1j*phase)*uv_cleaned_tmp

                # theta=np.sqrt((xy[0,n]-sxy[:,0])**2+(xy[1,n]-sxy[:,1])**2)
                # flux_inj=getFlux_fast(theta)
                # vis_data_planet*=flux_inj[:,None]
                # Clean the model exactly like the data: v - J_clean[n] @ (M_clean[n] @ v).
                coefs=np.matmul(M_clean[n],vis_data_planet[:,:,None])
                visData_planet_cleaned=(vis_data_planet[:,:,None]-np.matmul(J_clean[n],coefs))[:,:,0]

                # Per-row fraction of model amplitude that survives cleaning.
                amp_cleaned[i,j,n]=np.abs(visData_planet_cleaned).sum(axis=1)/amplitude_reference_sum
                # amp_cleaned[i,j,n]=amp_cleaned[i,j,n]
                if use_weight2==False:
                    amp_cleaned[i,j,n]=1.0
                weight2=weight*amp_cleaned[i,j,n,:,None]

                # for bt in range(6):
                #     visData_planet_cleaned[bt::6]-=visData_planet_cleaned[bt::6].mean(axis=0)

                # Weighted least-squares contrast; the per-row numerator and
                # denominator are reused below for the per-exposure/per-baseline fits.
                if contrast_fixed is None: 
                    W_visData_planet_cleaned=np.conj(weight2*visData_planet_cleaned)
                    numerator=(W_visData_planet_cleaned*visData_cleaned[n]).real.sum(axis=1)
                    denominator=(W_visData_planet_cleaned*visData_planet_cleaned).real.sum(axis=1)
                    contrast_total[i,j,n]=numerator.sum()/denominator.sum()
                else: 
                    contrast_total[i,j,n]=contrast_fixed

                # Negative contrasts are clipped to 0.
                if contrast_total[i,j,n]<0:
                    contrast_total[i,j,n]=0
                # Overwritten at every grid point; in visPlanet mode (Nrange=1)
                # it holds the only evaluation.
                planet_signal[n]=contrast_total[i,j,n]*visData_planet_cleaned
                chi2_total[i,j,n]= np.sum(weight2*np.abs(planet_signal[n]-visData_cleaned[n])**2)
                chi2_total[i,j,n]+= np.sum(chi2_zero[n]*(1-amp_cleaned[i,j,n,:,None]))


                # Same fit restricted to the rows of each exposure.
                for e in range(Nexp):
                    index=(index_exp==e)
                    if contrast_fixed is None: 
                        contrast_exp[i,j,n,e]=numerator[index].sum()/denominator[index].sum()
                    else:
                        contrast_exp[i,j,n,e]=contrast_fixed

                    if contrast_exp[i,j,n,e]<0:
                        contrast_exp[i,j,n,e]=0
                    chi2_exp[i,j,n,e] = np.sum(weight2[index]*np.abs(contrast_exp[i,j,n,e]*visData_planet_cleaned[index]-visData_cleaned[n][index])**2)
                    chi2_exp[i,j,n,e] += np.sum(chi2_zero[n,index]*(1-amp_cleaned[i,j,n,index,None]))

                # Same fit restricted to the rows of each baseline.
                for b in range(Nb):
                    index=(index_base==b)
                    if contrast_fixed is None: 
                        contrast_base[i,j,n,b]=numerator[index].sum()/denominator[index].sum()
                    else:
                        contrast_base[i,j,n,b]=contrast_fixed
                    if contrast_base[i,j,n,b]<0:
                        contrast_base[i,j,n,b]=0
                    chi2_base[i,j,n,b] = np.sum(weight2[index]*np.abs(contrast_base[i,j,n,b]*visData_planet_cleaned[index]-visData_cleaned[n][index])**2)
                    chi2_base[i,j,n,b] += np.sum(chi2_zero[n,index]*(1-amp_cleaned[i,j,n,index,None]))


    if visPlanet:
        return planet_signal
    
    # Average over DITs; rows are taken as ordered dit*Nb + baseline.
    amp_cleaned=amp_cleaned.reshape((Nrange,Nrange,Npoly,Ndit,Nb)).mean(axis=3)
    uv_cleaned=uv_cleaned.reshape((Nrange,Nrange,Npoly,Ndit,Nb)).mean(axis=3)
    
    return chi2_total,xmap,ymap,chi2_exp,chi2_base,amp_cleaned,uv_cleaned,contrast_total,contrast_exp,contrast_base


    ## START CODE : LOAD DATA


def getFlux(theta,diam=8,w_i=0.315,lambda_w=2.2e-6):
    """Relative fiber injection of a tilted plane wave through a circular pupil.

    The pupil (diameter ``diam``) is sampled on a 500x500 grid spanning
    ``[-2*diam, 2*diam]``. A tilt ``theta`` adds a linear phase across the
    grid, and the field is overlapped with a Gaussian fiber mode of width
    ``w_i*diam``. The squared overlap is normalised by the untilted one, so
    ``theta=0`` gives 1.

    Parameters
    ----------
    theta : float or array_like
        Tilt angle(s) in milliarcseconds (converted to radians with
        ``1e-3/(180/pi*3600)``). A scalar is treated as a length-1 array.
    diam : float
        Pupil diameter, in the same length unit as ``lambda_w`` (presumably m).
    w_i : float
        Fiber mode width as a fraction of ``diam``.
    lambda_w : float
        Wavelength; the default is the previously hard-coded value.

    Returns
    -------
    ndarray
        Injection efficiency, one value per tilt.
    """
    D=diam
    x=np.linspace(-D*2,D*2,500)
    y=np.linspace(-D*2,D*2,500)
    r=np.sqrt(x[:,None]**2+y[None,:]**2)
    m=r<=D/2
    # atleast_1d replaces the former bare try/except. Scalars become length-1
    # arrays and array-like input is unchanged, so unrelated errors (including
    # KeyboardInterrupt) are no longer swallowed.
    thetal=np.atleast_1d(np.asarray(theta))[:,None,None]
    # theta in milliarcseconds -> radians; phase ramp along the last grid axis
    # (broadcast against the pupil mask below).
    phase=x/lambda_w*thetal*1e-3/(180/np.pi*3600)*2*np.pi
    w_0=w_i*D
    Field_pup=m*np.exp(1j*phase)
    Field_fiber=np.exp(-1*r**2/(2*w_0**2))/w_0
    Inj=np.abs(np.sum(Field_pup*Field_fiber,axis=(1,2)))**2/np.abs(np.sum(m*Field_fiber))**2
    return Inj


def getFlux_fast(theta,diam=8,w_i=0.315,lambda_w=2.2e-6):
    """Coarse, faster variant of ``getFlux`` (20x20 pupil samples).

    The pupil is sampled on a 20x20 grid spanning ``[-diam/1.9, diam/1.9]``,
    which just covers the aperture. The normalisation makes ``theta=0`` give 1.

    Parameters
    ----------
    theta : float or array_like
        Tilt angle(s) in milliarcseconds. A scalar is treated as a length-1
        array; previously a scalar raised IndexError, unlike ``getFlux``.
    diam : float
        Pupil diameter, in the same length unit as ``lambda_w`` (presumably m).
    w_i : float
        Fiber mode width as a fraction of ``diam``.
    lambda_w : float
        Wavelength; the default is the previously hard-coded value.

    Returns
    -------
    ndarray
        Injection efficiency, one value per tilt.
    """
    D=diam
    x=np.linspace(-D/1.9,D/1.9,20)
    y=np.linspace(-D/1.9,D/1.9,20)
    r=np.sqrt(x[:,None]**2+y[None,:]**2)
    m=r<=D/2
    thetal=np.atleast_1d(np.asarray(theta))[:,None,None]
    # theta in milliarcseconds -> radians, then a linear phase ramp.
    phase=x/lambda_w*thetal*1e-3/(180/np.pi*3600)*2*np.pi
    Field_pup=m*np.exp(1j*phase)

    w_0=w_i*D
    # The fiber-mode amplitude factor cancels in the ratio, so it is omitted.
    Field_fiber=np.exp(-1*r**2/(2*w_0**2))
    Inj=np.abs(np.sum(Field_pup*Field_fiber,axis=(1,2)))**2/np.abs(np.sum(m*Field_fiber))**2

    return Inj


def saveFitsSpectrum(filename, wav, contrast, contrastCov, h_add, name = "Unknown", date_obs = "Unknown", mjd = "Unknown", instrument="GRAVITY", resolution = "Unknown"):
    """Save one or more contrast spectra to a FITS file (overwriting it).

    The primary HDU holds only header keywords. The spectra go to a
    binary-table extension named SPECTRUM. It has a WAVELENGTH column, one
    CONTRAST_<p> column per spectrum and, if ``contrastCov`` is given, one
    COVARIANCE_CONTRAST_<p> column per spectrum (p starting at 1).

    Parameters:
      filename: output path; an existing file is overwritten.
      wav: 1d array containing the wavelength grid (stored with unit "um").
      contrast: sequence of 1d arrays, one contrast spectrum each (unitless).
      contrastCov: sequence of 2d covariance matrices matching ``contrast``
        (each row stored as a len(wav) vector), or None to omit them.
      h_add: extra header cards merged into the primary header (anything
        accepted by astropy's Header.update), or None.
      name, date_obs, mjd, instrument, resolution: stored in the OBJECT,
        DATE-OBS, MJD-OBS, INSTRU and SPECRES keywords.
    """
    from datetime import timezone  # module level only imports the datetime class

    hdr = fits.Header()
    hdr['INSTRU'] = instrument
    hdr['FACILITY'] = 'ESO VLTI'
    # datetime.utcnow() is deprecated since Python 3.12. Dropping tzinfo keeps
    # the stored string in its previous "YYYY-MM-DD HH:MM:SS.ffffff" format.
    hdr['DATE'] = str(datetime.now(timezone.utc).replace(tzinfo=None))
    hdr['DATE-OBS'] = date_obs
    hdr['MJD-OBS'] = mjd
    hdr['OBJECT'] = name
    hdr['SPECRES'] = resolution
    if h_add is not None:
        hdr.update(h_add)
    hdr['COMMENT'] = "FITS file contains multiple extensions"
    primary_hdu = fits.PrimaryHDU(header = hdr)
    col0= [fits.Column(name='WAVELENGTH', format='1D', unit="um",  array=wav)]
    for p in range(len(contrast)):
        col0+= [fits.Column(name='CONTRAST_%i'%(p+1), format='1D',unit=" - ",  array=contrast[p])]
        if contrastCov is not None:
            # Each table row holds one row of the covariance matrix.
            col0+= [fits.Column(name='COVARIANCE_CONTRAST_%i'%(p+1), format='%iD'%len(wav), unit=" - ^2",  array=contrastCov[p])]
    secondary_hdu = fits.BinTableHDU.from_columns(col0)
    secondary_hdu.name='SPECTRUM'
    hdul = fits.HDUList([primary_hdu, secondary_hdu])
    hdul.writeto(filename, overwrite = True)
