import numpy as np
from scipy.integrate import solve_ivp
from scipy.optimize import brentq
import matplotlib.pyplot as plt

from skopt import gp_minimize
# ============================================================
# PHYSICAL CONSTANTS AND OPERATING POINT
# ============================================================
# Gas constant in two unit systems: J/(mol·K) for the Arrhenius factor in
# model(), L·atm/(mol·K) for the ideal-gas mole <-> pressure conversion.
R, R_gas = 8.314, 0.08206
T_fixed = 308.15  # reactor temperature [K] (35 °C)

# ============================================================
# UREASE KINETIC PARAMETERS (CABEQ paper)
# ============================================================
Ea = 35300          # activation energy [J/mol]; divided by R in the Arrhenius factor
T_star = 414.6      # temperature [K] at which the Arrhenius factor equals 1
KM = 3.21e-3        # substrate saturation constant, added to urea [mol/L]
KES1, KES2 = 7.57e-7, 1.27e-8   # pH-dependence constants, compared with [H+] [mol/L]
KP = 1.22e-2        # product-inhibition constant, compared with free NH3 [mol/L]
parameter_kinetic = 40.04
# Lumped enzyme activity used as the rate prefactor in model(). The origin of
# the 25 and 0.015 factors is not documented here (0.015 equals V_liq —
# NOTE(review): confirm). Evaluation order kept identical to preserve the value.
urease = (25 * 0.015 / 1000) * parameter_kinetic

# ============================================================
# REACTOR GEOMETRY AND MASS TRANSFER
# ============================================================
# Liquid and headspace volumes — presumably litres, consistent with R_gas in
# L·atm/(mol·K) (TODO confirm).
V_liq, V_gas = 0.015, 0.005
# Volumetric gas-liquid transfer coefficients for CO2 and NH3; their time unit
# follows the model's time axis (plots label it minutes — TODO confirm).
kla_CO2 = kla_NH3 = 0.0092

# ============================================================
# TEMPERATURE-DEPENDENT EQUILIBRIUM CONSTANTS
# ============================================================
def temperature_constants(T):
    """Evaluate the equilibrium and Henry's-law constants at temperature ``T`` [K].

    Every constant is an empirical correlation of the form
    ``exp(a + b/T + c*ln(T) + d*T)`` (some omit the ``ln(T)`` term; the two
    Henry constants use the negated exponent). All operations are
    element-wise, so ``T`` may be a scalar or a NumPy array.

    Returns
    -------
    tuple
        ``(KNH3, KCO2, KHCO3, KH2O, H_CO2, H_NH3)``:
        ammonia base-dissociation constant (paired with KH2O in the NH4+
        fraction), first and second carbonic-acid dissociation constants,
        the water ion product, and the Henry solubility constants for CO2
        and NH3 (dissolved concentration per unit partial pressure,
        presumably mol/(L·atm)).
    """
    ln_T = np.log(T)

    # Ammonia and carbonate-system acid/base equilibria, plus water.
    k_nh3 = np.exp(191.97 - 8451.61/T - 31.4335*ln_T + 0.0152123*T)
    k_co2 = np.exp(2767.92 - 80063.5/T - 478.653*ln_T + 0.714984*T)
    k_hco3 = np.exp(12.405 - 6286.89/T - 0.050628*T)
    k_w = np.exp(14.01708 - 10294.83/T - 0.039282*T)

    # Henry's-law solubilities (note the negated exponent).
    henry_co2 = np.exp(-(1082.37 - 34417.2/T - 182.28*ln_T + 0.25159*T))
    henry_nh3 = np.exp(-(160.559 - 8621.06/T - 25.6767*ln_T + 0.035388*T))

    return k_nh3, k_co2, k_hco3, k_w, henry_co2, henry_nh3

# Evaluate the equilibrium and Henry constants once, at the fixed reactor
# temperature. electroneutrality() and model() read these module-level names,
# so changing T_fixed after this line does not update them.
KNH3, KCO2, KHCO3, KH2O, H_CO2, H_NH3 = temperature_constants(T_fixed)

# ============================================================
# pH SOLVER
# ============================================================
def electroneutrality(pH, TA, TC):
    """Return the charge-balance residual [mol/L] at a trial ``pH``.

    Residual = (NH4+ + H+) - (HCO3- + 2*CO3^2- + OH-), with NH4+ taken from
    total ammonia ``TA`` and HCO3-/CO3^2- from total inorganic carbon ``TC``
    via the module-level equilibrium constants. Every cation term falls and
    every anion term rises as pH increases, so the residual is monotonically
    decreasing in pH and has at most one root.
    """
    h = 10**(-pH)
    carbonate_den = h**2 + KCO2*h + KCO2*KHCO3

    ammonium = TA * (h*KNH3)/(KNH3*h + KH2O)
    bicarbonate = TC * (KCO2*h)/carbonate_den
    carbonate = TC * (KCO2*KHCO3)/carbonate_den
    hydroxide = KH2O / h

    positive = ammonium + h
    negative = bicarbonate + 2*carbonate + hydroxide
    return positive - negative

def solve_pH(TA, TC):
    """Return the pH at which the liquid phase is electrically neutral.

    ``TA`` and ``TC`` (total ammonia / total inorganic carbon) are floored at
    1e-16 so the residual stays well defined if the integrator steps slightly
    below zero. The root is bracketed in pH [2, 12]; ``brentq`` raises
    ``ValueError`` if the residual has the same sign at both ends.
    """
    floor = 1e-16
    totals = (max(TA, floor), max(TC, floor))
    return brentq(electroneutrality, 2, 12, args=totals)

# ============================================================
# REACTOR MODEL
# ============================================================
def model(t, y):
    """Right-hand side of the closed batch-reactor ODEs, for ``solve_ivp``.

    State ``y`` = [UREA, TA, TC, n_NH3_gas, n_CO2_gas]: urea, total ammonia
    and total inorganic carbon in the liquid [mol/L], then moles of NH3 and
    CO2 in the headspace. ``t`` is unused (the system is autonomous) but is
    required by the ``solve_ivp`` calling convention.

    On every call the pH is re-solved from the charge balance. The free
    (un-ionised) CO2 and NH3 concentrations drive gas-liquid transfer against
    Henry's-law equilibrium with the headspace partial pressures. Urea is
    consumed at a Michaelis–Menten-type rate with a product-inhibition term,
    a pH-dependence term and an Arrhenius factor. Each mole of urea
    hydrolysed adds 2 mol to TA and 1 mol to TC.
    """
    urea_c, tot_ammonia, tot_carbon, gas_nh3_mol, gas_co2_mol = y

    # Liquid speciation at the current charge-balance pH.
    h_ion = 10**(-solve_pH(tot_ammonia, tot_carbon))
    h_sq = h_ion**2
    co2_dissolved = tot_carbon * h_sq / (h_sq + KCO2*h_ion + KCO2*KHCO3)
    nh3_dissolved = tot_ammonia * KH2O / (KNH3*h_ion + KH2O)

    # Ideal-gas headspace partial pressures and liquid -> gas transfer rates.
    p_co2 = gas_co2_mol * R_gas * T_fixed / V_gas
    p_nh3 = gas_nh3_mol * R_gas * T_fixed / V_gas
    flux_co2 = kla_CO2 * (co2_dissolved - H_CO2 * p_co2)
    flux_nh3 = kla_NH3 * (nh3_dissolved - H_NH3 * p_nh3)

    # Enzymatic hydrolysis rate. Concentrations are floored so a slightly
    # negative integrator state cannot flip the sign of the rate.
    # NOTE(review): inhibition uses free NH3, not NH4+ — confirm vs. the paper.
    substrate = max(urea_c, 1e-16)
    product = max(nh3_dissolved, 1e-16)
    arrhenius = np.exp(-Ea/R * (1/T_fixed - 1/T_star))
    rate = arrhenius * urease * substrate / (
        (KM + substrate) * (1 + product / KP) * (1 + h_ion/KES1 + KES2/h_ion)
    )

    # Liquid totals lose what is stripped; the headspace gains flux * V_liq.
    return [
        -rate,
        2*rate - flux_nh3,
        rate - flux_co2,
        flux_nh3 * V_liq,
        flux_co2 * V_liq,
    ]

# ============================================================
# SIMULATION FUNCTION USING INITIAL CO2 PRESSURE
# ============================================================
def simulate(P_CO2_init, urea_init=0.03, t_end=60.0, failure_value=None):
    """Integrate the reactor and return the urea concentration at ``t_end``.

    Parameters
    ----------
    P_CO2_init : float
        Initial headspace CO2 pressure [atm]. It is converted to moles with
        the ideal-gas law using ``V_gas`` and ``T_fixed``.
    urea_init : float, optional
        Initial urea concentration [mol/L]. Default 0.03, as used elsewhere
        in this script.
    t_end : float, optional
        Final integration time (the plots label the time axis in minutes).
        Default 60.
    failure_value : float or None, optional
        Value returned when the integration fails. ``None`` (default) returns
        ``urea_init``, meaning "no conversion", which is the worst physically
        meaningful outcome. This keeps failed runs on the same scale as
        successful ones; a huge sentinel (previously 1e6) swamps the
        normalized GP surrogate used by the optimizer.

    Returns
    -------
    float
        Urea concentration at ``t_end`` [mol/L], or ``failure_value``.
    """
    if failure_value is None:
        failure_value = urea_init

    n_CO2_init = P_CO2_init * V_gas / (R_gas * T_fixed)
    # State [UREA, TA, TC, n_NH3_gas, n_CO2_gas]: no ammonia or carbon yet in
    # the liquid, CO2 only in the headspace.
    y0 = [urea_init, 0.0, 0.0, 0.0, n_CO2_init]
    try:
        sol = solve_ivp(model, (0, t_end), y0, method="BDF", t_eval=[t_end],
                        rtol=1e-6, atol=1e-9)
    except ValueError:
        # brentq (inside solve_pH) raises ValueError when the pH bracket has
        # no sign change. Exceptions from the RHS propagate out of solve_ivp
        # instead of setting sol.success, so without this one bad trial
        # point would abort the whole optimization.
        return failure_value
    if not sol.success:
        return failure_value
    return float(sol.y[0, -1])

def objective(x):
    """Return the ``gp_minimize`` objective: final urea for initial CO2 pressure ``x[0]`` [atm]."""
    pressure = x[0]
    return simulate(pressure)

# ============================================================
# BAYESIAN OPTIMIZATION
# ============================================================
# One search dimension: initial headspace CO2 pressure [atm]. skopt treats the
# (0, 3.5) tuple as a real-valued interval because one bound is a float.
# NOTE(review): skopt's default number of random initial points uses up much
# of this 15-call budget — confirm the budget is sufficient.
space = [(0, 3.5)]
res = gp_minimize(func=objective, dimensions=space, acq_func="EI",
                  n_calls=15, random_state=42)

optimal_P_CO2, optimal_UREA = res.x[0], res.fun
# Ideal-gas moles of CO2 in the headspace at the optimal pressure.
optimal_n_CO2 = optimal_P_CO2 * V_gas / (R_gas * T_fixed)

print(
    "\n========== BAYESIAN OPTIMIZATION RESULT ==========",
    f"Optimal initial CO2 pressure: {optimal_P_CO2:.6f} atm",
    f"Corresponding moles of CO2: {optimal_n_CO2:.6e} mol",
    f"Final urea concentration: {optimal_UREA:.6f} mol/L",
    sep="\n",
)

# ============================================================
# SIMULATE OPTIMAL TRAJECTORY
# ============================================================
# Re-integrate from the optimum with dense output for plotting. The initial
# state matches simulate(): 0.03 mol/L urea, no ammonia or carbon in the
# liquid, CO2 only in the headspace.
y0_opt = [0.03, 0.0, 0.0, 0.0, optimal_n_CO2]
t_eval = np.linspace(0, 60, 300)
sol = solve_ivp(model, (0, 60), y0_opt, method="BDF", t_eval=t_eval, rtol=1e-6, atol=1e-9)
if not sol.success:
    # On failure sol.t stops early; warn instead of silently plotting a
    # truncated trajectory as if it were the full run.
    print(f"Warning: optimal-trajectory integration failed: {sol.message}")
time = sol.t
UREA, TA, TC, n_NH3_gas, n_CO2_gas = sol.y
# pH is not part of the ODE state; recover it at each output time.
pH = [solve_pH(ta, tc) for ta, tc in zip(TA, TC)]
P_CO2 = n_CO2_gas * R_gas * T_fixed / V_gas

# ============================================================
# PLOTTING ALL FIGURES
# ============================================================
plt.rcParams.update({'font.size': 12})

# 1a. Liquid-phase species along the optimal trajectory.
plt.figure(figsize=(10,6))
plt.plot(time, UREA, label="Urea")
plt.plot(time, TA, label="Total Ammonia")
plt.plot(time, TC, label="Total Carbon")
plt.gca().set(xlabel="Time [min]", ylabel="Concentration [mol/L]",
              title="Reactor Species Evolution")
plt.grid(True)
plt.legend()
plt.show()

# 1b. Headspace CO2 pressure and liquid pH, drawn on one shared y-axis.
plt.figure(figsize=(10,4))
plt.plot(time, P_CO2, label="CO2 pressure [atm]")
plt.plot(time, pH, label="pH")
plt.gca().set(xlabel="Time [min]", ylabel="Pressure / pH",
              title="CO2 Pressure and pH Evolution")
plt.grid(True)
plt.legend()
plt.show()

# 2. Bayesian Optimization Convergence: running best objective per call.
x_iters = np.ravel(res.x_iters)
func_vals = np.ravel(res.func_vals)
best_so_far = np.minimum.accumulate(func_vals)

plt.figure(figsize=(10,4))
plt.plot(best_so_far, marker='o', label='Best urea found')
plt.gca().set(xlabel='Iteration', ylabel='Final urea concentration [mol/L]',
              title='Bayesian Optimization Convergence')
plt.grid(True)
plt.legend()
plt.show()

# 3. Evaluated Points vs Objective, with the best value as a reference line.
plt.figure(figsize=(10,4))
plt.scatter(x_iters, func_vals, c='blue', label='Evaluated points')
plt.axhline(optimal_UREA, color='red', linestyle='--', label='Optimal urea')
plt.gca().set(xlabel='Initial CO2 Pressure [atm]', ylabel='Final urea [mol/L]',
              title='Evaluated Points vs Objective')
plt.grid(True)
plt.legend()
plt.show()

# 4. Surrogate GP mean ± 95% confidence, refitted to the evaluated points.
try:
    from skopt.learning import GaussianProcessRegressor
    from skopt.learning.gaussian_process.kernels import Matern
except ImportError:
    print("Skipping GP plot: sklearn required.")
else:
    X = np.array(x_iters).reshape(-1,1)
    y = np.array(func_vals)
    gp = GaussianProcessRegressor(kernel=Matern(length_scale=1.0), alpha=1e-9, normalize_y=True)
    gp.fit(X, y)
    # Plot only over the optimizer's search interval. The old upper limit
    # max(x_iters)*1.2 could exceed that bound and showed pure extrapolation
    # at pressures the optimizer never considered.
    p_lo, p_hi = space[0]
    X_plot = np.linspace(p_lo, p_hi, 500).reshape(-1,1)
    y_mean, y_std = gp.predict(X_plot, return_std=True)
    plt.figure(figsize=(10,4))
    plt.plot(X_plot, y_mean, 'k-', label='GP mean')
    # 1.96 standard deviations ~ two-sided 95% band under a Gaussian posterior.
    plt.fill_between(X_plot.flatten(), y_mean-1.96*y_std, y_mean+1.96*y_std, alpha=0.3, label='95% CI')
    plt.scatter(X, y, c='red', label='Evaluations')
    plt.axvline(optimal_P_CO2, color='green', linestyle='--', label='Optimal CO2')
    plt.xlabel('Initial CO2 Pressure [atm]')
    plt.ylabel('Final urea [mol/L]')
    plt.title('Gaussian Process Surrogate Model')
    plt.legend()
    plt.grid(True)
    plt.show()
