from warnings import warn
import os
from astropy.io import fits
import numpy as np
from astropy.modeling import models, fitting, Fittable2DModel
from pkg_resources import resource_filename

# Per-port fallback (degrees) used for the approximate binary position angle
# when the header keyword HIERARCH ESO INS DROTOFF<n> is missing or zero.
roof_pos = np.array([38.49, 38.54, 38.76, 39.80])

# Per-port approximate full-frame detector position (pixels) around which
# fit_port looks for the two sources, half a separation on either side.
roof_x = np.array([274.4, 787.1, 1236.1, 1673.4])
roof_y = np.array([242.3, 247.7, 225.8, 235.6])

# Station names; not referenced by the functions in this file.
northern_stations = ['G2', 'J3', 'J4', 'J5', 'J6']

# Approximate plate scale per array type ('AT' or 'UT'); fit_port divides a
# separation in mas by it to get pixels, so presumably mas/pixel.
approx_scale = dict(AT=80., UT=18.)

def do_all(data, dark, theta_in, rho_in, win, approx_PA, group, plot, array, hdulist):
    """Fit both components of the binary in every port of every frame group.

    Frames are averaged by groups of ``group``. In each of the four detector
    ports of the averaged frame, the two sources are fitted with fit_port.

    Args:
        data: acquisition cube indexed as data[frame, y, x]; the four ports
            lie side by side along x, each nx pixels wide.
        dark: if it is not exactly a Python float, it is subtracted from
            each averaged frame. If it is a float, the master dark matching
            the DIT (when one was loaded) is subtracted from each port instead.
        theta_in: not used by this function.
        rho_in: separation handed to fit_port (mas).
        win: half width of the fitting thumbnails, in pixels.
        approx_PA: approximate position angle for each port (degrees).
        group: number of frames averaged together.
        plot: plotting level forwarded to fit_port.
        array: 'AT' or 'UT', forwarded to fit_port.
        hdulist: FITS HDU list; its primary header provides the detector
            windowing (DET1 FRAMES NX/NY, FRAM<n> STRX/STRY) and the DIT.

    Returns:
        apparent_PA, apparent_sep, x1s, y1s, x2s, y2s: arrays of shape
        (ngroups, 4). Positions are the port-relative fits shifted by the
        window start STRX/STRY. apparent_sep is in pixels and apparent_PA is
        np.arctan2(dx, dy) in degrees.
    """
    # Integer division: np.zeros below needs an integer count. At least one
    # group is fitted even when there are fewer frames than ``group``.
    ngroups = max(int(data.shape[0])//group, 1)

    header = hdulist[0].header
    try:
        nx = header["HIERARCH ESO DET1 FRAMES NX"]
        ny = header["HIERARCH ESO DET1 FRAMES NY"]
    except KeyError:
        nx = 512
        ny = 512

    strx = dict()
    stry = dict()
    for p in range(4):
        try:
            strx[p] = header["HIERARCH ESO DET1 FRAM{} STRX".format(p+1)]
        except KeyError:
            print("key not found: "+"HIERARCH ESO DET1 FRAM{} STRX".format(p+1))
            strx[p] = 1
        try:
            stry[p] = header["HIERARCH ESO DET1 FRAM{} STRY".format(p+1)]
        except KeyError:
            print("key not found: "+"HIERARCH ESO DET1 FRAM{} STRY".format(p+1))
            stry[p] = 1
    dit = header["HIERARCH ESO DET1 SEQ1 DIT"]

    # A float dark means "use the master dark for this DIT"; anything else
    # (array, int, numpy scalar) is subtracted directly.
    use_master_dark = type(dark) is float

    x1s = np.zeros((ngroups, 4))
    y1s = np.zeros((ngroups, 4))
    x2s = np.zeros((ngroups, 4))
    y2s = np.zeros((ngroups, 4))
    apparent_PA = np.zeros((ngroups, 4))
    apparent_sep = np.zeros((ngroups, 4))

    for g in range(ngroups):
        avg_data = np.mean(data[g*group:(g+1)*group, :, :], axis=0)
        if not use_master_dark:
            avg_data -= dark

        for p in range(4):
            # View into avg_data: the in-place subtraction below only affects
            # this group's averaged frame.
            port = avg_data[0:ny, p*nx:(p+1)*nx]
            if use_master_dark and dit in master_darks:
                port -= master_darks[dit][stry[p]:stry[p]+ny, strx[p]:strx[p]+nx]
            x1, y1, x2, y2 = fit_port(port, approx_PA[p], roof_x[p]-strx[p], roof_y[p]-stry[p],
                                      rho_in, win, plot, array)
            x1s[g, p] = x1+strx[p]
            y1s[g, p] = y1+stry[p]
            x2s[g, p] = x2+strx[p]
            y2s[g, p] = y2+stry[p]
            dx = x2-x1
            dy = y2-y1
            apparent_sep[g, p] = np.sqrt(dx*dx+dy*dy)
            apparent_PA[g, p] = np.arctan2(dx, dy)*180./np.pi

    return apparent_PA, apparent_sep, x1s, y1s, x2s, y2s


class File:
    """RAW or P2VMRED GRAVITY file, exposing the acquisition camera field observables.

    Most attributes below are computed lazily by ``__getattr__`` the first
    time they are requested, then cached as ordinary instance attributes.

    TODO:
        Would need to document this better.

        We consider three objects in the general case:
         - the FT object (e.g. IRS16C)
         - the first SC object in the mapping sequence (e.g. S2), for
           which OFFX and OFFY are truly 0.
         - the current SC object (e.g. SgrA*).

        Several of the complications here try to alleviate the fact
        that OFFX and OFFY are stored in the raw files only since July
        2017.


    Args:
        fname (str): name of file to read.
        rho  (dict): override separation inferred from header, for estimating scale.
        first_offsets (dict): override INS.SOBJ.X/Y
        group (int): number of frames to average together
        dark: dark used when fitting raw frames. A float selects the
            master dark matching the DIT (if loaded); anything else is
            subtracted from the averaged frames as is.
        window: full width in pixels of the thumbnails fitted in raw frames.
        plot_fit (int): whether to plot something when fitting (for debugging).
        first_file (str or File): for P2VMRED files processed from raw
            files older than July 2017 (lacking INS.SOBJ.OFFX/OFFY),
            one may provide in first_file the first_file after the
            acquisition template. Various observables will be
            corrected: field_scale, field_sc_fiber_dx/dy, offx/offy,
            dx_in/dy_in.

    Attributes:
        rho: see Args
        filename: see fname in Args
        first_offsets: see Args
        group: see Args
        dark: see Args
        window: see Args
        plot_fit: see Args
        FTname (str): self.hdulist[0].header["HIERARCH ESO FT ROBJ NAME"]
        SCname (str): self.hdulist[0].header["HIERARCH ESO INS SOBJ NAME"]
        key (str): essentially "FTname:SCname", used as key for rho and first_offsets
        key_reversed (str): essentially "SCname:FTname", idem
        dx_in: INS.SOBJ.X (x offset from FT to SC target)
        dy_in: INS.SOBJ.Y (y offset from FT to SC target)
        offx: INS.SOBJ.OFFX or 0. (x mapping offset)
        offy: INS.SOBJ.OFFY or 0. (y mapping offset)
        rho_in (mas): separation of first SC object to FT object
        theta_in (deg): position angle on sky on FT --> first SC vector
        config (str): summary of the VLTI configuration
        array (str) 'AT' or 'UT'

        field_ft_x, field_ft_y: detected pixel position of FT target
            on each acq frame (or group of frames). If fname is a raw
            file, a fit is performed. Else the result from the DRS is
            used.

        field_sc_x, field_sc_y: same for first SC object (assuming
            it's the brightest in its vicinity).

        field_scale: scale estimated from apparent separation and rho_in, mas/pix

        field_sc_fiber_dx, field_sc_fiber_dy: correction that should
            be applied to the SC fibre to put it on the current SC
            object, assuming the tip-tilt is servoed to put the FT
            target on the FT fiber (raw files only).

        field_fiber_dx, field_fiber_dy: FIELD_FIBER_DX/DY from the
            OI_VIS_ACQ extension (reduced files only).

        field_*err: corresponding uncertainties estimated by the DRS
            (None for raw files).

        hdulist: fits HDU list

        offangles: self.hdulist[0].header['HIERARCH ESO INS OFFANG*']

        approx_PA: approximate position angle of the binary, from x to y.

        imaging_data_acq: the acquisition camera data cube

        oi_vis_acq: for P2VMRED files, the OI_VIS_ACQ extension

        nx, ny, strx, stry: the detector windowing information
            (HIERARCH ESO DET1 FRAM*)

        nframes: number of frames

        date: numpy datetime64 at the middle of each group of frames,
            computed from INS TIM2 START and INS TIM2 PERIOD

        pcr_acq_start: HIERARCH ESO PCR ACQ START as a numpy datetime64 object

        fiber_ft_x, fiber_ft_y, fiber_sc_x, fiber_sc_y: position of
            the fibers projected on the camera. Except at zenith,
            those position are offset from the source images by a
            common amount due to atmospheric dispersion between the H
            band and the K band.

    """

    # Field observables read from OI_VIS_ACQ (reduced files) or fitted (raw files).
    _FIELD_ATTRS = frozenset({"field_sc_x",
                              "field_sc_y",
                              "field_ft_x",
                              "field_ft_y",
                              "field_fiber_dx",
                              "field_fiber_dy",
                              "field_scale",
                              "field_sc_xerr",
                              "field_sc_yerr",
                              "field_ft_xerr",
                              "field_ft_yerr",
                              "field_fiber_dxerr",
                              "field_fiber_dyerr",
                              "field_scaleerr"})
    # Among the above, those the raw-file fit does not produce.
    _REDUCED_ONLY_ATTRS = frozenset({"field_fiber_dx",
                                     "field_fiber_dy",
                                     "field_fiber_dxerr",
                                     "field_fiber_dyerr"})
    # By-products of the raw-file fit.
    _SC_FIBER_ATTRS = frozenset({"field_sc_fiber_dx",
                                 "field_sc_fiber_dy",
                                 "field_sc_fiber_dxerr",
                                 "field_sc_fiber_dyerr"})

    def __init__(self, fname, rho=None, first_offsets=None, group=None, dark=0.,
                 window=20, plot_fit=0, first_file=None):
        # Open the file, retrying relative to the current directory.
        try:
            self.hdulist = fits.open(fname)
        except IOError:
            fname = "./" + fname
            self.hdulist = fits.open(fname)
        self.filename = fname

        # Store user parameters
        self.rho = rho
        self.first_offsets = first_offsets
        if group is not None:
            self.group = group
        self.dark = dark
        self.window = window
        self.plot_fit = plot_fit

        # Alias some useful values
        self.FTname = self.hdulist[0].header["HIERARCH ESO FT ROBJ NAME"]
        self.SCname = self.hdulist[0].header["HIERARCH ESO INS SOBJ NAME"]
        # ':' is doubled in the names so that the ':' separator is unambiguous.
        self.key = self.FTname.replace(':', '::') + ':' + self.SCname.replace(':', '::')
        self.key_reversed = self.SCname.replace(':', '::') + ':' + self.FTname.replace(':', '::')

        # First initialize according to header
        self.dx_in = self.hdulist[0].header["HIERARCH ESO INS SOBJ X"]
        self.dy_in = self.hdulist[0].header["HIERARCH ESO INS SOBJ Y"]

        try:
            self.offx = self.hdulist[0].header["HIERARCH ESO INS SOBJ OFFX"]
            self.offy = self.hdulist[0].header["HIERARCH ESO INS SOBJ OFFY"]
        except KeyError:
            self.offx = 0.
            self.offy = 0.

        # Offset from FT to the first SC object (mapping offset removed)
        orig_dx = self.dx_in - self.offx
        orig_dy = self.dy_in - self.offy

        self.rho_in = np.sqrt(orig_dx*orig_dx+orig_dy*orig_dy)
        if self.rho is not None:
            if self.key in self.rho:
                self.rho_in = self.rho[self.key]
            elif self.key_reversed in self.rho:
                self.rho_in = self.rho[self.key_reversed]
            elif None in self.rho:
                self.rho_in = self.rho[None]
        if self.first_offsets is not None:
            if self.rho is None:
                self.rho = dict()
            if (self.key not in self.rho and
                self.key_reversed not in self.rho and
                None not in self.rho):
                self.rho[self.key] = self.rho_in
            if self.key not in self.first_offsets:
                self.first_offsets[self.key] = [orig_dx,
                                                orig_dy,
                                                np.arctan2(orig_dx,
                                                           orig_dy)*180./np.pi]
            if self.key_reversed not in self.first_offsets:
                self.first_offsets[self.key_reversed] = [-orig_dx,
                                                         -orig_dy,
                                                         np.arctan2(-orig_dx,
                                                                    -orig_dy)
                                                         *180./np.pi]

        self.theta_in = np.arctan2(orig_dx, orig_dy)*180./np.pi

        sta = dict()
        GVPORTS = {7: 'GV1', 5: 'GV2', 3: 'GV3', 1: 'GV4'}
        self.config = ''
        self.array = 'AT'
        for t in ('4', '3', '2', '1'):
            port = GVPORTS[self.hdulist[0].header["HIERARCH ESO ISS CONF INPUT"+t]]
            tel = self.hdulist[0].header["HIERARCH ESO ISS CONF T"+t+"NAME"]
            if tel[0:2] == 'UT':
                self.array = 'UT'
            sta[port] = self.hdulist[0].header["HIERARCH ESO ISS CONF STATION"+t]
            self.config += ' {port}: {tel}/{sta}/{dl}'.format(port=port, tel=tel, sta=sta[port], dl=self.hdulist[0].header["HIERARCH ESO ISS CONF DL"+t])

        if first_file is not None:
            if isinstance(first_file, str):
                S2_ff = File(first_file)
            else:
                S2_ff = first_file
            scale_correction = (np.sqrt(S2_ff.dx_in**2+S2_ff.dy_in**2)
                                /np.sqrt((self.dx_in-self.offx)**2+(self.dy_in-self.offy)**2))
            self.field_scale *= scale_correction
            # Raw files carry no scale uncertainty (None).
            if self.field_scaleerr is not None:
                self.field_scaleerr *= scale_correction

            real_offx = self.dx_in - S2_ff.dx_in
            real_offy = self.dy_in - S2_ff.dy_in

            offx_correction = real_offx - self.offx
            offy_correction = real_offy - self.offy

            self.theta_in = S2_ff.theta_in
            self.offx = real_offx
            self.offy = real_offy

            PA_S2_FT = self.theta_in+180.
            PA_correction = np.arctan2(offx_correction, offy_correction)/np.pi*180.
            aPA_S2_FTi = np.arctan2(self.field_ft_x-self.field_sc_x, self.field_ft_y-self.field_sc_y)/np.pi*180.
            aPA_correction = aPA_S2_FTi+(PA_correction-PA_S2_FT)
            arho_correction = np.sqrt(offx_correction**2+offy_correction**2)/self.field_scale

            # NOTE(review): __getattr__ only produces field_sc_fiber_dx/dy for
            # raw files; for reduced files the two lines below raise
            # AttributeError (OI_VIS_ACQ is exposed as field_fiber_dx/dy).
            # Confirm which observable is intended here.
            self.field_sc_fiber_dx += arho_correction*np.sin(aPA_correction/180.*np.pi)
            self.field_sc_fiber_dy += arho_correction*np.cos(aPA_correction/180.*np.pi)

    def __getattr__(self, attrname):
        '''Fill attributes the first time they are requested.

        Only called when normal lookup fails. Each branch computes the
        attribute(s), stores them on the instance (so later lookups bypass
        this method) and returns the requested value.

        Raises:
            AttributeError: for unknown names, and for field observables that
                cannot be produced for this kind of file.
        '''

        # offset angles
        if attrname == "offangles":
            self.offangles = np.empty(4)
            for p in range(4):
                try:
                    self.offangles[p] = self.hdulist[0].header['HIERARCH ESO INS OFFANG'+str(p+1)]
                except KeyError:
                    self.offangles[p] = 0
            return self.offangles

        if attrname == "approx_PA":
            rp = np.zeros(4)
            for p in range(4):
                try:
                    rp[p] = self.hdulist[0].header["HIERARCH ESO INS DROTOFF"+str(p+1)]
                    # A zero value is treated like a missing keyword
                    if rp[p] == 0.:
                        rp[p] = roof_pos[p]
                except KeyError:
                    rp[p] = roof_pos[p]
            self.approx_PA = 270.-rp
            return self.approx_PA

        # default for group depending on whether we need to fit
        if attrname == "group":
            if self.oi_vis_acq is not None:
                self.group = 1
            else:
                # Raw file: about four records, at least one frame per group
                self.group = max(self.imaging_data_acq.shape[0]//4, 1)
            return self.group

        # imaging_data_acq attribute
        if attrname == "imaging_data_acq":
            hdus = [hdu.data for hdu in self.hdulist[1:]
                    if hdu.header.get("EXTNAME") == "IMAGING_DATA_ACQ"]
            if not hdus:
                self.imaging_data_acq = None
            else:
                # Several extensions may match: keep the first longest cube
                s = max(d.shape[0] for d in hdus)
                self.imaging_data_acq = next(d for d in hdus if d.shape[0] == s)
            return self.imaging_data_acq

        # oi_vis_acq attribute, only for reduced files
        if attrname in {"oi_vis_acq",
                        "oi_vis_met"}:
            try:
                data = self.hdulist[attrname.upper()].data
            except KeyError:
                data = None
            setattr(self, attrname, data)
            return data

        # Camera windowing
        if attrname in {"nx",
                        "ny",
                        "strx",
                        "stry"}:
            self.strx = np.zeros(4)
            self.stry = np.zeros(4)
            try:
                self.nx = self.hdulist[0].header["HIERARCH ESO DET1 FRAMES NX"]
                self.ny = self.hdulist[0].header["HIERARCH ESO DET1 FRAMES NY"]
            except KeyError:
                self.nx = 512
                self.ny = 512
            for p in range(4):
                try:
                    self.strx[p] = self.hdulist[0].header["HIERARCH ESO DET1 FRAM{} STRX".format(p+1)]
                except KeyError:
                    print("key not found: "+"HIERARCH ESO DET1 FRAM{} STRX".format(p+1))
                    self.strx[p] = 1
                try:
                    self.stry[p] = self.hdulist[0].header["HIERARCH ESO DET1 FRAM{} STRY".format(p+1)]
                except KeyError:
                    print("key not found: "+"HIERARCH ESO DET1 FRAM{} STRY".format(p+1))
                    self.stry[p] = 1
            return getattr(self, attrname)

        # Actual number of frames
        if attrname == "nframes":
            if self.oi_vis_acq is not None:
                # OI_VIS_ACQ holds one row per telescope and frame
                self.nframes = len(self.oi_vis_acq["TIME"])//4
            else:
                self.nframes = self.imaging_data_acq.shape[0]
            return self.nframes

        # Time at the middle of each group of frames
        if attrname == "date":
            period = np.timedelta64(int(1e6*self.hdulist[0].header['HIERARCH ESO INS TIM2 PERIOD']), 'us')
            acqsta = np.datetime64(self.hdulist[0].header['HIERARCH ESO INS TIM2 START'])

            nrec = self.nframes // self.group
            if nrec == 0:
                nrec = 1
            self.date = (acqsta +
                         0.5*period*self.group +
                         period*self.group*np.arange(nrec))
            return self.date

        # Recording start
        if attrname == "pcr_acq_start":
            self.pcr_acq_start = np.datetime64(self.hdulist[0].header['HIERARCH ESO PCR ACQ START'])
            return self.pcr_acq_start

        # Fibre positions
        if attrname in {"fiber_ft_x",
                        "fiber_ft_y",
                        "fiber_sc_x",
                        "fiber_sc_y"}:
            for fiber in ("ft", "sc"):
                for axis in ("x", "y"):
                    data = np.zeros(4)
                    for p in range(4):
                        data[p] = self.hdulist[0].header["ESO ACQ FIBER "+fiber.upper()+str(p+1)+axis.upper()]
                    setattr(self, "fiber_"+fiber+"_"+axis, data)
            return getattr(self, attrname)

        # Fitted positions and errors from MET
        if attrname in {"met_time",
                        "met_mjd",
                        "met_field_fiber_dx",
                        "met_field_fiber_dy"}:
            data = self.oi_vis_met[attrname[4:].upper()]
            nrec = len(data)//4
            setattr(self, attrname, np.reshape(data, (nrec, 4)))
            return getattr(self, attrname)

        # Fitted positions and errors.
        if attrname in self._FIELD_ATTRS or attrname in self._SC_FIBER_ATTRS:
            if self.oi_vis_acq is not None:
                # This is a reduced file
                if attrname in self._FIELD_ATTRS:
                    self._fill_field_from_reduced(attrname)
            elif attrname not in self._REDUCED_ONLY_ATTRS:
                # Probably a raw file: fill the attributes by fitting
                self._fill_field_from_raw()
            # Look the value up directly: calling getattr() here would
            # recurse forever when the attribute could not be produced.
            if attrname in vars(self):
                return vars(self)[attrname]

        # Unknown attribute: raise corresponding exception
        raise AttributeError("'" + type(self).__name__ + "'" +
                             " object has no attribute " +
                             "'" + attrname + "'")

    def _fill_field_from_reduced(self, attrname):
        '''Fill field_* attributes from the OI_VIS_ACQ table of a reduced file.

        With group == 1 only ``attrname`` is set, reshaped to (nrec, 4).
        Otherwise, rows are averaged by groups of ``self.group`` frames,
        weighting by 1/err**2. The weighted scatter is stored as the error.
        Positions with x == 0 or y == 0 (and scales equal to 0) are masked.
        All position observables are filled together;
        field_scale/field_scaleerr are filled on their own.
        '''
        if self.group == 1:
            data = self.oi_vis_acq[attrname.upper()]
            setattr(self, attrname, np.reshape(data, (len(data)//4, 4)))
            return

        group = self.group
        nrec = self.nframes//group
        if nrec == 0:
            # Fewer frames than requested: average all of them in one record
            nrec = 1
            group = self.nframes

        if attrname in ("field_scale", "field_scaleerr"):
            scale, escale = self._grouped_columns(("FIELD_SCALE", "FIELD_SCALEERR"), nrec, group)
            self.field_scale, self.field_scaleerr = self._weighted_group_stats(scale, escale, scale == 0.)
            return

        # Attribute prefix -> column prefix, e.g. field_fiber_dx <- FIELD_FIBER_DX
        for prefix in ("field_fiber_d", "field_sc_", "field_ft_"):
            column = prefix.upper()
            x, y, ex, ey = self._grouped_columns(
                (column+"X", column+"Y", column+"XERR", column+"YERR"), nrec, group)
            filter_out = (x == 0.) + (y == 0.)
            ox, oex = self._weighted_group_stats(x, ex, filter_out)
            oy, oey = self._weighted_group_stats(y, ey, filter_out)
            setattr(self, prefix+"x", ox)
            setattr(self, prefix+"y", oy)
            setattr(self, prefix+"xerr", oex)
            setattr(self, prefix+"yerr", oey)

    def _grouped_columns(self, names, nrec, group):
        '''Return OI_VIS_ACQ columns as masked arrays of shape (nrec, group, 4).

        Trailing frames that do not fill a complete group are dropped.
        '''
        keep = nrec*group
        return [np.reshape(np.ma.array(self.oi_vis_acq[name])[:keep*4], (nrec, group, 4))
                for name in names]

    @staticmethod
    def _weighted_group_stats(values, errors, bad):
        '''Weighted mean and weighted scatter along axis 1, ignoring ``bad`` entries.'''
        masked = np.ma.masked_where(bad, values)
        weights = 1./errors**2
        mean = np.ma.average(masked, axis=1, weights=weights)
        scatter = np.ma.sqrt(np.ma.average((masked-mean[:, None, :])**2, axis=1, weights=weights))
        return mean, scatter

    def _fill_field_from_raw(self):
        '''Fit the raw acquisition frames and fill the field_* attributes.

        Sets field_ft_x/y, field_sc_x/y, field_scale and
        field_sc_fiber_dx/dy. Sets the matching error attributes to None.
        '''
        # Both keywords must be present, otherwise the 512x512 default is used.
        try:
            nx = self.hdulist[0].header["HIERARCH ESO DET1 FRAMES NX"]
            ny = self.hdulist[0].header["HIERARCH ESO DET1 FRAMES NY"]
        except KeyError:
            nx = 512
            ny = 512
        data = self.imaging_data_acq[:, :ny, :]

        win = int(self.window)//2

        (apparent_PA, apparent_sep,
         self.field_ft_x, self.field_ft_y,
         self.field_sc_x, self.field_sc_y) = do_all(data, self.dark, self.theta_in, self.rho_in,
                                                    win, self.approx_PA, self.group, self.plot_fit,
                                                    self.array, self.hdulist)

        self.field_scale = self.rho_in/apparent_sep

        # sc_fiber_dx/dy: extra shift from the first SC object to the current one
        if self.first_offsets is not None:
            ddx = self.dx_in - self.first_offsets[self.key][0]
            ddy = self.dy_in - self.first_offsets[self.key][1]
            drho = np.sqrt(ddx*ddx+ddy*ddy)/self.field_scale.mean(axis=0)
            dth = np.arctan2(ddx, ddy)*180./np.pi
            PA = self.approx_PA+dth-self.first_offsets[self.key][2]
            ddx_pix = drho*np.sin(PA/180.*np.pi)
            ddy_pix = drho*np.cos(PA/180.*np.pi)
        else:
            ddx_pix = np.zeros(4)
            ddy_pix = np.zeros(4)

        self.field_sc_fiber_dx = (self.field_sc_x-self.field_ft_x)-(self.fiber_sc_x-self.fiber_ft_x)+ddx_pix
        self.field_sc_fiber_dy = (self.field_sc_y-self.field_ft_y)-(self.fiber_sc_y-self.fiber_ft_y)+ddy_pix

        # The fit provides no uncertainty estimate
        for aname in ["field_sc_xerr",
                      "field_sc_yerr",
                      "field_ft_xerr",
                      "field_ft_yerr",
                      "field_sc_fiber_dxerr",
                      "field_sc_fiber_dyerr",
                      "field_scaleerr"]:
            setattr(self, aname, None)


def load_master_darks(common_calib_dir="../common_calibration"):
    """Load the acquisition-camera master darks into the global ``master_darks``.

    For each DIT (0.7 and 2.8), the corresponding file is searched in order:
    the current directory, ``common_calib_dir``, ``$INS_ROOT/SYSTEM/REFDATA``
    (when INS_ROOT is set), and finally the package ``data/`` directory.
    A warning is emitted when a dark cannot be loaded from any of them.

    Args:
        common_calib_dir (str): second directory to search.

    Returns:
        dict: DIT -> first plane of the primary HDU data of the dark file.
        The same dict is bound to the global ``master_darks``.
    """
    global master_darks
    master_darks = dict()

    def read_first_plane(path):
        # Copy the plane so the file can be closed immediately.
        with fits.open(path) as hdulist:
            return np.array(hdulist[0].data[0, :, :])

    for DIT, fname in ((0.7, "gvacq_MasterSkyField.fits"), (2.8, "gvacq_MasterSkyField_2.8.fits")):
        paths = [fname, common_calib_dir + "/" + fname]
        if 'INS_ROOT' in os.environ:
            paths.append(os.environ['INS_ROOT'] + "/SYSTEM/REFDATA/" + fname)

        for path in paths:
            try:
                master_darks[DIT] = read_first_plane(path)
                break
            except IOError:
                continue
        else:
            # Last resort: copy shipped with the package (best effort).
            try:
                master_darks[DIT] = read_first_plane(resource_filename(__name__, 'data/'+fname))
            except Exception:
                warn('Could not load master dark '+fname)
    return master_darks

# Master darks keyed by DIT, loaded once at import time. load_master_darks
# rebinds this global itself and returns the same dict.
master_darks = load_master_darks()

def _plot_thumb_fit(thumb, model):
    """Show data, model and residual of one thumbnail fit side by side."""
    # NOTE(review): ``plt`` (matplotlib.pyplot) is not imported in this
    # module, so plot >= 2 raises NameError unless the caller binds it.
    vmax = thumb.max()
    panels = ((thumb, "Data"), (model, "Model"), (thumb - model, "Residual"))
    plt.figure(figsize=(8, 2.5))
    for index, (image, title) in enumerate(panels):
        plt.subplot(1, 3, index+1)
        plt.imshow(image, origin='lower', interpolation='nearest', vmin=0, vmax=vmax)
        plt.title(title)
    plt.show()


def fit_port(port, approx_PAp, roof_xp, roof_yp, rho, win, plot, array):
    """Fit the two components of the binary in one detector port.

    Each component is fitted with a 2D Gaussian on a (2*win+1)-pixel
    thumbnail. The thumbnails are centred half the expected separation on
    either side of the roof position. The second fit reuses the first fit's
    widths (held fixed).

    Args:
        port: 2D image of the port, indexed [y, x].
        approx_PAp: approximate position angle of the binary (degrees).
        roof_xp, roof_yp: port-relative roof position (pixels).
        rho: separation (mas), converted to pixels with approx_scale[array].
        win: thumbnail half width (pixels).
        plot: show data/model/residual when >= 2.
        array: 'AT' or 'UT'.

    Returns:
        x1, y1, x2, y2: port-relative fitted centres of the two components.

    NOTE(review): thumbnails are not clipped to the port. A source closer
    than ``win`` to an edge yields a thumbnail whose shape differs from the
    fitting grid (or an empty one), and the fit then fails.
    """
    scale = approx_scale[array]
    approx_dx = rho*np.sin(approx_PAp*np.pi/180)/scale
    approx_dy = rho*np.cos(approx_PAp*np.pi/180)/scale

    # Pixel offsets of the thumbnail relative to its centre
    y, x = np.mgrid[-win:win+1, -win:win+1]
    fit_g = fitting.LevMarLSQFitter()

    # First component: half a separation before the roof position
    xmax = int(roof_xp-0.5*approx_dx)
    ymax = int(roof_yp-0.5*approx_dy)
    thumb = port[ymax-win:ymax+win+1, xmax-win:xmax+win+1]

    g_init = models.Gaussian2D(amplitude=thumb.max(), x_mean=0, y_mean=0, x_stddev=3., y_stddev=3., theta=0.)
    g = fit_g(g_init, x, y, thumb)
    if plot >= 2:
        _plot_thumb_fit(thumb, g(x, y))

    x1 = xmax + g.x_mean
    y1 = ymax + g.y_mean

    # Second component: half a separation after the roof position, with the
    # widths fixed to those of the first component
    x2max = int(roof_xp+0.5*approx_dx)
    y2max = int(roof_yp+0.5*approx_dy)
    thumb2 = port[y2max-win:y2max+win+1, x2max-win:x2max+win+1]

    g2_init = models.Gaussian2D(amplitude=thumb2.max(), x_mean=0, y_mean=0, x_stddev=g.x_stddev, y_stddev=g.y_stddev, theta=0.)
    g2_init.x_stddev.fixed = True
    g2_init.y_stddev.fixed = True
    g2 = fit_g(g2_init, x, y, thumb2)
    if plot >= 2:
        _plot_thumb_fit(thumb2, g2(x, y))

    x2 = x2max+g2.x_mean
    y2 = y2max+g2.y_mean

    return x1, y1, x2, y2

class Filter:
    '''Filter functors to be used with filter()

    Calling an instance with a file name returns True when all of these hold:
    the file opens as FITS; its primary header has the FT and SC object names;
    the "FT:SC" pair passes the include/exclude lists; and every header
    keyword in keywords_equal has the required value.

    Args:
        exclude_object: container of "FT:SC" pair keys to reject, or None.
        include_object: container of "FT:SC" pair keys to accept
            exclusively, or None to accept all pairs.
        keywords_equal (dict): header keyword -> required value, or None
            for no keyword constraint.
    '''

    def __init__(self,
                 exclude_object=None,
                 include_object=None,
                 keywords_equal=None):
        self.exclude_object = exclude_object
        self.include_object = include_object
        self.keywords_equal = keywords_equal

    def __call__(self, fn):
        '''True if fn is the name of a suitable file'''
        # Remove files that can't be opened as FITS
        # or don't contain FT and SC name
        try:
            # Close the file right away; the primary header stays usable.
            with fits.open(fn) as hdulist:
                header = hdulist[0].header
            FTname = header["HIERARCH ESO FT ROBJ NAME"]
            SCname = header["HIERARCH ESO INS SOBJ NAME"]
            # ':' is doubled in the names so that the separator is unambiguous
            pair = FTname.replace(':', '::') + ':' + SCname.replace(':', '::')
        except (IOError, KeyError):
            return False
        # Exclude as per exclude_object
        if self.exclude_object is not None and pair in self.exclude_object:
            return False
        # If include_object has been used, exclude all other objects.
        if self.include_object is not None and pair not in self.include_object:
            return False
        # Check keywords that must match (None means no constraint)
        if self.keywords_equal is not None:
            for key in self.keywords_equal:
                try:
                    if header[key] != self.keywords_equal[key]:
                        return False
                except KeyError:
                    return False
        # All good, valid file
        return True
