"""
Initial Rate Kinetics Parameter Estimation & Joint Confidence Region Evaluation
================================================================================

Description:
------------
This script estimates the Michaelis-Menten kinetic parameters (Vmax, Km) for an 
enzymatic reaction step (specifically targeting step III of the cascade system).
It leverages initial rate approximations calculated via forward finite differences 
of early batch concentration trajectories across several substrate feed loadings.

Statistical analysis handles parameter errors, 95% single-parameter confidence intervals,
goodness-of-fit metrics (R², MSE, MAE), and standard covariance mapping to build a joint 
95% parameter confidence region ellipse.
"""

import matplotlib.pyplot as plt
import numpy as np
from scipy.optimize import curve_fit
from scipy.stats import chi2
from scipy.stats import f as f_dist
from scipy.stats import t as student_t

# ============================================================
# TRAJECTORY EXPERIMENTAL DATA
# ============================================================
# Sampling time array (min); the same time grid is used for every trajectory in S_data
t = np.array([0, 10, 20, 30, 40, 50, 60, 70])

# Concentration time-series profiles for substrate S at different initial loadings.
# One array per loading, listed in the same order as S0 below; each array holds
# one concentration value per entry of t (8 samples).
S_data = [
    np.array([9.99025318, 6.78598944, 4.22381507, 2.60394826, 1.40534969, 0.885985, 0.4578872, 0.25869337]),
    np.array([19.99509224, 15.2231965, 9.83667259, 6.52014741, 3.94687478, 2.29771564, 1.33836609, 0.767542]),
    np.array([50.00065009, 40.95494266, 35.48084396, 25.99122767, 19.85851197, 12.87924691, 8.61052574, 5.49791671]),
    np.array([99.97765372, 88.95970989, 79.61796638, 69.60376995, 57.49851737, 48.90891837, 38.99940212, 30.80597726]),
    np.array([150.03211406, 137.88984735, 128.2968814, 114.70602142, 106.15816397, 92.08511597, 81.17219713, 70.50926418])
]

# True targeted initial substrate loading concentrations (mmol L^-1).
# These are the nominal values; the first measured sample of each trajectory
# in S_data differs from them slightly.
S0 = np.array([10, 20, 50, 100, 150])
# One colour entry per loading (all five are the same hex value).
# NOTE(review): not referenced anywhere in the code visible here.
colors = ['#2171b5', '#2171b5', '#2171b5', '#2171b5', '#2171b5']

# ============================================================
# INITIAL RATE CALCULATION (Two-Point Forward Difference)
# ============================================================
# Length of the first sampling interval; it is the same for every trajectory,
# so it is computed once up front.
first_step = t[1] - t[0]

# Initial velocity per loading: v0 = -dS/dt, approximated by the concentration
# drop between the first two samples divided by the first time step.
# NOTE(review): for the lowest loading the substrate falls from ~10 to ~6.8
# within this interval, so the two-point slope is a coarse estimate there.
v0 = np.array([-(profile[1] - profile[0]) / first_step for profile in S_data])

# Report one line per loading: nominal S0 and the estimated initial rate
print("\n===== INITIAL RATES =====")
for s, v in zip(S0, v0):
    print(f"S0 = {s:>3} -> v0 = {v:.6f}")

# ============================================================
# MICHAELIS-MENTEN MATHEMATICAL MODEL
# ============================================================
def mm_model(S, Vmax, Km):
    """
    Evaluate the Michaelis-Menten rate law v = Vmax * S / (Km + S).

    Parameters:
    -----------
    S : float or numpy.ndarray
        Substrate concentration (mmol L^-1).
    Vmax : float
        Maximum reaction velocity (mmol L^-1 min^-1).
    Km : float
        Michaelis affinity constant (mmol L^-1).

    Returns:
    --------
    float or numpy.ndarray
        Reaction rate evaluated element-wise at S (scalar in, scalar out;
        array in, array out).
    """
    # Numerator and denominator are kept separate so the expression mirrors
    # the textbook form of the rate law.
    rate_numerator = Vmax * S
    rate_denominator = Km + S
    return rate_numerator / rate_denominator

# ============================================================
# NON-LINEAR LEAST-SQUARES REGRESSION (FIT)
# ============================================================
# Curve fitting initial guess: p0 = [Vmax_guess, Km_guess]
popt, pcov = curve_fit(mm_model, S0, v0, p0=[1.2, 50])

Vmax_fit, Km_fit = popt
# Standard errors: square roots of the diagonal of the variance-covariance matrix
perr = np.sqrt(np.diag(pcov))

# ============================================================
# ESTIMATE SINGLE-PARAMETER CONFIDENCE INTERVALS
# ============================================================
# With no sigma supplied, curve_fit scales pcov by the residual variance, which
# is itself estimated from only (number of points - number of parameters)
# degrees of freedom. The two-sided 95% cutoff therefore has to come from
# Student's t distribution; the large-sample z = 1.96 is far too small for a
# handful of data points and would make the intervals look tighter than they are.
dof = len(v0) - len(popt)
tval = student_t.ppf(0.975, dof)
Vmax_ci = [Vmax_fit - tval * perr[0], Vmax_fit + tval * perr[0]]
Km_ci   = [Km_fit   - tval * perr[1], Km_fit   + tval * perr[1]]

# Model-predicted initial rates at the fitted parameters, one per loading in S0
v0_pred = mm_model(S0, *popt)

# ============================================================
# EVALUATE FIT PERFORMANCE METRICS
# ============================================================
# Residual vector: observed initial rates minus the model predictions
residuals = v0 - v0_pred
squared_residuals = np.square(residuals)

# Average error magnitudes over all loadings
MSE = squared_residuals.mean()       # Mean Squared Error
MAE = np.abs(residuals).mean()       # Mean Absolute Error

# Coefficient of determination: R² = 1 - SS_res / SS_tot
SS_res = squared_residuals.sum()
SS_tot = np.square(v0 - v0.mean()).sum()
R2 = 1 - SS_res / SS_tot

# Console report: parameter estimates, confidence intervals, goodness of fit.
# The lines are collected first and then emitted one print call per line.
report_lines = (
    "\n===== FIT RESULTS (Initial Rates) =====",
    f"Vmax = {Vmax_fit:.5f} \xb1 {perr[0]:.5f}",
    f"Km   = {Km_fit:.5f} \xb1 {perr[1]:.5f}",
    "\n===== 95% CONFIDENCE INTERVALS =====",
    f"Vmax: [{Vmax_ci[0]:.5f}, {Vmax_ci[1]:.5f}]",
    f"Km  : [{Km_ci[0]:.5f}, {Km_ci[1]:.5f}]",
    "\n===== GOODNESS OF FIT =====",
    f"R\xb2  = {R2:.5f}",
    f"MSE = {MSE:.6f}",
    f"MAE = {MAE:.6f}",
)
for report_line in report_lines:
    print(report_line)

# ============================================================
# PLOT 1: RATE-SUBSTRATE CURVE WITH ARTIFICIAL ERROR BARS
# ============================================================
# New figure; grab its Axes once and draw through it
plt.figure(figsize=(4.5, 4.1))
ax = plt.gca()

# Illustrative error bars: 5% of each plotted value (not measured uncertainties).
# NOTE(review): xerr is computed but never passed to errorbar, so only vertical
# bars are drawn.
xerr = 0.05 * S0
yerr = 0.05 * v0

# Marker and error-bar appearance for the measured initial rates
point_style = dict(
    fmt='o',
    color='#2171b5',
    ecolor='black',
    elinewidth=1,
    capsize=3,
    markersize=6,
    alpha=0.9,
)
ax.errorbar(S0, v0, yerr=yerr, **point_style)

# Smooth model curve evaluated at the fitted parameters over 0-160 mmol L^-1
S_fine = np.linspace(0, 160, 300)
v_fit = mm_model(S_fine, *popt)
ax.plot(S_fine, v_fit, color='#6baed6', linewidth=1.2)

# Boxed frame: no grid, all four spines shown in black at a uniform width
ax.grid(False)
for spine in ax.spines.values():
    spine.set_visible(True)
    spine.set_linewidth(1.2)
    spine.set_color('black')

# Axis ticks and labels
ax.tick_params(direction='out', length=5, width=1)
ax.set_xlabel('$S_3$ / mmol L$^{-1}$', fontsize=14)
ax.set_ylabel('$r_\\mathrm{III,0}$ / mmol L$^{-1}$ min$^{-1}$', fontsize=14)
plt.xticks([0, 50, 100, 150], fontsize=14)
plt.yticks(fontsize=14)

# Write the PDF before showing the window
plt.tight_layout()
plt.savefig('a1.pdf')
plt.show()

# ============================================================
# JOINT CONFIDENCE ELLIPSE CONSTRUCTION (2 DOF)
# ============================================================
# Compute eigenvalues and eigenvectors to find principal axes components from the covariance matrix
# Principal axes of the parameter covariance: eigh is used because pcov is
# symmetric; vals are the variances along the eigenvector columns of vecs.
vals, vecs = np.linalg.eigh(pcov)

# Scale factor of the joint 95% region. pcov already contains the residual
# variance estimated from (points - parameters) degrees of freedom, so the
# linearised region is
#     (theta - theta_hat)^T pcov^-1 (theta - theta_hat) <= p * F_0.95(p, n - p).
# The chi-squared cutoff with p degrees of freedom is only the large-sample
# limit of this and gives a far too small ellipse for a handful of points.
n_params = len(popt)
resid_dof = len(v0) - n_params
region_scale = n_params * f_dist.ppf(0.95, n_params, resid_dof)

# Unit circle sampled at 200 angles; row 0 is cos(theta), row 1 is sin(theta)
theta = np.linspace(0, 2*np.pi, 200)
ellipse = np.array([np.cos(theta), np.sin(theta)])

# Stretch the circle by the semi-axis lengths and rotate it onto the
# eigenvector directions. Eigenvalues are clipped at zero so that a tiny
# negative value from round-off (nearly singular pcov) cannot produce NaNs.
semi_axes = np.sqrt(np.clip(vals, 0.0, None) * region_scale)
ellipse_scaled = vecs @ np.diag(semi_axes) @ ellipse

# Shift the ellipse so that it is centred on the fitted (Vmax, Km) estimate
ellipse_x = ellipse_scaled[0] + Vmax_fit
ellipse_y = ellipse_scaled[1] + Km_fit

# ============================================================
# PLOT 2: JOINT PARAMETER REGION PROFILE
# ============================================================
# New figure; grab its Axes once and draw through it
plt.figure(figsize=(4.5, 4.1))
ax = plt.gca()

# Ellipse outline plus a marker at the fitted (Vmax, Km) point, drawn on top
ax.plot(ellipse_x, ellipse_y, color='green', linewidth=1.2)
plt.scatter(Vmax_fit, Km_fit, color='red', zorder=3)

# Boxed frame: all four spines shown in black at a uniform width
for spine in ax.spines.values():
    spine.set_visible(True)
    spine.set_linewidth(1.2)
    spine.set_color('black')

# Axis ticks and labels
ax.tick_params(direction='out', length=5, width=1)
ax.set_xlabel('$V_\\mathrm{max}$ / mmol L$^{-1}$ min$^{-1}$', fontsize=14)
ax.set_ylabel('$K_\\mathrm{m}$ / mmol L$^{-1}$', fontsize=14)

# Write the PDF before showing the window
plt.tight_layout()
plt.savefig('a2.pdf')
plt.show()