"""
This library contains physical and astrophysical constants (in SI units),
as well as functions to do simple Stellar Physics computation.

Uses numpy for array operations

Some examples:

>>> import astro

Q: value and unit of the Gravitational constant?
A: >>> print(astro.SI.G, astro.SI.G_unit)
   6.67428e-11 m3.kg-1.s-2
   
Q: energy of a photon, in eV, of wavelength 1 micron?
A: >>> astro.SI.h*astro.SI.c/1e-6/astro.SI.eV
   1.2398429452389037
   
Q: what is the semi-major axis, in solar radii, of a binary with a
   total mass of 3 solar masses and period 100 days?
A: >>> astro.Kepler3rdLaw(P=100, M1M2=3)
   130.7971597440739

Q: what is the angular separation, in arcsec, of a binary composed of
   a G2V and F6IV, of period 300 days and with a parallax of 0.017
   arcsecond?
A: >>> m1 = astro.ApproximateMassFromSpectralType('G2', lum_class='V')
   >>> m2 = astro.ApproximateMassFromSpectralType('F6', lum_class='IV')
   >>> astro.EstimateApparentOrbitalSeparation(m1, m2, 300., 0.017)
   0.01972776788758886
   >>> # or directly:
   >>> astro.EstimateApparentOrbitalSeparation('G2V', 'F6IV', 300., 0.017)
   0.01972776788758886
"""

import numpy as np
from scipy import special

class SIconsts():
    """
    Physical and Astrophysical constants in SI units. Every constant X
    comes with a string attribute X_unit giving its unit.

    REF: from Astrophysical Quantities and/or 'The call to adopt a
    nominal set of astrophysical parameters and constants to improve
    the accuracy of fundamental physical properties of stars.' from
    P. Harmanec [1106.1508v1.pdf in Astro-Ph]
    """
    def __init__(self):
        # -- physical constants
        self.c = 2.99792458e8            # speed of light
        self.c_unit = 'm.s-1'
        self.G = 6.67428e-11             # gravitational constant
        self.G_unit = 'm3.kg-1.s-2'
        self.h = 6.626075e-34            # Planck constant
        self.h_unit = 'J.s'
        self.k = 1.380658e-23            # Boltzmann constant
        self.k_unit = 'J.K-1'
        self.eV = 1.602176565e-19        # electron-volt: an energy
        self.eV_unit = 'J'
        self.sigma = 5.670400e-8         # Stefan-Boltzmann constant
        self.sigma_unit = 'W.m-2.K-4'
        self.hc = self.h*self.c
        self.hc_unit = 'J.m'

        # -- time
        self.yr = 365.25*24*3600         # Julian year
        self.yr_unit = 's'
        self.day = 24*3600
        self.day_unit = 's'

        # -- distances
        self.AU = 1.495978707e11
        self.AU_unit = 'm'
        self.au = self.AU                # lower-case alias
        self.au_unit = self.AU_unit
        self.pc = 3.085677581503e16
        self.pc_unit = 'm'

        # -- energy / flux units
        self.erg = 1e-7
        self.erg_unit = 'J'
        self.Jy = 1.e-26                 # Jansky
        self.Jy_unit = 'J.s-1.m-2.Hz-1'

        # -- solar values
        self.Rsol = 6.95508e8
        self.Rsol_unit = 'm'
        self.Dsol = 2*self.Rsol          # solar diameter
        self.Dsol_unit = self.Rsol_unit
        self.Msol = 1.988419e30
        self.Msol_unit = 'kg'
        self.Lsol = 3.846e33*self.erg    # 3.846e33 erg/s
        self.Lsol_unit = 'J.s-1'
        # effective temperature implied by Lsol and Rsol through
        # Stefan-Boltzmann: L = 4 pi R^2 sigma Teff^4
        self.Teffsol = (self.Lsol/(4*np.pi*self.Rsol**2*self.sigma))**(1/4.)
        self.Teffsol_unit = 'K'
        self.Mbolsol = 4.75              # solar bolometric magnitude
        self.Mbolsol_unit = 'mag'

        # -- angles
        self.arcsec = np.pi/180.*1/(3600)
        self.arcsec_unit = 'rad'
        self.mas = self.arcsec/1000.     # milliarcsecond
        self.mas_unit = 'rad'

# module-wide instance used by the functions below
SI = SIconsts()

def Kepler3rdLaw(a=None, P=None, M1M2=None):
    """
    Kepler's third law, a**3 = C * M1M2 * P**2, in solar/day units:

    - a in solar radii (scalar or None)
    - P in days (scalar or None)
    - M1M2 = M1+M2 in solar masses (scalar or None)

    At least 2 of the 3 must be given. If exactly one is None, the
    function returns that missing quantity in the unit above. If all 3
    are given, it returns a boolean telling whether they are consistent,
    i.e. whether a**3/(P**2*M1M2) is within eC of the constant C.

    Raises TypeError if fewer than 2 of the 3 are given.

    REF: 'The call to adopt a nominal set of astrophysical parameters
    and constants to improve the accuracy of fundamental physical
    properties of stars.' from P. Harmanec [1106.1508v1.pdf in
    Astro-Ph]
    """
    if sum(v is not None for v in (a, P, M1M2)) < 2:
        raise TypeError('Kepler3rdLaw: at least 2 of a, P and M1M2 '
                        'must be given')
    # C = G*Msol/(4 pi^2), expressed in Rsol^3.day^-2
    C = SI.G*SI.Msol*SI.day**2/(4*np.pi**2*SI.Rsol**3)
    eC = 0.0075  # absolute tolerance on C for the consistency check
    if a is None:
        return (P**2*M1M2*C)**(1/3.)
    elif P is None:
        return (a**3/(M1M2*C))**(1/2.)
    elif M1M2 is None:
        return a**3/(P**2*C)
    else:
        return abs(a**3/(P**2*M1M2) - C) <= eC

def angDiamVK(V, K):
    """
    Angular diameter from the V and K magnitudes, using the
    surface-brightness relation of Kervella et al. 2004:

        log10(diam) = 0.2753*(V-K) + 0.5175 - 0.2*V

    (presumably the limb-darkened diameter in mas, as in that paper —
    confirm). Works element-wise on numpy arrays.
    """
    log_diam = 0.2753*(V - K) + 0.5175 - 0.2*V
    return 10**log_diam

def SpectralTypeFromColor(col_val, color='V-K', inverse=False):
    """
    spectral type from color (default is color='V-K' ---values for
    dwarfs--- other possible value is color='B-V').

    if inverse=True (or if col_val is a string), will assume col_val
    is a spectral type and will return the color.

    col_val may also be a list, tuple or numpy array: the conversion is
    then applied element-wise and a list is returned.

    The spectral type is obtained by linear interpolation in the table
    below and truncated to the integer subclass (e.g. 'G2').

    Raises ValueError for an unknown `color`.
    """
    if isinstance(col_val, (list, tuple, np.ndarray)):
        return [SpectralTypeFromColor(x, color=color, inverse=inverse)
                for x in col_val]
    if color == 'V-K':
        st = ['B0', 'B1', 'B2', 'B3', 'B4', 'B5', 'B6', 'B7', 'B8', 'B9',
              'A0', 'A2', 'A5', 'A7',
              'F0', 'F2', 'F5', 'F7',
              'G0', 'G2', 'G4', 'G6',
              'K0', 'K2', 'K4', 'K5', 'K7',
              'M0', 'M1', 'M2', 'M3', 'M4', 'M5', 'M6']
        col = [-0.83, -0.74, -0.66, -0.56, -0.49, -0.42, -0.36, -0.29, -0.24, -0.13,
               0.00, 0.14, 0.38, 0.50,
               0.70, 0.82, 1.10, 1.32,
               1.41, 1.46, 1.53, 1.64,
               1.96, 2.22, 2.63, 2.85, 3.16,
               3.65, 3.87, 4.11, 4.65, 5.28, 6.17, 7.37]
    elif color == 'B-V':
        st = ['O5', 'O9',
              'B0', 'B2', 'B5', 'B8',
              'A0', 'A2', 'A5',
              'F0', 'F2', 'F5', 'F8',
              'G0', 'G2', 'G5', 'G8',
              'K0', 'K2', 'K5',
              'M0', 'M2', 'M5']
        col = [-0.33, -0.31, -0.30, -0.24, -0.17, -0.11, -0.02, 0.05,
               0.15, 0.30, 0.35, 0.44, 0.52, 0.58, 0.63, 0.68,
               0.74, 0.81, 0.91, 1.15, 1.40, 1.49, 1.64]
    else:
        raise ValueError("unknown color %r: expected 'V-K' or 'B-V'"
                         % (color,))

    # decimal coding of the tabulated spectral types (see decimST)
    val = decimST(st)
    # linear interpolation in the table
    if inverse or isinstance(col_val, str):
        return np.interp(decimST(col_val), val, col)
    st_dec = np.interp(col_val, col, val)
    # Work in integer tenths: truncating 10*(st_dec - int(st_dec))
    # directly misreads exact table values (4.6 - 4 = 0.5999...), turning
    # e.g. 'G6' into 'G5'. The tiny epsilon absorbs that float noise.
    tenths = int(np.floor(st_dec*10 + 1e-9))
    return 'OBAFGKM'[tenths // 10] + str(tenths % 10)

def decimST(spectralType):
    """
    spectralType should be a string containing at least one character
    in 'OBAFGKM'; a list (or other non-string iterable) of such strings
    is converted element-wise and returned as a list.

    returns a decimal value to code the spectral type: the class letter
    gives the integer part (O=0, B=1, A=2, F=3, G=4, K=5, M=6) and the
    character right after it the first decimal: 'O5' is 0.5, 'B2' is
    1.2, 'A7' is 2.7 etc. Characters before the class letter are
    skipped ('sdB5' -> 1.5). If no digit follows the letter ('G',
    'GIII'), the mid-class value .5 is used ('G' -> 4.5).

    Raises ValueError if no class letter is found.
    """
    if not isinstance(spectralType, str):
        return [decimST(x) for x in spectralType]

    letters = 'OBAFGKM'
    k = 0
    while k < len(spectralType) and spectralType[k] not in letters:
        k += 1
    if k == len(spectralType):
        raise ValueError('no spectral class letter (one of %s) in %r'
                         % (letters, spectralType))
    res = float(letters.index(spectralType[k]))
    try:
        res += float(spectralType[k+1])/10.
    except (IndexError, ValueError):
        res += 0.5  # no subclass digit: default to mid-class
    return res
    
def ParallacticRadius(ang_diam=None, plx=None):
    """
    Linear radius, in solar radii, of a star of angular diameter
    `ang_diam` seen with parallax `plx` (both in the same angular unit).

    The physical diameter in AU is ang_diam/plx; half of it, converted
    from AU to solar radii, is returned.
    """
    au_in_rsol = SI.au/SI.Rsol
    return au_in_rsol/2.*ang_diam/plx

def LfromM_MS(m, calibrate=False):
    """
    Luminosity as a function of Mass, based on Mass Luminosity
    relation adapted from Torres et al. A&AR 16-67 (2010).

    The relation is a 2nd-order polynomial in log10(m/5):
        log10(L) = polyval(coeffs, log10(m) - log10(5))
    (presumably m in solar masses and L in solar luminosities — the
    tabulated points span roughly 0.4 to 27 in mass).

    If calibrate=True, the polynomial is (re)fitted to the tabulated
    points, its coefficients are printed, and None is returned.
    """
    log_m_ref = np.log10(5.)
    if calibrate:
        table_log_m = np.array([-0.381138, -0.220536, -0.129384, -0.0816377, -0.0425725, 0.035558,
                                0.12237, 0.174457, 0.252587, 0.343739, 0.413188, 0.534725, 0.612855,
                                0.677964, 0.738732, 0.851587, 0.938399, 1.02521, 1.11636, 1.20317, 1.29867,
                                1.37246, 1.43322])
        table_log_l = [-1.83824, -1.13235, -0.75, -0.514706, -0.264706,
                       0.191176, 0.514706, 0.735294, 1.11765, 1.54412, 1.77941, 2.10294, 2.47059,
                       2.79412, 3.14706, 3.45588, 3.77941, 4.02941, 4.35294, 4.64706, 4.79412,
                       5.02941, 5.19118]
        fitted = np.polyfit(table_log_m - log_m_ref, table_log_l, 2)
        print(list(fitted))
        return None

    # coefficients obtained with calibrate=True (highest order first)
    coeffs = [-0.70132566059863344,
              3.7052233442111202,
              2.8842163773692251]
    return 10**np.polyval(coeffs, np.log10(m) - log_m_ref)

def BolometricMagnitude(Lum=None, R=None, Teff=None):
    """
    returns bolometric magnitude for Lum in solar luminosity. if R
    (in Rsol) and Teff (in K) are given, these will be used to compute
    the luminosity instead, through L/Lsol = R**2 * (Teff/Teffsol)**4.

    Raises TypeError if neither Lum nor both R and Teff are given.
    """
    if (R is not None) and (Teff is not None):
        # Stefan-Boltzmann law in solar units
        Lum = R**2*(Teff/SI.Teffsol)**4
    if Lum is None:
        raise TypeError('BolometricMagnitude needs either Lum, or both '
                        'R and Teff')
    return SI.Mbolsol - 2.5*np.log10(Lum)

def ApproximateMassFromSpectralType(st, lum_class='V'):
    """
    approximate mass (in solar masses) from spectral type and
    luminosity class ('I', 'II', 'III', 'IV' or 'V').

    log10(mass) is linearly interpolated against the decimal coding of
    the spectral type (see decimST); outside a table the end value is
    used (numpy.interp clamping).

    - 'V'   : main-sequence table
    - 'III' : giant table
    - 'IV'  : arithmetic mean of the main-sequence and giant masses
    - 'II'  : arithmetic mean of the supergiant and giant masses
    - any other value: supergiant table (class I)

    st can be a list or np.ndarray, but lum_class needs to be a
    scalar
    """
    dwarfs = (['O3', 'O5', 'O6', 'O8',
               'B0', 'B3', 'B5', 'B8',
               'A0', 'A5', 'F0', 'F5',
               'G0', 'G5', 'K0', 'K5',
               'M0', 'M2', 'M5', 'M8'],
              [120, 60, 37, 23,
               17.5, 7.6, 5.9, 3.8,
               2.9, 2.0, 1.6, 1.4,
               1.05, 0.92, 0.79, 0.67,
               0.51, 0.40, 0.21, 0.06])
    giants = (['B0', 'B5', 'A0', 'G0', 'G5', 'K0', 'K5', 'M0'],
              [20, 7, 4, 1.0, 1.1, 1.1, 1.2, 1.2])
    supergiants = (['O5', 'O6', 'O8',
                    'B0', 'B5',
                    'A0', 'A5',
                    'F0', 'F5',
                    'G0', 'G5',
                    'K0', 'K5',
                    'M0', 'M2'],
                   [70, 40, 28,
                    25, 20,
                    16, 13,
                    12, 10,
                    10, 12,
                    13, 13,
                    13, 19])

    def log_interp(table):
        # interpolate in log10(mass), then go back to a linear mass
        types, masses = table
        return 10**np.interp(decimST(st), decimST(types), np.log10(masses))

    if lum_class == 'V':
        return log_interp(dwarfs)
    if lum_class == 'IV':
        return 0.5*(log_interp(dwarfs) + log_interp(giants))
    if lum_class == 'III':
        return log_interp(giants)
    if lum_class == 'II':
        return 0.5*(log_interp(supergiants) + log_interp(giants))
    return log_interp(supergiants)

def EstimateApparentOrbitalSeparation(M1=2.0, M2='F2IV', P=1000.0,
                                      plx=0.01):
    """
    Assuming 2 masses (or 2 spectral types), a period (in days) and a
    parallax (in arcsec), an estimated apparent separation is computed
    (in arcseconds).

    - M1, M2: masses in solar masses, or strings 'spectral type +
      luminosity class' such as 'G2V', 'F6IV', 'K0III' or 'B0.5Ia'.
      A string without luminosity class (e.g. 'G2') is treated as a
      main-sequence star (class V).
    - P: orbital period in days
    - plx: parallax in arcsec

    The semi-major axis from Kepler's 3rd law, in AU, times the
    parallax in arcsec gives the angular semi-major axis in arcsec.
    """
    def _mass(M):
        # numbers are used directly as masses
        if not isinstance(M, str):
            return M
        # the luminosity class starts at the first 'I' or 'V' after the
        # spectral class letter; everything before is the spectral type
        i = 1
        while i < len(M) and M[i] not in 'IV':
            i += 1
        return ApproximateMassFromSpectralType(M[:i],
                                               lum_class=M[i:] or 'V')

    sep_au = Kepler3rdLaw(P=P, M1M2=_mass(M1)+_mass(M2))*SI.Rsol/SI.au
    return sep_au*plx

def orbitalAngularMomentum(a, M1, M2, e=0.0, specific=False):
    """
    Orbital angular momentum of a binary,

        J = M1*M2*sqrt(G*a*(1-e**2)/(M1+M2))

    - a: semi-major axis in AU
    - M1, M2: component masses in Msol
    - e: eccentricity
    - specific: if True, return the specific angular momentum J/M with
      M = M1+M2 (in m2.s-1) instead of J (in kg.m2.s-1)

    see Tokovinin (2008), paragraph 5.5:

    - Eq. 4 for definition of orbital angular momentum J

    - end of first paragraph for definition of specific momentum (J/M)
    """
    M_tot = (M1+M2)*SI.Msol  # total mass in kg
    J = np.sqrt(a*SI.AU*(1-e**2))*M1*M2*SI.Msol**2*np.sqrt(SI.G/M_tot)
    if specific:
        J = J/M_tot
    return J

def tag2mjd(tag):
    """
    Convert a FITS time tag ('YYYY-MM-DDThh:mm:ss.sss') to MJD
    (universal time assumed).

    Year, month and day are parsed as integers, the time fields as
    floats; the conversion itself is done by date2mjd.
    """
    fields = tag.split('-')
    year = int(fields[0])
    month = int(fields[1])
    day_and_clock = fields[2].split('T')
    day = int(day_and_clock[0])
    clock = day_and_clock[1].split(':')
    hours = float(clock[0])
    minutes = float(clock[1])
    seconds = float(clock[2])
    return date2mjd(year, month, day, hours, minutes, seconds)

def date2mjd(y,m,d,hh=0.0,mm=0.0,ss=0.0):
    """
    convert a UTC date to modified julian date.

    - y, m, d: Gregorian calendar year, month (1-12) and day (integers)
    - hh, mm, ss: time of day (hours, minutes, seconds)

    returns the MJD as a float (MJD 0.0 is 1858-11-17 00:00).
    """
    # Integer Julian Day Number of the date. Every division here must be
    # an integer (floor) division: with Python 3 true division '/' the
    # result is wrong.
    a = (14-m)//12            # 1 for January/February, 0 otherwise
    y += 4800 - a             # year counted so that it starts in March
    m += 12*a - 3             # month counted from March (0) to February (11)
    jdn = d + (153*m+2)//5 + 365*y + y//4 - y//100 + y//400 - 32045
    # the JDN refers to noon, so MJD at 00:00 is JDN - 2400001; subtracting
    # before adding the fraction of day keeps more precision
    mjd = jdn - 2400001.0
    mjd += hh/24. + mm/(24*60.) + ss/(24*3600.0)
    return mjd

def tag2lst(tag, longitude=-70.4049868):
    """
    Convert a FITS time tag ('YYYY-MM-DDThh:mm:ss.sss') to LST (local
    sidereal time, decimal hours). Default longitude is for Paranal
    (value from FITS files), in degrees.

    NOTE(review): this default differs in the last digit from the one
    of date2lst (-70.40498688); negligible, but worth unifying.
    """
    fields = tag.split('-')
    year = int(fields[0])
    month = int(fields[1])
    day_and_clock = fields[2].split('T')
    day = int(day_and_clock[0])
    clock = day_and_clock[1].split(':')
    hours = float(clock[0])
    minutes = float(clock[1])
    seconds = float(clock[2])
    return date2lst(year, month, day, hours, minutes, seconds,
                    longitude=longitude)

def date2lst(y,m,d,hh=0.0,mm=0.0,ss=0.0, longitude=-70.40498688):
    """
    local apparent sidereal time in decimal hour as a function of UT
    date and time. Default longitude is Paranal, as given in FITS files
    (degrees, positive east).

    http://aa.usno.navy.mil/faq/docs/GAST.php
    """
    # days elapsed since JD 2451545.0 (J2000.0)
    days = date2mjd(y,m,d,hh,mm,ss)+2400000.5-2451545.0
    gmst = 18.697374558 + 24.06570982441908*days  # hours
    node_moon = 125.04-0.052954*days      # longitude of the Moon's ascending node (deg)
    lon_sun = 280.47 + 0.98565*days       # mean longitude of the Sun (deg)
    obliquity = 23.4393 - 0.0000004*days  # deg
    # -- equation of the equinoxes (hours)
    nutation = (-0.000319*np.sin(node_moon*np.pi/180)-
                0.000024*np.sin(2*lon_sun*np.pi/180))
    eqeq = nutation*np.cos(obliquity*np.pi/180)
    gast = gmst + eqeq
    # longitude in degrees -> hours, wrapped to [0, 24)
    return (gast + longitude/15.) % 24.
    
def degree2str(x):
    """
    convert decimal degrees (or decimal hours) to 'dd:mm:ss.sss' (or
    'hh:mm:ss.sss').

    Lists, tuples and non-scalar numpy arrays are converted element-wise
    and returned as a list. The sign is kept even when |x| < 1
    (-0.5 -> '-0:30:00.000'), and rounding to the millisecond carries
    into minutes and degrees, so the seconds field never reads 60.000.
    """
    if isinstance(x, (list, tuple)) or \
            (isinstance(x, np.ndarray) and x.ndim > 0):
        return [degree2str(xx) for xx in x]
    # work in integer milliseconds (of arc or of time) so rounding carries
    total_ms = int(round(abs(x)*3600000.))
    sign = '-' if (x < 0 and total_ms > 0) else ''
    whole, rest_ms = divmod(total_ms, 3600000)
    minutes, sec_ms = divmod(rest_ms, 60000)
    return '%s%d:%02d:%06.3f' % (sign, whole, minutes, sec_ms/1000.)

def ESOcoord2decimal(x):
    """
    Convert the ESO compact format 'hhmmss.sss' (or 'ddmmss.sss'),
    given as a string or a number, to decimal hours (or degrees).
    The sign of the input is carried over to the result.
    """
    value = float(x) if isinstance(x, str) else x
    sign = np.sign(value)
    value = abs(value)
    units = int(value*1e-4)                    # hours or degrees
    minutes = int((value - 1e4*units)*1e-2)
    seconds = value - 1e4*units - 1e2*minutes
    return sign*(units + minutes/60. + seconds/3600.)
    
def interfVisibility(diam_mas, B_m, lambda_um):
    """
    Visibility of a uniform disk, V = 2*J1(x)/x with x = pi*diam*B/lambda.

    - diam_mas: angular diameter in milliarcseconds
    - B_m: baseline in meters
    - lambda_um: wavelength in microns

    Inputs may be scalars or array-likes (broadcast together); a numpy
    array is returned. Where x == 0 (zero diameter or zero baseline) the
    limit V = 1 is returned instead of 0/0 = nan.
    """
    # mas -> rad and microns -> m
    x = np.pi*np.array(diam_mas)*(np.pi/180)*(1/3600000.)*\
        np.array(B_m)/(np.array(lambda_um)*1e-6)
    x = np.asarray(x)
    zero = (x == 0)
    # replace x == 0 by 1 before dividing so that no 0/0 is evaluated
    safe_x = np.where(zero, 1.0, x)
    res = np.where(zero, 1.0, 2*special.jv(1, safe_x)/safe_x)
    return np.array(res)

 
