import numpy as np
from matplotlib import pyplot
from . import astro

def orbit(t, param, Vrad=False, verbose=False):
    """
    Compute sky-projected Keplerian orbit positions (and optionally the
    line-of-sight velocity) at the given times.

    INPUT PARAMETERS:

    t: list (or array, or scalar) of times (MJD)
    param :  dictionary containing the parameters, e.g.:
    param = {'a': 99.1, 'i': 32.9, 'OMEGA': 172.8, 'e':0.938,
             'omega':2.1, 'T0': 2451798.0280-2400000.5, 'P':10.817}

    Vrad=True: returns also the radial velocity
    verbose=True: prints P, M and the derived semi-major axis when the
        latter is computed from 'P' and 'M'.

    RETURNS:

    3 (or 4) ndarray: X,Y,Z (and Vrad)

    DETAILS:

    Used keys: 'e', 'T0', 'P' (unless derivable, see below) and either
    'omega' or the pair 'omega0'/'domega' (linear apsidal motion:
    omega = omega0 + (t-T0)*domega, in degrees). 'i' and 'OMEGA' are
    optional; the corresponding rotation is skipped when absent.

    Alternatively, if 'M' the total mass is given:

    - if 'P' is missing, it is computed with astro.Kepler3rdLaw from 'a'
      and 'M' and stored in param['P'].

    - if both 'P' and 'M' are available, the semi-major axis actually used
      is ALWAYS derived from them (astro.Kepler3rdLaw(P, M)*Rsol/AU, i.e.
      in AU assuming Kepler3rdLaw returns solar radii), overriding any 'a'.
      'P' must be in the unit expected by astro.Kepler3rdLaw (the verbose
      output labels it days).

    If 'a' is missing but the semi-amplitude 'K' is given, a*sin(i) is
    computed from 'K', 'e' and |'P'| and stored in param['a'] (the unit of
    'K' is not checked):
        a(1,2)sini = (1-e**2)**(0.5)/(2*np.pi)*K(1,2)*P
    using the definition of the semi-amplitude:
        K(1,2) = 2*pi*a(1,2)*sin(i)/(P(1-e**2)**(0.5))

    NOTE: param is modified in place when 'P' or 'a' are derived.

    All angles in degrees. T0 and P should have the same unit (here days).

    The result is xyz where x is the RA offset and y the dec offset, in
    units of 'a'.

    The radial velocity is the numerical derivative dz/dt of the returned
    Z, in units of 'a' per unit of t (e.g. a/day), divided by param['plx']
    when that key is given. By the author's convention it is positive
    toward the observer. A mass ratio 'q' is NOT used by this function.
    Radial velocities are not computed by default (takes twice as long).

    refs:
    http://en.wikipedia.org/wiki/Mean_anomaly
    http://en.wikipedia.org/wiki/Eccentric_anomaly
    http://en.wikipedia.org/wiki/True_anomaly
    """
    # -- derive the period from a and the total mass if it is missing
    if 'P' not in param and 'M' in param:
        param['P'] = astro.Kepler3rdLaw(a=param['a'],
                                        M1M2=param['M'])

    # -- derive a*sin(i) from the semi-amplitude K if 'a' is missing. This
    #    must happen before the semi-major axis is chosen below, otherwise
    #    param['a'] would be read before it exists.
    if 'a' not in param and 'K' in param:
        # a*sin(i) actually; the unit of K is not checked
        param['a'] = ((1 - param['e']**2)**0.5 / (2*np.pi)
                      * param['K'] * np.abs(param['P']))

    # -- force a based on P and M
    if 'P' in param and 'M' in param:
        _a = (astro.Kepler3rdLaw(P=param['P'], M1M2=param['M'])
              * astro.SI.Rsol / astro.SI.AU)
        if verbose:
            print('P =', param['P'], 'days')
            print('M =', param['M'], 'Msol')
            print('A ->', _a , 'AU')
    else:
        _a = param['a']

    # -- accept lists / scalars: all arithmetic below needs an ndarray
    t = np.asarray(t)

    # The mean anomaly is the time since the last periapsis multiplied by the
    # mean motion, 2*pi/P; the modulo folds it into one orbit.
    mean_anomaly = ((t - param['T0']) % param['P']) / param['P'] * 2*np.pi

    # The eccentric anomaly E is related to the mean anomaly M by Kepler's
    # equation M = E - e sin E; it is inverted by interpolating a dense
    # tabulation of M(E). np.linspace requires an integer number of points.
    ecc_grid = np.linspace(0, 2*np.pi, 100000)
    ecc_anomaly = np.interp(mean_anomaly,
                            ecc_grid - param['e']*np.sin(ecc_grid),
                            ecc_grid)

    # The true anomaly is the angle between the direction of periapsis and
    # the current position of the body, as seen from the main focus.
    if np.abs(param['e']) < 1 - 1e-4:
        tmp = np.sqrt((1 + np.abs(param['e'])) /
                      (1 - np.abs(param['e'])))
    else:
        tmp = 1.0
    true_anomaly = 2*np.arctan(tmp*np.tan(ecc_anomaly/2.))

    separation = (1 - param['e']**2) / \
                 (1 + np.abs(param['e'])*np.cos(true_anomaly))
    separation *= _a

    # -- position in the orbital plane (periapsis along x)
    x = separation*np.cos(true_anomaly)
    y = separation*np.sin(true_anomaly)
    z = 0.0*separation

    # --- SEE: http://commons.wikimedia.org/wiki/File:Orbital_elements.svg
    #          * P2 is the sky plan
    #          * Upsilon (vernal point) is axis 'X' on the sky

    # -- argument of periapsis (optionally drifting linearly with time)
    if 'omega0' in param and 'domega' in param:
        omega = np.radians(param['omega0'] + (t - param['T0'])*param['domega'])
    else:
        omega = np.radians(param['omega'])
    x, y = (x*np.cos(omega) - y*np.sin(omega),
            x*np.sin(omega) + y*np.cos(omega))

    # -- inclination: rotation around the x axis
    if 'i' in param:
        inc = np.radians(param['i'])
        y, z = (y*np.cos(inc) + z*np.sin(inc),
                -y*np.sin(inc) + z*np.cos(inc))

    # -- longitude of the ascending node, offset by 90 deg so that x ends up
    #    as the RA offset and y as the dec offset
    if 'OMEGA' in param:
        node = np.radians(param['OMEGA']) - np.pi/2
        x, y = (x*np.cos(node) + y*np.sin(node),
                -x*np.sin(node) + y*np.cos(node))

    if not Vrad:
        return (x, y, z)

    # -- radial velocity (V_B - V_A) as a forward finite difference of z
    dt = param['P']*1e-5
    z_dt = orbit(t + dt, param, Vrad=False)[2]
    vrad = (z_dt - z)/dt  # -- in units of a per unit of t
    if 'plx' in param:
        vrad /= param['plx']  # -- in units of a/(unit of t)/plx
    return (x, y, z, vrad)
