#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Sun Jun 18 00:03:11 2017

@author: Felix Henningsen
"""

'''
The MetroPy Helper File:
    Handles plot numbering and plot labeling, as well as heatmaps for 2D plots.
'''

# import necessary packages for handling
import numpy as np
# import image handling packages
import matplotlib.pyplot as plt
from mpl_toolkits.axes_grid1 import make_axes_locatable

# Compute the desired subplot number for each PDF page layout
def plotnumber(page, n, counter):
    """Return the subplot position for channel index ``n`` on a given page.

    Parameters
    ----------
    page : str
        Page layout identifier: ``'raw2d'``, ``'phasor'`` or ``'phase'``.
    n : int
        Channel (column) index being plotted.
    counter : int
        Running counter supplied by the caller. The ``'raw2d'`` layout (for
        ``n < 64``) and the ``'phasor'`` layout return it shifted by a fixed
        per-range offset; every other range ignores it.

    Returns
    -------
    int or None
        The subplot number, or ``None`` when ``page`` is unknown or ``n``
        lies outside the ranges handled for that page.
    """
    if page == 'raw2d':
        if n < 64:
            # Shift added to ``counter`` for each consecutive group of eight
            # indices: 0-7, 8-15, ..., 56-63 (negative n falls in group 0).
            group_shift = (2, 8, 14, 20, -9, -3, 3, 9)
            group = 0 if n < 8 else int(n // 8)
            return counter + group_shift[group]
        if n < 80:
            # Indices 64-79 ignore ``counter``: positions are spaced five
            # apart, the first group of eight starting at 1, the second at 6.
            start = 1 if n < 72 else 6
            return start + 5 * (n % 8)
        return None

    if page == 'phasor':
        if n <= 6 or n >= 56:
            return counter + 1
        # (inclusive upper bound of the index range, shift added to counter)
        for upper, shift in ((14, 5), (22, 9), (30, 13),
                             (38, -11), (46, -7), (54, -3)):
            if n <= upper:
                return counter + shift
        return None

    if page == 'phase':
        # Each block of four indices maps to its own column offset.
        if n < 4:
            column = 0
        elif n < 8:
            column = 1
        elif n < 12:
            column = 2
        else:
            return None
        return (n - 4 * column) * 3 + column + 1

    return None

# Compute the desired subplot label for each PDF page layout
def plotlabel(page, n):
    """Return the subplot title for channel index ``n`` on a given page.

    Parameters
    ----------
    page : str
        Page layout identifier: ``'raw2d'``, ``'phase'`` or ``'phasor'``.
    n : int
        Channel (column) index being plotted.

    Returns
    -------
    str or None
        The label text, or ``None`` when ``page`` is unknown or ``n`` lies
        outside the ranges handled for that page (``n >= 80`` for
        ``'raw2d'``, ``n >= 12`` for ``'phase'``, ``54 < n < 56`` for
        ``'phasor'``).
    """
    if page == 'raw2d':
        # Blocks of eight indices; within a block, each pair of indices
        # shares one D number (hence the division by 2).
        if n < 8:
            return 'AT4 D%s FT' % (1 + int(n / 2))
        if 8 <= n < 16:
            return 'AT3 D%s FT' % (1 + int((n - 8) / 2))
        if 16 <= n < 24:
            return 'AT2 D%s FT' % (1 + int((n - 16) / 2))
        if 24 <= n < 32:
            return 'AT1 D%s FT' % (1 + int((n - 24) / 2))
        if 32 <= n < 40:
            return 'AT4 D%s SC' % (1 + int((n - 32) / 2))
        if 40 <= n < 48:
            return 'AT3 D%s SC' % (1 + int((n - 40) / 2))
        if 48 <= n < 56:
            return 'AT2 D%s SC' % (1 + int((n - 48) / 2))
        if 56 <= n < 64:
            return 'AT1 D%s SC' % (1 + int((n - 56) / 2))
        if 64 <= n < 72:
            return 'AT%s FC FT' % (4 - int((n - 64) / 2))
        if 72 <= n < 80:
            return 'AT%s FC SC' % (4 - int((n - 72) / 2))

    if page == 'phase':
        if n < 4:
            return '%s FT' % (4 - n)
        if 4 <= n < 8:
            return '%s SC' % (8 - n)
        # Only indices 8-11 share the combined label; larger n has no label.
        # The old test ``8 <= 12`` ignored n and was always true.
        if 8 <= n < 12:
            return 'FT-SC'

    if page == 'phasor':
        if n <= 6:
            return '%s D%s FT' % (4, int(n / 2 + 1))
        if 6 < n <= 14:
            return '%s D%s FT' % (3, int((n - 8) / 2 + 1))
        if 14 < n <= 22:
            return '%s D%s FT' % (2, int((n - 16) / 2 + 1))
        if 22 < n <= 30:
            return '%s D%s FT' % (1, int((n - 24) / 2 + 1))
        if 30 < n <= 38:
            return '%s D%s SC' % (4, int((n - 32) / 2 + 1))
        if 38 < n <= 46:
            return '%s D%s SC' % (3, int((n - 40) / 2 + 1))
        if 46 < n <= 54:
            return '%s D%s SC' % (2, int((n - 48) / 2 + 1))
        if n >= 56:
            return '%s D%s SC' % (1, int((n - 56) / 2 + 1))

    return None

# Create fast NumPy 2D-histogram heatmaps for 2D plots
def heatmap2d(fig, data, n, limits1, limits2, bins, size):
    """Draw a 2D-histogram heatmap of columns ``n`` and ``n + 1`` of ``data``.

    The histogram of ``data[:, n]`` (x) against ``data[:, n + 1]`` (y) is
    drawn with ``plt.imshow`` on the current pyplot axes. The visible range is
    fixed to the square ``[-lim, lim]`` on both axes, where ``lim`` is the
    larger of ``limits1`` and ``limits2``. A colorbar is attached to the
    right of the axes and the figure is resized afterwards.

    Parameters
    ----------
    fig : matplotlib.figure.Figure
        Figure that is resized via ``set_size_inches``.
    data : numpy.ndarray
        2D array; columns ``n`` and ``n + 1`` are histogrammed.
    n : int
        Index of the x column; ``n + 1`` is the y column.
    limits1, limits2 : float
        Candidate half-widths of the plotted window; the larger one is used.
    bins : int or sequence
        Passed straight to ``np.histogram2d``.
    size : tuple of float
        Passed straight to ``fig.set_size_inches``.

    Returns
    -------
    None
    """
    x_values = data[:, n]
    y_values = data[:, n + 1]
    counts, x_edges, y_edges = np.histogram2d(x_values, y_values, bins=bins)
    image_extent = [x_edges[0], x_edges[-1], y_edges[0], y_edges[-1]]

    # Symmetric square window sized by the larger of the two limits.
    lim = limits1 if limits1 > limits2 else limits2
    plt.xlim([-lim, lim])
    plt.ylim([-lim, lim])

    # histogram2d bins x along the first axis, while imshow draws the first
    # axis as rows, so transpose. origin='lower' puts the lowest y bin at
    # the bottom.
    image = plt.imshow(counts.T, extent=image_extent, origin='lower')

    # Carve a slim axes off the right edge so the colorbar lines up with
    # the image.
    colorbar_axes = make_axes_locatable(plt.gca()).append_axes(
        "right", size="5%", pad=0.05)
    plt.colorbar(image, cax=colorbar_axes)

    fig.set_size_inches(size, forward=True)
        
def chunking(arr, blocksize):
    """Split ``arr`` into equally sized chunks plus a leftover tail.

    Parameters
    ----------
    arr : sequence or numpy.ndarray
        Data to split along its first axis.
    blocksize : int
        Number of elements (rows) per chunk; must be at least 1.

    Returns
    -------
    chunks : numpy.ndarray
        Array of shape ``(len(arr) // blocksize, blocksize, *rest)`` holding
        the fully populated chunks. The leading dimension is 0 when ``arr``
        is shorter than ``blocksize`` (including an empty ``arr``).
    remainder : numpy.ndarray
        The trailing ``len(arr) % blocksize`` elements that did not fill a
        complete chunk.

    Raises
    ------
    ValueError
        If ``blocksize`` is smaller than 1.
    """
    if blocksize < 1:
        raise ValueError('blocksize must be a positive integer, got %r'
                         % (blocksize,))

    # Number of fully populated chunks and the number of elements they span.
    n_full_chunks = int(len(arr) // blocksize)
    even_size = int(n_full_chunks * blocksize)

    # Copy the evenly divisible head and the leftover tail into arrays.
    arr_even = np.array(arr[:even_size])
    arr_odd = np.array(arr[even_size:])

    # A reshape gives the same equal-sized chunks as np.array_split with
    # n_full_chunks sections. Unlike array_split, it also accepts 0 chunks,
    # so inputs shorter than blocksize no longer raise.
    chunk_shape = (n_full_chunks, int(blocksize)) + arr_even.shape[1:]
    chunks = arr_even.reshape(chunk_shape)

    return chunks, arr_odd
    


