"""
Kinetic Model Validation Framework: Training Fit and Unseen Test Evaluation
===========================================================================

Description:
------------
This script implements a validation pipeline for a complex network of enzymatic 
cascade reactions. It benchmarks the predictive performance of a mechanistic, 
system-of-ODEs model against synthetic experimental data containing artificial 
Gaussian noise.

Methodology:
------------
1. Data Generation: Dynamic trajectories of target product S4 are computed via
   `scipy.integrate.odeint`. Heteroscedastic Gaussian noise whose standard deviation
   is 5% of the noise-free value is added at the discrete sampling times to mimic
   real benchtop observations.
2. Training/Validation Split: 
   - Training Set: Explicit concentrations at low and high bounds (50 and 100 mmol/L).
   - Unseen Test Set: An intermediate concentration (75 mmol/L) to verify interpolation
     and true validation capability.
3. Statistics: Linear interpolation (`np.interp`) pairs continuous time domain model 
   predictions with discrete empirical collection points for the calculation of 
   the Coefficient of Determination (R²), Mean Absolute Error (MAE), and Mean Squared Error (MSE).
"""

import numpy as np
from scipy.integrate import odeint
import matplotlib.pyplot as plt

# Seed NumPy's global random generator so the synthetic noise is reproducible
np.random.seed(42)

# ============================================================
# KINETIC FIXED PARAMETERS
# ============================================================
# Kinetic coefficients and inhibition constants (from the tutorial network data),
# grouped one line per reaction step I-IV
kcat_I, KS1_I, Kc1_I, KiS1_I = 1.2, 30.2, 1.1, 4.3
kcat_II, KS2_II, KiS1_II, KiS5_II = 0.3, 8.2, 8.0, 14.3
kcat_III, KS3_III = 1.4, 24.7
kcat_IV, KS5_IV, Kc2_IV, KiS1_IV = 0.5, 15.2, 3.1, 8.3

# Total enzyme loadings, identical for all four enzymes (mmol L^-1)
E_A = 15.0
E_B = 15.0
E_C = 15.0
E_D = 15.0

# ============================================================
# CORE MECHANISTIC TRANSIENT SYSTEM (ODEs)
# ============================================================
def rates(S1, S2, S3, S4, c1, c2, S5):
    """
    Evaluate the four reaction rates (rI, rII, rIII, rIV) of the cascade.

    Every concentration that enters a rate law is floored at 1e-9 before use,
    which keeps zero or negative inputs out of the rate expressions. The floor
    is applied with ``np.maximum`` rather than the builtin ``max``, so the
    function accepts plain scalars as well as NumPy arrays (evaluated
    elementwise with broadcasting). This allows the rates to be computed along
    a whole trajectory, e.g. ``rates(*sol.T)`` for a solution array whose
    columns are ordered like the arguments.

    Parameters
    ----------
    S1, S2, S3, S5 : float or array-like
        Concentrations of the species that appear in the rate laws.
    S4 : float or array-like
        Accepted so the signature mirrors the full state vector; it does not
        appear in any rate law and is ignored.
    c1, c2 : float or array-like
        Concentrations of the co-factor species used by rI (c1) and rIV (c2).

    Returns
    -------
    tuple
        ``(rI, rII, rIII, rIV)``; NumPy scalars for scalar input, arrays for
        array input.
    """
    floor = 1e-9
    S1, S2, S3, c1, c2, S5 = (
        np.maximum(conc, floor) for conc in (S1, S2, S3, c1, c2, S5)
    )

    # Step I: saturating in S1 and c1, with an extra (1 + S1/KiS1_I) damping factor
    rI = kcat_I * E_A * S1 * c1 / ((KS1_I + S1) * (Kc1_I + c1) * (1 + S1 / KiS1_I))
    # Step II: saturating in S2, damped by both S1 and S5
    rII = kcat_II * E_B * S2 / ((KS2_II + S2) * (1 + S1 / KiS1_II) * (1 + S5 / KiS5_II))
    # Step III: plain saturation in S3
    rIII = kcat_III * E_C * S3 / (KS3_III + S3)
    # Step IV: saturating in S5 and c2, damped by S1
    rIV = kcat_IV * E_D * S5 * c2 / ((KS5_IV + S5) * (Kc2_IV + c2) * (1 + S1 / KiS1_IV))
    return rI, rII, rIII, rIV

def batch_ode(y, t):
    """
    Right-hand side of the batch mass balances, in the ``func(y, t)`` form
    that ``odeint`` is called with.

    ``y`` is the state vector ordered (S1, S2, S3, S4, c1, c2, S5). ``t`` is
    not used by the balances. Returns the seven time derivatives as a list in
    the same order as ``y``.
    """
    S1, S2, S3, S4, c1, c2, S5 = y
    rI, rII, rIII, rIV = rates(S1=S1, S2=S2, S3=S3, S4=S4, c1=c1, c2=c2, S5=S5)

    dS1 = -rI            # consumed by step I
    dS2 = rI - rII       # formed by step I, consumed by step II
    dS3 = rII - rIII     # formed by step II, consumed by step III
    dS4 = rIII           # accumulates from step III
    dc1 = -rI + rIV      # consumed by step I, restored by step IV
    dc2 = rI - rIV       # formed by step I, consumed by step IV
    dS5 = -rIV           # consumed by step IV
    return [dS1, dS2, dS3, dS4, dc1, dc2, dS5]

def simulate(S1_0, S5_0=None):
    """
    Integrate the batch model on the dense time grid ``t_fine``.

    The initial state is (S1_0, 0, 0, 0, 2.0, 0, S5_0): only S1, c1 (fixed at
    2.0 mmol L^-1) and S5 start non-zero. When ``S5_0`` is omitted it is set
    equal to ``S1_0`` (equimolar feed).

    Returns a tuple ``(y0, solution)`` where ``y0`` is the initial-state list
    and ``solution`` is the ``odeint`` output, indexed as
    ``solution[time_index, state_index]``.
    """
    S5_start = S1_0 if S5_0 is None else S5_0
    y0 = [S1_0, 0.0, 0.0, 0.0, 2.0, 0.0, S5_start]
    trajectory = odeint(batch_ode, y0, t_fine, rtol=1e-8, atol=1e-10)
    return y0, trajectory

# ============================================================
# TIME GRID DEFINITIONS
# ============================================================
# Sampling times of the synthetic measurements: every 15 min from 0 to 120 min
t_exp = 15.0 * np.arange(9, dtype=float)
# Dense, evenly spaced grid (300 points over 0-120 min) for integration and smooth plots
t_fine = np.linspace(0.0, 120.0, num=300)

# ============================================================
# SYNTHETIC REACTION DATA GENERATION
# ============================================================
noise_level = 0.05  # Relative standard deviation (5%) of the Gaussian measurement noise

def generate_data(S1_0):
    """
    Produce one noisy synthetic S4 time series for the initial concentration ``S1_0``.

    The model is simulated on the dense grid, the S4 column is linearly
    interpolated onto the sampling times ``t_exp``, zero-mean Gaussian noise
    with standard deviation ``noise_level * |S4|`` is added, and the result is
    clipped at zero so no negative concentrations are reported.

    Returns an array with one value per entry of ``t_exp``.
    """
    _, trajectory = simulate(S1_0)
    clean = np.interp(t_exp, t_fine, trajectory[:, 3])
    sigma = noise_level * np.abs(clean)  # noise scales with the signal
    noisy = clean + np.random.normal(0, sigma)
    return np.clip(noisy, 0, None)

# Training conditions: the two boundary initial concentrations
S1_0_train = [50.0, 100.0]
data_train = list(map(generate_data, S1_0_train))

# Test condition: an intermediate initial concentration not used for training
S1_0_test = 75.0
data_test = generate_data(S1_0_test)

# ============================================================
# PRE-COMPUTE NOMINAL VALIDATION CURVES
# ============================================================
# Noise-free S4 trajectories (state column 3) on the dense grid, used for plotting
preds_train = [trajectory[:, 3] for _, trajectory in map(simulate, S1_0_train)]
pred_test = simulate(S1_0_test)[-1][:, 3]

def pred_at_exp(S1_0):
    """Return the noise-free model S4 values at the sampling times ``t_exp``.

    The model is simulated for ``S1_0`` on the dense grid and the S4 column is
    linearly interpolated onto ``t_exp``.
    """
    _, trajectory = simulate(S1_0)
    s4_dense = trajectory[:, 3]
    return np.interp(t_exp, t_fine, s4_dense)

# ============================================================
# STATISTICAL VALIDATION METRICS
# ============================================================
def compute_metrics(data_list, S1_0_list, predict=None):
    """
    Compute pooled goodness-of-fit statistics over one or more experiments.

    All measurements and all matching predictions are concatenated, and the
    statistics are evaluated on the pooled residuals (data minus prediction).

    Parameters
    ----------
    data_list : iterable of array-like
        Measured series, one 1-D series per experiment.
    S1_0_list : iterable
        Experimental condition of each series, in the same order as
        ``data_list``; each entry is passed to ``predict``.
    predict : callable, optional
        Maps one entry of ``S1_0_list`` to the model predictions aligned with
        the corresponding measured series. Defaults to ``pred_at_exp``, so
        existing two-argument calls behave as before.

    Returns
    -------
    tuple
        ``(R2, MAE, MSE)``: coefficient of determination, mean absolute error
        and mean squared error of the pooled residuals. If the pooled data
        have zero variance, R2 is not finite (NumPy division semantics).

    Raises
    ------
    ValueError
        If no data are supplied, or if the pooled data and pooled predictions
        differ in length. Without this check a length-1 series would be
        silently broadcast against the predictions.
    """
    if predict is None:
        predict = pred_at_exp

    observed = [np.asarray(series, dtype=float) for series in data_list]
    if not observed:
        raise ValueError("compute_metrics requires at least one data series")
    predicted = [np.asarray(predict(S1_0), dtype=float) for S1_0 in S1_0_list]
    if not predicted:
        raise ValueError("compute_metrics requires at least one condition")

    all_data = np.concatenate(observed)
    all_pred = np.concatenate(predicted)
    if all_data.shape != all_pred.shape:
        raise ValueError(
            f"data and predictions are not aligned: {all_data.shape[0]} measured "
            f"values vs {all_pred.shape[0]} predicted values"
        )

    res = all_data - all_pred
    ss_res = np.sum(res**2)
    ss_tot = np.sum((all_data - np.mean(all_data))**2)

    R2 = 1 - ss_res / ss_tot
    MAE = np.mean(np.abs(res))
    MSE = np.mean(res**2)
    return R2, MAE, MSE

# Score the model on the training conditions and on the held-out condition
R2_train, MAE_train, MSE_train = compute_metrics(data_list=data_train, S1_0_list=S1_0_train)
R2_test, MAE_test, MSE_test = compute_metrics(data_list=[data_test], S1_0_list=[S1_0_test])

# Each report is a banner line followed by one line of metrics
print(
    "===== TRAINING SET =====",
    f"R2 = {R2_train:.4f},  MAE = {MAE_train:.3f} mmol/L,  MSE = {MSE_train:.3f} (mmol/L)^2",
    sep="\n",
)
print(
    "\n===== UNSEEN TEST EXPERIMENT =====",
    f"R2 = {R2_test:.4f},  MAE = {MAE_test:.3f} mmol/L,  MSE = {MSE_test:.3f} (mmol/L)^2",
    sep="\n",
)

# ============================================================
# VISUALIZATION UTILITIES
# ============================================================
def style_ax(ax):
    """Apply the shared figure styling to ``ax`` in place.

    Turns the grid off, draws outward ticks (length 5, width 1, label size 14)
    and gives every spine a black outline of width 1.2.
    """
    ax.grid(False)
    ax.tick_params(direction='out', length=5, width=1, labelsize=14)
    for border in ax.spines.values():
        border.set_color('black')
        border.set_linewidth(1.2)

# ------------------------------------------------------------
# FIGURE 1: TRAINING SET PERFORMANCE (COMBINED DATA PANEL)
# ------------------------------------------------------------
colors_train = ['#08306b', '#2171b5']
markers = ['o', 's']

fig, ax = plt.subplots(figsize=(4.5, 4.1))

# One colour per training condition: noisy points plus the matching model curve.
# The curve is left unlabelled here; its legend entry is added below.
for S1_0, data, pred, colour, marker in zip(S1_0_train, data_train, preds_train,
                                            colors_train, markers):
    ax.scatter(t_exp, data, s=25, marker=marker, color=colour, zorder=3,
               label=f'$S_1(0)={int(S1_0)}$, data')
    ax.plot(t_fine, pred, linewidth=1.5, color=colour)

# Empty proxy lines so the model entries are listed after the data entries in the legend
for S1_0, colour in zip(S1_0_train, colors_train):
    ax.plot([], [], linewidth=1.5, color=colour,
            label=f'$S_1(0)={int(S1_0)}$, model')

ax.set_xlabel('$t$ / min', fontsize=14)
ax.set_ylabel('$S_4$ / mmol L$^{-1}$', fontsize=14)
ax.legend(ncol=2, fontsize=9)
style_ax(ax)
fig.tight_layout()
plt.savefig('val_training_fit.pdf')
plt.show()

# ------------------------------------------------------------
# FIGURE 2: UNSEEN GENERALIZATION TEST PREDICTION
# ------------------------------------------------------------
fig, ax = plt.subplots(figsize=(4.5, 4.1))

# Held-out noisy measurements against the noise-free model curve
ax.scatter(t_exp, data_test, s=30, color='#6baed6', zorder=3,
           label='Unseen experiment')
ax.plot(t_fine, pred_test, linewidth=1.5, color='#08306b',
        label='Model prediction')

style_ax(ax)
ax.set_xlabel('$t$ / min', fontsize=14)
ax.set_ylabel('$S_4$ / mmol L$^{-1}$', fontsize=14)
ax.legend(fontsize=11)
fig.tight_layout()
plt.savefig('val_unseen_prediction.pdf')
plt.show()