import numpy as np
import matplotlib.pyplot as plt
from scipy.integrate import odeint
from scipy.stats import norm
from skopt import gp_minimize
from skopt.space import Real
from skopt.learning import GaussianProcessRegressor
from skopt.learning.gaussian_process.kernels import Matern

# ============================================================
# KINETIC PARAMETERS & BATCH REACTOR (BR) MODEL
# ============================================================
# Turnover numbers of enzymes I-IV (each is multiplied by its enzyme amount in `rates`).
kcat_I, kcat_II, kcat_III, kcat_IV = 1.2, 0.3, 1.4, 0.5

# Enzyme I: saturation constants for S1 and cofactor c1, plus the S1 substrate-inhibition constant.
KS1_I, Kc1_I, KiS1_I = 30.2, 1.1, 4.3

# Enzyme II: saturation constant for S2 and inhibition constants for S1 and S5.
KS2_II, KiS1_II, KiS5_II = 8.2, 8.0, 14.3

# Enzyme III: saturation constant for S3 (plain Michaelis-Menten form).
KS3_III = 24.7

# Enzyme IV: saturation constants for S5 and cofactor c2, plus the S1 inhibition constant.
KS5_IV, Kc2_IV, KiS1_IV = 15.2, 3.1, 8.3

# Amounts of enzymes I-IV (A->I, B->II, C->III, D->IV); units not stated in this file.
E_A = E_B = E_C = E_D = 15.0

def rates(S1, S2, S3, S4, c1, c2, S5):
    """Evaluate the four cascade reaction rates at the given concentrations.

    Rate laws (kcat times the enzyme amount over saturation/inhibition terms):
      * rI   - saturable in S1 and c1, with substrate inhibition by S1
      * rII  - saturable in S2, inhibited by S1 and by S5
      * rIII - plain Michaelis-Menten in S3
      * rIV  - saturable in S5 and c2, inhibited by S1

    ``S4`` is accepted so the signature mirrors the full state vector, but it
    does not enter any rate expression.

    Returns:
        Tuple ``(rI, rII, rIII, rIV)``.
    """
    # Denominators are spelled out with the same operand grouping as the
    # rate laws above so the floating-point results are unchanged.
    denom_I = (KS1_I + S1) * (Kc1_I + c1) * (1 + S1 / KiS1_I)
    denom_II = (KS2_II + S2) * (1 + S1 / KiS1_II) * (1 + S5 / KiS5_II)
    denom_IV = (KS5_IV + S5) * (Kc2_IV + c2) * (1 + S1 / KiS1_IV)

    rate_I = kcat_I * E_A * S1 * c1 / denom_I
    rate_II = kcat_II * E_B * S2 / denom_II
    rate_III = kcat_III * E_C * S3 / (KS3_III + S3)
    rate_IV = kcat_IV * E_D * S5 * c2 / denom_IV

    return rate_I, rate_II, rate_III, rate_IV

def br_model(y, t):
    """Right-hand side dy/dt of the batch-reactor mass balances.

    ``y`` is the state ordered as (S1, S2, S3, S4, c1, c2, S5). ``t`` is
    required by the ``odeint`` callback signature but is not used, because the
    balances do not depend on time explicitly. Returns the derivatives as a
    list in the same order as ``y``.
    """
    S1, S2, S3, S4, c1, c2, S5 = y
    r_I, r_II, r_III, r_IV = rates(S1, S2, S3, S4, c1, c2, S5)
    return [
        -r_I,          # dS1/dt: consumed by I
        r_I - r_II,    # dS2/dt: made by I, consumed by II
        r_II - r_III,  # dS3/dt: made by II, consumed by III
        r_III,         # dS4/dt: made by III
        -r_I + r_IV,   # dc1/dt: consumed by I, regenerated by IV
        r_I - r_IV,    # dc2/dt: made by I, consumed by IV
        -r_IV,         # dS5/dt: consumed by IV
    ]

# ============================================================
# OPTIMIZATION PROBLEM BOUNDS & OBJECTIVE
# ============================================================
t_eval = np.linspace(0.0, 60.0, num=200)  # 200 evenly spaced sample times from 0 to the 60 min harvest

def objective(x):
    """Black-box cost for the optimizer: minus the final S4 concentration.

    ``x[0]`` is S0, the initial concentration assigned to both S1 and S5
    (equimolar feed). Cofactor c1 starts at 2 and every other species starts
    at 0. The batch balances are integrated over the module-level time grid
    ``t_eval``, whose last point is the 60 min harvest time. The S4 value at
    that last point is returned with its sign flipped, so minimizing this cost
    maximizes product formation.
    """
    feed = x[0]

    # State order: S1, S2, S3, S4, c1, c2, S5
    initial_state = [feed, 0, 0, 0, 2, 0, feed]

    trajectory = odeint(br_model, initial_state, t_eval)
    final_S4 = trajectory[-1, 3]  # last time point, S4 column

    return -final_S4

# Search space: a single continuous variable, the equimolar initial substrate
# concentration S0 (mmol L^-1), bounded to [0.1, 100].
space = [Real(0.1, 100.0, name='S0')]

# ============================================================
# EXECUTE BAYESIAN OPTIMIZATION
# ============================================================
result = gp_minimize(
    func=objective,           # cost = -S4 at the harvest time
    dimensions=space,
    acq_func='EI',            # Expected Improvement acquisition function
    n_calls=40,               # total objective evaluations, initial points included
    n_initial_points=10,      # evaluations (random by default) before the GP guides the search
    random_state=42,          # fixed seed so the run is reproducible
)

# Report the best point found; undo the sign flip to recover S4 itself.
S_opt = result.x[0]
S4_opt = -result.fun
print(
    "=== OPTIMIZATION RESULT ===",
    f"Optimal S1 = S5 = {S_opt:.3f} mmol/L",
    f"S4 at 60 min = {S4_opt:.4f} mmol/L\n",
    sep="\n",
)

# ============================================================
# POST-PROCESSING: GAUSSIAN PROCESS RECONSTRUCTION
# ============================================================
# Training data: every point the optimizer evaluated, paired with its raw cost (-S4).
X = np.array(result.x_iters).reshape(-1, 1)
y = np.array(result.func_vals)  # Minimized cost metric entries (-S4)
y_best = np.min(y)              # Lowest cost observed (i.e. highest S4)

# Refit a Kriging surrogate with a smooth Matérn 5/2 covariance kernel to the
# evaluation history. normalize_y=True fits on standardized targets, while
# predictions are reported back in the original (-S4) units.
kernel = Matern(nu=2.5)
gp = GaussianProcessRegressor(kernel=kernel, normalize_y=True)
gp.fit(X, y)

# Prediction grid spanning the optimizer's own search interval. The bounds are
# read from `space` rather than duplicated as literals, so the surrogate plot
# stays in sync if the search range is changed.
S_lo, S_hi = space[0].bounds
S_grid = np.linspace(S_lo, S_hi, 400).reshape(-1, 1)
mu, sigma = gp.predict(S_grid, return_std=True)
sigma = np.maximum(sigma, 1e-9)  # Guard the Z division below against zero predictive std

# Expected Improvement for a minimization problem:
#   EI = (y_best - mu) * Phi(Z) + sigma * phi(Z),   Z = (y_best - mu) / sigma
improvement = y_best - mu
Z = improvement / sigma
EI = improvement * norm.cdf(Z) + sigma * norm.pdf(Z)
EI = np.maximum(EI, 0)

# ============================================================
# SURROGATE VISUALIZATION GRAPH
# ============================================================
# Publication-style defaults: 12 pt text, 1.2-wide axes lines, outward ticks.
plt.rcParams.update({
    "font.size": 12,
    "axes.linewidth": 1.2,
    "xtick.direction": "out",
    "ytick.direction": "out",
})

fig, ax = plt.subplots(figsize=(4.5, 4.1))

# The optimizer minimized -S4, so flip signs back to show S4 concentration itself.
S4_mean = -mu
half_width = 1.96 * sigma  # 1.96 predictive standard deviations either side of the mean
ax.plot(S_grid, S4_mean, color='black', lw=2, label='GP mean')
ax.fill_between(
    S_grid.ravel(),
    S4_mean - half_width,
    S4_mean + half_width,
    color='gray',
    alpha=0.3,
    label='95% CI',
)
ax.scatter(X, -y, c='red', s=20, label='Evaluations')

# Fully boxed frame drawn in black.
for spine in ax.spines.values():
    spine.set_visible(True)
    spine.set_linewidth(1.2)
    spine.set_color('black')

ax.tick_params(direction='out', length=5, width=1)

ax.set_xlabel('$S_{1}(0)$ = $S_{5}(0)$ / mmol L$^{-1}$')
ax.set_ylabel('$S_{4}$(60 min) / mmol L$^{-1}$')
ax.legend()
fig.tight_layout()
fig.savefig("GP_surrogate_batch.pdf")
plt.show()