#!/usr/bin/env python
# -*- coding: utf-8 -*-

"""Plot various quantities based on OI_VIS_ACQ in P2VMRED files
"""

from astropy.io import fits
import numpy as np
from matplotlib import pyplot as plt, dates as mdates
from matplotlib.lines import Line2D
from dateutil.tz import tzutc
import datetime
from scipy import ndimage
from astropy.modeling import models, fitting, Fittable2DModel
import argparse
import glob
import os
from pkg_resources import resource_filename
from warnings import warn

from . import parser_actions
from . import io
from .io import roof_pos

class Markers:
    """Assign a distinct matplotlib marker to each key, cycling when exhausted.

    The first time a key is looked up it receives the next marker from
    ``all_markers``; later lookups of the same key return the same marker.
    Once every marker has been handed out, assignment wraps around to the
    first one again.
    """
    def __init__(self, all_markers=None):
        """Create an empty key -> marker mapping.

        Parameters
        ----------
        all_markers : sequence of str, optional
            Markers to cycle through. Defaults to
            ``matplotlib.lines.Line2D.filled_markers``.

        Raises
        ------
        ValueError
            If ``all_markers`` is empty (cycling would divide by zero).
        """
        if all_markers is None:
            all_markers = Line2D.filled_markers
        self.all_markers = all_markers
        self.cur_marker = 0  # index of the next marker to hand out
        self.n_markers = len(self.all_markers)
        if self.n_markers == 0:
            raise ValueError("Markers needs at least one marker")
        self.markers = dict()  # key -> assigned marker

    def __getitem__(self, key):
        """Return the marker for ``key``, assigning the next free one if new."""
        # Direct dict membership test: O(1) instead of building a key list.
        if key not in self.markers:
            self.markers[key] = self.all_markers[self.cur_marker]
            self.cur_marker = (self.cur_marker + 1) % self.n_markers
        return self.markers[key]


def plot_file(fitsfile, ax1, ax2, use_labels, markers, labeled_markers, plot_what, err_threshold, gv_port):
    # NOTE(review): dead stub. A second ``plot_file`` with the same signature
    # is defined later in this module and rebinds the name, so this version
    # is never called. It only declares ``cur_marker`` global and returns
    # None; consider deleting it.
    global cur_marker

def compute_data(fitsfile, plot_what, err_threshold):
    """Compute the two quantities to plot (left/right panel) for one file.

    Parameters
    ----------
    fitsfile : io.File
        Loaded file exposing the ``field_*``, ``fiber_*`` and ``met_*``
        arrays, ``rho_in``, ``dx_in``/``dy_in``, ``first_offsets``, ``key``,
        ``date`` and ``hdulist``. Per-port arrays are indexed as
        ``[:, port]`` with four ports.
    plot_what : str
        Quantity to compute; one of the ``--plot_what`` choices of ``main``.
    err_threshold : float
        Frames where any of the SC/FT X/Y uncertainties exceeds this value
        are masked in the masked quantities (offsets, differential,
        skydifferential, scalerot).

    Returns
    -------
    tuple
        ``(mjds, data_left, data_right, data_left_err, data_right_err)``.
        The error entries are ``None`` when no uncertainty is available for
        the requested quantity.

    Raises
    ------
    ValueError
        If ``plot_what`` is not recognised.
    """
    mjds = fitsfile.date.astype(datetime.datetime)

    have_errors = fitsfile.field_sc_xerr is not None
    data_left_err = None
    data_right_err = None
    field_sc_ex = fitsfile.field_sc_xerr
    field_sc_ey = fitsfile.field_sc_yerr
    field_ft_ex = fitsfile.field_ft_xerr
    field_ft_ey = fitsfile.field_ft_yerr

    # SC - FT position difference, per frame and per port.
    field_ft_sc_x = np.ma.array(fitsfile.field_sc_x - fitsfile.field_ft_x)
    field_ft_sc_y = np.ma.array(fitsfile.field_sc_y - fitsfile.field_ft_y)

    # Reject frames where any position is exactly 0 (presumably a missing
    # measurement -- TODO confirm) ...
    filter_out = ((fitsfile.field_sc_x == 0.) |
                  (fitsfile.field_ft_x == 0.) |
                  (fitsfile.field_sc_y == 0.) |
                  (fitsfile.field_ft_y == 0.))

    if have_errors:
        # ... or where any uncertainty is above the threshold.
        filter_out = filter_out | ((field_sc_ex > err_threshold) |
                                   (field_ft_ex > err_threshold) |
                                   (field_sc_ey > err_threshold) |
                                   (field_ft_ey > err_threshold))
        # Uncertainty on SC - FT: quadrature sum of the SC and FT errors.
        # (The Y term used to read sc_ey*sc_ex, mixing in the X error.)
        field_ft_sc_ex = np.ma.array(np.sqrt(field_sc_ex*field_sc_ex + field_ft_ex*field_ft_ex))
        field_ft_sc_ey = np.ma.array(np.sqrt(field_sc_ey*field_sc_ey + field_ft_ey*field_ft_ey))

    field_ft_sc_x_masked = np.ma.masked_where(filter_out, field_ft_sc_x)
    field_ft_sc_y_masked = np.ma.masked_where(filter_out, field_ft_sc_y)

    # Plate scale: input separation rho_in over measured SC-FT separation.
    field_scale = fitsfile.rho_in/np.sqrt(field_ft_sc_x*field_ft_sc_x + field_ft_sc_y*field_ft_sc_y)
    field_rot_error = 0.*field_scale

    # Per-port DROTOFF header value, falling back to roof_pos when the
    # keyword is missing or zero.
    header = fitsfile.hdulist[0].header
    rp = np.zeros(4)
    for p in range(4):
        try:
            rp[p] = header["HIERARCH ESO INS DROTOFF" + str(p+1)]
            if rp[p] == 0.:
                rp[p] = roof_pos[p]
        except KeyError:
            rp[p] = roof_pos[p]
        try:
            offangle = header['HIERARCH ESO INS OFFANG' + str(p+1)]
        except KeyError:
            offangle = 0
        # Rotation error in degrees, wrapped into [-90, 270).
        field_rot_error[:, p] = (np.arctan2(field_ft_sc_x[:, p],
                                            field_ft_sc_y[:, p])*180./np.pi
                                 - (270-rp[p]) - offangle + 90.) % 360. - 90.

    if plot_what == "FToffset":
        data_right = fitsfile.field_ft_y - fitsfile.fiber_ft_y[None, :]
        data_left = fitsfile.field_ft_x - fitsfile.fiber_ft_x[None, :]
        if have_errors:
            data_left_err = field_ft_ex
            data_right_err = field_ft_ey
    elif plot_what == "SCoffset":
        data_right = fitsfile.field_sc_y - fitsfile.fiber_sc_y[None, :]
        data_left = fitsfile.field_sc_x - fitsfile.fiber_sc_x[None, :]
        if have_errors:
            data_left_err = field_sc_ex
            data_right_err = field_sc_ey
    elif plot_what == "FTscatter":
        data_left = fitsfile.field_ft_xerr
        data_right = fitsfile.field_ft_yerr
    elif plot_what == "SCscatter":
        data_left = fitsfile.field_sc_xerr
        data_right = fitsfile.field_sc_yerr
    elif plot_what == "diffscatter":
        data_left = fitsfile.field_fiber_dxerr
        data_right = fitsfile.field_fiber_dyerr
    elif plot_what == "offsets":
        data_right = field_ft_sc_y_masked
        data_left = field_ft_sc_x_masked
        if have_errors:
            data_left_err = field_ft_sc_ex
            data_right_err = field_ft_sc_ey
    elif plot_what == "differential" or plot_what == "skydifferential":
        data_right = field_ft_sc_y_masked-(fitsfile.fiber_sc_y-fitsfile.fiber_ft_y)[None, :]
        data_left = field_ft_sc_x_masked-(fitsfile.fiber_sc_x-fitsfile.fiber_ft_x)[None, :]
        if fitsfile.first_offsets is not None:
            PA_setpoint = fitsfile.first_offsets[fitsfile.key][2]
        else:
            PA_setpoint = np.arctan2(fitsfile.dx_in, fitsfile.dy_in)*180./np.pi
        PA_offset = 270. - rp - PA_setpoint
        if fitsfile.first_offsets is not None:
            # Add back the map offset relative to the first file of the pair.
            ddx = fitsfile.dx_in - fitsfile.first_offsets[fitsfile.key][0]
            ddy = fitsfile.dy_in - fitsfile.first_offsets[fitsfile.key][1]
            drho = np.sqrt(ddx*ddx+ddy*ddy)/field_scale
            dth = np.arctan2(ddx, ddy)*180./np.pi
            PA = dth + PA_offset
            data_left += drho*np.sin(PA/180.*np.pi)[None, :]
            data_right += drho*np.cos(PA/180.*np.pi)[None, :]
        if have_errors:
            data_left_err = field_ft_sc_ex
            data_right_err = field_ft_sc_ey
        if plot_what == "skydifferential":
            field_scale_masked = np.ma.masked_where(filter_out, field_scale)
            # neglect rotation error
            drho = (np.sqrt(data_left*data_left+data_right*data_right)
                    * field_scale_masked)
            PA = np.arctan2(data_left, data_right)*180./np.pi
            dth = PA - PA_offset
            data_left = drho*np.sin(dth/180.*np.pi)
            data_right = drho*np.cos(dth/180.*np.pi)
            if have_errors:
                # Let's just take the geometric mean
                e = np.sqrt(data_left_err*data_right_err)*field_scale_masked
                data_left_err = e
                data_right_err = e
    elif plot_what == "metrology":
        mjds = fitsfile.met_mjd.astype(datetime.datetime)
        data_left = fitsfile.met_field_fiber_dx
        data_right = fitsfile.met_field_fiber_dy
        data_left_err = None
        data_right_err = None
    elif plot_what == "metscatter":
        dxmed = np.median(fitsfile.met_field_fiber_dx)
        dymed = np.median(fitsfile.met_field_fiber_dy)
        # NOTE(review): a sample is rejected only when BOTH dx and dy are
        # more than 3 away from the median; logical_or may have been intended.
        filtermet = np.logical_and(np.abs(fitsfile.met_field_fiber_dx-dxmed) > 3,
                                   np.abs(fitsfile.met_field_fiber_dy-dymed) > 3)
        dxmasked = np.ma.masked_where(filtermet, fitsfile.met_field_fiber_dx)
        dymasked = np.ma.masked_where(filtermet, fitsfile.met_field_fiber_dy)
        # A single point per file: mean date and RMS scatter per port.
        mjds = np.reshape(np.average(fitsfile.met_mjd, weights=1.0-filtermet), (1,))
        data_left = np.reshape(np.ma.sqrt(np.ma.average((dxmasked-dxmed)**2, axis=0)), (1, 4))
        data_right = np.reshape(np.ma.sqrt(np.ma.average((dymasked-dymed)**2, axis=0)), (1, 4))
        data_left_err = None
        data_right_err = None
    elif plot_what == "scalerot":
        field_scale_masked = np.ma.masked_where(filter_out, field_scale)
        field_rot_error_masked = np.ma.masked_where(filter_out, field_rot_error)
        data_right = field_scale_masked
        data_left = field_rot_error_masked
        if have_errors:
            data_right_err = (np.power(field_scale/fitsfile.rho_in, 3)*fitsfile.rho_in *
                              (np.abs(field_ft_sc_x)*field_ft_sc_ex +
                               np.abs(field_ft_sc_y)*field_ft_sc_ey))

            data_left_err = (180./np.pi *
                             (np.abs(field_ft_sc_x)*field_ft_sc_ey +
                              np.abs(field_ft_sc_y)*field_ft_sc_ex) /
                             (field_ft_sc_x*field_ft_sc_x +
                              field_ft_sc_y*field_ft_sc_y))
    else:
        raise ValueError("unknown plot_what: %r" % (plot_what,))

    return (mjds, data_left, data_right, data_left_err, data_right_err)


def plot_file(fitsfile, ax1, ax2, use_labels, markers, labeled_markers, plot_what, err_threshold, gv_port):
    """Plot one file's data on the left (``ax1``) and right (``ax2``) axes.

    Parameters
    ----------
    fitsfile : io.File
        Loaded file; ``fitsfile.key`` identifies the object pair and selects
        the marker.
    ax1, ax2 : matplotlib Axes
        Left and right panels.
    use_labels : bool
        If True, error bars get a per-port ``"GVn"`` legend label.
    markers : Markers
        Key -> marker mapping shared across files.
    labeled_markers : set
        Keys that already have a legend entry; updated in place so that each
        object is labeled only once.
    plot_what, err_threshold
        Forwarded to ``compute_data``.
    gv_port : sequence of int
        1-based GRAVITY ports to plot (1..4).
    """
    mjds, data_left, data_right, data_left_err, data_right_err = compute_data(fitsfile, plot_what, err_threshold)

    have_errors = data_left_err is not None and data_right_err is not None
    colors = ['b', 'g', 'r', 'c']
    # The object label goes on the last requested port, so it appears once
    # per object whatever subset of ports is plotted. Previously it was tied
    # to port 4, so no object label appeared when port 4 was not selected.
    label_port = gv_port[-1] - 1 if len(gv_port) else None

    for port in gv_port:
        p = port - 1  # 0-based column index
        port_label = "GV" + str(port) if use_labels else None
        if p == label_port and fitsfile.key not in labeled_markers:
            obj_label = fitsfile.key
            labeled_markers.add(fitsfile.key)
        else:
            obj_label = None
        if have_errors:
            ax1.errorbar(mjds, data_left[:, p], fmt="none", yerr=data_left_err[:, p], ecolor=colors[p], label=port_label)
            ax2.errorbar(mjds, data_right[:, p], fmt="none", yerr=data_right_err[:, p], ecolor=colors[p], label=port_label)
        ax1.plot(mjds, data_left[:, p], colors[p]+markers[fitsfile.key], label=obj_label)
        ax2.plot(mjds, data_right[:, p], colors[p]+markers[fitsfile.key], label=obj_label)

# Titles of the (left, right) panels for each --plot_what choice.
_PLOT_TITLES = {
    "offsets": ("SC_X - FT_X", "SC_Y - FT_Y"),
    "scalerot": ("Error on rotation", "Plate scales"),
    "FToffset": ("FT_X - FIBER.FT#X", "FT_Y - FIBER.FT#Y"),
    "SCoffset": ("SC_X - FIBER.SC#X", "SC_Y - FIBER.SC#Y"),
    "FTscatter": ("FT_X scatter", "FT_Y scatter"),
    "SCscatter": ("SC_X scatter", "SC_Y scatter"),
    "diffscatter": ("FIELD_FIBER_DX scatter", "FIELD_FIBER_DY scatter"),
    "metscatter": ("MET_FIELD_FIBER_DX scatter", "MET_FIELD_FIBER_DY scatter"),
    "differential": ("dx from fibre to science incl. map offset (pix)",
                     "dy from fibre to science incl. map offset (pix)"),
    "skydifferential": ("RA offset from fibre to science incl. map offset (mas)",
                        "Dec offset from fibre to science incl. map offset (mas)"),
    "metrology": ("dx from fibre to science incl. map offset (pix)",
                  "dy from fibre to science incl. map offset (pix)"),
}


def _read_primary_header(fn):
    """Return the primary header of FITS file ``fn`` (or ``./fn``), or None.

    The file is closed before returning; the header object stays usable.
    """
    for candidate in (fn, "./" + fn):
        try:
            with fits.open(candidate) as hdulist:
                return hdulist[0].header
        except IOError:
            continue
    return None


def main():
    """Command-line entry point: parse arguments, select files and plot.

    Files come from the command line (default: ``./GRAVI*.fits``). They are
    filtered on their FT/SC object pair and on header keywords, then each
    one is loaded with ``io.File`` and plotted with ``plot_file``. With
    ``--output``, the plotted values are also written as ASCII, one line per
    time stamp: the date, the four left-panel values, then the four
    right-panel values.
    """
    parser = argparse.ArgumentParser(description=__doc__,
                                     formatter_class=argparse.RawDescriptionHelpFormatter,
                                     prefix_chars='-+')
    parser.add_argument('fname', nargs='*')

    parser.add_argument("-r", "--rho", help="separation of binary (mas) [from FITS header]. May be "
                        "prefixed with FT and SC target names and used several times: --rho=IRS16C:S2:1200. "
                        "The prefixed values are used for the specific pair (or the reverse) while a value "
                        "without prefix is a global default.",
                        type=str, action=parser_actions.Rho)
    parser.add_argument("-f", "+f", "--first_offsets", "--no-first_offsets",
                        action=parser_actions.Toggle, nargs=0, default=True, help=
                        "trust the first offsets found for each FT/SC pair. "
                        "This is useful if the first file for a pair "
                        "corresponds to the real position while other files "
                        "were taken at non-zero offset in the exposure template. "
                        "Use +f or --no-first_offsets to disable. [Enabled]")
    parser.add_argument("-k", "--keywords_equal", help=
                        "keywords should exist and match the value. May be "
                        "repeated. "
                        "[--keywords_equal='ESO DPR TYPE:OBJECT,DUAL' "
                        "--keywords_equal='ESO DPR CATG:SCIENCE']",
                        action=parser_actions.Keyword, type=str)
    parser.add_argument("-x", "--exclude_object", action=parser_actions.Object, type=str,
                        help="exclude FT:SC doublet. "
                        "[-xCalibration_unit_fiber_1:Calibration_unit_fiber_2]")
    parser.add_argument("-i", "--include_object", action=parser_actions.Object, type=str,
                        help="include FT:SC doublet. May be specified more "
                        "than once. If specified, all other doublets will be "
                        "excluded. --exclude_object has precedence over "
                        "--include_object. [None]")
    parser.add_argument("-p", "--plot_what", help="what to plot. "
                        "\"offsets\": SC - FT offsets, "
                        "[\"scalerot\"]: rotation errors and plate scales, "
                        "FToffset: offset from FT fiber to FT source, "
                        "SCoffset: offset from SC fiber to SC source, "
                        "FTscatter: scatter of FT fiber within 'group' frames, "
                        "SCscatter: scatter of SC fiber within 'group' frames, "
                        "differential: (FIELD_SC-FIELD_FT)-(FIBER_SC-FIBER_FT), "
                        "skydifferential: same as above projected on sky, "
                        "diffscatter: scatter of FIELD_FIBER_DX/Y within 'group' frames, "
                        "metscatter: scatter of MET_FIELD_FIBER_DX/Y within file, "
                        "metrology: field_fiber_dx and dy based on metrology",
                        type=str, default="scalerot",
                        choices=['scalerot', 'offsets', 'FToffset', 'SCoffset',
                                 'FTscatter', 'SCscatter', 'differential',
                                 'diffscatter', 'skydifferential', 'metrology',
                                 'metscatter'])
    parser.add_argument("-e", "--err_threshold", help=
                        "error threshold. If the uncertainty of SC_X, SC_Y, "
                        "FT_X or FT_Y is above this threshold, then the point is not plotted.",
                        type=float, default=0.2)
    parser.add_argument("-G", "--gv_port", help=
                        "GRAVITY ports to plot. Default: all.",
                        type=int, action='append', choices=[1, 2, 3, 4])
    parser.add_argument("-g", "--group", help="number of frames to average together [DET1.NDIT/4]", type=int)
    parser.add_argument("-o", "--output", help="file name to save ASCII data.",
                        type=str, default="")

    args = parser.parse_args()

    if args.gv_port is None:
        args.gv_port = [1, 2, 3, 4]

    if args.keywords_equal is None:
        args.keywords_equal = {"ESO PRO CATG": "DUAL_SCI_P2VMRED"}

    if args.exclude_object is None:
        args.exclude_object = set(
            ["Calibration_unit_fiber_1:Calibration_unit_fiber_2"])

    # If the file list is empty, use "./GRAVI*.fits"
    if not args.fname:
        args.fname = sorted(glob.glob("./GRAVI*.fits"))

    def filter_files(fn):
        '''True if fn is the name of a suitable file'''
        # Remove files that can't be opened as FITS
        header = _read_primary_header(fn)
        if header is None:
            return False
        # Remove files that don't contain FT and SC name
        try:
            FTname = header["HIERARCH ESO FT ROBJ NAME"]
            SCname = header["HIERARCH ESO INS SOBJ NAME"]
        except KeyError:
            return False
        # ':' separates FT and SC in the pair key, so escape it in the names.
        pair = FTname.replace(':', '::') + ':' + SCname.replace(':', '::')
        # Exclude as per exclude_object
        if args.exclude_object is not None and pair in args.exclude_object:
            return False
        # If include_object has been used, exclude all other objects.
        if args.include_object is not None and pair not in args.include_object:
            return False
        # Check keywords that must match
        for key in args.keywords_equal:
            try:
                if header[key] != args.keywords_equal[key]:
                    return False
            except KeyError:
                return False
        # All good, valid file
        return True

    args.fname = [fn for fn in args.fname if filter_files(fn)]

    if not args.fname:
        raise RuntimeError("No valid file found!")

    try:
        left_title, right_title = _PLOT_TITLES[args.plot_what]
    except KeyError:
        raise ValueError("unknown plot_what: " + str(args.plot_what))

    fig = plt.figure(figsize=(16, 5))
    ax1 = plt.subplot(1, 2, 1)
    ax1.set_title(left_title)
    ax2 = plt.subplot(1, 2, 2, sharex=ax1)
    ax2.set_title(right_title)

    use_labels = True
    markers = Markers()
    labeled_markers = set()
    rho = args.rho
    if rho is None:
        rho = dict()
    if args.first_offsets:
        args.first_offsets = dict()
    else:
        args.first_offsets = None

    # ASCII output file, if requested; closed even if loading/plotting fails.
    out = open(args.output, 'w') if args.output != "" else None
    try:
        for fname in args.fname:
            fitsfile = io.File(fname, rho=rho, first_offsets=args.first_offsets, plot_fit=False, group=args.group)

            if out is not None:
                mjds, data_left, data_right, _, _ = compute_data(fitsfile, args.plot_what, args.err_threshold)
                for i in range(mjds.size):
                    try:
                        mjd = mjds[i].isoformat()
                    except AttributeError:
                        mjd = mjds[i]
                    out.write('{} {} {} {} {} {} {} {} {}\n'.format(
                        mjd,
                        data_left[i][0],
                        data_left[i][1],
                        data_left[i][2],
                        data_left[i][3],
                        data_right[i][0],
                        data_right[i][1],
                        data_right[i][2],
                        data_right[i][3]))

            plot_file(fitsfile, ax1, ax2, use_labels, markers, labeled_markers,
                      args.plot_what, args.err_threshold, args.gv_port)
            use_labels = False
    finally:
        if out is not None:
            out.close()

    plt.legend(numpoints=1, fontsize='x-small', fancybox=True, framealpha=0.5)
    # DateFormatter needs a tzinfo instance, not the tzutc class itself.
    ax1.fmt_xdata = mdates.DateFormatter('%Y:%m:%dT%H:%M:%S.%f', tz=tzutc())
    # Pad the x range by 10% on each side.
    lims = np.array(ax1.axis())
    dx = lims[1] - lims[0]
    lims[0] -= 0.1*dx
    lims[1] += 0.1*dx
    ax1.axis(lims)
    try:
        fig.autofmt_xdate()
        plt.show()
    except ValueError:
        raise ValueError("'"+__name__+"': Feels like there's nothing to plot. Try increasing err_threshold.")

# Run the command-line interface only when executed as a script.
if __name__ == "__main__":
    main()
