#! /usr/bin/env python3
import argparse
import glob
from astropy.io import fits
import os
import os.path
from shutil import move, copyfile, rmtree


def format_value(value):
    """Convert the text of a patch value into the most specific scalar.

    Surrounding whitespace is removed first.  The text is then tried as an
    ``int`` and, failing that, as a ``float``; when neither conversion works
    the stripped string is returned unchanged.
    """
    text = value.strip()
    for convert in (int, float):
        try:
            return convert(text)
        except ValueError:
            continue
    return text


def process_patch(s):
    """Parse one patch specification into a ``(mode, key, value)`` tuple.

    Accepted syntax is ``ESO.SOMETHING.KEY<op>value [/ comment]`` where
    ``<op>`` is ``=`` (update) or ``~`` (addition).  Dots in the key are
    replaced by spaces (``ESO SOMETHING KEY``), and the value is converted
    with ``format_value``.  Anything after the first ``/`` is treated as a
    comment and discarded.

    Raises:
        ValueError: if *s* contains neither operator, or the key is empty.
            ``ValueError`` (rather than a bare ``Exception``) lets argparse,
            which uses this function as an argument ``type``, report a clean
            usage error instead of a traceback.
    """
    # Strip the optional trailing comment; it is not applied to the header.
    s = s.split('/', 1)[0]
    # Use whichever operator occurs first, so that a value may itself
    # contain '=' or '~' (e.g. "ESO.X~a=b" is an addition of "a=b").
    found = [(s.index(op), op) for op in ('=', '~') if op in s]
    if not found:
        raise ValueError("Patch file element must be an update (=) or an addition (~)!")
    mode = min(found)[1]
    key, value = s.split(mode, 1)
    key = key.strip()
    if not key:
        raise ValueError("Patch file element has an empty keyword: {0!r}".format(s))
    return (mode, key.replace('.', ' '), format_value(value.strip()))


def get_checksum_comment(filename, default=None):
    """Return the comment of the CHECKSUM card in the primary header of *filename*.

    The script prints this comment to let the user compare the current, saved
    and final versions of a file.  Files whose primary header has no CHECKSUM
    card return *default* instead of raising ``KeyError``, so a single such
    file no longer aborts the whole run.
    """
    with fits.open(filename) as hdulist:
        header = hdulist[0].header
        if 'CHECKSUM' not in header:
            return default
        return header.comments['CHECKSUM']


def update_or_insert_keyword(header, mode, key, value):
    """Apply a single patch to a FITS header in place.

    Args:
        header: the header to modify (in this script, the astropy primary
            header of the file being patched).
        mode: ``'='`` overwrites an existing keyword; ``'~'`` adds *value* to
            the current value with ``+=`` (numbers are summed, strings are
            concatenated).
        key: keyword name with spaces, e.g. ``'ESO DET NAME'``; it is written
            as a ``HIERARCH`` card.
        value: the new value, or the increment in ``'~'`` mode.

    A keyword that is not yet in the header is inserted with *value* in either
    mode, keeping the ESO keyword block sorted.
    """
    card_key = "HIERARCH " + key
    present = key in header
    if present and mode == '=':
        header[card_key] = value
        print("    {0} = {1}".format(key, value))
    elif present and mode == '~':
        old_value = header[card_key]
        header[card_key] += value
        new_value = header[card_key]
        print("    {0} ~ {1}: {2} -> {3}".format(key, value, old_value, new_value))
    else:
        _insert_hierarch_keyword(header, key, value)
        print("    {0} = {1}".format(key, value))


def _insert_hierarch_keyword(header, key, value):
    """Insert ``HIERARCH key = value`` so that the ESO keyword block stays sorted."""
    card = ("HIERARCH " + key, value)
    last_eso = None
    for index, existing in enumerate(header.keys()):
        # Prefix test: a plain substring test would also match unrelated
        # keywords that merely contain "ESO" (e.g. "RESOLUTN").
        if not existing.startswith('ESO '):
            continue
        if existing >= key:
            # First ESO keyword sorting at/after the new one: insert before it.
            header.insert(index, card)
            return
        last_eso = index
    # Every ESO keyword sorts before *key* (or there are none). Place the card
    # after the ESO block, or append it, so that it is never silently dropped.
    if last_eso is not None:
        header.insert(last_eso + 1, card)
    else:
        header.append(card)


def _sidecar_path(filename, suffix):
    """Return the path of a companion file stored next to *filename*.

    *suffix* is inserted right after the last ``.fits`` of the base name
    (``x.fits`` -> ``x.fits<suffix>``, ``x.fits.gz`` -> ``x.fits<suffix>.gz``).
    Base names without ``.fits`` simply get *suffix* appended.  Only the base
    name is touched, so directories containing ``.fits`` are left alone.  The
    result can never equal *filename*, which matters because a revert
    overwrites *filename* with the companion backup.
    """
    directory, base = os.path.split(filename)
    head, ext, tail = base.rpartition('.fits')
    if ext:
        base = head + ext + suffix + tail
    else:
        base = base + suffix
    return os.path.join(directory, base)


def _read_patch_file(path):
    """Parse every patch in the file at *path* into ``(mode, key, value)`` tuples.

    Blank lines and lines starting with ``#`` are skipped.  The whole file is
    parsed up front so that a malformed line aborts the run before any FITS
    file has been modified.
    """
    patches = []
    with open(path) as patch_file:
        for line in patch_file:
            line = line.strip()
            if line and not line.startswith('#'):
                patches.append(process_patch(line))
    return patches


def _revert_file(filename, saved_path, log_path):
    """Restore *filename* from its backup (if any) and delete its patch log."""
    print("  Reverting file...")
    if os.path.isfile(saved_path):
        # os.replace overwrites the target atomically, so there is no moment
        # at which neither the patched nor the original file exists.
        os.replace(saved_path, filename)
    if os.path.isfile(log_path):
        os.remove(log_path)


def _apply_patches(filename, saved_path, log_path, patches):
    """Back up *filename* if needed, apply *patches* to its primary header and log them."""
    # The backup is only created when none exists, so it keeps the file as it
    # was before the first patch run (until a revert consumes it).
    if not os.path.isfile(saved_path):
        move(filename, saved_path)
        copyfile(saved_path, filename)

    with open(log_path, 'a') as patch_log:
        print("  Patching...")
        with fits.open(filename, mode='update') as hdulist:
            header = hdulist[0].header
            for mode, key, value in patches:
                update_or_insert_keyword(header, mode, key, value)
                patch_log.write("{1} {0} {2}\n".format(mode, key, value))
            hdulist.flush(output_verify='ignore')


def main():
    """Command-line entry point: report checksums, then revert and/or patch each file."""
    parser = argparse.ArgumentParser(description='Patches the primary header of some FITS files.')
    parser.add_argument('files', type=str, nargs='+', help='FITS file(s) to patch')
    parser.add_argument('--revert', '-r', action='store_true', help="Revert all previously applied patches.")
    parser.add_argument('--patches', '-p', type=process_patch, default=[], action='append',
                        help='Command line patch with ESO.SOMETHING.KEY=value (update) '
                             'or ESO.SOMETHING.KEY~value (addition) format')
    parser.add_argument('--patch_file', '-f', type=str, nargs='?', default=None,
                        help='Patch file with one ESO.SOMETHING.KEY=value or '
                             'ESO.SOMETHING.KEY~value patch per line '
                             '(blank lines and lines starting with # are ignored)')
    args = parser.parse_args()

    # Command-line patches are applied before the patch-file ones. The patch
    # file is parsed once, before any FITS file is touched.
    patches = list(args.patches)
    if args.patch_file:
        patches.extend(_read_patch_file(args.patch_file))

    for filename in args.files:
        saved_path = _sidecar_path(filename, '_saved')
        log_path = _sidecar_path(filename, '.patch.log')

        print("Processing {0}...".format(filename))
        print("  [Current]  {0}.".format(get_checksum_comment(filename)))
        if os.path.isfile(saved_path):
            print("  [ Saved ]  {0}.".format(get_checksum_comment(saved_path)))

        if args.revert:
            _revert_file(filename, saved_path, log_path)

        if args.patches or args.patch_file:
            _apply_patches(filename, saved_path, log_path, patches)

        print("  [ Final ]  {0}.".format(get_checksum_comment(filename)))

    print("Done!")


if __name__ == '__main__':
    main()
