import warnings

import matplotlib.pyplot as plt
import numpy as np
from scipy.integrate import odeint
from scipy.optimize import root

# ============================================================
# KINETIC PARAMETERS
# ============================================================
# Turnover frequencies k_cat of the four enzymatic steps (min^-1)
kcat_I = 1.2      # step I
kcat_II = 0.3     # step II
kcat_III = 1.4    # step III
kcat_IV = 0.5     # step IV

# Michaelis-Menten, co-substrate and inhibition constants (mmol L^-1)
# Step I
KS1_I = 30.2      # saturation constant for S1
Kc1_I = 1.1       # saturation constant for co-substrate c1
KiS1_I = 4.3      # inhibition constant for S1

# Step II
KS2_II = 8.2      # saturation constant for S2
KiS1_II = 8.0     # inhibition constant for S1
KiS5_II = 14.3    # inhibition constant for S5

# Step III
KS3_III = 24.7    # saturation constant for S3

# Step IV
KS5_IV = 15.2     # saturation constant for S5
Kc2_IV = 3.1      # saturation constant for co-substrate c2
KiS1_IV = 8.3     # inhibition constant for S1

# Total enzyme concentrations (mmol L^-1); all four enzymes share one loading
E_A = E_B = E_C = E_D = 15.0

# ============================================================
# KINETICS (UPDATED MODEL)
# ============================================================

def rates(S1, S2, S3, S4, c1, c2, S5):
    """
    Evaluate the rate laws of the four-step enzymatic cascade.

    Every rate has the form kcat * E_total * (substrate terms) divided by
    saturation and inhibition factors. Only elementwise arithmetic is used,
    so the arguments may be scalars or equally shaped NumPy arrays.

    Parameters
    ----------
    S1, S2, S3, S4 : float or ndarray
        Core pathway species (mmol L^-1). S4 is accepted so every reactor
        model can pass the full state, but it does not enter any rate law.
    c1, c2 : float or ndarray
        Co-substrate / co-factor concentrations (mmol L^-1).
    S5 : float or ndarray
        Concentration of the parallel/loop substrate (mmol L^-1).

    Returns
    -------
    tuple
        (rI, rII, rIII, rIV) in mmol L^-1 min^-1.
    """
    # Step I (enzyme A): S1 + c1 -> S2 + c2; the (1 + S1/KiS1_I) factor is
    # the inhibition by S1
    denom_I = (KS1_I + S1) * (Kc1_I + c1) * (1 + S1 / KiS1_I)
    r_I = kcat_I * E_A * S1 * c1 / denom_I

    # Step II (enzyme B): S2 -> S3; inhibited by both S1 and S5
    denom_II = (KS2_II + S2) * (1 + S1 / KiS1_II) * (1 + S5 / KiS5_II)
    r_II = kcat_II * E_B * S2 / denom_II

    # Step III (enzyme C): S3 -> S4; plain Michaelis-Menten, no inhibition
    denom_III = KS3_III + S3
    r_III = kcat_III * E_C * S3 / denom_III

    # Step IV (enzyme D): S5 + c2 -> c1 (co-substrate regeneration);
    # inhibited by S1
    denom_IV = (KS5_IV + S5) * (Kc2_IV + c2) * (1 + S1 / KiS1_IV)
    r_IV = kcat_IV * E_D * S5 * c2 / denom_IV

    return r_I, r_II, r_III, r_IV


# ============================================================
# INITIAL CONDITIONS & PLOTTING CONFIGURATIONS
# ============================================================
# Initial state vector (mmol L^-1), ordered [S1, S2, S3, S4, c1, c2, S5]
y0 = [50, 0, 0, 0, 2, 0, 50]
# 100 evenly spaced output times over the 180 min simulation horizon
t = np.linspace(0, 180, 100)

# LaTeX legend entries, one per state variable in the same order as y0
labels = [f'${name}$' for name in ('S_1', 'S_2', 'S_3', 'S_4', 'c_1', 'c_2', 'S_5')]
# One line colour per state variable, same order as y0
colors = [
    '#08306b', '#2171b5', '#6baed6', '#c6dbef',  # S1, S2, S3, S4
    '#67001f', '#ce1256', '#f768a1',             # c1, c2, S5
]

# ============================================================
# 1. BATCH REACTOR (BR)
# ============================================================

def br_model(y, t):
    """
    Right-hand side of the batch-reactor mass balances, dx/dt = S * r.

    Parameters
    ----------
    y : sequence of float
        Current state [S1, S2, S3, S4, c1, c2, S5] (mmol L^-1).
    t : float
        Time (min). Not used (the system is autonomous) but required by odeint.

    Returns
    -------
    list of float
        Time derivatives in the same order as y (mmol L^-1 min^-1).
    """
    S1, S2, S3, S4, c1, c2, S5 = y
    rI, rII, rIII, rIV = rates(S1, S2, S3, S4, c1, c2, S5)

    dS1 = -rI           # consumed by step I
    dS2 = rI - rII      # formed by I, consumed by II
    dS3 = rII - rIII    # formed by II, consumed by III
    dS4 = rIII          # final product of step III
    dc1 = -rI + rIV     # used by I, regenerated by IV
    dc2 = rI - rIV      # formed by I, used by IV
    dS5 = -rIV          # consumed by step IV
    return [dS1, dS2, dS3, dS4, dc1, dc2, dS5]

# Integrate the batch ODEs over the time horizon
sol_br = odeint(br_model, y0, t)

# Batch reactor concentration-time profiles
fig, ax = plt.subplots(figsize=(4.5, 4.1))
for k, (lbl, clr) in enumerate(zip(labels, colors)):
    ax.plot(t, sol_br[:, k], color=clr, label=lbl)

# Full black frame around the plotting area
for spine in ax.spines.values():
    spine.set_visible(True)
    spine.set_linewidth(1.2)
    spine.set_color('black')

ax.tick_params(direction='out', length=5, width=1)
ax.set_xlabel('$t$ / min', fontsize=14)
ax.set_ylabel('$S_i$ / mmol L$^{-1}$', fontsize=14)
ax.legend(fontsize=9)
ax.set_xticks([0, 60, 120, 180])
for tick_label in ax.get_xticklabels():
    tick_label.set_fontsize(14)
ax.grid(False)
fig.tight_layout()
fig.savefig('BR_all.pdf')
plt.show()


# ============================================================
# 2. CSTR (UNSTEADY STATE)
# ============================================================
F = 1.0     # Volumetric flowrate (L min^-1)
V = 30.0    # Working volume (L); space time V/F = 30 min
Sin = y0    # Feed composition: the same list object as the initial state


def cstr_model(y, t):
    """
    Right-hand side of the transient CSTR mass balances.

        dx/dt = (F/V) * (x_in - x) + S * r

    F, V and Sin are module-level names looked up at call time, so any later
    reassignment of them changes what this function computes.

    Parameters
    ----------
    y : sequence of float
        Current reactor (= outlet) state [S1, S2, S3, S4, c1, c2, S5].
    t : float
        Time (min). Not used but required by odeint.

    Returns
    -------
    list of float
        Time derivatives in the same order as y (mmol L^-1 min^-1).
    """
    S1, S2, S3, S4, c1, c2, S5 = y
    rI, rII, rIII, rIV = rates(S1, S2, S3, S4, c1, c2, S5)

    dilution = F / V
    # Convective in-minus-out term for each species, in state order
    conv = [dilution * (x_in - x) for x_in, x in zip(Sin, y)]

    return [
        conv[0] - rI,
        conv[1] + rI - rII,
        conv[2] + rII - rIII,
        conv[3] + rIII,
        conv[4] - rI + rIV,
        conv[5] + rI - rIV,
        conv[6] - rIV,
    ]

# Integrate the transient CSTR equations, starting from the filled reactor
sol_cstr = odeint(cstr_model, y0, t)

# Transient CSTR concentration-time profiles
fig, ax = plt.subplots(figsize=(4.5, 4.1))
for k, (lbl, clr) in enumerate(zip(labels, colors)):
    ax.plot(t, sol_cstr[:, k], color=clr, label=lbl)

# Full black frame around the plotting area
for spine in ax.spines.values():
    spine.set_visible(True)
    spine.set_linewidth(1.2)
    spine.set_color('black')

ax.tick_params(direction='out', length=5, width=1)
ax.set_xlabel('$t$ / min', fontsize=14)
ax.set_ylabel('$S_i$ / mmol L$^{-1}$', fontsize=14)
ax.legend(fontsize=9)
ax.grid(False)
ax.set_xticks([0, 60, 120, 180])
for tick_label in ax.get_xticklabels():
    tick_label.set_fontsize(14)
fig.tight_layout()
fig.savefig('CSTR_unsteady_all.pdf')
plt.show()


# ============================================================
# 3. CSTR STEADY STATE vs SPACE TIME (TAU)
# ============================================================
tau_list = np.linspace(1, 180, 75)  # Range of space times (min) to evaluate
S_out = []                          # Steady-state outlet concentrations, one row per tau


def steady(y, dilution):
    """
    Residuals of the steady-state CSTR mass balances (dx/dt = 0).

    Parameters
    ----------
    y : sequence of float
        Trial outlet concentrations [S1, S2, S3, S4, c1, c2, S5] (mmol L^-1).
    dilution : float
        Dilution rate F/V = 1/tau (min^-1).

    Returns
    -------
    list of float
        One residual per species; all are zero at a steady state.
    """
    S1, S2, S3, S4, c1, c2, S5 = y
    rI, rII, rIII, rIV = rates(S1, S2, S3, S4, c1, c2, S5)

    return [
        dilution*(Sin[0]-S1) - rI,
        dilution*(Sin[1]-S2) + rI - rII,
        dilution*(Sin[2]-S3) + rII - rIII,
        dilution*(Sin[3]-S4) + rIII,
        dilution*(Sin[4]-c1) - rI + rIV,
        dilution*(Sin[5]-c2) + rI - rIV,
        dilution*(Sin[6]-S5) - rIV
    ]


for tau in tau_list:
    # Loop-local names: assigning F and V here would overwrite the module-level
    # flowrate/volume that cstr_model reads at call time.
    flow = 1.0                  # Volumetric flowrate (L min^-1)
    volume = tau * flow         # Volume giving the targeted space time (L)

    # Solve the non-linear algebraic system with Levenberg-Marquardt
    sol = root(steady, y0, args=(flow / volume,), method='lm')
    if not sol.success:
        # 'lm' does not raise on failure; report it instead of silently
        # plotting an unconverged point.
        warnings.warn(
            f"Steady-state solve did not converge at tau = {tau:.1f} min: "
            f"{sol.message}",
            RuntimeWarning,
        )
    S_out.append(sol.x)

S_out = np.array(S_out)

# Steady-state outlet concentrations as a function of space time tau
fig, ax = plt.subplots(figsize=(4.5, 4.1))
for k, (lbl, clr) in enumerate(zip(labels, colors)):
    ax.plot(tau_list, S_out[:, k], color=clr, label=lbl)

# Full black frame around the plotting area
for spine in ax.spines.values():
    spine.set_visible(True)
    spine.set_linewidth(1.2)
    spine.set_color('black')

ax.tick_params(direction='out', length=5, width=1)
ax.set_xlabel('$\\tau$ / min', fontsize=14)
ax.set_xticks([0, 60, 120, 180])
for tick_label in ax.get_xticklabels():
    tick_label.set_fontsize(14)
ax.set_xlim(-8, 188)
ax.set_ylabel('$S_i$ / mmol L$^{-1}$', fontsize=14)
ax.legend(fontsize=9)
ax.grid(False)
fig.tight_layout()
fig.savefig('CSTR_steady_tau.pdf')
plt.show()


# ============================================================
# 4. PFR STEADY STATE
# ============================================================
z = np.linspace(0, 180, 100)  # Axial positions along the reactor (cm)
u = 1.0                       # Linear velocity inside the tube (cm min^-1)


def pfr_model(y, z):
    """
    Steady-state plug-flow mass balances along the axial coordinate z.

        dx/dz = (S * r) / u

    Parameters
    ----------
    y : sequence of float
        State at position z: [S1, S2, S3, S4, c1, c2, S5] (mmol L^-1).
    z : float
        Axial position (cm). Not used but required by odeint.

    Returns
    -------
    list of float
        Axial derivatives in the same order as y (mmol L^-1 cm^-1).
    """
    S1, S2, S3, S4, c1, c2, S5 = y
    rI, rII, rIII, rIV = rates(S1, S2, S3, S4, c1, c2, S5)

    # Net production rate of each species, in state order
    net = (-rI, rI - rII, rII - rIII, rIII, -rI + rIV, rI - rIV, -rIV)
    return [r / u for r in net]

# Integrate the steady PFR balances along the axial coordinate
sol_pfr = odeint(pfr_model, y0, z)

# Steady-state PFR axial concentration profiles
fig, ax = plt.subplots(figsize=(4.5, 4.1))
for k, (lbl, clr) in enumerate(zip(labels, colors)):
    ax.plot(z, sol_pfr[:, k], color=clr, label=lbl)

# Full black frame around the plotting area
for spine in ax.spines.values():
    spine.set_visible(True)
    spine.set_linewidth(1.2)
    spine.set_color('black')

ax.tick_params(direction='out', length=5, width=1)
ax.set_xlabel('$z$ / cm', fontsize=14)
ax.set_ylabel('$S_i$ / mmol L$^{-1}$', fontsize=14)
ax.legend(fontsize=9)
ax.grid(False)
ax.set_xticks([0, 60, 120, 180])
for tick_label in ax.get_xticklabels():
    tick_label.set_fontsize(14)
fig.tight_layout()
fig.savefig('PFR_steady_all.pdf')
plt.show()


# ============================================================
# 5. PFR UNSTEADY STATE (1D PDE discretized via Method of Lines)
# ============================================================
Nz = 40                 # Number of spatial grid points (Nz - 1 intervals)
L = 75                  # Total length of the PFR domain (cm)
z = np.linspace(0, L, Nz)
# The node spacing must match the grid the results are plotted on: Nz points
# spanning [0, L] are L / (Nz - 1) apart, not L / Nz.
dz = L / (Nz - 1)

# Initial state: every grid point holds [S1, S2, S3, S4, c1, c2, S5] =
# [50, 0, 0, 0, 2, 0, 50], the same composition as y0
Y0 = np.zeros((Nz, 7))
Y0[:, 0] = 50           # S1
Y0[:, 4] = 2            # c1
Y0[:, 6] = 50           # S5
Y0 = Y0.flatten()       # odeint needs a 1-D state vector (node-by-node, row-major)

def pfr_unsteady(Y, t):
    """
    Right-hand side of the transient 1D convection-reaction PDE, discretized
    in space by the method of lines with a first-order upwind (backward
    difference) approximation of the advection term:

        dX_i/dt = -u * (X_i - X_{i-1}) / dz + S * r(X_i),   i = 1 .. Nz-1

    Parameters
    ----------
    Y : ndarray
        Flattened state of length Nz * 7. After reshaping, row i holds
        [S1, S2, S3, S4, c1, c2, S5] at grid node i.
    t : float
        Time (min). Not used (the system is autonomous) but required by odeint.

    Returns
    -------
    ndarray
        Flattened time derivatives with the same layout as Y. Row 0 (the inlet
        node) stays zero, so the inlet is held at its initial (feed) values.
    """
    Y = Y.reshape((Nz, 7))
    dYdt = np.zeros_like(Y)

    # All interior nodes are processed at once instead of in a Python loop.
    # rates() uses only elementwise arithmetic, so it accepts whole columns.
    upstream, nodes = Y[:-1], Y[1:]
    advection = -u * (nodes - upstream) / dz
    rI, rII, rIII, rIV = rates(*nodes.T)

    # PDE discretization: dX/dt = -u*(dX/dz) + S*r
    dYdt[1:, 0] = advection[:, 0] - rI
    dYdt[1:, 1] = advection[:, 1] + rI - rII
    dYdt[1:, 2] = advection[:, 2] + rII - rIII
    dYdt[1:, 3] = advection[:, 3] + rIII
    dYdt[1:, 4] = advection[:, 4] - rI + rIV
    dYdt[1:, 5] = advection[:, 5] + rI - rIV
    dYdt[1:, 6] = advection[:, 6] - rIV

    return dYdt.flatten()

# Integrate the resulting coupled ODE system across time
sol_pde = odeint(pfr_unsteady, Y0, t)

# Visualize the unsteady progression of the axial profile of product S4.
# Snapshot indices refer to positions in t, so t must have more than 80 entries.
snapshot_indices = [0, 20, 40, 60, 80]
plt.figure(figsize=(4.5, 4.1))
for ti in snapshot_indices:
    profile = sol_pde[ti].reshape((Nz, 7))
    plt.plot(z, profile[:, 3], label=f't={t[ti]:.0f} min')

ax = plt.gca()
# Full black frame around the plotting area
for spine in ax.spines.values():
    spine.set_visible(True)
    spine.set_linewidth(1.2)
    spine.set_color('black')

ax.tick_params(direction='out', length=5, width=1)
plt.xlabel('$z$ / cm', fontsize=14)
plt.ylabel('$S_4$ / mmol L$^{-1}$', fontsize=14)
plt.legend(fontsize=9)
plt.xticks([0, 25, 50, 75], fontsize=14)
plt.grid(False)
plt.tight_layout()
plt.savefig('PFR_unsteady_profiles.pdf')
plt.show()