# Script for creating snapshots (xyz files) for animations of vibrational motions of atoms in IR spectroscopy
# (C) Kai Sellschopp, TUHH Hamburg, Germany
import numpy as np
import poscarRW as prw
import xyzRW as xyzrw


def readEigenvectors(filename='eigenvectors.txt'):
    """
    Read an eigenvector file and return its eigenfrequencies and displacements.

    The file content is split into entries at every occurrence of the string
    '\\n \\n  ' (newline, space, newline, two spaces); text before the first
    separator is treated as a preamble and ignored. Each entry is expected to
    look like (presumably the eigenvector section of a VASP OUTCAR -- confirm
    against your input)::

        1 f  =  v1 unit1  v2 unit2  v3 unit3  v4 unit4
        <column header line, skipped>
        x  y  z  dx  dy  dz
        ...
        <blank line or end of entry>

    The four frequency values are the numbers following the '=' sign, so a
    header whose marker is glued to '=' (such as 'f/i=') is parsed correctly.

    Parameters
    ----------
    filename : str
        Path of the file to read (default: 'eigenvectors.txt').

    Returns
    -------
    frequencies : numpy.ndarray of float64, shape (n_entries, 4)
        The four header values of each entry, in the order they appear.
    displacements : numpy.ndarray of float64, shape (n_entries, n_atoms, 3)
        The last three columns of every data line of each entry.
        A file without entries yields shapes (0, 4) and (0, 0, 3).

    Raises
    ------
    IndexError
        If a header line has fewer than four values after '='.
    ValueError
        If a value cannot be parsed as a float, or if entries have different
        numbers of data lines.
    """
    # read file
    with open(filename) as f:
        ev_file = f.read()

    # split into different entries (one entry per frequency); the chunk before
    # the first separator is not an entry
    ev_entries = ev_file.split('\n \n  ')[1:]

    # collect rows in plain lists and convert once at the end (avoids the
    # quadratic cost of np.append inside the loop)
    frequency_rows = []
    displacement_blocks = []

    for ev_entry in ev_entries:
        lines = ev_entry.split('\n')

        # header: after '=' the values alternate with their unit labels, so
        # they sit at positions 0, 2, 4, 6. Parsing relative to '=' instead of
        # absolute token positions also copes with 'f/i=' style markers.
        value_tokens = lines[0].partition('=')[2].split()
        frequency_rows.append([float(value_tokens[k]) for k in (0, 2, 4, 6)])

        # data lines start at the third line of the entry; a blank line ends it
        entry_disp = []
        for line in lines[2:]:
            if not line.strip():
                break
            # displacements are the last three columns; parsed as float64 to
            # keep the full precision of the file
            entry_disp.append([float(tok) for tok in line.split()[-3:]])
        displacement_blocks.append(np.array(entry_disp, dtype=float).reshape(-1, 3))

    frequencies = np.array(frequency_rows, dtype=float).reshape(-1, 4)
    if displacement_blocks:
        # all entries must describe the same number of atoms
        displacements = np.stack(displacement_blocks)
    else:
        displacements = np.empty((0, 0, 3))

    # return frequencies and displacements
    return frequencies, displacements


def calcAnimation(positions, displacements, scaling=1.0, steps=100):
    """
    Return `steps` snapshots of the positions oscillating along `displacements`.

    Snapshot s is

        positions + scaling * sin(2*pi*s/steps) * displacements

    for s = 0 .. steps-1. The sine phase excludes the endpoint 2*pi, so the
    first snapshot is the undisplaced structure and a hypothetical snapshot
    number `steps` would coincide with it again.

    Parameters
    ----------
    positions : numpy.ndarray
        Atom positions, reshapeable to (n_atoms, 3).
    displacements : numpy.ndarray
        Displacement vectors, reshapeable to (n_atoms, 3).
    scaling : float
        Amplitude factor applied to the displacements.
    steps : int
        Number of snapshots to generate.

    Returns
    -------
    numpy.ndarray, shape (steps, n_atoms, 3)
    """
    # add a leading "snapshot" axis of length 1 to both inputs
    base = positions.reshape(1, len(positions), 3)
    modes = displacements.reshape(1, len(displacements), 3)

    # one scalar amplitude per snapshot; broadcasting over atoms and xyz
    # replaces the explicit repeat() copies of the inputs
    phase = np.linspace(0, 2*np.pi, steps, endpoint=False)
    amplitude = (scaling*np.sin(phase)).reshape(steps, 1, 1)

    return base + amplitude*modes


if __name__=="__main__":
    import sys

    # usage: <script> POSCAR [EIGENVECTOR_FILE]
    if len(sys.argv) < 2:
        sys.exit('Usage: {} POSCAR [eigenvectors.txt]'.format(sys.argv[0]))

    # get data from files; the eigenvector file defaults to 'eigenvectors.txt'
    poscar = prw.readPOSCAR(sys.argv[1])
    ev_filename = sys.argv[2] if len(sys.argv) > 2 else 'eigenvectors.txt'
    frequ, disp = readEigenvectors(ev_filename)

    # get positions in cartesian coordinates
    # NOTE(review): presumably poscar[0] is the scale factor, poscar[1] the
    # lattice vectors and poscar[4] fractional coordinates -- confirm in poscarRW
    positions = poscar[0]*np.dot(poscar[4],poscar[1])

    # create atom labels for xyz-file: each label in poscar[2] repeated by the
    # matching count in poscar[3]
    nAtoms = np.sum(poscar[3])
    tAtoms = [atomLabel
              for atomLabel, atomNumber in zip(poscar[2], poscar[3])
              for _ in range(atomNumber)]

    # iterate over different eigenfrequencies
    for f, d in zip(frequ, disp):
        # print() with a single argument works under both Python 2 and 3
        print('Writing animation for frequency: {:7.0f}'.format(f[2]))

        # calculate animation steps
        anim_steps = calcAnimation(positions, d)

        # repeat nAtoms and tAtoms to match the number of steps
        nAtoms_steps = np.array([nAtoms]).repeat(len(anim_steps))
        tAtoms_steps = np.array(tAtoms).reshape(1,len(tAtoms)).repeat(len(anim_steps),axis=0)

        # create filename (third frequency column, rounded to an integer)
        filename = 'anim_frequ={:.0f}.xyz'.format(f[2])

        # write animation to xyz file
        xyzrw.writeXYZ(filename, nAtoms_steps, tAtoms_steps, anim_steps)

    print('DONE!')
