#! /usr/bin/env python3
# -*- coding: iso-8859-15 -*-

#%%==============================================================================
# CODE TO REDUCE THE ASTRORED TO GET ASTROMETRY
#==============================================================================

# from asyncio.format_helpers import _format_callback_source
import os
import sys
import time
import matplotlib
if "VSCODE_PID" in os.environ:
    matplotlib.use('Qt5Agg')  # To use to show plots
else:
    matplotlib.use('Agg')
import matplotlib.pyplot as plt
from matplotlib.pyplot import plot,hist,clf,figure,savefig,title,xlabel,ylabel,legend,xlim,ylim,imshow,colorbar
plt.ion()
from matplotlib.backends.backend_pdf import PdfPages
from astropy.io import fits
from glob import glob
import numpy as np
from numpy import pi,linspace,sqrt,append,arange,array,zeros,ones,dot,angle,exp,conj,cos,sin
from scipy.linalg import pinv
from scipy.optimize import leastsq
from scipy.interpolate import interp1d
from sklearn.cluster import MeanShift
from gravi_astrored import astrored_loadData as al
from gravi_astrored import astrored_astroFit as af
from scipy.linalg import pinv
from astropy.time import Time as astroTime
np.seterr(divide='ignore', invalid='ignore')

import argparse
from argparse import Namespace

from tqdm import tqdm
from datetime import datetime

#%%
usage = """
description:
    Reduce GRAVITY astroreduced data for exoplanet detection
"""

examples = """
examples:
    Get help:
    run_gravi_astrored_astrometry.py -h

    Run slowly:
    run_gravi_astrored_astrometry.py --slow=TRUE

    Use short baselines:
    run_gravi_astrored_astrometry.py --shortB=TRUE

    Do not include part of the spectra where there is tellurics lines:
    run_gravi_astrored_astrometry.py --removeTelluric=TRUE

    Computed chi2 map on reduced size (mas):
    run_gravi_astrored_astrometry.py --range=50

    Fix the contrast to a given value:
    run_gravi_astrored_astrometry.py --contrast=0.5

    Computed chi2 map around position (RA, DEC):
    run_gravi_astrored_astrometry.py --RA=50 --DEC=1000

    Redo the reduction:
    run_gravi_astrored_astrometry.py --reDo=TRUE
"""

#
# Implement options
#
# The TRUE/FALSE switches are upper-cased before being validated against
# TrueFalse, so "--slow=true" is accepted like "--slow=TRUE". The parsed
# value is always exactly 'TRUE' or 'FALSE', which is what the code below
# compares against. Unknown arguments are tolerated (parse_known_args) and
# returned in recipes_args.

parser = argparse.ArgumentParser(description=usage, epilog=examples,conflict_handler="resolve",
                                 formatter_class=argparse.RawDescriptionHelpFormatter)
TrueFalse = ['TRUE', 'FALSE']

parser.add_argument("--fromDir", dest="fromDir", default='TRUE',
                    help="read data from directory [TRUE]")

parser.add_argument("--slow", dest="slow", default='FALSE', choices=TrueFalse, type=str.upper,
                    help="Run slowly [FALSE]")

parser.add_argument("--shortB", dest="shortB", default='FALSE', choices=TrueFalse, type=str.upper,
                    help="also using short baselines [FALSE]")

parser.add_argument("--removeTelluric", dest="removeTelluric", default='FALSE', choices=TrueFalse, type=str.upper,
                    help="Do not include the part of the spectra where there are tellurics lines [FALSE]")

parser.add_argument("--range", dest="range", default=50, type=int,
                    help="Use chi2 map of size [50] in mas")

parser.add_argument("--contrast", dest="contrast", default=0, type=float,
                    help="fix the contrast to a given value [0]")

parser.add_argument("--RA", dest="RA", default=0, type=int,
                    help="Search around position [mas]: RA")

parser.add_argument("--DEC", dest="DEC", default=0, type=int,
                    help="Search around position [mas]: DEC")

parser.add_argument("--reDo", dest="reDo", default='TRUE', choices=TrueFalse, type=str.upper,
                    help="redo the reduction [TRUE]")


argoptions, recipes_args = parser.parse_known_args()


#%%
# Translate the parsed command-line options into the module-level flags used
# by the processing cells below. Messages are printed in option order.

fast = argoptions.slow != 'TRUE'
if not fast:
    print("Running processing in slow mode...")

# Half-size limit (mas) applied to every chi2 map.
search_within=argoptions.range

use_weight2 = argoptions.shortB != 'TRUE'
if not use_weight2:
    print("We are not weighting down the short baselines")

removeTelluric = argoptions.removeTelluric == 'TRUE'
if removeTelluric:
    print("We are cutting the wavelength to remove telluric signal")

if argoptions.contrast > 0:
    contrast_fixed = argoptions.contrast
    print("We are using a fix value for the contrast of %f"%contrast_fixed)
else:
    contrast_fixed = None

# Either coordinate of the requested position may legitimately be zero, so the
# position is used as soon as at least one of them is non-zero (both zero is
# the "no position given" default).
if argoptions.RA != 0 or argoptions.DEC != 0:
    search_at_xy=[argoptions.RA,argoptions.DEC]
    print("We are searching aroung value ",search_at_xy)
else:
    search_at_xy=None

doNotRedo = argoptions.reDo == 'FALSE'

# A --fromDir value containing at least two "/" (i.e. an explicit path)
# clears fromDir; the default 'TRUE' keeps it set.
fromDir = len(argoptions.fromDir.split("/")) <= 2
if fromDir:
    print("running from command line")

# Collect the *_astroreduced.fits products of the working directory, stored
# without their ".fits" extension and processed in alphabetical order.
# (An earlier variant took file names / wildcards from sys.argv instead.)
print("In directory :",os.getcwd())
filelist=sorted(
    os.path.splitext(fname)[0]
    for fname in os.listdir(".")
    if fname.endswith("_astroreduced.fits")
)

# list.sort() returns None; expose the sorted list itself under this name.
filelist_dir=filelist


# Directory scanned for *_astroreduced.fits files when the script runs inside
# VS Code (see the loading cell). Edit the night below to switch datasets.
#
# Previously reduced nights (relative to the root below):
#   51Eri/2023-11-25, AFlep/2023-12-25, AFlep/2023-11-02,
#   BetaPictoris/2020-03-08, BetaPictoris/2022-08-18,
#   HD206893/2021-08-27, HD206893/2021-09-27, HD206893/2022-05-22,
#   HD206893/2023-07-02, HD206893/2024-05-31, HD206893/2024-06-01,
#   HD206893/2024-06-XX, HD206893/2024-06-30,
#   WT766/2022-08-18, CD-50_869/2022-07-22, HD100546/2021-01-06,
#   HIP99770/2023-05-31, HIP99770/2023-07-02, ABAur/2023-01-07,
#   HIP64892/2024-02-11, WDSJ05055+1948/2021-12-28, YSES_2/2023-05-10,
#   HR2562/2024-03-19, Elias2_24/2024-04-29, TW_Hya/2024-04-27,
#   TW_Hya/2025-05-28, HIP79098/2022-09-05, StKM_1-1494/2023-05-08,
#   HD72946/2020-02-08, HD984/2021-08-27, Gl229/2024-03-29,
#   LP754-5/2024-06-03, HD95086/2020-02-10, Mu2Sco/2023-07-02,
#   HD135344B/2024-04-27, HD135344B/2024-06-02, PDS70/2021-01-07
_REDUCED_ROOT = "/Users/slacour/Nextcloud/REDUCED_DATA/reduced_2023-07-12/"
files_dir = _REDUCED_ROOT + "HIP99770/2023-07-02/"

# Developer overrides, uncomment as needed:
# fast=False
# fromDir=False
# search_at_xy=[45,-57]
# search_at_xy=[13,-57]
# removeTelluric = True
# search_within=5
# # fast=True
# contrast_fixed=9e-5
# use_weight2=False

#%%
# Load the data. Inside VS Code the files come from `files_dir`; otherwise
# the *_astroreduced.fits files of the working directory (`filelist`) are used.


if "VSCODE_PID" in os.environ:
    print("loading data from remote directory")
    files=glob(files_dir+"/*_astroreduced.fits")
else:
    print("loading data from current directory")
    files=[f+".fits" for f in filelist]


files=np.sort(files)#[:-1]#[8:11:2]

if len(files) == 0:
    raise ValueError('No good files !!!')

# Files whose DATE-OBS is listed in al.badfilesList are skipped. For the
# others, one GravDataSet is built from extension 10 when the polarisation
# mode is COMBINED, otherwise from extensions 11 and 12 (presumably one per
# polarisation channel).
dataset=[]
for file in files:
    fheader=fits.getheader(file)
    if np.any(np.array(al.badfilesList)==fheader["DATE-OBS"]):
        print("Rejecting file :"+file)
    else:
        print("Reading file :"+file)
        if fheader['HIERARCH ESO INS POLA MODE'] == "COMBINED":
            dataset+=[al.GravDataSet(file,fheader,10)]
        else:
            dataset+=[al.GravDataSet(file,fheader,11)]
            dataset+=[al.GravDataSet(file,fheader,12)]

# Every file may have been rejected above: fail with a clear message instead
# of an IndexError on dataset[0] further down.
if len(dataset) == 0:
    raise ValueError('No good files left after rejecting al.badfilesList !!!')

for d in dataset:
    d.get_xy_and_comment(True)

for d in dataset:
    d.read_oi_vis(remove_telluric=removeTelluric)


# True for unit-telescope data (the 8 m aperture is used for them when the
# spectra are extracted), False for auxiliary-telescope data.
UT_obs = dataset[0].h_v["UT"]=="UT1"

uvmax = np.max([d.uvmax for d in dataset])


#%%
# Parameters
#
# Coarse-to-fine grids for the chi2 maps. xyrange_map_list_target holds the
# extent of each map in mas (presumably a half-width, since the number of
# points scales with 2*range), clipped to --range. Nrange_list_target holds
# the matching number of grid points. The coarse and medium extents and the
# coarse sampling depend on the telescope type (UT vs AT).

if UT_obs:
    xyrange_map_list_target=[70,10,uvmax]
    coarse_sampling=4
else:
    xyrange_map_list_target=[200,40,uvmax]
    coarse_sampling=1.5

xyrange_map_list_target=[min(r,search_within) for r in xyrange_map_list_target]

Nrange_list_target=np.array([
    int(2*xyrange_map_list_target[0]/uvmax*coarse_sampling),
    int(2*xyrange_map_list_target[1]/uvmax*8),
    int(2*xyrange_map_list_target[2]/0.02),
], dtype=np.int64)

if fast:
    Nrange_list_target[0]=int(Nrange_list_target[0]/1.5)
    Nrange_list_target[1]=int(Nrange_list_target[1]/1.5)
    Nrange_list_target[2]=int(Nrange_list_target[2]/2.5)


# The swap search starts from the same grids. Both containers are copied so
# that tuning one search never changes the other through aliasing.
xyrange_map_list_swap=list(xyrange_map_list_target)
Nrange_list_swap=Nrange_list_target.copy()

# Nrange_list_target[2]/=2
# xyrange_map_list_swap=np.array(xyrange_map_list_swap)//2
# Nrange_list_swap=Nrange_list_swap//2

make_spectra = True

# MeanShift bandwidth used to group exposures by fibre position (sxy).
data_cluster_size = 16

if not UT_obs:
    # Scaled by the 8 m / 1.8 m aperture ratio used elsewhere for UT vs AT.
    data_cluster_size*=8/1.8

#%%
# Compact every exposure that is not star-centred (d.centred == 0). In fast
# mode compact_data is called with full=True, otherwise with its defaults.

compact_kwargs={"full":True} if fast else {}
for d in dataset[:]:
    if d.centred == 0:
        d.compact_data(**compact_kwargs)

#%%
# Output file prefix: AstroL_<most frequent target NAME>_<date of the first
# exposure>, plus "_shortB" when short baselines are not down-weighted and
# "_fast" / "_slow" for the processing mode.

plt.close('all')

names=np.array([d.h_v['NAME'] for d in dataset])
name_unique,number=np.unique(names,return_counts=True)
Name_object=name_unique[number.argmax()].replace(" ","")
Date=dataset[0].h_v["DATE"][:10]

name_files="_".join(("AstroL",Name_object,Date))
if not use_weight2:
    name_files+="_shortB"
name_files+="_fast" if fast else "_slow"

dataset[0].name_files=name_files
print(name_files)


# Pair the swap exposures. sxyFTSCName / sxySCFTName are per-exposure name
# strings (presumably encoding the fibre offset in the two FT/SC
# orientations; confirm in astrored_loadData). When a name n appears both as
# one exposure's sxyFTSCName and as another's sxySCFTName, those exposures
# form swap pair number Swap_todo. d.swap = +Swap_todo for the side observed
# first (earliest MJD; ties go to the sxyFTSCName side) and -Swap_todo for
# the other side.
sxyFTSCName=np.array([d.sxyFTSCName for d in dataset])
sxySCFTName=np.array([d.sxySCFTName for d in dataset])

sxyUnique=np.unique(sxyFTSCName)
# NOTE(review): GoodSwap, Good1, Good2 and name are not used in this cell.
GoodSwap=np.zeros(len(sxyFTSCName),dtype=int)
Swap_todo=0
name=[]
for n in sxyUnique:
    if (np.sum(n==sxySCFTName)>0):
        print(n)
        Swap_todo+=1
        Good1=n==sxyFTSCName
        Good2=n==sxySCFTName
        # MJDs of the exposures on each side of the pair.
        mjd1=[]
        mjd2=[]
        for d in dataset:
            if n==d.sxyFTSCName:
                mjd1+=[d.mjd]
            if n==d.sxySCFTName:
                mjd2+=[d.mjd]
        # inv=-1 when the sxySCFTName side was observed first.
        inv=1
        if np.min(mjd2) < np.min(mjd1):
            inv=-1
        # The membership test above reads the 'Done'-marked array, while the
        # assignments below read the unmodified d.sxyFTSCName/d.sxySCFTName.
        for d in dataset:
            if n==d.sxyFTSCName:
                d.swap=Swap_todo*inv
            if n==d.sxySCFTName:
                d.swap=Swap_todo*-1*inv
                # NOTE(review): appended after np.min(mjd2) was used; this has
                # no effect on the pairing.
                mjd2+=[d.mjd]
        # Mark both sides as processed, presumably so the partner's name does
        # not open a second swap later in sxyUnique.
        # NOTE(review): a numpy fixed-width string array truncates 'Done' if
        # its names are shorter than 4 characters; confirm this cannot happen.
        sxySCFTName[n==sxyFTSCName]='Done'
        sxySCFTName[n==sxySCFTName]='Done'
        
# Classify the exposures and give every off-axis (target) exposure a target
# number plus the calibrators used as its phase and amplitude references.
Target_todo=0
res=np.array([d.h_v["RES"] for d in dataset])
axis=np.array([d.h_v["AXIS"] for d in dataset])
pola=np.array([d.h_v["POLA"] for d in dataset])
ut=np.array([d.h_v["UT"] for d in dataset])
centred=np.array([d.centred for d in dataset])
swap=np.array([d.swap for d in dataset])
sxy=np.array([d.sxy for d in dataset])
ext_day=np.array([d.ext_day for d in dataset])

# Boolean masks over `dataset`:
#   GoodZero = centred (on-axis) exposures,
#   GoodSwap = swap exposures,
#   GoodSxy  = everything else, i.e. the target exposures.
# Comparing with 0 keeps GoodZero boolean even if d.centred holds integer
# flags, so `~` is a logical NOT and sxy[GoodSxy] is a mask rather than
# integer indexing.
GoodZero=(centred!=0)
GoodSwap=(swap!=0)
GoodSxy=(~GoodZero)&(~GoodSwap)

# Group the target exposures by fibre position (MeanShift); each cluster
# becomes one target. MeanShift cannot be fitted on an empty sample, which
# happens when a night has only centred/swap exposures.
if np.any(GoodSxy):
    ms = MeanShift(bandwidth=data_cluster_size, bin_seeding=True)
    ms.fit(sxy[GoodSxy])
    cluster_centers = ms.cluster_centers_
    clusters_good= ms.labels_
else:
    print("no off-axis exposure: no target to search")
    cluster_centers=np.zeros((0,2))
    clusters_good=np.zeros(0,dtype=np.int64)

# A user-given position replaces the clusters with a single one. Exposures
# within 40 mas of it get the label False == 0; the others are left out.
if search_at_xy is not None:
    cluster_centers=[search_at_xy]
    sxy_t=np.array(sxy[GoodSxy])-search_at_xy
    clusters_good = np.sqrt(sxy_t[:,0]**2+sxy_t[:,1]**2) > 40

for lab,xy in enumerate(cluster_centers):
    GoodExpN=np.arange(len(GoodSxy))[GoodSxy][clusters_good == lab]
    Target_todo+=1
    n_used=0   # exposures of this cluster that keep the target number

    for g in GoodExpN:
        e=ext_day[g]
        resolution = res[g]
        axis_onoff = axis[g]
        polarisation = pola[g]
        Target_value = Target_todo

        GoodExtension=(e==ext_day)
        GoodResolution=(res==resolution)
        GoodAxis=(axis==axis_onoff)
        GoodPola=(pola==polarisation)
        # Candidate references, all from the same extension/day and resolution:
        #   swaps (not centred) with the same axis mode;
        #   centred exposures with the same polarisation (amplitude), and
        #   additionally the same axis mode (phase).
        GoodRefSwap=GoodSwap&(~GoodZero)&GoodResolution&GoodAxis&GoodExtension
        GoodRefZeroAmp=GoodZero&GoodResolution&GoodPola&GoodExtension
        GoodRefZeroPhase=GoodZero&GoodResolution&GoodAxis&GoodPola&GoodExtension

        # Start every exposure from "no reference", so a missing reference
        # never falls back on the previous exposure's selection.
        Ref_Phase=np.zeros(len(dataset),dtype=bool)
        Ref_Amp=np.zeros(len(dataset),dtype=bool)

        # Phase reference: prefer swaps, otherwise centred exposures.
        mode_Phase='none'
        if np.sum(GoodRefSwap)>0:
            mode_Phase="off-axis"
            Ref_Phase=GoodRefSwap
        elif np.sum(GoodRefZeroPhase)>0:
            mode_Phase="on-axis"
            Ref_Phase=GoodRefZeroPhase
        else:
            print('no phase reference for file n%i'%g)
            Target_value = 0

        # Amplitude reference: prefer centred exposures, otherwise swaps.
        mode_Amp='none'
        if np.sum(GoodRefZeroAmp)>0:
            Ref_Amp=GoodRefZeroAmp
            mode_Amp="on-axis"
        elif np.sum(GoodRefSwap)>0:
            Ref_Amp=GoodRefSwap
            mode_Amp="off-axis"
        else:
            print('no amplitude reference for file n%i'%g)
            Target_value = 0

        # An exposure without a usable reference gets target 0 and is left
        # out of the target fit.
        dataset[g].target=Target_value
        dataset[g].Ref_Phase=np.where(Ref_Phase)[0]
        dataset[g].Ref_Amp=np.where(Ref_Amp)[0]
        dataset[g].xy_search=np.array(xy)
        dataset[g].mode_Amp=mode_Amp
        dataset[g].mode_Phase=mode_Phase
        if Target_value!=0:
            n_used+=1

    # Release the number of a target left without any exposure, so the
    # fitting cell never receives an empty target.
    if n_used==0:
        print("no usable exposure for the target around position ",xy)
        Target_todo-=1
        

# Calibrator exposures (centred or swap) use, for both phase and amplitude,
# every calibrator that shares their extension/day, resolution, axis mode
# and polarisation. Ref_Phase / Ref_Amp hold indices into `dataset`.
GoodExpN=np.arange(len(GoodZero))[GoodZero|GoodSwap]
calibrator=GoodZero|GoodSwap

for g in GoodExpN:
    same_setup=(res==res[g])&(axis==axis[g])&(pola==pola[g])&(ext_day==ext_day[g])
    GoodRefZeroPhase=calibrator&same_setup
    dataset[g].Ref_Phase=np.where(GoodRefZeroPhase)[0]
    dataset[g].Ref_Amp=np.where(GoodRefZeroPhase)[0]

# With --reDo=FALSE, stop here if the summary image of this reduction
# already exists.
name_files=dataset[0].name_files
name_files_check=name_files+".png"
if doNotRedo and os.path.isfile(name_files_check):
    raise Exception("Data already there, not redoing...")


#%%
# make intro plot
#
# Text summary shown on the "intro" figure, in three parts:
#   - one line per exposure: name, date, NDIT, DIT, setup, fibre offset and
#     TARGET/SWAP/CENTRED labels;
#   - one line of observing conditions per exposure;
#   - the first verification message of each exposure.

intro_lines=[dataset[0].h_v["UT"]+":"]

for d in dataset:
    h_v=d.h_v
    entry="".join([
        "\n",
        h_v['NAME'],":  ",
        h_v["DATE"],"      ",
        str(h_v["NDIT"])," ",
        "%8.2f"%(h_v["DIT"]),"s    ",
        h_v["RES"],"   ",
        h_v["POLA"],"   ",
        h_v["AXIS"],"   ",
        str(h_v["SXY"]),
    ])
    for label,value in (("TARGET",d.target),("SWAP",d.swap),("CENTRED",d.centred)):
        if value!=0:
            entry+=" %s#%i"%(label,value)
    intro_lines.append(entry)

intro_lines.append("\n")

for d in dataset:
    h_v=d.h_v
    intro_lines.append(
        "\n"+h_v["DATE"]+" --> "+h_v["DATE_END"]+"  "
        +"airmass=%.2f-%.2f / "%(h_v["AIRM_START"],h_v["AIRM_END"])
        +"tau0=%.1f-%.1f ms / "%(h_v["TAU0_START"]*1e3,h_v["TAU0_END"]*1e3)
        +"seeing=%.2f-%.2f arcsec"%(h_v["SEEING_START"],h_v["SEEING_END"])
    )

intro_lines.append("\n")

for d in dataset:
    h_v=d.h_v
    intro_lines.append("\n"+h_v["verification"][0])

string_total="".join(intro_lines)
print(string_total)

fig=figure("intro",figsize=(9,1+len(dataset)*.4),clear=True)
t=plt.figtext(.01,.95,string_total,va='top',wrap=1,fontsize=8)
fig_intro=fig

#%%
# process centred
#
# Reference of each star-centred exposure: the weighted mean of its
# FT-normalised visibilities over the DIT axis (axis 0). Its modulus and
# angle are the amplitude and phase references; the summed weights are the
# reference weight. (The on-axis visibility figures are produced in the
# "plot data on axis" cell further down.)

dc=[d for d in dataset if d.centred == True]
for d in dc:
    summed_weight=d.weight_ftnormed.sum(axis=0)
    visData_ftnormed_mean=d.visData_ftnormed_weighted.sum(axis=0)/summed_weight

    d.phase_reference=np.angle(visData_ftnormed_mean)
    d.amplitude_reference=np.abs(visData_ftnormed_mean)
    d.weight_reference=np.abs(summed_weight)

#%% 
# Calculate swaps amplitude ratio


# for swap_number in arange(1,Swap_todo+1):

#     ds=[d for d in dataset if np.abs(d.swap) == swap_number]
#     visData_ftnormed_weighted=[d.visData_ftnormed_weighted for d in ds]
#     weight_ftnormed=[d.weight_ftnormed for d in ds]
#     swap=[d.swap for d in ds]

#     Nf=len(visData_ftnormed_weighted)

#     visdata_amp1=[]
#     visdata_amp2=[]
#     visdata_wei1=[]
#     visdata_wei2=[]
#     for i in range(Nf):
#         if swap[i] > 0 :
#             visdata_amp1+=[abs(visData_ftnormed_weighted[i]).sum(axis=(0,1,2))]
#             visdata_wei1+=[abs(weight_ftnormed[i]).sum(axis=(0,1,2))]
#         else:
#             visdata_amp2+=[abs(visData_ftnormed_weighted[i]).sum(axis=(0,1,2))]
#             visdata_wei2+=[abs(weight_ftnormed[i]).sum(axis=(0,1,2))]

#     ratio=np.sum(visdata_amp1)/np.sum(visdata_wei1)/(np.sum(visdata_amp2)/np.sum(visdata_wei2))
#     print("Swap ratio = ",ratio)

#     for i in range(Nf):
#         if swap[i] > 0 :
#             ds[i].swap_ratio=ratio
#         else:
#             ds[i].swap_ratio=1/ratio

    
#%%
# Calculate swaps separations
#
# For every swap pair (exposures with d.swap = +swap_number or -swap_number):
#  1. refine the separation on the coarse-to-fine chi2 grids
#     (af.get_chi2_swap), keeping the maps and positions on ds[0] for the
#     plots;
#  2. at the best position, compute each exposure's amplitude and phase
#     references. The two halves are scaled by d.ratio_spectra: the square
#     root of their per-wavelength amplitude ratio, inverted on the negative
#     half, which brings both halves to a common level.


for swap_number in arange(1,Swap_todo+1):

    ds=[d for d in dataset if np.abs(d.swap) == swap_number]
    visData_weighted=[d.visData_weighted for d in ds]
    # Per-pair lists use their own names, so the module-level `swap` and
    # `ext_day` arrays (one entry per dataset) are left untouched.
    swap_sign=[d.swap for d in ds]
    ucoord_swap=[d.ucoord for d in ds]
    vcoord_swap=[d.vcoord for d in ds]
    ext_day_swap=[d.ext_day for d in ds]
    # Starting position: mean fibre offset, with the negative half's offsets
    # sign-flipped.
    xy=(np.array([d.sxy for d in ds])*np.sign(np.array(swap_sign))[:,None]).mean(axis=0)


    xy_list=[xy]
    chi2_map_list=[]
    xy_map_list=[]
    print("Searching Swap around position: ",xy)
    for xyrange_map,Nrange in zip(xyrange_map_list_swap,Nrange_list_swap):
        chi2_map,xmap,ymap=af.get_chi2_swap(visData_weighted,ucoord_swap,vcoord_swap,swap_sign,ext_day_swap,xy,xyrange_map,Nrange)
        # Next start: the grid point maximising the map summed over axes 0 and 3.
        chi2_map_sum=chi2_map.sum(axis=(0,3))
        i,j=np.unravel_index(chi2_map_sum.argmax(), chi2_map_sum.shape)
        xy=xmap[i],ymap[j]
        xy_list+=[xy]
        chi2_map_list+=[chi2_map]
        xy_map_list+=[[xmap,ymap]]
        print("Found Swap around position: ",xy)


    ds[0].chi2_map_list = chi2_map_list
    ds[0].xy_list = xy_list
    ds[0].xy_map_list = xy_map_list
    xy_best = xy_list[-1]

    # Re-phase each exposure on the best separation (phase sign flipped for
    # the negative half) and accumulate the two halves separately.
    visdata_amp1=[]
    visdata_amp2=[]
    visdata_wei1=[]
    visdata_wei2=[]
    for d in ds:
        phase=xy_best[0]*d.ucoord + xy_best[1]*d.vcoord
        if d.swap > 0:
            visdata_amp1+=[d.visData_ftnormed_weighted*exp(1j*phase)]
            visdata_wei1+=[d.weight_ftnormed]
        else:
            phase *= -1
            visdata_amp2+=[d.visData_ftnormed_weighted*exp(1j*phase)]
            visdata_wei2+=[d.weight_ftnormed]

    # Per-wavelength amplitude of each half. The steps are:
    #   - coherent sum over exposures and DITs;
    #   - modulus, then sum over baselines;
    #   - division by the summed weights.
    # Exposures are concatenated along the DIT axis, which also works when
    # they do not all have the same number of DITs.
    visdata1=np.abs(np.concatenate(visdata_amp1).sum(axis=0)).sum(axis=0)/np.concatenate(visdata_wei1).sum(axis=(0,1))
    visdata2=np.abs(np.concatenate(visdata_amp2).sum(axis=0)).sum(axis=0)/np.concatenate(visdata_wei2).sum(axis=(0,1))
    ratio_spectra=visdata2/(visdata1+1e-12)

    for d in ds:
        phase=xy_best[0]*d.ucoord + xy_best[1]*d.vcoord
        if d.swap > 0:
            d.ratio_spectra = np.sqrt(ratio_spectra)
        else:
            phase *= -1
            d.ratio_spectra = 1/np.sqrt(ratio_spectra+1e-12)

        visdata_amp=d.visData_ftnormed_weighted*exp(1j*phase)
        visdata_weigth=d.weight_ftnormed
        visdata_avg=visdata_amp.sum(axis=0)/visdata_weigth.sum(axis=0)
        d.amplitude_reference=np.abs(visdata_avg*d.ratio_spectra)
        d.weight_reference=np.abs(visdata_weigth.sum(axis=0))
        d.phase_reference=np.angle(visdata_avg)
        a=d.amplitude_reference
        # Mean amplitude of each baseline relative to the overall mean.
        d.ratio_baselines=a.mean(axis=(1))/a.mean(axis=(0,1))
        print("ratio between baselines of swap:",d.ratio_baselines)


#%%
# Computing reference array (phase and amplitude)
# Zeroing phase
#
# For every exposure, combine the references selected earlier (d.Ref_Phase /
# d.Ref_Amp are indices into `dataset`):
#   - a mean phase reference, which is removed from all four visibility
#     arrays of the exposure;
#   - a weighted mean amplitude reference.
# Target exposures (d.target != 0) also keep both combined references for
# the cleaning cell. Reference exposures are expected to have target 0, so
# their own stored references are not overwritten while looping.

for d in dataset:
    target_number=d.target
    Ref_Phase = d.Ref_Phase
    Ref_Amp = d.Ref_Amp
    phase_reference = np.array([dataset[r].phase_reference for r in Ref_Phase])
    phase_reference_weight = np.array([dataset[r].weight_reference for r in Ref_Phase])
    amplitude_reference = np.array([dataset[r].amplitude_reference for r in Ref_Amp])
    amplitude_reference_weight = np.array([dataset[r].weight_reference for r in Ref_Amp])
    amp_mjd=np.array([dataset[r].mjd for r in Ref_Amp])

    # Average unit phasors; points whose reference weight is below 1 are
    # zeroed so they do not pull the mean phase.
    phase_reference_complex=np.exp(1j*phase_reference)
    phase_reference_complex[phase_reference_weight<1]=0
    phase_reference_mean=np.angle(phase_reference_complex.mean(axis=0))

    # 0/1 weights: only reference points with weight above 1 contribute.
    amp_weight=np.double(amplitude_reference_weight>1)

    dif_mjd=(d.mjd-amp_mjd)
    # NOTE(review): `[:]` selects every reference, so the x10 boost applies to
    # all of them and cancels in the ratio below (the loop is a no-op). A
    # slice such as `[:1]` was presumably meant to favour the reference
    # closest in time; confirm.
    amp_refnumber=abs(dif_mjd).argsort()[:]
    amp_weight[amp_refnumber]*=10
    # Weighted mean over references. All-zero weights give NaN (division
    # warnings are silenced by np.seterr at the top of the file).
    amplitude_reference_mean=(amplitude_reference*amp_weight).mean(axis=0)/amp_weight.mean(axis=0)
    amp_weight[amp_refnumber]/=10

    if target_number!=0:
        d.phase_reference=phase_reference_mean
        d.amplitude_reference=amplitude_reference_mean

    # Remove the phase reference in place from every visibility array.
    d.visData*=np.exp(-1j*phase_reference_mean)
    d.visData_weighted*=np.exp(-1j*phase_reference_mean)
    d.visData_ftnormed*=np.exp(-1j*phase_reference_mean)
    d.visData_ftnormed_weighted*=np.exp(-1j*phase_reference_mean)

#%%
# create cleaning matrices
#
# For each target exposure and each polynomial size in d.Npoly_list:
#   - model every baseline as a polynomial in the normalised wavelength
#     xLambda, multiplied by the reference amplitude (design matrix J);
#   - build, per DIT and baseline, the weighted least-squares projector
#         M = pinv(J^T W J) J^T W        (shape Npoly x Nwave).
# The fitted model J.M.visdata goes to d.visData_fitclean and the residual
# (data - fit) to d.visData_cleaned.
# NOTE(review): the fitted term is presumably the stellar/speckle
# contribution, leaving the companion signal in the residual; confirm.

for target_number in arange(1,Target_todo+1):
    dt = [d for d in dataset if d.target == target_number]
    for d in dt:
        wave=d.wave
        # Normalised wavelength coordinate used as the polynomial variable.
        xLambda=(wave.mean()/wave-1)/(wave.max()-wave.min())*wave.mean()
        weight=d.weight_ftnormed
        visdata=d.visData_ftnormed
        amplitude_reference=d.amplitude_reference
        Ndit=len(weight)
        Nwave=len(wave)

        d.Npoly_list=[4,5,7]
        d.J_clean=[]
        d.M_clean=[]
        for Npoly in d.Npoly_list:
            d.M_clean+=[np.zeros((Ndit,6,Npoly,Nwave))]
            d.J_clean+=[np.zeros((Ndit,6,Nwave,Npoly))]
            for baseline in range(6):
                # Design matrix (Nwave, Npoly): xLambda**k times this
                # baseline's reference amplitude; the same for every DIT.
                J = xLambda[:, None]**range(Npoly) * amplitude_reference[baseline][:, None]
                d.J_clean[-1][:,baseline] = J
                for dit in range(Ndit):
                    # Projector pinv(J^T W J) J^T W with the weights W of
                    # this DIT and baseline.
                    W = weight[dit,baseline]
                    JTW = J.T*W
                    JTWJ = np.dot(JTW,J)
                    JTWJinv = pinv(JTWJ)
                    JTWJinv_JTW=np.dot(JTWJinv,JTW)
                    d.M_clean[-1][dit,baseline]=JTWJinv_JTW

        # Fit every DIT/baseline with each basis; keep the fit and the
        # residual.
        d.visData_cleaned=[]
        d.visData_fitclean = []
        for n in range(len(d.Npoly_list)):
            d.visData_fitclean += [visdata.copy()]
            d.visData_cleaned += [visdata.copy()]
            for baseline in range(6):
                for dit in range(Ndit):
                    coefs = np.dot(d.M_clean[n][dit,baseline],visdata[dit,baseline])
                    fit_data = np.dot(d.J_clean[n][dit,baseline],coefs)
                    d.visData_fitclean[-1][dit,baseline] = fit_data
                    d.visData_cleaned[-1][dit,baseline] -= fit_data

        # Weighted squared residuals, shape (len(d.Npoly_list),) + weight.shape.
        d.chi2=np.abs(d.visData_cleaned)**2*weight

    # visData_cleaned_mean=[]
    # for d in dt:
    #     visData_cleaned_mean+=[[v.mean(axis=0) for v in d.visData_cleaned]]

    # visData_cleaned_mean=np.array(visData_cleaned_mean).mean(axis=0)

    # for d in dt:
    #     for n in range(len(d.Npoly_list)):
    #         d.visData_cleaned[n]-=visData_cleaned_mean[n]

#%%
# fitting data
#
# For each target:
#   - stack all its exposures into rows = (DIT, baseline) and
#     columns = wavelength;
#   - search the companion position on the coarse-to-fine grids with
#     af.get_chi2_target;
#   - store the fit results and the model visibilities on dt[0];
#   - if make_spectra, extract the spectrum and its covariance for every
#     basis in Npoly_list.
# NOTE(review): a target number without any exposure would make
# np.concatenate below raise ValueError.


for target_number in arange(1,Target_todo+1):


    dt=[d for d in dataset if d.target == target_number]
    print("Ndit = ",np.sum([len(d.weight) for d in dt]))
    Nexp=len(dt)

    # Per-exposure quantities (uv offsets, fibre position sxy, reference
    # amplitude) are broadcast to every DIT/baseline, then everything is
    # flattened to (Ndit*Nb, Nwave) rows (sxy to (Ndit*Nb, 2)).
    weight = np.concatenate([d.weight_ftnormed for d in dt])
    Ndit,Nb,Nwave = weight.shape
    ucoord=np.concatenate([d.ucoord for d in dt])
    vcoord=np.concatenate([d.vcoord for d in dt])
    ucoord_diff=np.concatenate([d.ucoord_diff*np.ones((len(d.weight),Nb,1)) for d in dt])
    vcoord_diff=np.concatenate([d.vcoord_diff*np.ones((len(d.weight),Nb,1)) for d in dt])
    sxy=np.concatenate([d.sxy*np.ones((len(d.weight),Nb,1)) for d in dt])
    amplitude_reference=np.concatenate([d.amplitude_reference*np.ones((len(d.weight),1,1)) for d in dt])

    weight=weight.reshape((Ndit*Nb,Nwave))
    ucoord=ucoord.reshape((Ndit*Nb,Nwave))
    vcoord=vcoord.reshape((Ndit*Nb,Nwave))
    ucoord_diff=ucoord_diff.reshape((Ndit*Nb,Nwave))
    vcoord_diff=vcoord_diff.reshape((Ndit*Nb,Nwave))
    sxy=sxy.reshape((Ndit*Nb,2))
    amplitude_reference=amplitude_reference.reshape((Ndit*Nb,Nwave))

    xy_search = dt[0].xy_search
    Npoly_list = dt[0].Npoly_list
    Npoly=len(Npoly_list)

    # One stacked residual / projector / design-matrix array per basis.
    visData_cleaned = []
    for i in range(Npoly):
        visData_cleaned += [[d.visData_cleaned[i] for d in dt]]
        visData_cleaned[-1] = np.concatenate(visData_cleaned[-1]).reshape((Ndit*Nb,Nwave))

    M_clean = []
    for i in range(Npoly):
        M_clean += [[d.M_clean[i] for d in dt]]
        M_clean[-1] = np.concatenate(M_clean[-1]).reshape((Ndit*Nb,-1,Nwave))

    J_clean = []
    for i in range(Npoly):
        J_clean += [[d.J_clean[i] for d in dt]]
        J_clean[-1] = np.concatenate(J_clean[-1]).reshape((Ndit*Nb,Nwave,-1))

    # index_exp labels each row with its exposure. With 3 exposures or fewer,
    # each exposure is split into two halves (the second half is labelled
    # en+Nexp), presumably to get more independent subsets; confirm.
    index = [np.ones((len(d.weight),Nb),dtype=np.int64) for d in dt]
    for en,i in enumerate(index):
        if Nexp > 3:
            i[:]=en
        else :
            Ni=len(i)
            i[:]=en
            i[(Ni+1)//2:]=en+Nexp
    index_exp=np.concatenate(index).reshape((Ndit*Nb))
    # The same arrays are then reused to label each row with its baseline.
    for en,i in enumerate(index):
        i[:]=arange(Nb)
    index_base=np.concatenate(index).reshape((Ndit*Nb))

    print("Searching Target around position: ",xy_search)
    # One (x, y) position per basis in Npoly_list; each grid pass restarts
    # from the previous best position.
    xy=np.ones((2,Npoly))*xy_search[:,None]
    
    result_fit=[]
    result_xy=[]
    for Nrange,xyrange_map in zip(Nrange_list_target,xyrange_map_list_target):
        result=af.get_chi2_target(visData_cleaned,ucoord,vcoord,ucoord_diff,vcoord_diff,weight,amplitude_reference,M_clean,J_clean,sxy,xy,xyrange_map,Nrange,Npoly_list,index_exp,index_base,use_weight2=use_weight2,contrast_fixed=contrast_fixed)
        xy_total,xy_base,xy_exp=af.get_xy_from_results(result)
        print("xy =",xy_total)
        xy=xy_total
        result_fit+=[result]
        result_xy+=[[xy_total,xy_base,xy_exp]]
    # Model visibilities at the final position (visPlanet=True), computed
    # with the last grid's xyrange_map/Nrange.
    visData_planet=af.get_chi2_target(visData_cleaned,ucoord,vcoord,ucoord_diff,vcoord_diff,weight,amplitude_reference,M_clean,J_clean,sxy,xy,xyrange_map,Nrange,Npoly_list,index_exp,index_base,visPlanet=True,use_weight2=use_weight2,contrast_fixed=contrast_fixed)

    dt[0].result_fit=result_fit
    dt[0].visData_planet=visData_planet.reshape((Npoly,Ndit,Nb,Nwave))
    dt[0].result_xy=result_xy

    if make_spectra:

        amplitude_reference_sum=amplitude_reference.sum(axis=1)
        spectra=np.zeros((Npoly,Nwave))
        coVariance=np.zeros((Npoly,Nwave,Nwave))

        for n in range(Npoly):
            # The printed value is the polynomial degree (Npoly_list[n]-1).
            print("extracting spectra with Npoly = %i"%(Npoly_list[n]-1))
            # Model: reference amplitude with the phase of the fitted
            # position, times af.getFlux at the distance between that
            # position and each row's fibre position (8 m aperture for UT
            # data, 1.8 m otherwise). getFlux is presumably the fibre
            # injection efficiency.
            phase=xy[0,n]*ucoord + xy[1,n]*vcoord
            vis_data_planet=amplitude_reference*np.exp(-1j*phase)

            theta=np.sqrt((xy[0,n]-sxy[:,0])**2+(xy[1,n]-sxy[:,1])**2)
            if dt[0].h_v["UT"]=="UT1":
                flux_inj=af.getFlux(theta,diam=8.0)
            else:
                flux_inj=af.getFlux(theta,diam=1.8)
            vis_data_planet*=flux_inj[:,None]

            # Weights scaled per row by the fraction of model amplitude that
            # survives the cleaning projection.
            coefs=np.matmul(M_clean[n],vis_data_planet[:,:,None])
            visData_planet_cleaned=(vis_data_planet[:,:,None]-np.matmul(J_clean[n],coefs))[:,:,0]
            amp_cleaned=np.abs(visData_planet_cleaned).sum(axis=1)/(amplitude_reference_sum+1e-30)
            weight2=weight*amp_cleaned[:,None]

            # Per-row (Nwave x Nwave) response: column w is the cleaned model
            # for a unit spectrum at wavelength w, i.e.
            # diag(model) - J.(M * model).
            coefs=M_clean[n]*vis_data_planet[:,None]
            visData_planet_cleaned=np.matmul(J_clean[n],-coefs)
            for w in range(Nwave):
                visData_planet_cleaned[:,w,w]+=vis_data_planet[:,w]
            visData_planet_cleaned=visData_planet_cleaned.reshape(-1,Nwave)

            # Weighted linear least squares for the spectrum:
            #   spectra = pinv(Re(A^H W A)) . Re(A^H W data)
            # The symmetrised pinv is stored as the covariance.
            W_visData_planet_cleaned=np.conj(weight2.ravel()[:,None]*visData_planet_cleaned)
            denominator=np.dot(W_visData_planet_cleaned.T,visData_planet_cleaned).real
            numerator=np.dot(W_visData_planet_cleaned.T,visData_cleaned[n].ravel()).real
            JWJ_inv=pinv(denominator)
            spectra[n]=np.dot(JWJ_inv,numerator)
            coVariance[n]=(JWJ_inv+JWJ_inv.T)/2

        dt[0].spectra=np.array(spectra)
        dt[0].coVariance=np.array(coVariance)



#%%
# plot data on axis:


# Exposures taken with the star centred in the science fibre: for every
# "<RES> <ext_day>" group, plot the visibility amplitude and phase after
# removing the mean phase and normalising by the mean amplitude.
# fig_centered is filled alternately amplitude, phase, amplitude, phase, ...
fig_centered = []
dc = [d for d in dataset if d.centred == True]
res_ext_day = np.array([d.h_v["RES"] + " " + d.ext_day for d in dc])

for red in np.unique(res_ext_day):
    dred = np.array(dc)[res_ext_day == red]

    # Tag every row with the index of the file it comes from, then locate the
    # rows where a new file starts; the boundaries are repeated for each of the
    # 6 baseline blocks, offset by the number of rows per block.
    index_exp = np.concatenate(
        [np.full(len(d.weight), k, dtype=np.int64) for k, d in enumerate(dred)]
    )
    file_breaks = np.where(np.diff(index_exp) > 0.5)[0]
    n_rows = len(index_exp)
    Next_files = np.array([file_breaks + b * n_rows for b in range(6)]).ravel()

    visData_ftnormed = np.concatenate([d.visData_ftnormed for d in dred])
    mean_phasor = np.exp(-1j * np.angle(visData_ftnormed.mean(axis=0)))
    mean_amplitude = np.abs(visData_ftnormed).mean(axis=(0, 1))
    visData_display = visData_ftnormed * mean_phasor / mean_amplitude

    Npoly_list = [1]
    for kind, values in (("Amplitude", np.abs(visData_display)),
                         ("Phase", np.angle(visData_display))):
        fig = af.make_plot_vis("On-axis star centered Data " + kind + " " + red,
                               values[None], Npoly_list, Next_files,
                               real_only=True)
        fig_centered += [fig]


#%%
# plot data swap

# For every swap sequence: chi2 maps (combined, then per baseline) and the
# flux ratio between the two components.
fig_swap = []


for swap_number in arange(1, Swap_todo + 1):
    ds = [d for d in dataset if np.abs(d.swap) == swap_number]
    # Wavelength grid of *this* swap dataset. Without this the flux-ratio plot
    # below used whatever `wave` an earlier cell had left in the namespace.
    wave = ds[0].wave
    ratio_baselines = [d.ratio_baselines for d in ds]
    chi2_map_list = ds[0].chi2_map_list
    xy_list = ds[0].xy_list
    sxy = ds[0].sxy
    xy_map_list = ds[0].xy_map_list

    # Chi2 map averaged over the last axis (the axis labelled "Baseline" in the
    # per-baseline figure below).
    name_fig = "chi2 swap %i" % swap_number
    Data = [c.transpose((1, 2, 0, 3)).mean(axis=3)[:, :, :, None] for c in chi2_map_list]
    Nx1 = len(Data)
    Nx2 = len(Data[0][0, 0])
    Nx = Nx1 * Nx2
    xmap = [x[0][:, None] * np.ones(Nx2) for x in xy_map_list]
    ymap = [x[1][:, None] * np.ones(Nx2) for x in xy_map_list]
    Y_label = [""]
    X_title = ["" for i in range(Nx)]

    fig = af.make_plot_chi2(name_fig, Data, xmap, ymap, Y_label, X_title, sxy, samevaxis=True, label_axis_inside=False)
    fig_swap += [fig]
    fig.suptitle(ds[0].h_v['NAME'] + " RA=%.3f mas, DEC=%.3f mas" % (xy_list[-1][0], xy_list[-1][1]))

    # Same chi2 maps, one row per baseline.
    name_fig = "chi2 swap per Baseline %i" % swap_number
    Data = [c.transpose((1, 2, 0, 3)) for c in chi2_map_list]
    Nx1 = len(Data)
    Nx2 = len(Data[0][0, 0])
    Nx = Nx1 * Nx2
    Ny = len(Data[0][0, 0, 0])
    xmap = [x[0][:, None] * np.ones(Nx2) for x in xy_map_list]
    ymap = [x[1][:, None] * np.ones(Nx2) for x in xy_map_list]
    Y_label = ["Baseline %i" % i for i in range(1, Ny + 1)]
    X_title = ["" for i in range(Nx)]

    fig = af.make_plot_chi2(name_fig, Data, xmap, ymap, Y_label, X_title, sxy, samevaxis=False, label_axis_inside=True)
    fig_swap += [fig]

    # Contrast spectrum (inverse of the stored ratio) and per-exposure ratio.
    ratio_spectra = ds[0].ratio_spectra
    fig, ax = plt.subplots(2, num="plot swap flux ratio between component B and A %i" % swap_number, figsize=(9, 6), clear=True)
    ax[0].errorbar(wave * 1e6, 1 / ratio_spectra, fmt='o-', markeredgecolor='k')
    ax[0].set_xlabel("Wavelength (microns)")
    ax[0].set_ylabel("Contrast ratio")
    ax[1].plot(ratio_baselines, 'o-')
    ax[1].set_xlabel("Nexp")
    ax[1].set_ylabel("Flux ratio between baselines")
    fig_swap += [fig]

    

#%%
# plot data and calibrator

for target_number in arange(1, Target_todo + 1):
    # For every target:
    #   - plot the reference amplitude/phase;
    #   - plot the raw, star-cleaned and planet-fitted visibilities;
    #   - plot the chi2/contrast maps and, optionally, the extracted spectrum;
    #   - print the astrometric solution and grade the detection;
    #   - write all figures to "<name_file_target>.pdf".
    fig_target = []
    fig_main = []

    dt = [d for d in dataset if d.target == target_number]
    xy_search = dt[0].xy_search
    wave = dt[0].wave
    wave_um = wave * 1e6
    wave_lim = [min(wave_um), max(wave_um)]

    # --- Reference visibilities ---------------------------------------------
    Ref_Phase = dt[0].Ref_Phase
    Ref_Amp = dt[0].Ref_Amp
    phase_reference = np.array([dataset[r].phase_reference for r in Ref_Phase])
    amplitude_reference = np.array([dataset[r].amplitude_reference for r in Ref_Amp])
    phase_reference_weight = np.array([dataset[r].weight_reference for r in Ref_Phase])
    amplitude_reference_weight = np.array([dataset[r].weight_reference for r in Ref_Amp])

    # Mean reference phasor; samples with weight < 1 contribute nothing.
    phase_reference_complex = np.exp(1j * phase_reference)
    phase_reference_complex[phase_reference_weight < 1] = 0
    phase_reference_final = np.angle(phase_reference_complex.mean(axis=0))

    # Phase of each reference relative to that mean, wrapped by np.angle.
    phase = np.angle(np.exp(1j * (phase_reference - phase_reference_final)))
    phase[phase_reference_weight < 1] = 0

    # Weighted mean reference amplitude. A sample counts as valid when its
    # weight is >= 1, the same criterion used for the phase above and for the
    # mask applied to `amp` just below (weight == 1 samples used to be left
    # out of the mean while still being plotted).
    amp_weight = np.double(amplitude_reference_weight >= 1)
    amplitude_reference_final = (amplitude_reference * amp_weight).mean(axis=0) / amp_weight.mean(axis=0)

    amp = amplitude_reference / amplitude_reference_final
    amp[amplitude_reference_weight < 1] = 1

    fig, ax = plt.subplots(6, num="amplitude ref", figsize=(9, 7), clear=True, sharex=True)
    for i in arange(6):
        ax[i].plot(wave_um, amp[:, i].T)
        ax[i].set_xlim(wave_lim)
    ax[0].set_title("Amplitude reference value")
    ax[5].set_xlabel("wavelength (microns)")
    ax[3].set_ylabel("Amplitude of ref visibilities (normalized to mean)")
    fig.subplots_adjust(wspace=0, hspace=0)
    fig_target += [fig]

    fig, ax = plt.subplots(6, num="phase ref", figsize=(9, 7), clear=True, sharex=True)
    for i in arange(6):
        ax[i].plot(wave_um, phase[:, i].T * 180 / np.pi)
        ax[i].set_xlim(wave_lim)
    ax[0].set_title("Phase reference value")
    ax[5].set_xlabel("wavelength (microns)")
    ax[3].set_ylabel("Phase of ref visibilities (degrees)")
    fig.subplots_adjust(wspace=0, hspace=0)
    fig_target += [fig]

    # fig_centered alternates amplitude / phase figures of the on-axis star.
    if dt[0].mode_Amp == "on-axis":
        fig_target += fig_centered[0::2]
    if dt[0].mode_Phase == "on-axis":
        fig_target += fig_centered[1::2]

    # --- Target visibilities ------------------------------------------------
    weight = np.concatenate([d.weight for d in dt])
    # Per-file reference amplitude, repeated for every exposure of that file.
    amplitude_reference = np.concatenate([d.amplitude_reference * np.ones((len(d.weight), 1, 1)) for d in dt])

    # Rows where a new file starts, repeated for each of the 6 baseline blocks.
    index_exp = np.concatenate([np.full(len(d.weight), en, dtype=np.int64) for en, d in enumerate(dt)])
    file_breaks = np.where(np.diff(index_exp) > 0.5)[0]
    Next_files = np.array([file_breaks + b * len(index_exp) for b in range(6)]).ravel()

    visData_ftnormed = np.concatenate([d.visData_ftnormed for d in dt])
    Npoly_list = dt[0].Npoly_list
    Npoly = len(Npoly_list)
    visData_cleaned = [np.concatenate([d.visData_cleaned[n] for d in dt]) for n in range(Npoly)]

    # Raw data do not depend on the polynomial order: replicate them once per
    # entry of Npoly_list so they line up with the cleaned datasets.
    visData_raw = np.array(visData_ftnormed / amplitude_reference) * np.ones((Npoly, 1, 1, 1))
    fig = af.make_plot_vis("Dataset Raw", visData_raw, Npoly_list, Next_files)
    fig_target += [fig]
    fig = af.make_plot_vis("Dataset Cleaned for stellar light", np.array(visData_cleaned) / amplitude_reference, Npoly_list, Next_files)
    fig_target += [fig]

    visData_planet = dt[0].visData_planet

    fig = af.make_plot_vis("Dataset Exoplanet fitted", np.array(visData_planet) / amplitude_reference, Npoly_list, Next_files)
    fig_target += [fig]
    fig = af.make_plot_vis("Dataset Cleaned for stellar light and exoplanet", (np.array(visData_cleaned) - np.array(visData_planet)) / amplitude_reference, Npoly_list, Next_files)
    fig_target += [fig]

    # --- Astrometric fit results --------------------------------------------
    sxy = dt[0].sxy

    mjd = np.mean([d.mjd for d in dt])
    result_fit = dt[0].result_fit
    result_xy = dt[0].result_xy

    xy_final = [r[0][:, :, None] for r in result_xy]
    xy_exp = [r[2] for r in result_xy]
    best_xy = np.array(xy_final[-1][:, :, 0])

    chi2_total = [r[0][:, :, :, None] for r in result_fit]
    xmap = [r[1] for r in result_fit]
    ymap = [r[2] for r in result_fit]
    chi2_exp = [r[3] for r in result_fit]
    chi2_base = [r[4] for r in result_fit]
    amp_cleaned = [r[5] for r in result_fit]
    uv_cleaned = [r[6] for r in result_fit]
    contrast_total = [r[7][:, :, :, None] for r in result_fit]
    contrast_exp = [r[8] for r in result_fit]
    contrast_base = [r[9] for r in result_fit]

    store_print = ""
    for p in range(Npoly):
        # Covariance of the mean position from the per-exposure positions,
        # plus a 0.02 floor added in quadrature (printed below in mas).
        covM = np.cov([xy_exp[-1][0, p], xy_exp[-1][1, p]]) / (len(xy_exp[-1][0, p]))
        covM += np.identity(2) * 0.02**2
        try:
            egenVal, egenVec = np.linalg.eig(covM)
        except np.linalg.LinAlgError:
            # e.g. non-finite covariance: fall back to unit eigenvalues/vectors.
            egenVal, egenVec = np.linalg.eig(np.identity(covM.shape[0]))

        # Position angle of both eigenvectors, folded into [-90, 90) degrees.
        PA_rms = (np.arctan2(egenVec[0, 0], egenVec[1, 0]) * 180 / pi, np.arctan2(egenVec[0, 1], egenVec[1, 1]) * 180 / pi)
        PA_rms = ((array(PA_rms) + 360 + 90) % 180) - 90

        printA = "Reduced with polynomial = %i" % (Npoly_list[p] - 1)
        printB = "MJD= %6.3f RA = %3.3f mas, DEC = %3.3f mas" % (mjd, best_xy[0, p], best_xy[1, p])
        printC = "EigenVal = (%.3fmas ,%4.2fdeg) , (%.3fmas ,%4.2fdeg)" % (sqrt(egenVal[0]), PA_rms[0], sqrt(egenVal[1]), PA_rms[1])
        printD = "std RA/DEC/rho: [%6.3f, %6.3f, %6.3f]" % (sqrt(covM[0, 0]), sqrt(covM[1, 1]), covM[0, 1] / (covM[0, 0]**0.5 * covM[1, 1]**0.5))

        store_print += printA + "\n" + printB + "\n" + printC + "\n" + printD + "\n" + "\n"

    print(store_print)

    # --- Main figures: total chi2 / contrast maps ---------------------------
    name_fig = "chi2 total" + "_%.2f_%.2f" % (xy_search[0], xy_search[1])
    Data = [(c - c.max(axis=(0, 1))) for c in chi2_total]
    Data_star = list(xy_final)
    Data_circle = list(xy_exp)
    Nx1 = len(Data)
    Nx2 = len(Data[0][0, 0])
    Nx = Nx1 * Nx2
    Ny = len(Data[0][0, 0, 0])

    Y_label = ["Polynomial = %i" % (Npoly_list[i] - 1) for i in range(0, Ny)]
    X_title = ["" for i in range(Nx)]

    fig = af.make_plot_chi2(name_fig, Data, xmap, ymap, Y_label, X_title, sxy, Data_star, Data_circle, figsize=(12, 15.5), samevaxis=True, label_axis_inside=True, show_chi2=True)
    fig_main += [fig]
    fig.suptitle(store_print)
    fig.tight_layout()
    name_fig = "Contrast total"
    Data = list(contrast_total)
    fig = af.make_plot_chi2(name_fig, Data, xmap, ymap, Y_label, X_title, sxy, Data_star, Data_circle, figsize=(12, 15.5), samevaxis=True, label_axis_inside=True, show_contrast=True)
    fig_main += [fig]
    fig.suptitle(store_print)
    fig.tight_layout()

    if make_spectra == True:
        spectra = dt[0].spectra
        coVariance = dt[0].coVariance

        labs = ["Polynomial = %i" % (npoly - 1) for npoly in Npoly_list]
        fig, ax = plt.subplots(1, 1, num="Spectra", figsize=(12, 6), clear=True)
        for p in arange(len(Npoly_list)):
            err = sqrt(np.diag(coVariance[p]))
            # Error bars more than 10x the median are replaced by the median
            # |spectrum| so outliers do not dominate the plot.
            err[err > np.median(err) * 10] = np.median(np.abs(spectra[p]))
            # One label per curve (previously the whole list was passed).
            ax.errorbar(wave_um, spectra[p], yerr=err, fmt='o-', zorder=0, label=labs[p], markeredgecolor='k')
        ax.set_xlabel("Wavelength (microns)")
        ax.set_ylabel("Contrast ratio")
        ax.legend()
        fig_main += [fig]

    # --- Diagnostic maps per baseline / per exposure ------------------------
    # NOTE(review): X_title[npoly*3+1] assumes 3 map columns per polynomial
    # order — confirm against the layout used by af.make_plot_chi2.
    name_fig = r"$\eta$ from uv rotation per baseline"
    Data = list(uv_cleaned)
    Data_star = [c * np.ones((1, 1, 6)) for c in xy_final]
    Nx1 = len(Data)
    Nx2 = len(Data[0][0, 0])
    Nx = Nx1 * Nx2
    Ny = len(Data[0][0, 0, 0])

    Y_label = ["Baseline %i" % i for i in range(1, Ny + 1)]
    X_title = ["" for i in range(Nx)]
    for npoly in range(Npoly):
        X_title[npoly * 3 + 1] = "Polynomial = %i" % (Npoly_list[npoly] - 1)

    fig = af.make_plot_chi2(name_fig, Data, xmap, ymap, Y_label, X_title, sxy, Data_star, samevaxis=True, label_axis_inside=True, show_amp=True)
    fig_target += [fig]

    name_fig = r"$\eta_{\rm processing}$ per baseline"
    Data = list(amp_cleaned)
    Data_star = [c * np.ones((1, 1, 6)) for c in xy_final]
    Nx1 = len(Data)
    Nx2 = len(Data[0][0, 0])
    Nx = Nx1 * Nx2
    Ny = len(Data[0][0, 0, 0])

    Y_label = ["Baseline %i" % i for i in range(1, Ny + 1)]
    X_title = ["" for i in range(Nx)]
    for npoly in range(Npoly):
        X_title[npoly * 3 + 1] = "Polynomial = %i" % (Npoly_list[npoly] - 1)

    fig = af.make_plot_chi2(name_fig, Data, xmap, ymap, Y_label, X_title, sxy, Data_star, samevaxis=True, label_axis_inside=True, show_amp=True)
    fig_target += [fig]

    name_fig = r"$\chi_2$ per baseline"
    Data = [c - c.max(axis=(0, 1)) for c in chi2_base]
    Nx1 = len(Data)
    Nx2 = len(Data[0][0, 0])
    Nx = Nx1 * Nx2
    Ny = len(Data[0][0, 0, 0])

    Y_label = ["Baseline %i" % i for i in range(1, Ny + 1)]
    X_title = ["" for i in range(Nx)]
    for npoly in range(Npoly):
        X_title[npoly * 3 + 1] = "Polynomial = %i" % (Npoly_list[npoly] - 1)

    fig = af.make_plot_chi2(name_fig, Data, xmap, ymap, Y_label, X_title, sxy, Data_star, samevaxis=True, label_axis_inside=True, show_chi2=True)
    fig_target += [fig]
    name_fig = "Contrast per baseline"
    Data = list(contrast_base)
    fig = af.make_plot_chi2(name_fig, Data, xmap, ymap, Y_label, X_title, sxy, Data_star, samevaxis=False, label_axis_inside=True, show_contrast=True)
    fig_target += [fig]

    name_fig = r"$\chi_2$ per exposure"
    Data = [c - c.max(axis=(0, 1)) for c in chi2_exp]
    Nx1 = len(Data)
    Nx2 = len(Data[0][0, 0])
    Nx = Nx1 * Nx2
    Ny = len(Data[0][0, 0, 0])

    Y_label = ["Exposure %i" % i for i in range(1, Ny + 1)]
    X_title = ["" for i in range(Nx)]
    for npoly in range(Npoly):
        X_title[npoly * 3 + 1] = "Polynomial = %i" % (Npoly_list[npoly] - 1)

    fig = af.make_plot_chi2(name_fig, Data, xmap, ymap, Y_label, X_title, sxy, xy_exp, xy_exp, samevaxis=True, label_axis_inside=True, show_chi2=True)
    fig_target += [fig]
    name_fig = "Contrast per exp"
    Data = list(contrast_exp)
    fig = af.make_plot_chi2(name_fig, Data, xmap, ymap, Y_label, X_title, sxy, xy_exp, xy_exp, samevaxis=True, label_axis_inside=True, show_contrast=True)
    fig_target += [fig]

    fig_list = [fig_intro] + fig_main + fig_target + fig_swap

    # --- Output file name ---------------------------------------------------
    names = np.array([d.h_v['NAME'] for d in dt])
    name_unique = np.unique(names)
    number = np.array([np.sum(n == names) for n in name_unique])
    Name_object = name_unique[number.argmax()].replace(" ", "")  # most frequent name
    Date = dataset[0].h_v["DATE"][:10]

    name_file_target = "AstroL_" + Name_object + "_" + Date + "_%.2f_%.2f" % (xy_search[0], xy_search[1])

    if use_weight2 == False:
        name_file_target += "_shortB"
    if fast == True:
        name_file_target += "_fast"
    else:
        name_file_target += "_slow"

    # --- Detection grade: compare the solutions of the polynomial orders ----
    # Uses the un-expanded chi2/contrast maps of the last fit range.
    # NOTE(review): assumes Npoly_list has (at least) 3 entries.
    result_fit = dt[0].result_fit
    result_xy = dt[0].result_xy
    chi2_total = [r[0] for r in result_fit]
    contrast_total = [r[7] for r in result_fit]
    x = []
    y = []
    c = []
    d = []
    for p in arange(3):
        delta_chi2 = np.max(chi2_total[-1][:, :, p]) - np.min(chi2_total[-1][:, :, p])
        argmin_chi2 = chi2_total[-1][:, :, p].argmin()
        xyc = np.unravel_index(argmin_chi2, chi2_total[-1][:, :, p].shape)
        contrast_avg = contrast_total[-1][xyc[0], xyc[1], p]
        c += [contrast_avg]
        d += [delta_chi2]
        x += [result_xy[-1][0][0, p]]
        y += [result_xy[-1][0][1, p]]

    delta_c = []
    delta_pos = []
    sum_chi2_50 = []
    sum_chi2_100 = []
    for i in range(3):
        for j in range(i + 1, 3):
            # Relative contrast difference must be small in either direction;
            # without abs() any pair with c[i] < c[j] passed unconditionally.
            delta_c += [abs((c[i] - c[j]) * 2 / (c[i] + c[j])) < 0.5]
            delta_pos += [np.sqrt((x[i] - x[j])**2 + (y[i] - y[j])**2) < 1.5]
            sum_chi2_50 += [(d[i] + d[j]) / 2 > 50]
            sum_chi2_100 += [(d[i] + d[j]) / 2 > 100]

    gF = np.array(delta_c) & np.array(delta_pos) & np.array(sum_chi2_50)
    if np.sum(gF) > 1:
        if np.sum(sum_chi2_100) > 1:
            name_file_target += "_good"
        else:
            name_file_target += "_oki"
    elif np.sum(gF) > 0:
        name_file_target += "_oki"
    else:
        name_file_target += '_bad'

    dt[0].name_file_target = name_file_target

    # Save the figure objects themselves; looking them up again by label
    # could return a different figure sharing the label, or a new blank one.
    print(name_file_target + '.pdf')
    with PdfPages(name_file_target + '.pdf') as pdf:
        for fi in fig_list:
            pdf.savefig(fi)



# name_files=dataset[0].name_files
# np.savez(name_files+".npz",dataset=dataset)


# %%


for target_number in arange(1, Target_todo + 1):
    # Write one FITS file per target containing the extracted contrast spectra.
    # The header summarises the astrometric fit per polynomial order and the
    # observing conditions.
    dt = [d for d in dataset if d.target == target_number]
    name_file_target = dt[0].name_file_target
    print(name_file_target)
    mjd = np.mean([d.mjd for d in dt])
    sxy = np.mean([d.sxy for d in dt], axis=0)
    result_xy = dt[0].result_xy
    xy_exp = [r[2] for r in result_xy]
    result_fit = dt[0].result_fit
    chi2_total = [r[0] for r in result_fit]
    contrast_total = [r[7] for r in result_fit]
    wave = dt[0].wave

    name = dt[0].h_v["NAME_SC"]
    resolution = dt[0].h_v["RES"]
    date_obs = dt[0].h_v["DATE"]

    h_oifits = {}
    # Name of the parent of the working directory, when it can be determined.
    try:
        h_oifits["OBJ_ALT"] = os.getcwd().split('/')[-2]
    except (IndexError, OSError):
        pass
    h_oifits["PROGID"] = dt[0].h_v["PROG_ID"]
    h_oifits["STATION"] = dt[0].h_v["UT"]
    h_oifits["RA"] = dt[0].h_v["RA"]
    h_oifits["DEC"] = dt[0].h_v["DEC"]

    h_oifits["X_FIBER"] = sxy[0]
    h_oifits["Y_FIBER"] = sxy[1]

    h_oifits["AXIS"] = dt[0].h_v["AXIS"]
    h_oifits["POLA"] = dt[0].h_v["POLA"]

    Ref_Phase = dt[0].Ref_Phase
    Ref_Amp = dt[0].Ref_Amp
    names_phase_reference = np.unique(np.array([dataset[r].h_v['NAME'] for r in Ref_Phase]))
    names_amplitude_reference = np.unique(np.array([dataset[r].h_v['NAME'] for r in Ref_Amp]))

    h_oifits["REF_PHAS"] = ",".join(names_phase_reference)
    h_oifits["MODE_PHA"] = dt[0].mode_Phase
    h_oifits["REF_AMP"] = ",".join(names_amplitude_reference)
    h_oifits["MODE_AMP"] = dt[0].mode_Amp

    h_oifits["NEXP"] = len(dt)
    h_oifits["DIT"] = dt[0].h_v['DIT']
    h_oifits["NDIT"] = dt[0].h_v['NDIT']
    h_oifits["DATE_STA"] = dt[0].h_v['DATE']
    h_oifits["DATE_END"] = dt[-1].h_v['DATE']

    airmass_start = [d.h_v["AIRM_START"] for d in dt]
    airmass_end = [d.h_v["AIRM_END"] for d in dt]

    tau_start = [d.h_v["TAU0_START"] for d in dt]
    tau_end = [d.h_v["TAU0_END"] for d in dt]

    seeing_start = [d.h_v["SEEING_START"] for d in dt]
    seeing_end = [d.h_v["SEEING_END"] for d in dt]

    h_oifits["AIRM_MIN"] = np.array([airmass_start, airmass_end]).min()
    h_oifits["AIRM_MAX"] = np.array([airmass_start, airmass_end]).max()

    h_oifits["TAU0_MIN"] = np.array([tau_start, tau_end]).min()
    h_oifits["TAU0_MAX"] = np.array([tau_start, tau_end]).max()

    h_oifits["SEEI_MIN"] = np.array([seeing_start, seeing_end]).min()
    h_oifits["SEEI_MAX"] = np.array([seeing_start, seeing_end]).max()

    # Polynomial orders are written in reverse: index p gets suffix 3 - p.
    for p in arange(3)[::-1]:
        p2 = 3 - p
        h_oifits["NPOLY_%i" % p2] = dt[0].Npoly_list[p] - 1
        h_oifits["X_%i" % p2] = np.round(result_xy[-1][0][0, p], 3)
        h_oifits["Y_%i" % p2] = np.round(result_xy[-1][0][1, p], 3)
        covM = np.cov([xy_exp[-1][0, p], xy_exp[-1][1, p]]) / (len(xy_exp[-1][0, p]))
        covM += np.identity(2) * 0.02**2
        # Only write the error keywords for a finite covariance (e.g. a single
        # exposure gives NaN); presumably the FITS writer cannot store NaN/inf.
        if np.all(np.isfinite(covM)):
            x_err, y_err, rho = (sqrt(covM[0, 0]), sqrt(covM[1, 1]), covM[0, 1] / (covM[0, 0]**0.5 * covM[1, 1]**0.5))
            h_oifits["X_ERR_%i" % p2] = np.round(x_err, 3)
            h_oifits["Y_ERR_%i" % p2] = np.round(y_err, 3)
            h_oifits["XY_RHO_%i" % p2] = np.round(rho, 4)
        delta_chi2 = np.max(chi2_total[-1][:, :, p]) - np.min(chi2_total[-1][:, :, p])
        h_oifits["CHI2_%i" % p2] = np.round(delta_chi2, 1)
        argmin_chi2 = chi2_total[-1][:, :, p].argmin()
        xyc = np.unravel_index(argmin_chi2, chi2_total[-1][:, :, p].shape)
        contrast_avg = contrast_total[-1][xyc[0], xyc[1], p]  # contrast at best chi2
        h_oifits["CONTRA_%i" % p2] = contrast_avg
        print(contrast_avg)

    # NOTE(review): spectra/coVariance look like they are only produced by the
    # spectrum-extraction step (the plotting cell guards them with
    # make_spectra) — confirm they exist when make_spectra is False.
    contrast = dt[0].spectra[::-1]
    coVariance = dt[0].coVariance[::-1]
    mean_c = np.mean([np.diag(c).mean() for c in coVariance])
    # Regularise with 1% of the mean diagonal. Build a new array: `+=` on the
    # reversed slice (a view) would write back into dt[0].coVariance, so
    # re-running this cell would regularise the stored matrices again.
    coVariance = coVariance + np.identity(len(wave)) * mean_c / 100

    print("Writing " + name_file_target + ".fits")
    af.saveFitsSpectrum(name_file_target + ".fits", wave * 1e-6, contrast, coVariance, h_oifits, name, date_obs, mjd, resolution=resolution)

    

# %%


for swap_number in arange(1, Swap_todo + 1):
    # Write one FITS file per swap sequence with the B/A contrast spectrum.
    ds = [d for d in dataset if np.abs(d.swap) == swap_number]
    wave = ds[0].wave
    mjd = np.mean([d.mjd for d in ds])
    sxy = np.mean([d.sxy for d in ds], axis=0)
    result_xy = ds[0].xy_list[-1]

    name = ds[0].h_v["NAME_SC"]
    resolution = ds[0].h_v["RES"]
    date_obs = ds[0].h_v["DATE"]

    h_oifits = {}

    # Name of the parent of the working directory, when it can be determined.
    try:
        h_oifits["OBJ_ALT"] = os.getcwd().split('/')[-2]
    except (IndexError, OSError):
        pass

    h_oifits["PROGID"] = ds[0].h_v["PROG_ID"]
    h_oifits["STATION"] = ds[0].h_v["UT"]
    h_oifits["RA"] = ds[0].h_v["RA"]
    h_oifits["DEC"] = ds[0].h_v["DEC"]

    h_oifits["X_FIBER"] = sxy[0]
    h_oifits["Y_FIBER"] = sxy[1]

    h_oifits["AXIS"] = ds[0].h_v["AXIS"]
    h_oifits["POLA"] = ds[0].h_v["POLA"]

    h_oifits["X"] = np.round(result_xy[0], 3)
    h_oifits["Y"] = np.round(result_xy[1], 3)

    ratio_spectra = ds[0].ratio_spectra

    print(1 / ratio_spectra.mean())
    contrast = 1 / ratio_spectra
    coVariance = None
    # Mean contrast over the channels where the ratio is non-zero (finite).
    h_oifits["CONTRAST"] = contrast[ratio_spectra != 0].mean()

    h_oifits["NEXP"] = len(ds)
    h_oifits["DIT"] = ds[0].h_v['DIT']
    h_oifits["NDIT"] = ds[0].h_v['NDIT']
    h_oifits["DATE_STA"] = ds[0].h_v['DATE']
    h_oifits["DATE_END"] = ds[-1].h_v['DATE']

    airmass_start = [d.h_v["AIRM_START"] for d in ds]
    airmass_end = [d.h_v["AIRM_END"] for d in ds]

    tau_start = [d.h_v["TAU0_START"] for d in ds]
    tau_end = [d.h_v["TAU0_END"] for d in ds]

    seeing_start = [d.h_v["SEEING_START"] for d in ds]
    seeing_end = [d.h_v["SEEING_END"] for d in ds]

    h_oifits["AIRM_MIN"] = np.array([airmass_start, airmass_end]).min()
    h_oifits["AIRM_MAX"] = np.array([airmass_start, airmass_end]).max()

    h_oifits["TAU0_MIN"] = np.array([tau_start, tau_end]).min()
    h_oifits["TAU0_MAX"] = np.array([tau_start, tau_end]).max()

    h_oifits["SEEI_MIN"] = np.array([seeing_start, seeing_end]).min()
    h_oifits["SEEI_MAX"] = np.array([seeing_start, seeing_end]).max()

    names = np.array([d.h_v['NAME'] for d in ds])
    name_unique = np.unique(names)
    number = np.array([np.sum(n == names) for n in name_unique])
    Name_object = name_unique[number.argmax()].replace(" ", "")  # most frequent name
    Date = dataset[0].h_v["DATE"][:10]

    name_file_target = "SwapL_" + Name_object + "_" + Date
    # Several swap sequences of the same object on the same night would
    # otherwise all be written to the same file; tell them apart.
    if Swap_todo > 1:
        name_file_target += "_%i" % swap_number

    if use_weight2 == False:
        name_file_target += "_shortB"
    if fast == True:
        name_file_target += "_fast"
    else:
        name_file_target += "_slow"

    print("Writing " + name_file_target + ".fits")
    af.saveFitsSpectrum(name_file_target + ".fits", wave * 1e-6, contrast[None], coVariance, h_oifits, name, date_obs, mjd, resolution=resolution)


# Save the introductory check figure. Use the figure object directly: a
# lookup by label could return another figure sharing the label, or create a
# new empty one if the original had been closed.
print(name_files_check)
fig_intro.savefig(name_files_check)


# %%
