import numpy as np
from matplotlib import pyplot as plt
import astropy.io.fits as pyfits
import os
import scipy.ndimage
import dpfit
import time

"""
import graviArc

graviArc.reducedArLines(graviArc.filesAug['Med'][0], graviArc.filesAug['Med'][1], fitLines=1)
graviArc.reducedArLines(graviArc.filesSept['Med'][0], graviArc.filesSept['Med'][1], fitLines=1)

graviArc.reducedArLines(graviArc.filesAug['High'][0], graviArc.filesAug['High'][1], fitLines=1)
"""

directory = '/Users/amerand/DATA/GRAVITY/ARC/20150807'

filesAug = {'High':(os.path.join(directory,'SC_HighPol_argon.fits'),
                 os.path.join(directory,'SC_HighPol_argon_dark.fits')),
        'Med':(os.path.join(directory, 'SC_MedPol_argon2.fits'),
               os.path.join(directory, 'SC_MedPol_argon2_dark.fits'))}

directory = '/Users/amerand/DATA/GRAVITY/ARC/20150908'

filesSept = {'Med':(os.path.join(directory, '2015-09-09_Argon_MedPol.fits'),
         os.path.join(directory, '2015-09-09_Argon_MedPol_dark.fits'))}


def analyseRawArc( arcfile, darkfile, nLines=10):
    """
    Locate and fit the brightest argon lines in a raw SC arc exposure.

    The dark-subtracted frame is cleaned (per-column median removed, 3x3
    median filter). The `nLines` brightest lines are then located on a
    collapsed spectrum, and each one is fitted with `arcLineModel`. That
    model describes the 48 outputs as polynomial traces across the detector.

    arcfile, darkfile: paths of the arc and dark FITS files. The mode is taken
        from the arc file name: 'High' means high resolution, anything else
        medium resolution. A 2048-pixel-wide medium resolution frame is
        cropped to columns 950:1250.
    nLines: number of bright lines to locate and fit.

    returns a list with one entry per line, brightest first. Each entry is
    the `dpfit.leastsqFit` result dict, extended with:
        'xs', 'ys': fitted output positions
        'imageRaw': the fitted data window
        'imageModel': the model of that window

    Side effects: diagnostic plots in figure 99 and figures 0..nLines-1.
    """
    # -- dark-subtracted image, as float so the in-place background
    #    subtractions below also work for integer FITS data
    with pyfits.open(arcfile) as f, pyfits.open(darkfile) as d:
        image = np.asarray(f[0].data, dtype=float) - np.asarray(d[0].data, dtype=float)
    image -= np.median(image, axis=0)[None,:]
    image = scipy.ndimage.median_filter(image, size=3, mode='nearest')

    # -- find and keep track of the brightest lines
    if 'High' in arcfile:
        eraseWidth = 32
    else: # -- medium
        eraseWidth = 6
        if image.shape[1] == 2048:
            image = image[:,950:1250]
        print('IMAGE.SHAPE', image.shape)

    # -- collapsed spectrum: 90th percentile of each column
    spect = np.percentile(image, 90, axis=0)
    spect -= np.median(spect)
    # -- remove continuum:
    spect -= scipy.ndimage.median_filter(spect, size=4*eraseWidth, mode='nearest')

    lines = [] # index where lines are found
    plt.figure(99)
    plt.clf()
    plt.plot(spect)

    for k in range(nLines):
        lines.append(spect.argmax())
        plt.text(lines[-1], spect.max(), str(k))
        # -- erase line to find the next one. Integer division keeps the slice
        #    bounds integral; the start is clipped at 0 because a negative
        #    start would wrap around to the end of the array.
        spect[max(lines[-1]-eraseWidth//2, 0):lines[-1]+eraseWidth//2] = 0.0

    print('FOUND %d LINES'%len(lines))

    # -- width of the band on which each line is fitted
    if 'High' in arcfile:
        width = 14
    else: # -- medium (same convention as eraseWidth above)
        width = 8
    halfWidth = int(width/2.)

    # -- remove edges from fit: loop invariant, so done once before the
    #    fit windows are extracted
    if 'High' in arcfile and image.shape[0]>300:
        image[:,:5] = 0
        image[:,-5:] = 0

    allFits = []
    for i,l in enumerate(lines): # -- series of fits
        print('#'*12, i, '#'*30)
        # -- window of columns around the line, clipped to the image
        lo, hi = max(l-halfWidth, 0), min(l+halfWidth, image.shape[1])
        # -- copy: the in-place operations below must not alter `image`, whose
        #    columns may be shared with the (overlapping) window of another line
        tmp = np.transpose(image[:, lo:hi]).copy()
        # -- remove background
        tmp -= np.median(tmp)
        # -- remove low pix:
        tmp[tmp<-5*(np.percentile(tmp, 84)-np.percentile(tmp, 16))] = 0
        X, Y = np.meshgrid(np.arange(tmp.shape[1]), np.arange(lo, hi))
        if i==0:
            # -- first guess for the vertical position of the spectral lines
            mask = 1.0*((tmp - tmp.min()) > 3*(np.percentile(tmp, 84)-np.percentile(tmp, 16))).max(axis=0)
            a0y = np.sum(mask*np.arange(tmp.shape[1]))/np.sum(mask)
            print('a0y=', a0y)
            guess = {'a0x':l, 'a1x':0.0, 'a0y':a0y,
                    'a1y':12.40, 'dXpol':-0.05, 'dYpol':-0.1,
                    'continuum': 0.01,'offset':0.0}
            guess['e0'] = tmp.max()/2.
            if 'High' in arcfile:
                guess['a2x'] = -0.004
                guess['widthX']=1.0
                guess['widthY']=1.0
            elif 'Med' in arcfile:
                guess['a1y'] = 6.0
                if '201509' in arcfile:
                    guess['a1y'] = 7.0
                guess['a2x'] = 0.0
                guess['width'] = 1.0

            for k in ['e1', 'e2', 'e3', 'e4', 'e5']:
                guess[k] = 1.0
        else:
            # -- start from the brightest line's solution
            guess = allFits[0]['best'].copy()
            guess['a0x']=l
            guess['e0']=tmp.max()/2.
            if 'High' in arcfile:
                # -- some educated guess
                guess['a0y'] = allFits[0]['best']['a0y'] - \
                    (l-allFits[0]['best']['a0x'])*12/2000.
                guess['dXpol'] = allFits[0]['best']['dXpol'] - \
                            (l-allFits[0]['best']['a0x'])*0.3/2000.
                guess['dYpol'] = allFits[0]['best']['dYpol'] + \
                        (l-allFits[0]['best']['a0x'])*0.3/2000.

                guess['a1x'] = np.interp(l, [200, 2000],[15e-3, -5e-3])
                guess['a2x'] = np.interp(l, [0,   2000], [-8e-3, 0e-3])
                guess['a1y'] = np.interp(l, [0, 1000, 2000],[12.5, 12.3, 12.5])
            elif 'Med' in arcfile:
                guess['a1x'] = np.interp(l, [50, 250],[-25e-3, -35e-3])
                guess['a2x'] = np.interp(l, [50, 250], [-1e-3, 0e-3])
                guess['dYpol'] = allFits[0]['best']['dYpol'] + \
                        (l-allFits[0]['best']['a0x'])*0.3/200.
                guess.pop('blend_dX', None)
                guess.pop('blend_e', None)

        if i!=0:
            # -- profile widths and relative fluxes are taken from the first line
            doNotFit = ['width', 'widthX', 'widthY', 'e1', 'e2', 'e3', 'e4', 'e5', 'blend_dX']
        else:
            doNotFit=['blend_dX']

        # -- blended lines
        if 'Med' in arcfile and i in [0,1,2]:
            guess['blend_dX'] = -0.17 # based on argon doublet at ~2.22um
            guess['blend_e'] = 0.05

        fit = dpfit.leastsqFit(arcLineModel, (Y,X), guess, tmp.flatten(), doNotFit=doNotFit)
        # -- second passage, check for +- 1 subplot
        bestFit = fit
        if i==0: # only for brightest line
            for dy in [-1, 1, -2, 2]:
                guess2 = guess.copy()
                guess2['a0y'] += dy*guess['a1y']
                fit2 = dpfit.leastsqFit(arcLineModel, (Y,X), guess2, tmp.flatten(),
                                        doNotFit=doNotFit)
                if bestFit['chi2'] > fit2['chi2']:
                    print('dy=', dy, 'leads to better fit! -> a0y=', fit2['best']['a0y'], end=' ')
                    print('chi2', bestFit['chi2'], '->', fit2['chi2'])
                    bestFit = fit2

        # -- report the retained fit (value and uncertainty from the same fit)
        for k in guess:
            if not k in doNotFit:
                print(k, guess[k], '->', bestFit['best'][k], '+-', end=' ')
                print(bestFit['uncer'][k])
        model = arcLineModel((Y,X), bestFit['best'], retImage=True)

        xs, ys = arcLineModel(None, bestFit['best'], retXY=True)
        allFits.append(bestFit)
        allFits[-1]['xs'] = xs
        allFits[-1]['ys'] = ys
        allFits[-1]['imageRaw'] = tmp
        allFits[-1]['imageModel'] = model

        plt.close(i)
        plt.figure(i)
        plt.clf()
        ax0 = plt.subplot(411)
        plt.ylabel('residuals')
        plt.pcolormesh(X, Y, tmp-model, cmap='hot')

        plt.subplot(412, sharex=ax0, sharey=ax0)
        plt.ylabel('model')
        plt.pcolormesh(X, Y, model, cmap='hot')

        plt.subplot(413, sharex=ax0, sharey=ax0)
        plt.ylabel('observed')
        plt.pcolormesh(X, Y, tmp, cmap='hot',
                        vmin=np.percentile(tmp, 1),
                        vmax=np.percentile(tmp, 99))
        plt.plot(ys, xs, '+c')
        plt.subplot(414, sharex=ax0)
        plt.plot(tmp.sum(axis=0), '-r', label='observed')
        plt.plot(model.sum(axis=0), '-c', label='model' )
        plt.legend(fontsize=6)

    return allFits

def arcLineModel(xy, param, retXY=False, retImage=False):
    """
    Model of one arc line as imaged by the 48 outputs on the detector.

    xy: pair of 2D pixel-coordinate arrays. xy[0] is compared with the X
        positions and xy[1] with the Y positions of the outputs. Unused
        (may be None) when retXY is True.
    param: dict of model parameters:
        - 'a<n>x' / 'a<n>y': coefficients passed to dpfit.polyN (with the
          trailing 'x'/'y' stripped) giving the X / Y position of each output
          as a function of the output index s = -24..23
        - 'dXpol', 'dYpol' (optional): offsets added to the even-indexed
          outputs (presumably one of the two split polarisations)
        - 'width', or 'widthX' / 'widthY': widths w of the exp(-d**2/w**2)
          profiles (the specific keys override 'width')
        - 'blend_dX' with 'blend_e' (optional): add a second line offset by
          blend_dX in X, with relative flux blend_e; otherwise 'blend'
          (optional): X-width factor on the side where xy[0] <= X position
        - 'continuum' (optional): amplitude of a term constant along xy[0]
        - 'e0' and 'e1'..'e5': outputs 0-7 are scaled by e0, outputs
          8k..8k+7 by e0*ek; if the needed key is missing, a single 'e' is
          used when present
        - 'offset' (optional): constant added to the whole model
    retXY: if True, only return the (Xs, Ys) positions of the outputs.
    retImage: if True, return the model with the shape of xy[0], otherwise
        the flattened model.
    """
    s = np.arange(-24, 24) # all outputs of the IOBC for 4T combination, split pola

    # -- polynomial coefficients of the traces, keyed without 'x' / 'y'
    coefX = dict((key[:-1], val) for key, val in param.items() if key.endswith('x'))
    coefY = dict((key[:-1], val) for key, val in param.items() if key.endswith('y'))
    Xs = dpfit.polyN(s, coefX)
    Ys = dpfit.polyN(s, coefY)
    if 'dXpol' in param:
        Xs[::2] += param['dXpol']
    if 'dYpol' in param:
        Ys[::2] += param['dYpol']
    if retXY:
        return Xs, Ys

    if 'width' in param:
        widthX = widthY = param['width']
    if 'widthX' in param:
        widthX = param['widthX']
    if 'widthY' in param:
        widthY = param['widthY']

    # -- parameter switches do not change from one output to the next
    hasDoublet = 'blend_dX' in param and 'blend_e' in param
    hasAsymmetry = 'blend' in param
    hasContinuum = 'continuum' in param

    model = 0
    for i in range(s.size):
        x0, y0 = Xs[i], Ys[i]
        profile = np.exp(-(xy[1]-y0)**2/widthY**2)
        if hasDoublet:
            frac = param['blend_e']
            profile = profile*((1-np.abs(frac))*np.exp(-(xy[0]-x0)**2/widthX**2)+
                               frac*np.exp(-(xy[0]-x0-param['blend_dX'])**2/widthX**2))
        elif hasAsymmetry:
            dx = xy[0]-x0
            onLeft, onRight = dx <= 0, dx > 0
            profile[onLeft] *= np.exp(-dx[onLeft]**2/(widthX*param['blend'])**2)
            profile[onRight] *= np.exp(-dx[onRight]**2/widthX**2)
        else:
            profile = profile*np.exp(-(xy[0]-x0)**2/widthX**2)

        if hasContinuum:
            profile = profile + param['continuum']*np.exp(-(xy[1]-y0)**2/widthY**2)

        fluxKey = 'e%d' % (i//8)
        if 'e0' in param and fluxKey in param:
            profile = profile*param['e0']
            if i >= 8:
                profile = profile*param[fluxKey]
        elif 'e' in param:
            profile = profile*param['e']
        model = model + profile

    if 'offset' in param:
        model = model + param['offset']
    return model if retImage else model.flatten()

def reducedArLines( arcfile=None, darkfile=None, nLines=10, fitLines=False):
    """
    Derive the SC distortion and wavelength model from fitted argon lines.

    The bright lines are fitted with `analyseRawArc`. The result is cached
    in the module global `_Ar` and re-computed when `fitLines` is true or
    when no usable cached fit exists.

    The per-line trace parameters (a0y, a1y, a1x, a2x, dXpol, dYpol) are
    fitted as polynomials of the line position a0x. The known argon
    wavelengths are fitted as a polynomial of a0x (dispersion relation).

    Everything is plotted in figure 90, saved as gravity_SC_<HR|MR>.pdf/.png,
    and written to gravity_SC_<HR|MR>.fits in the current directory. The
    FITS file holds:
        - a primary HDU
        - an 'SC_MODEL' table with the Y position and wavelength of each of
          the 48 outputs for every X pixel
        - a 'TEST_WAVE' image cube (mask count, mask sum, wavelength map)

    arcfile, darkfile: arc and dark FITS files (defaults: high resolution
        files under /Users/amerand/DATA/GRAVITY/ARC). The mode is inferred
        from the arc file name ('High' or 'Med').
    nLines: number of lines to fit when fitting (only 7 or 10 supported).
    fitLines: force re-fitting of the lines.

    returns None. Raises ValueError if the cached fit does not have 7 or 10
    lines, or if the mode cannot be inferred from `arcfile` when writing.
    """
    global _Ar
    # ==== FIT LINES IN DATA:
    if arcfile is None:
        arcfile = '/Users/amerand/DATA/GRAVITY/ARC/SC_HighPol_argon.fits'
    if darkfile is None:
        darkfile = '/Users/amerand/DATA/GRAVITY/ARC/SC_HighPol_argon_dark.fits'
    try:
        len(_Ar)
    except (NameError, TypeError):
        # -- no (usable) cached line fit yet
        fitLines = True
    if fitLines:
        if nLines in [7, 10]:
            _Ar = analyseRawArc( arcfile, darkfile, nLines)
        else:
            print('supports only 7 or 10 bright lines (!)')
            return

    # =================================================
    plt.close(90)
    # =================================================
    # -- the arc header is kept for the output file (DATE-OBS, WCS keywords)
    with pyfits.open(arcfile) as f, pyfits.open(darkfile) as d:
        arcHeader = f[0].header.copy()
        image = np.asarray(f[0].data, dtype=float) - np.asarray(d[0].data, dtype=float)
    image -= np.median(image, axis=0)[None,:]
    image -= np.median(image)

    if 'Med' in arcfile:
        if image.shape[1] == 2048:
            image = image[:,950:1250]
        print('IMAGE.SHAPE', image.shape)

    # -- remove edges
    if image.shape[0]>300:
        image[:,:5] = 0
        image[:,-5:] = 0

    # -- compress the dynamic range (display and extraction below)
    image = np.abs(image)**0.1

    plt.figure(90, figsize=(16,10))
    plt.subplots_adjust(left=0.07, bottom=0.07, top=0.97, right=0.97)
    plt.clf()

    # -- thumbnails of each fitted line (observed / residuals), sorted in X
    nCols = 2*(len(_Ar)+1)
    for j,i in enumerate(np.argsort([a['best']['a0x'] for a in _Ar])):
        axt = plt.subplot(4, nCols, j+1)
        axt.set_yticklabels([])
        axt.set_xticklabels([])
        thumb = np.transpose(_Ar[i]['imageRaw'])[::-1,:]
        axt.imshow(thumb, cmap='gray_r',
                    vmin=np.percentile(thumb, 10), vmax=np.percentile(thumb, 99.9),
                    aspect='auto', interpolation='nearest')
        axt.set_title('X=%d'%(int(_Ar[i]['best']['a0x'])), fontsize=10)
        if j==0:
            axt.set_ylabel('observed')
        axt = plt.subplot(4, nCols, j+1+nCols)
        resid = thumb - np.transpose(_Ar[i]['imageModel'])[::-1,:]
        lim = max(-np.percentile(resid, 1.), np.percentile(resid, 99.))
        axt.imshow(resid, cmap='RdBu', aspect='auto', interpolation='nearest',
                 vmin=-lim, vmax=lim)
        axt.set_yticklabels([])
        axt.set_xticklabels([])
        if j==0:
            axt.set_ylabel('residuals')
        axt.set_title('F=%dADU\nres+-%3.1f'%(int(_Ar[i]['best']['e0']), lim), fontsize=8)

    # -- trace parameters of each line, with their errors
    det = {}
    e0ref = _Ar[0]['best']['e0']
    for k in _Ar[0]['best']:
        if k.startswith('a') or k.endswith('pol'):
            det[k] = np.array([a['best'][k] for a in _Ar])
            # -- errors weighted by the flux of the line
            det['err_'+k] = np.array([a['uncer'][k]*(e0ref/a['best']['e0'])**0.5 for a in _Ar])
    # -- sort every quantity by line position
    order = np.argsort(det['a0x'])
    det = {k:v[order] for k,v in det.items()}

    fits = {}

    pix = np.arange(image.shape[1])
    p0 = pix.mean()
    print('pix', pix.min(), pix.max())

    # -- distortion: each trace parameter as a polynomial of a0x.
    #    Panel order also sets the key order of `fits` used further down.
    #    (key, subplot index, display scale, polynomial degree, label)
    panels = [('a0y', 13, 1, 1, 'A0Y'),
              ('a1y', 15, 1, 2, 'A1Y'),
              ('a1x', 17, 1000, 1, 'A1X'),
              ('a2x', 19, 1000, 1, 'A2X'),
              ('dXpol', 21, 1, 1, r'pol $\Delta$X'),
              ('dYpol', 23, 1, 1, r'pol $\Delta$Y')]
    ax0 = None
    for k, n, scale, deg, label in panels:
        if ax0 is None:
            ax0 = plt.subplot(12,2,n)
            plt.title('Distortion: polynomial parameters')
        else:
            plt.subplot(12,2,n, sharex=ax0)
        plt.errorbar(det['a0x'], scale*det[k],
                     xerr=det['err_a0x'], yerr=scale*det['err_'+k],
                     fmt='.', color='k')
        guess = {'A0':det[k].mean()}
        for o in range(1, deg+1):
            guess['A%d'%o] = 0.0
        fits[k] = dpfit.leastsqFit(dpfit.polyN, det['a0x']-p0, guess,
                                   det[k], det['err_'+k])
        plt.plot(pix, scale*dpfit.polyN(pix-p0, fits[k]['best']), '-g')
        plt.ylabel(label)
    plt.xlabel('A0X (channel #24)')

    #-- http://www.eso.org/sci/facilities/paranal/decommissioned/isaac/tools/atlas_arcs.pdf
    # -- the reference list must match the number of fitted lines
    nFit = len(_Ar)
    if nFit == 7:
        #-- 7 brightest lines
        ArLines = np.array([2.062186, 2.099184, 2.154009, 2.208321, 2.313952, 2.385154, 2.397306])
    elif nFit == 10:
        #-- 10 brightest lines
        ArLines = np.array([1.997118, 2.032256, 2.062186, 2.099184, 2.133871,
                            2.154009, 2.208321, 2.313952, 2.385154, 2.397306])
        # -- in MR, the line at 2.208321 is actually double and not properly resolved!
    else:
        raise ValueError('supports only 7 or 10 bright lines, got %d'%nFit)

    plt.subplot(426, sharex=ax0)

    # -- lines used for the dispersion fit
    good = np.arange(nFit)
    if 'High' not in arcfile and 'Med' in arcfile and nFit == 10:
        # -- the lines 2, 3 and 6 are blended and not properly resolved
        good = np.array([0,1,4,5,7,8,9])
    blended = [j for j in range(nFit) if not j in good]
    for j in blended:
        plt.text( det['a0x'][j], ArLines[j], 'blended\nline!',
            color='r', rotation=-70, ha='center', va='center')

    # -- first, convert errs in X to errs in wavelength (slope of a linear fit)
    fits['wl'] = dpfit.leastsqFit(dpfit.polyN, det['a0x'][good]-p0,
                                  {'A0':2.2, 'A1':0.0},
                                  ArLines[good], verbose=0)
    err = fits['wl']['best']['A1']*det['err_a0x']
    plt.errorbar(det['a0x'], ArLines, yerr=err, fmt='.', color='k')
    if 'High' in arcfile:
        guessW = {'A0':2.2, 'A1':0.0, 'A2':0.0, 'A3':0.0, 'A4':0.0}
    else: # -- medium
        guessW = {'A0':2.2, 'A1':0.0, 'A2':0.0, 'A3':0.0}

    fits['wl'] = dpfit.leastsqFit(dpfit.polyN, det['a0x'][good]-p0,
                                guessW, ArLines[good], err[good], verbose=1)
    plt.plot(pix, dpfit.polyN(pix-p0, fits['wl']['best']), '-g',
             label='poly fit (N=%d)'%(len(fits['wl']['best'])-1))
    plt.ylabel('wavelength (um)')
    plt.legend(loc='upper left', fontsize=12)
    plt.ylim(1.95, 2.45)

    plt.subplot(428, sharex=ax0)
    residWl = 1000*(ArLines[good] - fits['wl']['model'])
    plt.errorbar(det['a0x'][good], residWl, 1000*err[good], fmt='.', color='k')
    print(np.std(residWl))
    for j in blended:
        plt.errorbar(det['a0x'][j], 1000*(ArLines[j] -
            dpfit.polyN(det['a0x'][j]-p0, fits['wl']['best'])),
            1000*err[j], fmt='.', color='r')

    plt.ylabel('residuals (nm)')
    plt.hlines(0, 0, 2200, linestyle='dashed')
    yMax = max(-plt.ylim()[0], plt.ylim()[1])
    plt.ylim(-yMax, yMax)
    plt.xlabel('A0X (channel #24)')

    axi = plt.subplot(222, sharex=ax0)

    # -- traces of the outputs and wavelength, from the polynomial model
    pixTrace = np.linspace(-20,image.shape[1]+20,100 )
    traces = []
    for p in pixTrace:
        g = {'a0x':p}
        for k in fits:
            g[k] = dpfit.polyN(p-p0, fits[k]['best'])
        traces.append(list(arcLineModel(None, g, retXY=True))+[g['wl']])

    axi.imshow(image, cmap='gray_r',
            vmin=np.percentile(image, 10), vmax=np.percentile(image, 99.9),
            aspect='auto', interpolation='nearest')
    try:
        width = _Ar[0]['best']['width']
    except KeyError:
        width = _Ar[0]['best']['widthY']
    # -- profile is exp(-d**2/width**2), i.e. a Gaussian of sigma width/sqrt(2)
    sigma = width/np.sqrt(2)
    fwhm = 2*np.sqrt(2*np.log(2))*sigma
    plt.title('FWHM = %3.1f pix, %3.2fnm, R=%4.0f'%(fwhm,
            1000*fwhm*fits['wl']['best']['A1'], round(2.2/fwhm/fits['wl']['best']['A1'], 0)))

    # -- write result in a FITS file (astropy header / table API)
    hdu = pyfits.PrimaryHDU(None)
    hdu.header['HIERARCH REDUCED DATE'] = time.asctime()
    hdu.header['HIERARCH ARC FILENAME'] = os.path.basename(arcfile)
    hdu.header['HIERARCH ARC DATE'] = arcHeader['DATE-OBS']
    hdu.header['HIERARCH DARK FILENAME'] = os.path.basename(darkfile)
    hdu.header['AUTHOR'] = 'amerand@eso.org'
    cols=[]
    allX = np.arange(image.shape[1])
    cols.append(pyfits.Column(name='X', format='I', array=allX, unit='pix'))

    # -- synthetic spectrum on the detector
    Y = np.arange(image.shape[0])[:,None]
    meanPtp = np.mean([np.ptp(a['xs']) for a in _Ar])
    pixExtract = np.linspace(np.min([a['xs'].min() for a in _Ar]) - 0.1*meanPtp,
                             np.max([a['xs'].max() for a in _Ar]) + 0.1*meanPtp,
                             image.shape[1])
    print(pixExtract.min(), pixExtract.max())
    try:
        widthY = _Ar[0]['best']['widthY']
    except KeyError:
        widthY = _Ar[0]['best']['width']
    w_ = np.array([t[2] for t in traces]) # same wavelength trace for all outputs
    Mask, bMask, wMask = 0.0, 0.0, 0.0
    for i in range(48): # -- assumes split mode
        color = 'b' if i%2==0 else 'r'
        x = np.array([a['xs'][i] for a in _Ar])
        y = np.array([a['ys'][i] for a in _Ar])
        o = np.argsort(x)
        x, y = x[o], y[o]
        if i==0 or i==47:
            axi.plot(x, y, '|', color=color,
                markersize=9, mew=1.5, alpha=0.5)

        x_ = np.array([t[0][i] for t in traces])
        y_ = np.array([t[1][i] for t in traces])
        axi.plot(x_, y_, '-'+color, alpha=0.3)
        axi.text(x[0]-10, y[0], str(i), color=color,
                 va='center', ha='right' if i%2==0 else 'left', fontsize=9)
        # -- profile of this output on the detector
        mask = np.exp(-(Y-np.interp(pixExtract, x_, y_)[None,:])**2/widthY**2)
        cols.append(pyfits.Column(name='Y%02d'%i, format='F7.3',
                    array=np.interp(allX, x_, y_), unit='pix'))
        cols.append(pyfits.Column(name='W%02d'%i, format='F5.3',
                    array=np.interp(allX, x_, w_), unit='microns'))
        # -- spectrum on the detector
        Mask += mask
        bMask += (mask > 0.1)
        wMask += (mask > 0.1)*np.interp(np.arange(mask.shape[1]), x_, w_)[None,:]*1e-6
    # -- axes limits from the last output (as set on every iteration before)
    axi.set_xlim(x.min()-0.1*np.ptp(x), x.max()+0.1*np.ptp(x))
    axi.set_ylim(0, image.shape[0])
    axi.set_xlabel('X (pix)')
    axi.set_ylabel('Y (pix)')

    hdub = pyfits.BinTableHDU.from_columns(pyfits.ColDefs(cols))
    hdub.header['EXTNAME'] = ('SC_MODEL', '')

    # -- another HDU, the equivalent of TEST_WAVE in P2VM:
    print(wMask[wMask>0].min(), wMask.max())
    print(Mask.shape)
    hduc = pyfits.ImageHDU(np.array([bMask,Mask,wMask]))
    hduc.header['EXTNAME'] = ('TEST_WAVE', '')
    for k in ['CRVAL1', 'CRVAL2', 'CRPIX1', 'CRPIX2',
              'CD1_1', 'CD1_2', 'CD2_1', 'CD2_2']:
        hduc.header[k] = (arcHeader[k], '')

    # -- combine all HDUs
    thdulist = pyfits.HDUList([hdu, hdub, hduc])

    # -- write file
    if 'High' in arcfile:
        mode = 'HR'
    elif 'Med' in arcfile:
        mode = 'MR'
    else:
        raise ValueError('cannot infer High/Med resolution from %s'%arcfile)
    outfile = 'gravity_SC_%s.fits'%mode
    plt.savefig('gravity_SC_%s.pdf'%mode)
    plt.savefig('gravity_SC_%s.png'%mode)

    if os.path.exists(outfile):
        os.remove(outfile)
    print('writting ->', outfile)
    thdulist.writeto(outfile)
    return

def compareArcP2vm(res='HR'):
    """
    Plot the wavelength difference between a P2VM and the argon-arc model.

    Reads the arc model written by `reducedArLines` (./gravity_SC_HR.fits or
    ./gravity_SC_MR.fits) and a hard-coded P2VM file. For each of the 48
    outputs it scatter-plots, in figure 0 and along the model trace, the
    difference in nm between the P2VM 'TEST_WAVE' wavelength map (scaled by
    1e6) and the arc model wavelength.

    res: 'HR' or 'MR'.

    returns None. Raises ValueError for any other `res`.
    """
    # -- per mode: (X shift between arc model and P2VM, arc model, P2VM file)
    setups = {'HR': (5, './gravity_SC_HR.fits',
                     '/Users/amerand/DATA/GRAVITY/ARC/GRAVITY.2015-08-14T07-43-51_0003.fits'),
              'MR': (53, './gravity_SC_MR.fits',
                     '/Users/amerand/DATA/GRAVITY/ARC/GRAVITY.2015-08-14T07-13-40_0003.fits')}
    if res not in setups:
        raise ValueError("res should be 'HR' or 'MR', got %r"%(res,))
    shiftx, arcfile, p2vmfile = setups[res]

    with pyfits.open(arcfile) as arc, pyfits.open(p2vmfile) as p2vm:
        wave = p2vm['TEST_WAVE'].data[2] # individual wavelength
        arcWave = arc['TEST_WAVE'].data[2]

        print('TEST_WAVE (p2vm)', wave.shape)
        print('TEST_WAVE (arc)', arcWave.shape)

        print(arcWave[arcWave>0].min(), end=' ')
        print(arcWave.max())
        print(arcWave[:,53].max())

        plt.close(0)
        plt.figure(0, figsize=(12,6))
        plt.subplots_adjust(left=0.03, right=0.99, top=0.96, bottom=0.05)
        axt = plt.subplot(111)
        vmax = {'MR':12, 'HR':22}[res]
        model = arc['SC_MODEL'].data
        nY, nX = wave.shape
        print('ARC:', len(model['X']))
        # -- X pixels of the model, shifted onto the P2VM grid
        x = model['X'][shiftx:][:nX]-shiftx
        for i in range(48):
            # -- compare position of spectrum
            y = model['Y%02d'%i][shiftx:][:nX]
            w = model['W%02d'%i][shiftx:][:nX]

            # -- P2VM wavelength at the nearest row of the trace. Where it is
            #    0, use the row above, then the row below; rows are clipped to
            #    the image (float indices are not allowed by numpy).
            iy = np.clip(np.rint(y).astype(int), 0, nY-1)
            w_p2vm = wave[iy, x]*1e6
            for dy in (1, -1):
                zero = w_p2vm==0
                w_p2vm[zero] = wave[np.clip(iy[zero]+dy, 0, nY-1), x[zero]]*1e6
            sc = axt.scatter(x, y, c=1000*(w_p2vm-w), cmap='RdBu_r', marker='d',
                             vmin=-vmax, vmax=vmax, edgecolors='None', s=22)

        plt.colorbar(sc)
        plt.title('wavelength calibration difference P2VM - Ar (nm), %s mode'%res)
        plt.xlim(x.min(), x.max())
        plt.ylim(0, nY)
    return
