import numpy as np
import pylab as plt
from astropy.io import fits
import astropy._erfa as erfa


# Convert tangent-plane offsets back into spherical coordinates
#
#  xi, eta:   offsets in the tangent plane, in radians
#  raz, decz: spherical coordinates of the tangent point, in radians
#
# Returns (ra, dec) in radians, with ra wrapped into [0, 2*pi)
#
def dtp2s(xi, eta, raz, decz):
    sin_dec0, cos_dec0 = np.sin(decz), np.cos(decz)
    # Shared term of both arctangents below
    along = cos_dec0 - eta * sin_dec0
    ra = np.mod(np.arctan2(xi, along) + raz, 2.0*np.pi)
    dec = np.arctan2(sin_dec0 + eta * cos_dec0,
                     np.sqrt(xi * xi + along * along))
    return ra, dec


# Format an angle as "<sign>DD:MM:SS.ffffff"
#
#  angle: angle handed to erfa.a2af (radians, per the ERFA documentation)
#
# The fields come from erfa.a2af called with a resolution of 6, which is
# why the last field is zero-padded to six digits.
#
def angle2str(angle):
    sign, idmsf = erfa.a2af(6, angle)
    # Reduce the returned sign to a single text character
    sign_char = sign.flatten()[0][:1].decode()[0]
    template = "{0}{1:02}:{2:02}:{3:02}.{4:06}"
    return template.format(sign_char, idmsf[0], idmsf[1], idmsf[2], idmsf[3])


# Convert a right ascension packed as HHMMSS.sss into an angle
#
#  ra: packed value, as a number or a numeric string
#      (e.g. 123456.7 is split into 12 h, 34 min, 56.7 s)
#
# Returns the result of erfa.tf2a (radians, per the ERFA documentation)
#
def rastr2rad(ra):
    packed = float(ra)
    # erfa.tf2a takes the sign separately from the unsigned fields
    nonneg = packed >= 0
    sign = '+' if nonneg else '-'
    rest = packed if nonneg else -packed
    # Peel off the two-digit fields, most significant first
    hours = int(rest/10000.0)
    rest -= hours*10000.0
    minutes = int(rest/100.0)
    seconds = rest - minutes*100.0
    return erfa.tf2a(sign, hours, minutes, seconds)


# Convert a declination packed as DDMMSS.sss into an angle
#
#  dec: packed value, as a number or a numeric string
#       (e.g. -123456.7 is split into sign '-', 12 deg, 34 arcmin, 56.7 arcsec)
#
# Returns the result of erfa.af2a (radians, per the ERFA documentation)
#
def decstr2rad(dec):
    packed = float(dec)
    # erfa.af2a takes the sign separately from the unsigned fields
    nonneg = packed >= 0
    sign = '+' if nonneg else '-'
    rest = packed if nonneg else -packed
    # Peel off the two-digit fields, most significant first
    degrees = int(rest/10000.0)
    rest -= degrees*10000.0
    arcmin = int(rest/100.0)
    arcsec = rest - arcmin*100.0
    return erfa.af2a(sign, degrees, arcmin, arcsec)


# Chain erfa.atciq and erfa.atioq using precomputed astrometry parameters
#
#  rc, dc, pr, pd, px, rv: forwarded unchanged to erfa.atciq (per the ERFA
#                          documentation: position, proper motion, parallax
#                          and radial velocity of the target)
#  astrom:                 astrometry parameters used by both calls
#
# Returns the five outputs of erfa.atioq, in the same order
#
def atcoq(rc, dc, pr, pd, px, rv, astrom):
    intermediate = erfa.atciq(rc, dc, pr, pd, px, rv, astrom)
    ra_i, dec_i = intermediate
    az, zd, ha, dec_obs, ra_obs = erfa.atioq(ra_i, dec_i, astrom)
    return az, zd, ha, dec_obs, ra_obs


# Scale vectors to unit Euclidean length
#
#  x: a single vector (1-D array) or a stack of vectors (2-D array, one
#     vector per row)
#
# Returns an array with the same shape as x. No special handling is applied
# to zero-length vectors: they go through a plain numpy division by zero.
#
def normalize(x):
    x = np.asanyarray(x)
    if x.ndim == 1:
        # A lone vector, e.g. what az2enu returns for scalar angles; the
        # row-wise formula below would raise on it because axis 1 is missing
        return x/np.sqrt(np.sum(x**2))
    return x/np.sqrt(np.sum(x**2, axis=1))[:, None]


# Turn an azimuth / zenith-angle direction into ENU direction cosines
#
#  a: azimuth in radians (N: 0, E:90)
#  z: zenith angle in radians
#
# Returns the (east, north, up) components; for 1-D inputs of length N the
# result has shape (N, 3)
#
def az2enu(a, z):
    # Length of the projection onto the horizontal plane
    horizontal = np.sin(z)
    east = horizontal*np.sin(a)
    north = horizontal*np.cos(a)
    up = np.cos(z)
    return np.transpose(np.array([east, north, up]))


# Re-express WSU vector(s) in the ENU frame
#
#  wsu: vector(s) with the three components along the last axis
#       (presumably West/South/Up, going by the name -- confirm)
#
# The first two components change sign, the third is kept as is
#
def wsu2enu(wsu):
    axis_signs = np.array([-1.0, -1.0, 1.0])
    return wsu*axis_signs


# Build the U, V, W unit vectors for a target given by its RA/DEC
#
# U and V are central finite differences of the observed direction, taken
# for small tangent-plane offsets (xi for U, eta for V) around the target.
# W is the opposite of the observed direction of the target itself.
# The UVW vectors are expressed in the local ENU frame
#
#  rc0:    right ascension in radians
#  dc0:    declination in radians
#  astrom: astrometry parameters forwarded to atcoq
#
def UVW(rc0, dc0, astrom):
    # Derivative angle, adjusted for optimal accuracy
    eps = 10.0  # arcsec
    step = np.deg2rad(eps/3600.0)

    def observed_enu(rc, dc):
        # Observed direction as ENU cosines, with proper motion, parallax
        # and radial velocity all set to zero
        az, zd, _, _, _ = atcoq(rc, dc, 0.0, 0.0, 0.0, 0.0, astrom)
        return az2enu(az, zd)

    # Offset directions on either side of the target
    u_plus = observed_enu(*dtp2s(+step, 0.0, rc0, dc0))
    u_minus = observed_enu(*dtp2s(-step, 0.0, rc0, dc0))
    v_plus = observed_enu(*dtp2s(0.0, +step, rc0, dc0))
    v_minus = observed_enu(*dtp2s(0.0, -step, rc0, dc0))
    # Compute UVW
    eU = u_plus - u_minus
    eV = v_plus - v_minus
    eW = -observed_enu(rc0, dc0)
    return normalize(eU), normalize(eV), normalize(eW)


# Build the U, V, W unit vectors for a target given by its RA/DEC
#
# Variant of UVW that first converts the target with erfa.atciq (using the
# given proper motion, with parallax and radial velocity set to zero) and
# applies the tangent-plane offsets to that intermediate position. Each
# position is then turned into an observed direction with erfa.atioq.
# U and V are central finite differences; W is the opposite of the observed
# direction of the target itself.
# The UVW vectors are expressed in the local ENU frame
#
#  rc0:    right ascension in radians
#  dc0:    declination in radians
#  pr, pd: proper motion in RA and DEC, forwarded to erfa.atciq
#  astrom: astrometry parameters used by erfa.atciq and erfa.atioq
#
def UVW2(rc0, dc0, pr, pd, astrom):
    ri0, di0 = erfa.atciq(rc0, dc0, pr, pd, 0.0, 0.0, astrom)
    # Derivative angle, adjusted for optimal accuracy
    eps = 10.0  # arcsec
    step = np.deg2rad(eps/3600.0)

    def observed_enu(ri, di):
        # Observed direction of an intermediate position, as ENU cosines
        az, zd, _, _, _ = erfa.atioq(ri, di, astrom)
        return az2enu(az, zd)

    # Offset directions on either side of the intermediate position
    u_plus = observed_enu(*dtp2s(+step, 0.0, ri0, di0))
    u_minus = observed_enu(*dtp2s(-step, 0.0, ri0, di0))
    v_plus = observed_enu(*dtp2s(0.0, +step, ri0, di0))
    v_minus = observed_enu(*dtp2s(0.0, -step, ri0, di0))
    # Compute UVW
    eU = u_plus - u_minus
    eV = v_plus - v_minus
    eW = -observed_enu(ri0, di0)
    return normalize(eU), normalize(eV), normalize(eW)


if __name__ == '__main__':
    # Cross-check script: recompute the U coordinate of every row from the
    # station positions and print it next to the UCOORD column of the file
    with fits.open("/Users/jwoillez/Documents/Work/workspace/GRAVITY/data/pipeline/2016-01-19/reduced/GRAVI.2016-01-20T08:25:11.461_p2vmreduced.fits") as hdulist:
        header = hdulist[0].header
        # NOTE(review): HDU 9 is addressed by position; it must be the table
        # holding the MJD, STA_INDEX and UCOORD columns -- confirm
        vis = hdulist[9].data
        array_table = hdulist['OI_ARRAY'].data
        # UTC as a two-part date: MJD zero point plus the MJD column
        utc1 = 2400000.5
        utc2 = vis['MJD']
        N = len(utc2)
        rc = rastr2rad(header["ESO FT ROBJ ALPHA"])
        dc = decstr2rad(header["ESO FT ROBJ DELTA"])
        # No proper motion applied
        pr = 0.0
        pd = 0.0
        dut1 = 0.0
        # Observer position and environment, in the order taken by erfa.apco13
        elong = np.deg2rad(-70.40498688)
        phi = np.deg2rad(-24.62743941)
        hm = 2681.0
        xp = 0.0
        yp = 0.0
        # NOTE(review): zero pressure should switch refraction off in ERFA -- confirm
        phpa = 0.0
        tc = 0.0  # degC
        rh = 0.0  # ERFA documents the range as 0-1 (not 0-100)
        wl = 0.0  # um
        # Prepare celestial to observed transform at the middle timestamp,
        # then refresh its time-dependent part for every timestamp
        astrom0, _ = erfa.apco13(utc1, utc2[N//2], dut1, elong, phi, hm, xp, yp, phpa, tc, rh, wl)
        ut11, ut12 = erfa.utcut1(utc1, utc2, dut1)
        astrom = erfa.aper13(ut11, ut12, astrom0)
        # Create observed UVW reference frame
        eU, eV, eW = UVW2(rc, dc, pr, pd, astrom)
        # Compute baselines: second station minus first, converted to ENU
        station2xyz = dict(zip(array_table['STA_INDEX'], array_table['STAXYZ']))
        B = np.array([wsu2enu(station2xyz[pair[1]]-station2xyz[pair[0]]) for pair in vis['STA_INDEX']])
        # Compute UVW as the projection of each baseline on the frame vectors
        U = np.sum(B*eU, axis=1)
        V = np.sum(B*eV, axis=1)
        W = np.sum(B*eW, axis=1)
        print(U[:6])
        print(vis['UCOORD'][:6])

    # No figure is created above; call kept from the original script
    plt.show()
