"""
Integral Method Parameter Estimation via Global Non-Linear Regression
====================================================================

Description:
------------
This script implements the integral method of kinetic analysis to estimate 
Michaelis-Menten parameters (Vmax, Km) for an enzymatic step (e.g., step III).
Unlike the differential initial rates approach, this script simultaneously fits 
the full dynamic concentration trajectories (progress curves) across multiple 
independent batch experiments with varying initial substrate loadings (S0).

Methodology:
------------
1. Integrates the underlying governing ordinary differential equation (ODE) 
   dynamically using `scipy.integrate.odeint`.
2. Flattens and concatenates multiple time-series data streams into a single 
   1D target vector (`S_all`).
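3. Passes a placeholder index array (`x_dummy`) as the independent variable that
   `scipy.optimize.curve_fit` requires; the model ignores it and simulates every
   trajectory internally, so all runs are fitted with one shared (Vmax, Km) pair.
4. Computes parameter standard errors and 95% confidence intervals, the joint
   95% confidence ellipse for (Vmax, Km), and residual-based goodness-of-fit
   statistics (R², MSE, MAE).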
"""

import numpy as np
import matplotlib.pyplot as plt
from scipy.integrate import odeint
from scipy.optimize import curve_fit
from scipy.stats import chi2
from scipy.stats import t as student_t

# ============================================================
# TRAJECTORY EXPERIMENTAL DATA & LAYOUT SETUP
# ============================================================
# Discrete experimental sampling points (min), shared by every batch run
t = np.array([0, 10, 20, 30, 40, 50, 60, 70])

# Substrate progress curves (concentration vs. time) for 5 distinct batch runs;
# entry j of each array is the sample taken at t[j]
S_data = [
    np.array([9.99025318, 6.78598944, 4.22381507, 2.60394826, 1.40534969, 0.885985, 0.4578872, 0.25869337]),
    np.array([19.99509224, 15.2231965, 9.83667259, 6.52014741, 3.94687478, 2.29771564, 1.33836609, 0.767542]),
    np.array([50.00065009, 40.95494266, 35.48084396, 25.99122767, 19.85851197, 12.87924691, 8.61052574, 5.49791671]),
    np.array([99.97765372, 88.95970989, 79.61796638, 69.60376995, 57.49851737, 48.90891837, 38.99940212, 30.80597726]),
    np.array([150.03211406, 137.88984735, 128.2968814, 114.70602142, 106.15816397, 92.08511597, 81.17219713, 70.50926418])
]

# Nominal initial concentrations matching the trajectories (mmol L^-1),
# in the same order as S_data
S0_list = [10, 20, 50, 100, 150]

# Publication-quality monochromatic blue sequential color map for plotting
# (one color per trajectory, same order as S_data)
colors = ['#041f3d', '#08306b', '#2171b5', '#6baed6', '#deebf7']

# Validate the layout before fitting. The global model simulates one len(t)-point
# trajectory per S0_list entry and the fit compares that vector element-wise with
# S_all, so inconsistent sizes would either fail with an opaque shape error or, if
# the totals happened to agree, silently pair residuals with the wrong samples.
if len(S0_list) != len(S_data):
    raise ValueError(
        f"S0_list has {len(S0_list)} entries but S_data has {len(S_data)} trajectories"
    )
_bad_runs = [k for k, S in enumerate(S_data) if np.shape(S) != t.shape]
if _bad_runs:
    raise ValueError(
        f"trajectories {_bad_runs} do not have exactly one sample per time point in t"
    )
if len(colors) < len(S_data):
    raise ValueError(f"need at least {len(S_data)} plot colors, got {len(colors)}")

# Flatten multi-batch data matrices into a continuous 1D array for least-squares residuals
S_all = np.concatenate(S_data)
# Placeholder independent variable with one entry per observation; curve_fit requires
# an x argument, but the model never reads its values
x_dummy = np.arange(len(S_all))

# ============================================================
# DYNAMIC MASS BALANCE & CONCATENATED PREDICTION MODEL
# ============================================================
def simulate_curve(S0, t, Vmax, Km):
    """
    Integrate the batch Michaelis-Menten substrate balance.

        dS/dt = - (Vmax * S) / (Km + S),   S(t[0]) = S0

    Parameters
    ----------
    S0 : float
        Substrate concentration at the first time point t[0].
    t : array_like
        Monotonic sequence of output times; t[0] is the initial time.
    Vmax, Km : float
        Michaelis-Menten kinetic parameters.

    Returns
    -------
    numpy.ndarray
        1-D array with the substrate concentration at every entry of ``t``.
    """
    def _consumption_rate(S, _time):
        # Autonomous ODE: the rate depends on S only, the time argument is unused
        return -Vmax * S / (Km + S)

    # odeint returns one row per time point and one column per state variable;
    # with a single state the result is collapsed to a 1-D trajectory.
    trajectory = odeint(_consumption_rate, S0, t)
    return trajectory.ravel()

def model(dummy_x, Vmax, Km):
    """
    Global model for curve_fit: all batch trajectories, back to back, in one vector.

    Parameters
    ----------
    dummy_x : array_like
        Ignored. curve_fit always passes the independent variable first; the actual
        time grid (module-level ``t``) and initial loadings (module-level
        ``S0_list``) are read from the enclosing module instead.
    Vmax, Km : float
        Michaelis-Menten parameters shared by every trajectory.

    Returns
    -------
    numpy.ndarray
        Simulated concentrations ordered like ``S0_list``, each run sampled at ``t``,
        concatenated into a single 1-D array of length len(S0_list) * len(t).
    """
    trajectories = [simulate_curve(S0, t, Vmax, Km) for S0 in S0_list]
    return np.concatenate(trajectories)

# ============================================================
# PARAMETER ESTIMATION (NON-LINEAR REGRESSION)
# ============================================================
# Fit (Vmax, Km) to all trajectories simultaneously. p0=[Vmax, Km] is a fixed,
# hand-chosen starting point for the optimizer, not a data-driven estimate.
popt, pcov = curve_fit(model, x_dummy, S_all, p0=[0.1, 50])
Vmax_fit, Km_fit = popt

# One-sigma standard errors: square roots of the covariance-matrix diagonal
perr = np.sqrt(np.diag(pcov))

# Two-sided 95% single-parameter confidence intervals.
# absolute_sigma is not passed, so curve_fit scales pcov by the residual variance
# estimated from this fit; the matching critical value is therefore Student's t
# with N - p degrees of freedom, not the large-sample normal quantile 1.96.
dof = len(S_all) - len(popt)
tval = student_t.ppf(0.975, dof)
Vmax_ci = [Vmax_fit - tval * perr[0], Vmax_fit + tval * perr[0]]
Km_ci   = [Km_fit   - tval * perr[1], Km_fit   + tval * perr[1]]

# Report point estimates with standard errors, then the confidence intervals
print("\n===== FIT RESULTS =====")
print(f"Vmax = {Vmax_fit:.5f} \xb1 {perr[0]:.5f}")
print(f"Km   = {Km_fit:.5f} \xb1 {perr[1]:.5f}")

print("\n===== 95% CONFIDENCE INTERVALS =====")
print(f"Vmax: [{Vmax_ci[0]:.5f}, {Vmax_ci[1]:.5f}]")
print(f"Km  : [{Km_ci[0]:.5f}, {Km_ci[1]:.5f}]")

# ============================================================
# PLOT 1: DYNAMIC PROGRESS CURVES & GLOBAL FIT REGRESSION
# ============================================================
plt.figure(figsize=(4.5, 4.1))
# High-resolution time grid over the sampled interval for smooth curve rendering;
# derived from `t` so it follows any change to the sampling schedule
t_fine = np.linspace(t.min(), t.max(), 200)

for i, S in enumerate(S_data):
    S0, color = S0_list[i], colors[i]
    # Measured samples. Labelling the real scatter (instead of separate empty
    # proxy artists) keeps the legend markers identical to the plotted data.
    # NOTE(review): the lightest palette entry ('#deebf7') at alpha=0.7 is barely
    # visible on a white background; consider a darker tone or marker edges.
    plt.scatter(t, S, color=color, s=25, alpha=0.7, label=f'{S0} mmol L$^{{-1}}$')

    # Continuous model trajectory at the fitted parameters, from the nominal S0
    S_fit = simulate_curve(S0, t_fine, Vmax_fit, Km_fit)
    plt.plot(t_fine, S_fit, color='#6baed6', linewidth=1.2)

# Graph styling: no grid, black 1.2-pt frame on all four sides, outward ticks
plt.grid(False)
ax = plt.gca()
for spine in ax.spines.values():
    spine.set_visible(True)
    spine.set_linewidth(1.2)
    spine.set_color('black')

ax.tick_params(direction='out', length=5, width=1)

plt.legend(fontsize=11)
plt.xlabel('$t$ / min', fontsize=14)
plt.ylabel('$S_3$ / mmol L$^{-1}$', fontsize=14)
plt.xticks([0, 25, 50, 75], fontsize=14)
plt.yticks(fontsize=14)
plt.tight_layout()
plt.savefig('a3.pdf')
plt.show()

# ============================================================
# PLOT 2: JOINT 95% PARAMETER CONFIDENCE REGION (ELLIPSE)
# ============================================================
# Linearized joint region {p : (p - popt)^T pcov^-1 (p - popt) <= chi2_0.95(2)}:
# an ellipse whose axes follow the eigenvectors of pcov, with half-lengths
# sqrt(eigenvalue * chi2_val).
vals, vecs = np.linalg.eigh(pcov)
# A near-singular covariance (strongly correlated Vmax/Km) can report tiny negative
# eigenvalues from round-off; clip them so the square root below stays real
# instead of turning the whole ellipse into NaN.
vals = np.clip(vals, 0.0, None)
# Chi-squared critical cutoff for 2 degrees of freedom (joint estimation,
# large-sample approximation)
chi2_val = chi2.ppf(0.95, df=2)

# Unit circle sampled at 200 angles (endpoint included, so the curve closes)
theta = np.linspace(0, 2*np.pi, 200)
ellipse = np.array([np.cos(theta), np.sin(theta)])

# Stretch the circle along the principal axes, then rotate into parameter space
ellipse_scaled = vecs @ np.diag(np.sqrt(vals * chi2_val)) @ ellipse

# Center the ellipse on the fitted estimates
ellipse_x = ellipse_scaled[0] + Vmax_fit
ellipse_y = ellipse_scaled[1] + Km_fit

plt.figure(figsize=(4.5, 4.1))
plt.plot(ellipse_x, ellipse_y, color='green', linewidth=1.2)
plt.scatter(Vmax_fit, Km_fit, color='red')

# Axis formatting: black 1.2-pt frame on all four sides, outward ticks
ax = plt.gca()
for spine in ax.spines.values():
    spine.set_visible(True)
    spine.set_linewidth(1.2)
    spine.set_color('black')

ax.tick_params(direction='out', length=5, width=1)
plt.xlabel('$V_\\mathrm{max}$ / mM min$^{-1}$', fontsize=14)
plt.ylabel('$K_\\mathrm{m}$ / mM', fontsize=14)
plt.tight_layout()
plt.savefig('a4.pdf')
plt.show()

# ============================================================
# GOODNESS-OF-FIT STATISTICAL EVALUATION
# ============================================================
# Re-evaluate the global model at the optimum and compare with every observation
S_pred = model(x_dummy, *popt)
residuals = S_all - S_pred
squared_err = np.square(residuals)

# Sums of squares: residual (unexplained) and total (about the grand mean of all runs)
SS_res = np.sum(squared_err)
SS_tot = np.sum(np.square(S_all - np.mean(S_all)))

# Summary statistics over all concatenated trajectories
R2 = 1 - SS_res / SS_tot                # coefficient of determination
MSE = np.mean(squared_err)              # Mean Squared Error
MAE = np.mean(np.abs(residuals))        # Mean Absolute Error

# Output performance metrics to the terminal
print("\n===== GOODNESS OF FIT =====")
print(f"R\xb2  = {R2:.5f}")
print(f"MSE = {MSE:.5f}")
print(f"MAE = {MAE:.5f}")