import warnings

import numpy as np
from scipy.integrate import odeint
import matplotlib.pyplot as plt

# ============================================================
# KINETIC PARAMETERS & REFERENCE INITIAL CONDITIONS
# ============================================================
# Turnover numbers of reactions I-IV (used as kcat * E * ... in rates()).
kcat_I   = 1.2
kcat_II  = 0.3
kcat_III = 1.4
kcat_IV  = 0.5

# Reaction I: saturation constants for S1 and c1, inhibition constant for S1.
KS1_I    = 30.2
Kc1_I    = 1.1
KiS1_I   = 4.3

# Reaction II: saturation constant for S2, inhibition constants for S1 and S5.
KS2_II   = 8.2
KiS1_II  = 8.0
KiS5_II  = 14.3

# Reaction III: saturation constant for S3.
KS3_III  = 24.7

# Reaction IV: saturation constants for S5 and c2, inhibition constant for S1.
KS5_IV   = 15.2
Kc2_IV   = 3.1
KiS1_IV  = 8.3

# Enzyme amounts multiplying kcat: E_A -> reaction I, E_B -> II, E_C -> III, E_D -> IV.
E_A = 15.0
E_B = 15.0
E_C = 15.0
E_D = 15.0

# Initial concentrations: [S1, S2, S3, S4, c1, c2, S5]
y0 = [50, 0, 0, 0, 2, 0, 50]

# Inlet feed concentrations, same species order as y0 (the reactor starts
# filled with feed). Stored as a copy so that modifying y0 in place (e.g. to
# study a different start-up state) cannot silently change the feed as well.
Sin = list(y0)

# ============================================================
# KINETICS FUNCTION
# ============================================================
def rates(S1, S2, S3, S4, c1, c2, S5):
    """Return the rates ``(rI, rII, rIII, rIV)`` of the four cascade reactions.

    Rate laws as implemented:
      rI   -- saturation in S1 and c1, with an extra (1 + S1/KiS1_I) inhibition term.
      rII  -- saturation in S2, inhibited by S1 and by S5.
      rIII -- single-substrate saturation in S3.
      rIV  -- saturation in S5 and c2, inhibited by S1.

    S4 is accepted to keep the same argument order as the state vector, but it
    does not appear in any rate law. All arithmetic is elementwise, so scalar
    or NumPy-array arguments are both accepted.
    """
    # Reaction I
    inhib_I = 1 + S1 / KiS1_I
    num_I = kcat_I * E_A * S1 * c1
    den_I = (KS1_I + S1) * (Kc1_I + c1) * inhib_I
    rI = num_I / den_I

    # Reaction II
    inhib_S1_II = 1 + S1 / KiS1_II
    inhib_S5_II = 1 + S5 / KiS5_II
    num_II = kcat_II * E_B * S2
    den_II = (KS2_II + S2) * inhib_S1_II * inhib_S5_II
    rII = num_II / den_II

    # Reaction III
    num_III = kcat_III * E_C * S3
    den_III = KS3_III + S3
    rIII = num_III / den_III

    # Reaction IV
    inhib_IV = 1 + S1 / KiS1_IV
    num_IV = kcat_IV * E_D * S5 * c2
    den_IV = (KS5_IV + S5) * (Kc2_IV + c2) * inhib_IV
    rIV = num_IV / den_IV

    return (rI, rII, rIII, rIV)

# ============================================================
# SENSITIVITY SWEEP SETUP
# ============================================================
V = 30.0                               # Constant CSTR volume (L)
F_values = np.linspace(0.1, 5.0, 4000) # Parameter sweep: volumetric flowrate (L min^-1)
t_long = np.linspace(0, 900, 400)      # Time grid (min); t_ss resolution is its spacing, 900/399 ~ 2.26 min

tol = 1e-4                             # Steady state: max_j |dx_j/dt| < tol (state units per min)


# ============================================================
# DYNAMIC SENSITIVITY LOOP
# ============================================================
def cstr_model_F(y, t, F):
    """Transient CSTR mass balances at volumetric flow rate ``F``.

    ``y`` is the state [S1, S2, S3, S4, c1, c2, S5]. Its entries may be scalars
    (as supplied by ``odeint``) or equal-length arrays of sampled states (e.g.
    ``sol.T``); in the latter case each returned derivative is an array.
    ``t`` is unused (the balances are autonomous) but is required by ``odeint``.
    """
    S1, S2, S3, S4, c1, c2, S5 = y
    rI, rII, rIII, rIV = rates(S1, S2, S3, S4, c1, c2, S5)
    D = F / V  # dilution rate (min^-1)
    return [
        D*(Sin[0]-S1) - rI,
        D*(Sin[1]-S2) + rI - rII,
        D*(Sin[2]-S3) + rII - rIII,
        D*(Sin[3]-S4) + rIII,
        D*(Sin[4]-c1) - rI + rIV,
        D*(Sin[5]-c2) + rI - rIV,
        D*(Sin[6]-S5) - rIV,
    ]


def _settling_time(t, dxdt, tol):
    """Return ``(t_ss, settled)`` for one sampled trajectory.

    ``dxdt`` has one row per entry of ``t`` and one column per state variable.
    ``t_ss`` is the earliest sample time from which *every* later sample has
    ``max_j |dxdt[i, j]| < tol``; requiring the test to keep holding stops a
    momentary dip (e.g. several states passing through extrema together) from
    being reported as steady state. If the final sample still fails the test,
    ``t_ss`` falls back to ``t[-1]`` and ``settled`` is False.
    """
    within_tol = np.all(np.abs(dxdt) < tol, axis=1)
    violations = np.flatnonzero(~within_tol)
    if violations.size == 0:
        return t[0], True
    first_settled = violations[-1] + 1
    if first_settled == len(t):
        return t[-1], False
    return t[first_settled], True


t_ss_list = np.empty_like(F_values)    # Settling time for each F
n_unsettled = 0

for k, F in enumerate(F_values):
    # Simulate dynamic trajectory; F is forwarded to the RHS via `args`.
    sol = odeint(cstr_model_F, y0, t_long, args=(F,))

    # Exact model derivatives at every sampled state (rows = time points);
    # cstr_model_F is elementwise, so it accepts the transposed trajectory.
    dsol_dt = np.asarray(cstr_model_F(sol.T, t_long, F)).T

    t_ss_list[k], settled = _settling_time(t_long, dsol_dt, tol)
    if not settled:
        n_unsettled += 1

if n_unsettled:
    warnings.warn(
        f"{n_unsettled} of {F_values.size} flow rates did not reach steady "
        f"state within t = {t_long[-1]:g} min; their settling time is capped "
        f"at the horizon."
    )

# ============================================================
# PLOT RESULTS
# ============================================================
# Settling time versus feed flow rate, drawn with the object-oriented API.
fig, ax = plt.subplots(figsize=(4.5, 4.1))
ax.plot(F_values, t_ss_list, '-', color='#08306b')

# Full black frame: all four spines shown, 1.2 wide.
for spine in ax.spines.values():
    spine.set_visible(True)
    spine.set_linewidth(1.2)
    spine.set_color('black')

ax.tick_params(direction='out', length=5, width=1)

ax.set_xlabel('$F$ / L min$^{-1}$', fontsize=14)
ax.set_ylabel('$t_{rss}$ / min', fontsize=14)

# NOTE(review): np.linspace(0, 5, 5) yields ticks at 0, 1.25, 2.5, 3.75, 5;
# use np.linspace(0, 5, 6) if integer tick spacing was intended.
ax.set_xticks(np.linspace(0, 5, 5))
for tick_label in ax.get_xticklabels():
    tick_label.set_fontsize(14)

ax.grid(False)
fig.tight_layout()
fig.savefig('CSTR_time_to_SS_vs_F.pdf')
plt.show()