# Script for plotting calculated IR spectra
# (C) Kai Sellschopp, TUHH Hamburg, Germany
import numpy as np
from matplotlib import pyplot as plt


def creategausspeaks(peakpos, peakheight=1.0, peakwidth=2.0, res=10, xgrid=None):
    """
    Build a spectrum as the sum of Gaussian peaks.

    Each peak is evaluated with ``gausscurve`` and all peaks are added up on a
    common x grid.

    Parameters
    ----------
    peakpos : float or array-like
        Center position(s) of the peaks.
    peakheight : float or array-like, optional
        Peak height(s). Either a single value shared by all peaks or one
        value per peak.
    peakwidth : float or array-like, optional
        Peak width(s), passed on as the ``width`` argument of ``gausscurve``.
        Either a single value shared by all peaks or one value per peak.
    res : int or float, optional
        Only used when ``xgrid`` is None: number of grid points per smallest
        peak width.
    xgrid : array-like, optional
        x values at which the spectrum is evaluated. If None, an evenly
        spaced grid is generated that extends 10 times the largest peak
        width beyond the outermost peaks.

    Returns
    -------
    xgrid : array-like
        The grid that was passed in, or the generated grid.
    gausspeaks : numpy.ndarray
        Sum of all Gaussian peaks evaluated on ``xgrid``.

    Raises
    ------
    ValueError
        If ``peakheight`` or ``peakwidth`` has neither one element nor one
        element per peak.
    """
    centers = np.atleast_1d(np.asarray(peakpos, dtype=float)).ravel()
    heights = np.asarray(peakheight, dtype=float).ravel()
    widths = np.asarray(peakwidth, dtype=float).ravel()

    # a parameter is either shared by all peaks or given once per peak
    for name, values in (("peakheight", heights), ("peakwidth", widths)):
        if values.size not in (1, centers.size):
            raise ValueError(
                "%s must have 1 or %d elements, got %d"
                % (name, centers.size, values.size)
            )

    # create x-axis values, if no grid is given
    # (identity test: '== None' would compare elementwise on an array)
    if xgrid is None:
        xmin = np.min(centers) - 10 * np.max(widths)
        xmax = np.max(centers) + 10 * np.max(widths)
        # np.linspace requires an integer number of samples
        nSteps = int(np.round(res * (xmax - xmin) / np.min(widths) + 1))
        xgrid = np.linspace(xmin, xmax, nSteps)

    grid = np.asarray(xgrid, dtype=float)
    heights = np.broadcast_to(heights, centers.shape)
    widths = np.broadcast_to(widths, centers.shape)

    # accumulate peak by peak instead of storing one full curve per peak
    gausspeaks = np.zeros(grid.shape)
    for center, width, height in zip(centers, widths, heights):
        gausspeaks += gausscurve(grid, center, width, height)

    return xgrid, gausspeaks


def gausscurve(x, center=0.0, width=1.0, height=1.0/np.sqrt(2.0*np.pi)):
    """
    Evaluate a Gaussian function at the values of x.

    Computes ``height * exp(-(x - center)**2 / (2 * width**2))``.

    Parameters
    ----------
    x : float or numpy.ndarray
        Value(s) at which the Gaussian is evaluated.
    center : float, optional
        Position of the maximum.
    width : float, optional
        Width parameter appearing as ``2 * width**2`` in the exponent.
    height : float, optional
        Value of the function at ``x == center``. The default is
        ``1 / sqrt(2 * pi)``.

    Returns
    -------
    float or numpy.ndarray
        The Gaussian function values, with the broadcast shape of the inputs.
    """
    offset = x - center
    exponent = -offset**2 / (2.0 * width**2)
    return height * np.exp(exponent)


def plotIR(x, y, labels=None, title="IR spectrum", xlabel="Wavenumber 1/cm", ylabel="Intensity (a.u.)", invertX=True, padding=0.1, datatitle="IR.png"):
    """
    Plot a spectrum and save the figure to a file.

    Parameters
    ----------
    x, y : array-like
        Data passed to ``Axes.plot``.
    labels : sequence of str, optional
        Legend entries; no legend is drawn when None.
    title, xlabel, ylabel : str, optional
        Plot title and axis labels.
    invertX : bool, optional
        If True, the x axis runs from high to low values.
    padding : float, optional
        Extra space on both sides of the x range, as a fraction of the
        largest x value.
    datatitle : str, optional
        File name the figure is saved to (passed to ``Figure.savefig``).

    Notes
    -----
    The y axis is fixed to the range 0.0 to 1.1.
    """
    # create figure and axes
    fig, ax = plt.subplots()

    # plot data
    ax.plot(x, y)

    # set title and labels
    ax.set_title(title)
    ax.set_xlabel(xlabel)
    ax.set_ylabel(ylabel)
    # identity test: '!= None' is ambiguous when labels is a numpy array
    if labels is not None:
        ax.legend(labels, loc='best')

    # set x-axis limits, rounded outwards to whole numbers
    pad = padding * np.max(x)
    x_upper = np.ceil(np.max(x) + pad)
    x_lower = np.floor(np.min(x) - pad)
    if invertX:
        ax.set_xlim(left=x_upper, right=x_lower)
    else:
        ax.set_xlim(left=x_lower, right=x_upper)

    # set y-axis limits
    ax.set_ylim(bottom=0.0, top=1.1)

    # save, then release only the figure created here so that figures
    # opened elsewhere are left untouched
    fig.savefig(datatitle)
    plt.close(fig)


def readResults(filename="Results.txt"):
    """
    Read peak positions and heights from a whitespace-separated text file.

    Every non-blank line must have at least three columns. The second column
    is taken as the peak position and the third as the peak height; the first
    and any further columns are ignored. Empty and whitespace-only lines are
    skipped.

    Parameters
    ----------
    filename : str, optional
        Path of the file to read.

    Returns
    -------
    peakpos : numpy.ndarray
        Peak positions as float64.
    peakheight : numpy.ndarray
        Peak heights as float64.

    Raises
    ------
    ValueError
        If a non-blank line has fewer than three columns, or if a position
        or height cannot be converted to a float.
    """
    # initialize output lists
    peakpos = []
    peakheight = []

    # parse data line by line
    with open(filename) as f:
        for lineno, line in enumerate(f, start=1):
            fields = line.split()
            # split() yields an empty list for empty or whitespace-only lines
            if not fields:
                continue
            if len(fields) < 3:
                raise ValueError(
                    "%s, line %d: expected at least 3 columns, got %d"
                    % (filename, lineno, len(fields))
                )
            peakpos.append(fields[1])
            peakheight.append(fields[2])

    # convert to numpy arrays and return
    return np.array(peakpos, dtype='float64'), np.array(peakheight, dtype='float64')


## MAIN PROGRAM ##
# execute only if module is called from __main__ (as a script)
if __name__ == "__main__":
    import sys

    # the results file is the only command-line argument
    if len(sys.argv) < 2:
        sys.exit("usage: %s RESULTS_FILE" % sys.argv[0])

    filename = sys.argv[1]
    peakpos, peakheight = readResults(filename)

    # broaden the peaks and plot the summed spectrum
    xgrid, spectrum = creategausspeaks(peakpos, peakheight)

    plotIR(xgrid, spectrum)
