"""
Dynamic Optimization of a Complex Enzymatic Reaction Cascade using Pyomo.DAE
=============================================================================

PREREQUISITES:
--------------
To execute this script, ensure the following packages and solver are installed:
1. Pyomo: pip install pyomo
2. Matplotlib (used for the result plots): pip install matplotlib
3. Ipopt solver:
   - Via conda: conda install -c conda-forge ipopt
   - Or install the binary with your system package manager (or build it from source) and add it to your system PATH.

Description:
------------
This script builds a Pyomo ConcreteModel over a continuous 12-hour (720 min) time
horizon to determine the optimal enzyme allocations (E_A, E_B, E_C, E_D), the initial
S1 and S5 loadings, and the time-dependent substrate feed rate u(t) that maximize the
average productivity of product S4, i.e. S4(t_final) / t_final.
The differential equations are discretized with a backward Euler finite-difference
scheme (250 finite elements), and the resulting NLP is solved with Ipopt.
"""

import warnings

from pyomo.environ import *
from pyomo.dae import *
from pyomo.opt import TerminationCondition
import matplotlib.pyplot as plt

# ------------------------------------------------------------
# MODEL SETUP
# ------------------------------------------------------------
m = ConcreteModel(name="Enzymatic_Cascade_Optimization")

# Batch horizon in minutes: 12 h x 60 min/h = 720 min.
t_final = 12 * 60
m.t = ContinuousSet(bounds=(0, t_final))

# ------------------------------------------------------------
# STATE & CONTROL VARIABLES
# ------------------------------------------------------------
# (component name, doc string) for every concentration state (mM). All share
# the bounds [0, 500]; the tuple order fixes the declaration order.
_state_docs = (
    ("S1", "Substrate 1 concentration profile"),
    ("S2", "Substrate 2 concentration profile"),
    ("S3", "Substrate 3 concentration profile"),
    ("S4", "Product 4 concentration profile"),
    ("c1", "Co-substrate 1 concentration profile"),
    ("c2", "Co-substrate 2 concentration profile"),
    ("S5", "Substrate 5 concentration profile"),
)
for _name, _doc in _state_docs:
    m.add_component(_name, Var(m.t, bounds=(0, 500), doc=_doc))

# A single feed-rate control u(t); it is added to both the S1 and S5 balances.
m.u = Var(m.t, bounds=(0, 10), initialize=0, doc="Substrate feed rate control profile")

# One time derivative per state, named d<state>dt (dS1dt, ..., dS5dt).
for _name, _doc in _state_docs:
    m.add_component("d" + _name + "dt", DerivativeVar(m.component(_name), wrt=m.t))

# ------------------------------------------------------------
# INITIAL CONDITIONS (t = 0)
# ------------------------------------------------------------
# The initial S1 and S5 loads are left free for the optimizer within [1, 50].
for _name in ("S1", "S5"):
    m.component(_name)[0].setlb(1)
    m.component(_name)[0].setub(50)

# Intermediates, product and c2 start at zero; the c1 pool starts at 2.
for _name, _value in (("S2", 0), ("S3", 0), ("S4", 0), ("c1", 2), ("c2", 0)):
    m.component(_name)[0].fix(_value)

# Drop the loop helpers so the module namespace matches a plain declaration.
del _name, _doc, _value, _state_docs

# ============================================================
# KINETIC FIXED PARAMETERS
# ============================================================
# Grouped by reaction: turnover number kcat, saturation constants K*, and
# inhibition constants Ki* used by the rate expressions rI..rIV.

# Reaction I (enzyme A): S1 + c1 -> S2 + c2, inhibited by S1
kcat_I, KS1_I, Kc1_I, KiS1_I = 1.2, 30.2, 1.1, 4.3

# Reaction II (enzyme B): S2 -> S3, inhibited by S1 and S5
kcat_II, KS2_II, KiS1_II, KiS5_II = 0.3, 8.2, 8.0, 14.3

# Reaction III (enzyme C): S3 -> S4
kcat_III, KS3_III = 1.4, 24.7

# Reaction IV (enzyme D): S5 + c2 -> c1, inhibited by S1
kcat_IV, KS5_IV, Kc2_IV, KiS1_IV = 0.5, 15.2, 3.1, 8.3

# ============================================================
# ENZYME ALLOCATION VARIABLES (Optimization Decisions)
# ============================================================
# Time-invariant enzyme loadings, each individually bounded to [0, 20].
for _ename, _edoc in (
    ("E_A", "Enzyme A concentration allocation"),
    ("E_B", "Enzyme B concentration allocation"),
    ("E_C", "Enzyme C concentration allocation"),
    ("E_D", "Enzyme D concentration allocation"),
):
    m.add_component(_ename, Var(bounds=(0, 20), doc=_edoc))
del _ename, _edoc

# ============================================================
# KINETIC RATE EXPRESSIONS
# ============================================================
def rI(m, t):
    """Rate of reaction I (enzyme A): S1 + c1 -> S2 + c2.

    Two-substrate saturation kinetics in S1 and c1, divided by an extra
    (1 + S1 / KiS1_I) factor that models inhibition by S1.
    """
    s1 = m.S1[t]
    cof1 = m.c1[t]
    numerator = kcat_I * m.E_A * s1 * cof1
    denominator = (KS1_I + s1) * (Kc1_I + cof1) * (1 + s1 / KiS1_I)
    return numerator / denominator

def rII(m, t):
    """Rate of reaction II (enzyme B): S2 -> S3.

    Saturation kinetics in S2, divided by two inhibition factors:
    (1 + S1 / KiS1_II) and (1 + S5 / KiS5_II).
    """
    s2 = m.S2[t]
    numerator = kcat_II * m.E_B * s2
    denominator = (KS2_II + s2) * (1 + m.S1[t] / KiS1_II) * (1 + m.S5[t] / KiS5_II)
    return numerator / denominator

def rIII(m, t):
    """Rate of reaction III (enzyme C): S3 -> S4, plain saturation kinetics."""
    s3 = m.S3[t]
    return kcat_III * m.E_C * s3 / (KS3_III + s3)

def rIV(m, t):
    """Rate of reaction IV (enzyme D): S5 + c2 -> c1.

    Two-substrate saturation kinetics in S5 and c2, divided by an extra
    (1 + S1 / KiS1_IV) factor that models inhibition by S1.
    """
    s5 = m.S5[t]
    cof2 = m.c2[t]
    numerator = kcat_IV * m.E_D * s5 * cof2
    denominator = (KS5_IV + s5) * (Kc2_IV + cof2) * (1 + m.S1[t] / KiS1_IV)
    return numerator / denominator

# ------------------------------------------------------------
# MASS BALANCE CONSTRAINTS (ODEs)
# ------------------------------------------------------------
# Every balance rule returns Constraint.Skip at t = 0: that point carries the
# initial conditions, and a backward difference has no earlier point there.

def _S1(m, t):
    """S1 balance: consumed by reaction I, supplied by the feed u(t)."""
    if t != 0:
        return m.dS1dt[t] == -rI(m, t) + m.u[t]
    return Constraint.Skip


def _S2(m, t):
    """S2 balance: formed by reaction I, converted by reaction II."""
    if t != 0:
        return m.dS2dt[t] == rI(m, t) - rII(m, t)
    return Constraint.Skip


def _S3(m, t):
    """S3 balance: formed by reaction II, converted by reaction III."""
    if t != 0:
        return m.dS3dt[t] == rII(m, t) - rIII(m, t)
    return Constraint.Skip


def _S4(m, t):
    """S4 balance: accumulates the output of reaction III and is never consumed."""
    if t != 0:
        return m.dS4dt[t] == rIII(m, t)
    return Constraint.Skip


def _c1(m, t):
    """c1 balance: used by reaction I, regenerated by reaction IV."""
    if t != 0:
        return m.dc1dt[t] == -rI(m, t) + rIV(m, t)
    return Constraint.Skip


def _c2(m, t):
    """c2 balance: released by reaction I, consumed by reaction IV."""
    if t != 0:
        return m.dc2dt[t] == rI(m, t) - rIV(m, t)
    return Constraint.Skip


def _S5(m, t):
    """S5 balance: consumed by reaction IV, supplied by the feed u(t)."""
    if t != 0:
        return m.dS5dt[t] == -rIV(m, t) + m.u[t]
    return Constraint.Skip


def _ET(m):
    """Enzyme budget: E_A + E_B + E_C + E_D may not exceed 50."""
    total_enzyme = m.E_A + m.E_B + m.E_C + m.E_D
    return total_enzyme <= 50

# Register one differential-balance constraint per state over the whole time
# domain (each rule skips t = 0 itself), then the scalar enzyme budget.
for _cname, _crule in (
    ("eq_S1", _S1),
    ("eq_S2", _S2),
    ("eq_S3", _S3),
    ("eq_S4", _S4),
    ("eq_c1", _c1),
    ("eq_c2", _c2),
    ("eq_S5", _S5),
):
    m.add_component(_cname, Constraint(m.t, rule=_crule))
del _cname, _crule

m.eq_ET = Constraint(rule=_ET)

# ------------------------------------------------------------
# OBJECTIVE FUNCTION
# ------------------------------------------------------------
# Average space-time yield of S4: S4 is fixed to 0 at t = 0, so the final
# concentration divided by the horizon length is the mean production rate
# over the batch (mM/min).
space_time_yield = m.S4[t_final] / t_final
m.obj = Objective(expr=space_time_yield, sense=maximize)

# ------------------------------------------------------------
# DISCRETIZATION (Method of Lines via Backward Euler)
# ------------------------------------------------------------
# Split the time domain m.t (the model's only ContinuousSet) into 250 finite
# elements and replace every DerivativeVar with backward-difference equations.
discretizer = TransformationFactory('dae.finite_difference')
discretizer.apply_to(
    m,
    wrt=m.t,
    nfe=250,
    scheme='BACKWARD',
)

# ------------------------------------------------------------
# SOLVER CALL
# ------------------------------------------------------------
solver = SolverFactory('ipopt')
results = solver.solve(m, tee=True)

# Ipopt can stop without an optimal point (infeasible, iteration limit, ...).
# Everything below reads whatever values were left in the model, so make a
# failed solve visible instead of silently plotting a meaningless trajectory.
if results.solver.termination_condition != TerminationCondition.optimal:
    warnings.warn(
        "Ipopt did not report an optimal solution "
        f"(status: {results.solver.status}, "
        f"termination: {results.solver.termination_condition}); "
        "the extracted profiles and plots may not be meaningful.",
        RuntimeWarning,
    )

# ------------------------------------------------------------
# DATA EXTRACTION
# ------------------------------------------------------------
# Discretization points in increasing order; shared x-axis of every plot.
t_vals = sorted(m.t)


def _trajectory(indexed_var):
    """Return the solved value of a time-indexed component at each point of t_vals."""
    return [value(indexed_var[point]) for point in t_vals]


S1 = _trajectory(m.S1)
S2 = _trajectory(m.S2)
S3 = _trajectory(m.S3)
S4 = _trajectory(m.S4)
S5 = _trajectory(m.S5)
c1 = _trajectory(m.c1)
c2 = _trajectory(m.c2)
u = _trajectory(m.u)

# ------------------------------------------------------------
# POST-PROCESSING VISUALIZATION
# ------------------------------------------------------------
# Shared time-axis ticks every 120 min (2 h) from 0 through t_final, i.e.
# 0, 120, 240, 360, 480, 600, 720 for the 12 h horizon.
time_ticks = list(range(0, t_final + 1, 120))


def _style_axes(ax):
    """Apply the common frame used by every figure to *ax*.

    All four spines are shown in black with line width 1.2, and ticks point
    outward with length 5 and width 1.
    """
    for spine in ax.spines.values():
        spine.set_visible(True)
        spine.set_linewidth(1.2)
        spine.set_color('black')
    ax.tick_params(direction='out', length=5, width=1)


# Plot 1: Main Product S4 Trajectory
plt.figure(figsize=(4.5, 4.1))
plt.plot(t_vals, S4, color='black')
_style_axes(plt.gca())
plt.xlabel('$t$ / min')
plt.ylabel('$S_4$ / mM ')
plt.yticks([0, 100, 200, 300, 400, 500], fontsize=14)
plt.xticks(time_ticks, fontsize=14)
plt.tight_layout()
plt.savefig('f1.pdf')
plt.show()

# Plot 2: Substrate Network Profiles (S1, S2, S3, S5)
plt.figure(figsize=(4.5, 4.1))
plt.plot(t_vals, S1, label='$S_1$')
plt.plot(t_vals, S2, label='$S_2$')
plt.plot(t_vals, S3, label='$S_3$')
plt.plot(t_vals, S5, label='$S_5$')
_style_axes(plt.gca())
plt.xlabel('$t$ / min')
plt.ylabel('$S_i$ / mM ')
plt.legend()
plt.yticks([0, 10, 20, 30, 40, 50], fontsize=14)
plt.xticks(time_ticks, fontsize=14)
plt.tight_layout()
plt.savefig('f2.pdf')
plt.show()

# Plot 3: Co-factor Dynamic Cyclic Profiles (c1, c2)
plt.figure(figsize=(4.5, 4.1))
plt.plot(t_vals, c1, label='$c_1$')
plt.plot(t_vals, c2, label='$c_2$')
_style_axes(plt.gca())
plt.xlabel('$t$ / min')
plt.ylabel('$c_i$ / mM ')  # co-factor concentrations, not substrates
plt.yticks([0, 0.5, 1, 1.5, 2, 2.5], fontsize=14)
plt.xticks(time_ticks, fontsize=14)
plt.legend()
plt.tight_layout()
plt.savefig('f3.pdf')
plt.show()

# Plot 4: Optimal Substrate Feed Rate Strategy u(t)
plt.figure(figsize=(4.5, 4.1))
plt.plot(t_vals, u, label='u(t) / mM ')
plt.xlabel('$t$ / min')
plt.ylabel('$u$ / mM/min')
plt.yticks([0, 0.3, 0.6, 0.9, 1.2, 1.5], fontsize=14)
plt.xticks(time_ticks, fontsize=14)
_style_axes(plt.gca())
plt.tight_layout()
plt.savefig('f4.pdf')
plt.show()