#! /usr/bin/env python3
# -*- coding: iso-8859-15 -*-
try:
   from astropy.io import fits as pyfits
except:
   import pyfits
import numpy as np
import time
import os
import glob
import argparse
import sys


##
# Create an issue (and print it)
def create_issue(str):
    """Report one comparison issue.

    The message is printed with a two-space indent and returned wrapped in a
    one-element list, so that callers can accumulate issues with ``+=``.
    """
    indented = '  ' + str
    print(indented)
    return [str]

##
# Compare headers
def compare_header(h_r, h_n):
    """Compare two headers keyword by keyword.

    Parameters
    ----------
    h_r : header-like
        Reference header: iterable over its keywords, supports ``in`` and
        ``[]`` lookup (e.g. a FITS header or a dict).
    h_n : header-like
        Header of the new product, same protocol as ``h_r``.

    Returns
    -------
    list of str
        One message (also printed through ``create_issue``) per keyword that
        is missing from the product, changes value, or is new in the product.
    """
    issues = []
    # Iterating a header may yield the same keyword several times (FITS
    # headers can repeat commentary keywords such as COMMENT or HISTORY).
    # dict.fromkeys drops the repeats while keeping the first-seen order, so
    # each keyword is reported at most once instead of once per occurrence.
    for h in dict.fromkeys(h_r):
        if h not in h_n:
            issues += create_issue ('keyword "'+h+'" not in product')
        elif h_r[h] != h_n[h]:
            issues += create_issue ('keyword "'+h+'" changes value '+str(h_r[h])+' -> '+str(h_n[h]))
    for h in dict.fromkeys(h_n):
        if h not in h_r:
            issues += create_issue ('keyword "'+h+'" new in product')
    return issues

##
# Compare tables
def compare_recarray(h_r, h_n, rdiff, adiff):
    """Compare the columns of two tables and report their differences.

    For every column of the reference table ``h_r``:

    * a column missing from ``h_n`` is reported;
    * a column whose shape (rows and cell dimensions) differs is reported
      and not compared further;
    * string columns are compared element-wise and the number of differing
      entries is reported;
    * numerical columns are summarised by the mean absolute difference, the
      relative difference ``sum|r - n| / (2 * sum|r + n|)`` and the maximum
      absolute difference (NaN entries ignored); an issue is reported when
      the relative difference exceeds ``rdiff`` or the mean absolute one
      exceeds ``adiff``.  Entries that are NaN in only one of the tables
      are reported separately, since the statistics above ignore them.

    Columns only present in ``h_n`` are reported as well.

    Parameters
    ----------
    h_r, h_n : table data
        Reference and new table data; must expose ``.names`` and return a
        numpy array for ``table[name]``.
    rdiff : float
        Threshold on the relative difference.
    adiff : float
        Threshold on the mean absolute difference.

    Returns
    -------
    list of str
        One message per reported issue (each also printed).
    """
    issues = []
    # Width used to align column names in the report; default=0 keeps a
    # table without columns from raising on an empty sequence.
    maxstr = str(max([len(s) for s in h_r.names], default=0))

    for c in h_r.names:
        if c not in h_n.names:
            issues += create_issue ('column "'+c+'" not in product')
            continue

        col_r = h_r[c]
        col_n = h_n[c]

        # Different row counts or cell dimensions cannot be compared
        # element-wise: report the change instead of crashing or comparing
        # only the common part.
        if np.shape(col_r) != np.shape(col_n):
            issues += create_issue ('column "%s" changes shape %s -> %s'
                                    % (c, np.shape(col_r), np.shape(col_n)))
            continue

        # Nothing to compare in an empty column
        if np.size(col_r) == 0:
            continue

        # check differences if strings (str or bytes, any cell dimension)
        if np.asarray(col_r).dtype.kind in 'SU':
            ndiff = np.count_nonzero(col_r != col_n)
            if ndiff > 0:
                issues += create_issue ('column "'+c+'" has %i strings with different values'%ndiff)
            continue

        # compute numerical differences
        flat_r = col_r.flatten()
        flat_n = col_n.flatten()
        # Integers can wrap around on subtraction (e.g. uint8 3 - 5 = 254):
        # compare them as float64 instead.
        if flat_r.dtype.kind in 'iu':
            flat_r = flat_r.astype(np.float64)
        if flat_n.dtype.kind in 'iu':
            flat_n = flat_n.astype(np.float64)

        try:
            dd = np.abs(flat_r - flat_n)
            ss = np.abs(flat_r + flat_n)
        except (TypeError, ValueError):
            # No subtraction for this type (e.g. booleans): count mismatches
            dd = flat_r != flat_n
            ss = flat_r * 0 + 1

        # A value that is NaN in only one table is invisible to the
        # nan-aware statistics below; report it explicitly.
        if flat_r.dtype.kind in 'fc' and flat_n.dtype.kind in 'fc':
            n_nan = np.count_nonzero(np.isnan(flat_r) != np.isnan(flat_n))
            if n_nan > 0:
                issues += create_issue ('column "%s" has %i values that are NaN in only one file'
                                        % (c, n_nan))

        delta_abs = np.nanmean(dd)
        delta_rel = np.nansum(dd) / (np.nansum(ss) * 2.0 + 1e-20)
        delta_max = np.nanmax(dd)

        if delta_rel > rdiff or delta_abs > adiff:
            issues += create_issue (('column %-'+maxstr+'s  %e (abs)    %e (rel)    %e (max abs)')%(c, delta_abs, delta_rel, delta_max))

    for c in h_n.names:
        if c not in h_r.names:
            issues += create_issue ('column "'+c+'" not in reference')

    return issues

##
# Compare image
def compare_image(h_r, h_n, rdiff, adiff):
    """Compare two images and report a numerical difference.

    The images are summarised by the mean absolute pixel difference and the
    relative difference ``sum|r - n| / (2 * sum|r + n|)``, ignoring NaN
    pixels (consistent with the table comparison).  An issue is reported
    when the relative difference exceeds ``rdiff`` or the mean absolute one
    exceeds ``adiff``.  Also reported: data present in only one of the
    inputs, a change of shape, and pixels that are NaN in only one image.

    Parameters
    ----------
    h_r, h_n : array-like or None
        Reference and new image data (None when the HDU holds no data).
    rdiff : float
        Threshold on the relative difference.
    adiff : float
        Threshold on the mean absolute difference.

    Returns
    -------
    list of str
        One message per reported issue (each also printed).
    """
    issues = []

    # An HDU without data array gives None: nothing to compare numerically.
    if h_r is None or h_n is None:
        if h_r is not None or h_n is not None:
            issues += create_issue ('image present in only one of the files')
        return issues

    if np.shape(h_r) != np.shape(h_n):
        issues += create_issue ('image changes shape %s -> %s'
                                % (np.shape(h_r), np.shape(h_n)))
        return issues

    img_r = np.ravel(h_r)
    img_n = np.ravel(h_n)
    if img_r.size == 0:
        return issues

    # Integer (notably unsigned) pixels wrap around on subtraction, e.g.
    # uint16 3 - 5 = 65534: compare them as float64 instead.
    if img_r.dtype.kind in 'biu':
        img_r = img_r.astype(np.float64)
    if img_n.dtype.kind in 'biu':
        img_n = img_n.astype(np.float64)

    # Pixels that are NaN in only one image are ignored by the nan-aware
    # statistics below; report them explicitly.
    n_nan = np.count_nonzero(np.isnan(img_r) != np.isnan(img_n))
    if n_nan > 0:
        issues += create_issue ('image has %i pixels that are NaN in only one file' % n_nan)

    diff = np.abs(img_r - img_n)
    delta_abs = np.nanmean(diff)
    delta_rel = np.nansum(diff) / (np.nansum(np.abs(img_r + img_n)) * 2.0 + 1e-20)

    if delta_rel > rdiff or delta_abs > adiff:
        issues += create_issue ('image %e (abs) %e (rel)'%(delta_abs, delta_rel))

    return issues

##
# Implement options and help

usage = """
description:
  Compare two GRAVITY files (so far only tables)
"""
examples = """
examples:
 
"""

parser = argparse.ArgumentParser(description=usage, epilog=examples,formatter_class=argparse.RawDescriptionHelpFormatter)
TrueFalse = ['TRUE','FALSE']

parser.add_argument("--ref", dest="file_ref", type=str,
                    default="GRAVI.2016-01-21T07:01:57.790_visdualsciraw_org4.fits",
                    help="Reference file for the comparison")

parser.add_argument("--new", dest="file_new", type=str,
                    default="GRAVI.2016-01-21T07:01:57.790_visdualsciraw.fits",
                    help="Reference file for the comparison")

parser.add_argument("--rdiff-threshold", dest="rdiff", default=0.0, type=float,
                  help="Threshold for the relative difference between numerical values")

parser.add_argument("--adiff-threshold", dest="adiff", default=0.0, type=float,
                  help="Threshold for the absolute difference between numerical values")

parser.add_argument("--headers", dest="headers",default='TRUE',choices=TrueFalse,
                  help="Check headers [TRUE]")

parser.add_argument("--tables", dest="tables",default='TRUE',choices=TrueFalse,
                  help="Check tables [TRUE]")

parser.add_argument("--images", dest="images",default='TRUE',choices=TrueFalse,
                  help="Check tables [TRUE]")


##
## Main code
if __name__ == "__main__":
    
    # Parse arguments
    argoptions = parser.parse_args()
    
    file_ref = argoptions.file_ref
    print(('\nref = '+file_ref))
    
    file_new = argoptions.file_new
    print(('new = '+file_new))

    # Open FITS files
    hdulist_r = pyfits.open(os.path.realpath(file_ref))
    hdulist_n = pyfits.open(os.path.realpath(file_new))

    # Init the number of differences
    all_issues = []
    
    # Verify number of HDU
    if len(hdulist_r) != len(hdulist_n):
        all_issues += create_issue ("\nIncompatible number of EXTENSION")
        # sys.exit(len(all_issues))

    # Main header
    if argoptions.headers == 'TRUE':
        print ('\nCompare main headers')
        all_issues += compare_header (hdulist_r[0].header, hdulist_n[0].header)

    # Compare first image
    if hdulist_r[0].data is not None \
      and hdulist_n[0].data is not None:
        print ('\nCompare first images')
        all_issues += compare_image (hdulist_r[0].data, hdulist_n[0].data,
                                     argoptions.adiff, argoptions.rdiff);

    # Get all couple of EXTNAME, EXTVER in reference
    extname = []
    for h in hdulist_r[1:]:
        if h.header['EXTNAME'] == 'IMAGING_DATA_ACQ':
            continue
        if 'EXTVER' in h.header:
            tmp = (h.header['EXTNAME'],h.header['EXTVER'])
        else:
            tmp = (h.header['EXTNAME'],None)
        extname.append (tmp)

    # Verify each HDU
    for i,j in extname:
        xtension = hdulist_r[i,j].header['XTENSION']
        print(('\nExtension '+i+','+str(j)+' ('+xtension+')'))        

        hdu_r = hdulist_r[i,j]
        try:
            hdu_n = hdulist_n[i,j]
        except:
            all_issues += create_issue ("\nCannot get %s in new"%xtension)
            continue
        
        # Verify header
        if argoptions.headers == 'TRUE':
            all_issues += compare_header (hdu_r.header, hdu_n.header)
    
        # Verify the table content
        if 'TABLE' in xtension and argoptions.tables == 'TRUE':
            all_issues += compare_recarray (hdu_r.data, hdu_n.data,
                                            argoptions.adiff, argoptions.rdiff)
            
        # Verify the table content
        if 'IMAGE' in xtension and argoptions.images == 'TRUE':
            all_issues += compare_image (hdu_r.data, hdu_n.data,
                                         argoptions.adiff, argoptions.rdiff)
            
    # Close FITS files
    hdulist_r.close()    
    hdulist_n.close()    

    # Return the number of difference
    sys.exit(len(all_issues))
    
