# Python script for writing spatial LBM simulation results to .vts-files for visualisation with ParaView
# ------------------------------------------------------------------------------------------------------
from matplotlib import cm
import numpy as np
from sys import argv

import matplotlib.pyplot as plt

# Startup banner: remind the user that the PyEVTK package must be installed.
_BANNER_RULE = "============================================================="
print("\n".join((
    _BANNER_RULE,
    "This script uses the library PyEVTK.",
    "Source: https://github.com/paulo-herrera/PyEVTK",
    "Wiki: https://vtk.org/Wiki/VTK/Writing_VTK_files_using_python",
    "Please make sure it is installed.",
    _BANNER_RULE,
)))

# Library for writing data to .vts-files:
from evtk.hl import imageToVTK
from evtk.hl import gridToVTK 
# Source: https://vtk.org/Wiki/VTK/Writing_VTK_files_using_python
# https://github.com/paulo-herrera/PyEVTK


# Define time steps for which spatial results are extracted:
# ----------------------------------------------------------
# The list starts with time step 1, followed by every `output_interval`-th
# time step: 1, 5000, 10000, ...  Each entry is the file-name stem of a
# results file ('1.plt', '5000.plt', ...).
number_of_time_steps = 58  # Select number of time steps (if desired)
output_interval = 5000  # Increment of 5000 time steps

periodic_steps = output_interval * np.arange(1, number_of_time_steps + 1)
# Convert to int first so the strings read '5000', not '5000.0'.
my_times = np.concatenate(([1], periodic_steps)).astype(int).astype(str)

# Alternative definition of selected time steps:
# my_times = ['1','500','1000','1500','5000','10000','15000','20000','25000','30000'] # Uncomment if only selected time steps are of interest.
# Note: please make sure that the corresponding results files '1.plt','500.plt' etc. exist.


def read_parameters(my_times):
    """Load the first nine columns of one LBM results file.

    Parameters
    ----------
    my_times : str
        File-name stem of a single time step (e.g. '5000'); the file
        '<my_times>.plt' is read.  Despite its name this is one time step,
        not the whole list of time steps.

    Returns
    -------
    tuple of numpy.ndarray
        ``(x, y, z, rho, upx, upy, upz, p, obst)`` -- the first nine
        tab-separated columns, in file order, after skipping one header line.
        (Presumably coordinates, density, velocity components, pressure and
        an obstacle flag -- inferred from the names; confirm against the
        solver's output format.)
    """
    results_file = my_times + ".plt"
    columns = np.loadtxt(
        results_file,
        delimiter="\t",
        skiprows=1,
        usecols=tuple(range(9)),
        unpack=True,
    )
    # With unpack=True each row of `columns` is one file column, so
    # iterating over it yields the nine column arrays in order.
    return tuple(columns)

# ---------------------------------------------------------------------

def _lattice_indices(x, y, z, shape):
    """Convert result-file coordinates to integer (ix, iy, iz) array indices.

    The coordinates are truncated to int.  Raises ValueError if any index
    lies outside ``[0, n)`` for the matching entry of *shape*; that usually
    means lx, ly, lz below do not match the simulation.
    """
    indices = tuple(np.asarray(c).astype(int) for c in (x, y, z))
    for axis, idx, n in zip("xyz", indices, shape):
        # Negative indices would silently wrap around in NumPy, so check both ends.
        if idx.size and (idx.min() < 0 or idx.max() >= n):
            raise ValueError(
                "{} coordinates span [{}, {}] but the grid has {} cells along {}; "
                "check lx, ly, lz.".format(axis, idx.min(), idx.max(), n, axis))
    return indices


def _scatter_to_grid(indices, values, shape):
    """Return a zero-filled array of *shape* holding *values* at *indices*.

    *indices* is an ``(ix, iy, iz)`` tuple.  The array is indexed [x, y, z]
    so its first axis runs along x, matching the [i, j, k] layout of the
    vertex coordinate arrays passed to gridToVTK.
    """
    grid = np.zeros(shape)
    grid[indices] = values
    return grid


def _vertex_positions(length, spacing):
    """Return the vertex positions 0, spacing, ..., length along one axis."""
    # The extra 0.1*spacing makes arange include `length` despite rounding.
    return np.arange(0, length + 0.1 * spacing, spacing, dtype='float64')


# Make sure that the model dimensions lx, ly, lz are correct
lx = 200
ly = 200
lz = 200

# One cell per lattice node: nx*ny*nz cells of size dx*dy*dz.
nx, ny, nz = lx, ly, lz
dx, dy, dz = lx * 1.0 / nx, ly * 1.0 / ny, lz * 1.0 / nz
cell_shape = (nx, ny, nz)

# Vertex coordinates of the structured grid: shape (nx+1, ny+1, nz+1),
# indexed [i, j, k] as gridToVTK expects.  They do not depend on the time
# step, so they are built once instead of once per results file.
# NOTE(review): every vertex is shifted by one cell (+dx, +dy, +dz), so cell
# (0, 0, 0) spans [dx, 2*dx] etc.  Kept for compatibility with existing
# output; confirm whether this shift is intended.
x_grid, y_grid, z_grid = np.meshgrid(
    _vertex_positions(lx * 1.0, dx) + dx,
    _vertex_positions(ly * 1.0, dy) + dy,
    _vertex_positions(lz * 1.0, dz) + dz,
    indexing='ij',
)

for t in my_times:
    print('Reading results file for time step t =', t, 'ts...')

    x, y, z, rho, upx, upy, upz, p, obst = read_parameters(t)

    # Scatter the point-wise results onto dense cell arrays indexed [x, y, z].
    cell_index = _lattice_indices(x, y, z, cell_shape)
    pressure = _scatter_to_grid(cell_index, p, cell_shape)
    density = _scatter_to_grid(cell_index, rho, cell_shape)

    # ------------------------------------------
    # Write vtk-data to files:
    # ------------------------------------------
    print("Writing .vts-file for time step t = {:}".format(str(t)), "ts...")

    gridToVTK("./pressure_t_{:}".format(str(t)), x_grid, y_grid, z_grid,
              cellData={"pressure": pressure}, pointData=None)
    gridToVTK("./density_t_{:}".format(str(t)), x_grid, y_grid, z_grid,
              cellData={"density": density}, pointData=None)
    # Note: The obstacles are only written for the initial time step.
    if int(t) <= 1:
        obstacles = _scatter_to_grid(cell_index, obst, cell_shape)
        gridToVTK("./obstacles_t_{:}".format(str(t)), x_grid, y_grid, z_grid,
                  cellData={"obstacles": obstacles}, pointData=None)
    print("Done.")
    print("-----")
