'''Compute separation and position angle of a binary

This script is designed to work with the orbital parameters listed in
the ORB6 catalog.

The 't' argument gives the observing date as a string in any format
automatically recognized by astropy.time.

Example for date strings:
 B2020.0                   Besselian years (same as 'y' in ORB6)
 J2020.0                   Julian years
 2020-01-01 calendar date
 "2020-01-01 12:00:00.000" ISO format date
 2020-01-01T12:00:00.000   ISOT (FITS) format date
 2020:001:12:00:00.000     year day time

The 'orbital_elements' argument may be:
  - one full line from the ORB6 catalog (as one quoted string);
  - a substring thereof, used as an identifier to query this catalog;
  - 11 individual tokens yielding the orbital elements:
    ID P uP a ua i Omega T uT e omega (See below).

Examples for orbital_elements:
 a WDS ID:
  01388-1758
 a bit more than the WDS ID:
 '07351+3058 STT 175AB'
 a larger fraction of a row in the ORB6 catalog:
  "00114+5850 SKW   1Aa,Ab   .     .      .       15.    19.      462.7"
 actual orbital parameters:
  "BY Dra" 143.4 h 4.4 m 154. 152. 53999.2144 m 0.3 230.

If 'orbital_elements' is an ID, the script will default to downloading
the ORB6 catalog from its canonical URL. You may use the --catalog
option to use a local file or alternate URL instead. Using a local
file is faster than downloading the catalog from the web each time. If
the ID you use is ambiguous, the ephemerides for all matching rows will
be computed.

If 11 tokens are given as orbital elements, they must be as follows:
 ID: free form ID (for printing);
 P:  orbital period
 uP: unit for P: m=minutes, h=hours, d=days, y=years, c=centuries
 a:  semi-major axis
 ua: unit for a: a=arcseconds, m=mas, M=arcminutes, u=microarcseconds
 i:  inclination in degrees
 Omega: position angle of the line of nodes in degrees
 T:  time of periastron passage 
 uT: unit for T: c=centuries (fractional year / 100), d=truncated
                 Julian date (JD-2,400,000 days), m=modified Julian
                 date (MJD = JD-2,400,000.5 days), y=fractional
                 Besselian year
 e:  eccentricity
 omega: longitude of periastron in degrees

'''

import numpy
import datetime
import astropy.time
import argparse
import requests


# Location of the ORB6 catalog used when --catalog is not given.
default_catalog='http://ad.usno.navy.mil/wds/orb6/orb6orbits.txt'

# Command-line interface. The module docstring doubles as the long
# description, hence the raw formatter that keeps its layout intact.
parser = argparse.ArgumentParser(
    description=__doc__,
    formatter_class=argparse.RawDescriptionHelpFormatter)

# Observing date, handed over as a string.
parser.add_argument('t', help="date, e.g. 2019-10-15 or 2019-10-15T12:00:55.2")
# Either a single string (catalog row or identifier) or 11 separate tokens.
parser.add_argument('orbital_elements', nargs='+',
                    help="orbital elements specifier (either one string or 11 tokens)")
parser.add_argument('--catalog', '-c', default=default_catalog,
                    help=f"if orbital_elements is an ID, where to find the catalog? May be a local file or HTTP URL (default: {default_catalog})")

def todays(P, uP):
    '''Convert a duration to days.

    input:
    P: some duration (a number or a numpy array; never modified in place)
    uP: unit as per ORB6: 'm' (minutes), 'h' (hours), 'd' (days),
        'y' (years) or 'c' (centuries)

    output:
    the duration expressed in days; years are counted as
    365.242198781 days each.

    raises:
    ValueError if uP is not one of the units listed above.
    '''
    if uP == 'm':   # minutes
        return P / (24.*60.)
    if uP == 'h':   # hours
        return P / 24.
    if uP == 'd':   # days
        return P
    if uP == 'y':   # years
        return P * 365.242198781
    if uP == 'c':   # centuries
        return P * (365.242198781*100.)
    # An unrecognised unit used to be treated silently as days; fail
    # loudly instead, as toMJD does for dates.
    raise ValueError(f'unknown unit for P: {uP!r}')

def toMJD(T, uT):
    '''Convert a date to a modified Julian date (MJD).

    input:
    T: some date (a number or a numpy array; never modified in place)
    uT: unit as per ORB6:
        'd': Julian date minus 2,400,000 days
        'm': modified Julian date (JD - 2,400,000.5 days)
        'y': fractional Besselian year
        'c': "centuries" (fractional year / 100)

    output:
    the date expressed as MJD.

    raises:
    ValueError if uT is not one of the units listed above.
    '''
    if uT == 'd':   # Julian date (-2,400,000 days)
        # Build a new value rather than using `-=`, which would alter a
        # caller's numpy array in place (and fail on integer arrays).
        return T - 0.5
    if uT == 'm':   # modified Julian date (MJD = JD-2,400,000.5 days)
        return T
    if uT == 'y':   # fractional Besselian year
        return (T-1900)*365.242198781 + 2415020.31352 - 2400000.5
    if uT == 'c':   # "centuries" (year/100)
        return (T*100.-1900)*365.242198781 + 2415020.31352 - 2400000.5
    raise ValueError('unknown unit for T: '+uT)

def _read_catalog_rows(catalog):
    '''Return the ORB6 catalog found at `catalog` as a list of lines.

    `catalog` is first handed to requests.get(); when requests rejects
    it for lacking a usable URL scheme, it is opened as a local file
    instead.
    '''
    try:
        answer = requests.get(catalog)
    except (requests.exceptions.MissingSchema,
            requests.exceptions.InvalidSchema):
        # Not something requests can fetch: treat it as a local path.
        # The context manager closes the file as soon as it is read.
        with open(catalog) as catalog_file:
            return catalog_file.readlines()
    answer.raise_for_status()
    return answer.text.splitlines()


def _parse_orb6_row(row):
    '''Slice one ORB6 catalog row into
    (ID, P, uP, a, ua, i, Omega, T, uT, e, omega).

    Format description: http://ad.usno.navy.mil/wds/orb6/format.html

    Columns that are not extracted (0-based slices):
      radec [0:18]     WDS [19:29]      DD [30:44]      ADS [45:50]
      HD [51:57]       HIP [58:65]      V1 [66:71]      V1flag [71:72]
      V2 [73:78]       V2flag [78:79]   eP [93:104]     ea [115:125]
      ei [134:143]     Omegaflag [151:152]              eOmega [152:162]
      eT [175:187]     ee [196:205]     eomega [214:223]
      EQNX [223:227]   LAST [228:232]   etc.
    '''
    return (row[19:65],           # ID: the WDS ... HIP identifier columns
            float(row[79:92]),    # P (yes, field starts at 79 on some rows)
            row[92:93],           # uP
            float(row[105:114]),  # a
            row[114:115],         # ua
            float(row[125:134]),  # i
            float(row[143:151]),  # Omega
            float(row[162:174]),  # T
            row[174:175],         # uT
            float(row[187:196]),  # e
            float(row[205:214]))  # omega


def _eccentric_anomaly(M, e):
    '''Solve Kepler's equation E - e*sin(E) = M for E (radians).

    M is an array of mean anomalies, e the eccentricity. Entries with
    M == 0 are left at E == 0; the others are refined by Newton
    iterations (at most 1000) from a series-like first guess.
    '''
    E = numpy.zeros(M.shape)
    Mc = numpy.zeros(M.shape)   # mean anomaly implied by the current E
    ind = numpy.where(M != 0.)
    E[ind] = M[ind] + e * numpy.sin(M[ind]) + e**2 * numpy.sin(2.*M[ind])
    for _ in range(1000):
        # The test runs before Mc is refreshed, so it measures the
        # residual left by the previous iteration.
        if numpy.max(numpy.abs(Mc - M)) < 1e-7:
            break
        Mc[ind] = E[ind] - e * numpy.sin(E[ind])
        E[ind] = E[ind] + (M[ind]-Mc[ind]) / (1. - e * numpy.cos(E[ind]))
    return E


def ephemerids(t, orbital_elements, uT=None, uP=None, ua=None, format=None, catalog=default_catalog):
    '''answers = ephemerids(t, orbital_elements)

    Compute separation and position angle of one or several binaries.

    inputs:
    t: date(s) for which ephemerids are desired: a single value or an
         array-like of str, datetime.datetime or numbers, or an
         astropy.time.Time. Numbers are taken to be expressed in the
         same unit as T (see uT below).
    orbital_elements: one of
         - a string, or a 1-item sequence holding a string, that is
           either one full row of the ORB6 catalog (223 characters or
           more) or a substring used to select rows of the catalog:
           http://ad.usno.navy.mil/wds/orb6/orb6orbits.txt
         - a sequence of 11 items:
           (ID, P, uP, a, ua, i, Omega, T, uT, e, omega)
    uT, uP, ua: kept for backward compatibility; the units are always
         taken from orbital_elements and these arguments are not used.
    format: passed on to astropy.time.Time when t holds str or datetime
    catalog: URL or local file name of the ORB6 catalog, read only when
         orbital_elements is a substring to look up

    orbital parameters:
    ID: free form identifier, returned unchanged
    P: period, in unit uP (see todays())
    a: semi major axis, in unit ua
    i: orbital inclination in degrees
    Omega: position angle of ascending node in degrees
    T: date of passage through periastron, in unit uT (see toMJD())
    e: eccentricity
    omega: argument of periastron in degrees

    When t holds dates (str, datetime, Time), P and T are converted to
    days and MJD, so uP and uT must not be None. When t holds numbers
    and uT is None, no conversion takes place: t, P and T must then be
    given in compatible units, e.g. all in MJD, or all in years.

    output:
    a list with one tuple (ID, rho, theta, ua) per orbit, i.e. several
    tuples when a catalog lookup matches several rows. rho and theta
    have the same shape as t. rho is the separation in the same unit as
    parameter `a'; theta is the position angle in degrees, computed as
    Omega + atan2(sin(nu+omega)*cos(i), cos(nu+omega)) and not reduced
    to the 0-360 range.

    raises:
    KeyError if no row of the catalog contains the requested substring
    ValueError if orbital_elements has neither 1 nor 11 items, if a
         catalog field cannot be parsed as a number, or if t holds
         dates while uP or uT is None
    NotImplementedError if t is numeric and uT is given but uP is None
    TypeError if t is of an unsupported type
    '''
    # Accept a bare string as documented, not only a 1-item sequence.
    if isinstance(orbital_elements, str):
        orbital_elements = [orbital_elements]

    if len(orbital_elements) == 1:
        spec = orbital_elements[0]
        if len(spec) < 223:
            # Too short to hold every orbital element field: use it as
            # an identifier and keep all catalog rows containing it.
            rows = [row for row in _read_catalog_rows(catalog) if spec in row]
            if not rows:
                raise KeyError(f"Target spec '{spec}' not found in ORB6 catalog")
        else:
            # A complete catalog row was supplied: no lookup needed.
            rows = [spec]
        orbits = [_parse_orb6_row(row) for row in rows]
    elif len(orbital_elements) == 11:
        ID, P, P_unit, a, a_unit, i, Omega, T, T_unit, e, omega = orbital_elements
        orbits = [(ID,
                   float(P), P_unit,
                   float(a), a_unit,
                   float(i),
                   float(Omega),
                   float(T), T_unit,
                   float(e),
                   float(omega))]
    else:
        raise ValueError("orbital_elements must hold either 1 item (ORB6 row "
                         f"or ID) or 11 tokens, got {len(orbital_elements)}")

    # Normalise t once, before looping over the orbits, so that every
    # orbit starts from the caller's dates and not from values already
    # converted for a previous orbit.
    dates_mjd = None   # MJD array when t holds absolute dates
    numeric_t = None   # array of plain numbers otherwise
    if isinstance(t, astropy.time.Time):
        isscalar = t.isscalar
        dates_mjd = numpy.atleast_1d(t.mjd)
    else:
        # ndim() == 0 also recognises a lone datetime, which
        # numpy.isscalar() does not.
        isscalar = numpy.ndim(t) == 0
        values = numpy.atleast_1d(numpy.asarray(t))
        first = values[0]
        if isinstance(first, (str, datetime.datetime)):
            dates_mjd = numpy.atleast_1d(
                astropy.time.Time(values, format=format).mjd)
        elif numpy.isreal(first):
            numeric_t = values
        else:
            raise TypeError('t should be a(n array of) float, str, datetime or astropy.time.Time')

    answers = []
    for ID, P, P_unit, a, a_unit, i, Omega, T, T_unit, e, omega in orbits:
        if dates_mjd is not None:
            # Absolute dates: bring P to days and T to MJD.
            if T_unit is None or P_unit is None:
                raise ValueError("Need to know uP and uT when t is not a float")
            P = todays(P, P_unit)
            T = toMJD(T, T_unit)
            t_row = dates_mjd
        elif T_unit is None:
            # No unit information: t, P and T are used as given.
            t_row = numeric_t
        else:
            # Numeric t in the same unit as T: convert everything to MJD.
            if P_unit is None:
                raise NotImplementedError("uT is not None but uP is None")
            P = todays(P, P_unit)
            T = toMJD(T, T_unit)
            # Work on a private float copy so that the caller's array
            # is never modified by the conversion.
            t_row = toMJD(numpy.array(numeric_t, dtype=float), T_unit)

        mu = 2.*numpy.pi/P   # mean motion

        Omega *= numpy.pi/180.
        i *= numpy.pi/180.
        omega *= numpy.pi/180.

        # Mean anomaly in radians
        M = mu*(t_row-T)

        # Eccentric anomaly in radians
        E = _eccentric_anomaly(M, e)

        # True anomaly
        tan_nu_over_2 = numpy.sqrt((1.+e)/(1.-e)) * numpy.tan(0.5 * E)
        nu = 2.*numpy.arctan(tan_nu_over_2)

        # Radius in the orbital plane, then projection onto the sky
        r = a * (1-e**2)/(1.+e*numpy.cos(nu))
        y = numpy.sin(nu+omega) * numpy.cos(i)
        x = numpy.cos(nu+omega)
        theta = Omega + numpy.arctan2(y, x)
        rho = r*numpy.sqrt(x**2+y**2)

        if isscalar:
            rho = rho[0]
            theta = theta[0]

        answers.append((ID, rho, theta*180./numpy.pi, a_unit))

    return answers

def main(args):
    '''Print separation and position angle for the requested orbit(s).

    input:
    args: object with attributes t, orbital_elements and catalog, which
          are passed to ephemerids().

    One line is printed per tuple returned by ephemerids(). The unit
    code of the semi-major axis is spelled out; a code that is None or
    not one of a/m/M/u is reported as "(unknown unit)".
    '''
    # ORB6 unit codes for the semi-major axis -> printable unit.
    unit_names = {
        'a': 'arcsec',
        'm': 'mas',
        'M': 'arcmin',
        'u': 'μas',
    }
    answers=ephemerids(args.t, args.orbital_elements, catalog=args.catalog)
    for ID, rho, theta, ua in answers:
        # A lookup with a fallback always binds `unit`, so an unexpected
        # code can neither crash nor reuse the previous row's unit.
        unit = unit_names.get(ua, "(unknown unit)")
        print(f"{ID} rho={rho:g}{unit}, theta={theta}°")

