"""
Confidence ellipses comparison for Section 7.3 application example.

Uses Reaction III as an isolated sub-system (S3 as sole substrate, enzyme E_C),
which gives a clean 2-parameter (kcat_III, KS3_III) Michaelis-Menten problem.

Three scenarios — all use EXACTLY 3 progress curves to isolate the effect of
data diversity from data quantity:
  A: 3 replicates of the same initial condition
  B: 3 curves with moderate variation in S3(0)
  C: 3 curves spanning low, intermediate, and saturating S3(0)
"""

import numpy as np
from scipy.integrate import odeint
from scipy.optimize import least_squares
from scipy.stats import chi2
import matplotlib.pyplot as plt

# Seed NumPy's global RNG: the synthetic measurement noise below is drawn from
# it, so seeding makes the data, the fits and the figures reproducible.
np.random.seed(seed=0)

# ============================================================
# TRUE PARAMETERS (Reaction III of the application example)
# ============================================================
# NOTE(review): with the units as written, kcat [mmol L^-1 min^-1 mg^-1] times
# E_C [mg L^-1] gives mmol L^-2 min^-1, not the mmol L^-1 min^-1 of dS3/dt.
# Confirm whether kcat should be expressed in mmol min^-1 mg^-1.
kcat_III_true = 1.4   # turnover constant, mmol L^-1 min^-1 mg^-1 (see note)
KS3_III_true = 24.7   # half-saturation constant for S3, mmol L^-1
E_C_demo = 0.5        # enzyme E_C, mg L^-1; dilute so the dynamics span ~120 min

# ============================================================
# ISOLATED REACTION III MODEL  dS3/dt = -rIII
# ============================================================
def ode_rxnIII(S3, t, kcat, Km):
    """Right-hand side of the isolated Reaction III balance for ``odeint``.

    dS3/dt = -kcat * E_C_demo * S3 / (Km + S3)

    ``t`` is not used (the system is autonomous) but is part of the
    ``func(y, t, *args)`` signature that ``odeint`` expects.  S3 is floored at
    1e-12 before the rate is evaluated, so non-positive values handed in by
    the integrator produce a (near-)zero consumption rate.
    """
    floor = 1e-12
    conc = max(S3, floor)
    rate = kcat * E_C_demo * conc / (Km + conc)
    return -rate

def simulate_rxnIII(S3_0, t_grid, kcat, Km):
    """Integrate the Reaction III model from ``S3_0`` and sample it on ``t_grid``.

    Returns a 1-D array holding one S3 value per entry of ``t_grid``.
    Tight tolerances keep integration error well below the synthetic noise.
    """
    trajectory = odeint(
        ode_rxnIII,
        S3_0,
        t_grid,
        args=(kcat, Km),
        rtol=1e-9,
        atol=1e-11,
    )
    # odeint returns shape (len(t_grid), 1) for a single state; drop that axis.
    return trajectory.reshape(-1)

# ============================================================
# TIME GRID
# ============================================================
# Sampling times (min) of the synthetic measurements used in the fits.
t_exp = np.linspace(start=0, stop=120, num=10)
# Dense grid (min) used to draw the smooth model curves.
t_fine = np.linspace(start=0, stop=120, num=400)

# ============================================================
# DATA GENERATION SCENARIOS (3 curves per scenario)
# ============================================================
# Noise std. dev. as a fraction of the noise-free S3 value, i.e. 5% relative
# noise (consistent with parameterization scripts).
noise_level = 0.05

def generate_scenario_data(S3_0_list):
    """Create one noisy synthetic S3 progress curve on ``t_exp`` per initial value.

    Each curve is the true-parameter simulation plus Gaussian noise with
    standard deviation ``noise_level * |S3_true|``, drawn from NumPy's global
    RNG in list order.  Negative values are clipped to 0.
    Returns a list of 1-D arrays in the same order as ``S3_0_list``.
    """
    def _noisy_curve(initial):
        clean = simulate_rxnIII(initial, t_exp, kcat_III_true, KS3_III_true)
        sigma = noise_level * np.abs(clean)
        return np.clip(clean + np.random.normal(0, sigma), 0, None)

    return [_noisy_curve(initial) for initial in S3_0_list]

# All three scenarios draw their noise from the same seeded global RNG, so the
# generation order A -> B -> C is part of what makes the data reproducible.

# Scenario A: one initial condition measured three times; the replicates share
# the same noise-free trajectory and differ only in their noise.
S3_0_A = [50.0] * 3
data_A = generate_scenario_data(S3_0_A)

# Scenario B: three curves with a moderate spread in S3(0).
S3_0_B = [20.0, 50.0, 80.0]
data_B = generate_scenario_data(S3_0_B)

# Scenario C: three curves from well below Km to well above it
# (Km = 24.7):  S3(0) = 5 (~0.2 Km), 50 (~2 Km), 200 (~8 Km).
S3_0_C = [5.0, 50.0, 200.0]
data_C = generate_scenario_data(S3_0_C)

# ============================================================
# FITTING — least_squares, Jacobian-based covariance
# ============================================================
# Starting point for the optimiser, ordered [kcat_III, KS3_III].
p0 = [
    1.0,    # kcat_III
    30.0,   # KS3_III
]
# Box constraints as (lower, upper), same parameter order as p0.
bounds = (
    [0.01, 0.1],
    [10.0, 500.0],
)

def residuals(params, S3_0_list, data_list):
    """Stack model-minus-data residuals of every progress curve into one vector.

    ``params`` is ``(kcat, Km)``.  Each curve is simulated on ``t_exp`` from its
    entry in ``S3_0_list`` and compared with the matching array in
    ``data_list``.  The residuals are unweighted.
    """
    kcat, Km = params
    per_curve = [
        simulate_rxnIII(initial, t_exp, kcat, Km) - observed
        for initial, observed in zip(S3_0_list, data_list)
    ]
    return np.concatenate(per_curve)

def fit_and_ellipse(S3_0_list, data_list, chi2_level=0.95):
    """Fit (kcat_III, KS3_III) to a set of progress curves and build a confidence ellipse.

    The fit minimises ``residuals`` from the module-level ``p0`` within
    ``bounds``.  The parameter covariance is the linearised estimate
    ``s^2 (J^T J)^-1`` at the optimum.  It is computed from an SVD of the
    Jacobian (the approach used by ``scipy.optimize.curve_fit``) instead of
    explicitly inverting ``J^T J``, which would square the condition number.

    NOTE(review): the synthetic data carry *relative* noise
    (std = noise_level * |S3|), but the residuals are unweighted and ``s^2``
    is pooled over all curves, i.e. homoscedastic errors are assumed.  Curves
    with large S3(0) therefore dominate ``s^2``.  Consider weighting the
    residuals if calibrated ellipse sizes matter.

    Parameters
    ----------
    S3_0_list : sequence of float
        Initial S3 concentration of each progress curve.
    data_list : sequence of ndarray
        Observed S3 on ``t_exp`` for each curve, same order as ``S3_0_list``.
    chi2_level : float, optional
        Coverage probability of the (asymptotic, chi-square with 2 dof)
        confidence ellipse.  Default 0.95.

    Returns
    -------
    p_fit : ndarray, shape (2,)
        Fitted ``[kcat_III, KS3_III]``.
    cov : ndarray, shape (2, 2)
        Estimated parameter covariance.
    ellipse_pts : ndarray, shape (2, 300)
        Ellipse boundary centred at the origin; add ``p_fit`` to position it.

    Raises
    ------
    ValueError
        If there are not more residuals than parameters (``s^2`` undefined).
    numpy.linalg.LinAlgError
        If the Jacobian is rank-deficient at the optimum (covariance undefined).
    """
    result = least_squares(
        residuals, p0, bounds=bounds, args=(S3_0_list, data_list),
        method='trf', jac='3-point'
    )
    p_fit = result.x
    J = result.jac
    n = result.fun.size
    npar = p_fit.size
    dof = n - npar
    if dof <= 0:
        # Without spare residuals the variance estimate would be inf or negative.
        raise ValueError(
            f"need more residuals than parameters to estimate the noise "
            f"variance (got {n} residuals for {npar} parameters)"
        )
    s2 = np.sum(result.fun**2) / dof

    # (J^T J)^-1 = V diag(1/s^2) V^T from the thin SVD J = U diag(s) V^T.
    _, sing, VT = np.linalg.svd(J, full_matrices=False)
    if sing[-1] <= np.finfo(float).eps * max(J.shape) * sing[0]:
        raise np.linalg.LinAlgError(
            "Jacobian is rank-deficient at the optimum; the parameters are "
            "not jointly identifiable from these curves"
        )
    cov = s2 * ((VT.T / sing**2) @ VT)

    # Confidence ellipse: scale the covariance principal axes by the chi-square
    # quantile and map the unit circle through them.
    vals, vecs = np.linalg.eigh(cov)
    vals = np.clip(vals, 0.0, None)  # tiny negative round-off would give NaN
    chi2_val = chi2.ppf(chi2_level, df=2)
    theta = np.linspace(0, 2*np.pi, 300)
    unit_circle = np.array([np.cos(theta), np.sin(theta)])
    ellipse_pts = vecs @ np.diag(np.sqrt(vals * chi2_val)) @ unit_circle
    return p_fit, cov, ellipse_pts

# Fit each scenario in turn; every call yields (fitted params, covariance, ellipse).
(fit_A, cov_A, ell_A), (fit_B, cov_B, ell_B), (fit_C, cov_C, ell_C) = [
    fit_and_ellipse(initials, observations)
    for initials, observations in (
        (S3_0_A, data_A),
        (S3_0_B, data_B),
        (S3_0_C, data_C),
    )
]

# ============================================================
# PRINT CORRELATION COEFFICIENTS
# ============================================================
def correlation(cov):
    """Return the correlation coefficient between parameters 0 and 1 of ``cov``."""
    variance_product = cov[0, 0] * cov[1, 1]
    covariance = cov[0, 1]
    return covariance / np.sqrt(variance_product)

# One summary line per scenario: fitted (kcat, Km) and their correlation.
print(
    f"Scenario A  kcat={fit_A[0]:.3f}, Km={fit_A[1]:.2f}  |  rho = {correlation(cov_A):.3f}",
    f"Scenario B  kcat={fit_B[0]:.3f}, Km={fit_B[1]:.2f}  |  rho = {correlation(cov_B):.3f}",
    f"Scenario C  kcat={fit_C[0]:.3f}, Km={fit_C[1]:.2f}  |  rho = {correlation(cov_C):.3f}",
    sep="\n",
)

# ============================================================
# ELLIPSE AREAS
# ============================================================
def ellipse_area(cov, chi2_level=0.95):
    """Return the area of the ``chi2_level`` confidence ellipse of a 2x2 covariance.

    With c the chi-square quantile (2 dof) at ``chi2_level``, the ellipse is
    {x : x^T cov^-1 x <= c}, whose area is pi * sqrt(det(c * cov)).  For a
    2x2 matrix that equals pi * c * sqrt(det(cov)).  The formula assumes
    ``cov`` is 2x2.
    """
    scale = chi2.ppf(chi2_level, df=2)
    root_det = np.sqrt(np.linalg.det(cov))
    return np.pi * scale * root_det

# Ellipse areas share mixed units (kcat x Km), so only their ratios are compared.
area_A, area_B, area_C = (ellipse_area(c) for c in (cov_A, cov_B, cov_C))

print(
    "",
    "95% confidence ellipse areas:",
    f"  Scenario A: {area_A:.4f}",
    f"  Scenario B: {area_B:.4f}  (relative to A: {area_B/area_A:.3f})",
    f"  Scenario C: {area_C:.4f}  (relative to A: {area_C/area_A:.3f})",
    sep="\n",
)

# ============================================================
# FIGURE STYLE HELPER
# ============================================================
def style_ax(ax):
    """Apply the axis styling shared by both figures.

    Ticks point outward (length 5, width 1, label size 12), every spine is
    drawn black with line width 1.2, and the grid is switched off.
    """
    ax.tick_params(direction='out', length=5, width=1, labelsize=12)
    for side in ax.spines.values():
        side.set_color('black')
        side.set_linewidth(1.2)
    ax.grid(False)

# ============================================================
# FIGURE 1 — CONFIDENCE ELLIPSE COMPARISON
# ============================================================
# Each scenario takes the middle shade of the hue family it has in the
# progress-curve figure (A: reds, B: blues, C: greens).  This keeps a scenario
# the same colour in both figures, and the three hues stay clearly distinct.
colors_ell = ['#d73027', '#2166ac', '#238b45']   # A=red, B=blue, C=green
labels_ell = [
    r'A: 3$\times$$S_3(0)=50$',
    r'B: $S_3(0)\in\{20,50,80\}$',
    r'C: $S_3(0)\in\{5,50,200\}$',
]

fig, ax = plt.subplots(figsize=(4.5, 4.1))

# Ellipse points are returned centred at the origin; shift each one to its
# fitted parameter pair and mark that pair.
for (p_fit, ell, color, label) in zip(
        [fit_A, fit_B, fit_C],
        [ell_A, ell_B, ell_C],
        colors_ell, labels_ell):
    ax.plot(ell[0] + p_fit[0], ell[1] + p_fit[1],
            color=color, linewidth=1.5, label=label)
    ax.scatter(p_fit[0], p_fit[1], color=color, s=30, zorder=5)

# Parameters used to generate the synthetic data.
ax.scatter(kcat_III_true, KS3_III_true, color='black', marker='+', s=80,
           zorder=6, linewidths=1.5, label='True value')

ax.set_xlabel(r'$k_{\mathrm{cat,III}}$ / mmol L$^{-1}$ min$^{-1}$ mg$^{-1}$', fontsize=12)
ax.set_ylabel(r'$K_{S_3,\mathrm{III}}$ / mmol L$^{-1}$', fontsize=12)
ax.legend(fontsize=9)
style_ax(ax)
plt.tight_layout()
plt.savefig('ellipses_comparison.pdf')
plt.show()

# ============================================================
# FIGURE 2 — PROGRESS CURVE DATA USED IN EACH SCENARIO
# (helps the reader see what changes between A, B, C)
# ============================================================
# One colour per curve: A in red (replicates share it), B in blues, C in greens.
scenario_colors = [
    ['#d73027', '#d73027', '#d73027'],
    ['#4393c3', '#2166ac', '#08306b'],
    ['#74c476', '#238b45', '#00441b'],
]
scenario_S3_0s  = [S3_0_A, S3_0_B, S3_0_C]
scenario_data   = [data_A,  data_B,  data_C]
scenario_labels = ['A', 'B', 'C']

fig, axes = plt.subplots(1, 3, figsize=(9, 3.5), sharey=False)

for ax, S3_0_list, data_list, sc_colors, sc_label in zip(
        axes, scenario_S3_0s, scenario_data, scenario_colors, scenario_labels):
    drawn = set()   # initial conditions whose noise-free curve is already on this axis
    for S3_0, data, color in zip(S3_0_list, data_list, sc_colors):
        # Every replicate contributes its own noisy data points.
        ax.scatter(t_exp, data, color=color, s=20, zorder=3)
        if S3_0 in drawn:
            # Replicates share one noise-free trajectory; drawing it again
            # would only stack identical lines and repeat the legend entry.
            continue
        drawn.add(S3_0)
        pred = simulate_rxnIII(S3_0, t_fine, kcat_III_true, KS3_III_true)
        n_rep = S3_0_list.count(S3_0)
        # Same label style as the ellipse figure ("50", not "50.0"); replicates
        # are summarised as "3x" in a single entry.
        if n_rep > 1:
            curve_label = rf'{n_rep}$\times$$S_3(0)={S3_0:g}$'
        else:
            curve_label = rf'$S_3(0)={S3_0:g}$'
        ax.plot(t_fine, pred, color=color, linewidth=1.2, label=curve_label)
    ax.set_xlabel('$t$ / min', fontsize=11)
    ax.set_ylabel('$S_3$ / mmol L$^{-1}$', fontsize=11)
    ax.set_title(f'Scenario {sc_label}', fontsize=11)
    ax.legend(fontsize=8)
    style_ax(ax)

plt.tight_layout()
plt.savefig('ellipses_scenarios_data.pdf')
plt.show()
