import warnings

import matplotlib.pyplot as plt
import numpy as np
from matplotlib.patches import Ellipse
from scipy import stats
from scipy.integrate import solve_ivp
from scipy.optimize import brentq, least_squares

# ---------------------------
# CONSTANTS
# ---------------------------
R = 8.314          # gas constant, J/(mol K); pairs with Ea in the Arrhenius factor
R_gas = 0.08206    # gas constant, L atm/(mol K); ideal-gas law for the headspace
T_fixed = 308.15   # reactor temperature, K (35 degC)

# ---------------------------
# PAPER PARAMETERS
# ---------------------------
# Constants of the urea hydrolysis rate law ``ve``, which has the form
#   kT * urease * S / ((KM + S) * (1 + P/KP) * (1 + [H+]/KES1 + KES2/[H+])).
# Concentration units are presumably mol/L -- confirm against the source paper.
Ea = 35_300        # activation energy (divided by R, so J/mol)
T_star = 414.6     # reference temperature of the Arrhenius factor, K
KM = 3.21e-3       # Michaelis-type saturation constant for urea (S)
KES1 = 7.57e-7     # pH term:  1 + [H+]/KES1 ...
KES2 = 1.27e-8     # ...       + KES2/[H+]
KP = 1.22e-2       # inhibition constant for free NH3 (P)

# ---------------------------
# REACTOR
# ---------------------------
V_liq = 0.015      # liquid volume (presumably L, consistent with R_gas)
V_gas = 0.005      # headspace volume (presumably L)
kla_NH3 = 0.05     # NH3 mass-transfer coefficient -- held fixed, not fitted

# ---------------------------
# EXPERIMENTAL DATA
# ---------------------------
initial_CO2_exp = [0.0, 0.5, 1.0, 2.1, 3.5]  # initial headspace CO2 pressure, atm
# Measured urea concentration (mol/L) at t = 60 min, one per pressure above.
UREA_exp_60 = np.array([
    0.02614146539,
    0.0223338683,
    0.0176443878,
    0.0148875675,
    0.0156582799,
])

# ---------------------------
# TEMPERATURE CONSTANTS
# ---------------------------
def temperature_constants(T):
    """Evaluate the temperature-dependent equilibrium and Henry constants.

    Every constant is the exponential of an empirical correlation in ``T``
    (absolute temperature; the script calls this with ``T_fixed = 308.15``).
    How the rest of the script uses the returned values:

    * ``KNH3``  -- NH3 base-dissociation constant; the NH4+ fraction of total
      ammonia is ``KNH3*H / (KNH3*H + KH2O)``.
    * ``KCO2``  -- first dissociation constant of dissolved CO2.
    * ``KHCO3`` -- second dissociation constant (HCO3- -> CO3--).
    * ``KH2O``  -- ion product of water (``OH = KH2O / H``).
    * ``H_CO2``, ``H_NH3`` -- Henry solubility constants; multiplied by a gas
      partial pressure they give the interfacial equilibrium concentration.

    Args:
        T: absolute temperature (scalar or NumPy array).

    Returns:
        Tuple ``(KNH3, KCO2, KHCO3, KH2O, H_CO2, H_NH3)``.
    """
    ln_T = np.log(T)
    ammonia_kb = np.exp(191.97 - 8451.61/T - 31.4335*ln_T + 0.0152123*T)
    carbonic_k1 = np.exp(2767.92 - 80063.5/T - 478.653*ln_T + 0.714984*T)
    carbonic_k2 = np.exp(12.405 - 6286.89/T - 0.050628*T)
    water_kw = np.exp(14.01708 - 10294.83/T - 0.039282*T)
    # The Henry correlations are written with a leading minus sign.
    henry_co2 = np.exp(-(1082.37 - 34417.2/T - 182.28*ln_T + 0.25159*T))
    henry_nh3 = np.exp(-(160.559 - 8621.06/T - 25.6767*ln_T + 0.035388*T))
    return ammonia_kb, carbonic_k1, carbonic_k2, water_kw, henry_co2, henry_nh3

# Evaluate every temperature-dependent constant once at the fixed reactor
# temperature; the pH solver and the ODE right-hand sides read these globals.
(
    KNH3,
    KCO2,
    KHCO3,
    KH2O,
    H_CO2,
    H_NH3,
) = temperature_constants(T_fixed)

# ---------------------------
# SAFE pH SOLVER
# ---------------------------
def electroneutrality(pH, TA, TC):
    """Return the liquid-phase charge imbalance at a trial pH.

    Cations (NH4+ and H+) minus anions (HCO3-, 2*CO3-- and OH-).  NH4+ is the
    protonated fraction of total ammonia ``TA``; HCO3- and CO3-- are the
    dissociated fractions of total inorganic carbon ``TC``.  The imbalance is
    zero at the equilibrium pH, which ``solve_pH_safe`` locates.  Reads the
    module-level constants KNH3, KCO2, KHCO3 and KH2O.

    Args:
        pH: trial pH.
        TA: total ammonia concentration (NH3 + NH4+).
        TC: total inorganic carbon concentration (CO2 + HCO3- + CO3--).

    Returns:
        Charge residual; positive when cations are in excess.
    """
    h_ion = 10**(-pH)
    hydroxide = KH2O / h_ion
    ammonium = TA * (h_ion*KNH3)/(KNH3*h_ion + KH2O)
    carb_denominator = h_ion**2 + KCO2*h_ion + KCO2*KHCO3
    bicarbonate = TC * (KCO2*h_ion)/carb_denominator
    carbonate = TC * (KCO2*KHCO3)/carb_denominator
    positive_charge = ammonium + h_ion
    negative_charge = bicarbonate + 2*carbonate + hydroxide
    return positive_charge - negative_charge

def solve_pH_safe(TA, TC):
    """Solve the electroneutrality condition for pH without raising.

    The root is searched on pH 2-12 first.  If ``electroneutrality`` does not
    change sign there, the search is widened to pH 0-14.  Only if that also
    fails does the function fall back to pH 7.0.  The fallback emits a
    ``RuntimeWarning`` so the failure is no longer silent.

    Args:
        TA: total ammonia concentration (NH3 + NH4+).
        TC: total inorganic carbon concentration (CO2 + HCO3- + CO3--).

    Returns:
        The equilibrium pH, or 7.0 if no root could be bracketed.
    """
    # The ODE starts from TA = TC = 0 and the integrator may step slightly
    # below zero; keep both totals strictly positive for the speciation terms.
    TA = max(TA, 1e-16)
    TC = max(TC, 1e-16)
    for pH_lo, pH_hi in ((2.0, 12.0), (0.0, 14.0)):
        try:
            return brentq(electroneutrality, pH_lo, pH_hi, args=(TA, TC))
        except ValueError:
            # brentq raises ValueError when both bracket ends share a sign.
            continue
    # Constant message so the default warning filter reports it once rather
    # than on every right-hand-side evaluation.
    warnings.warn(
        "No electroneutrality root found in pH 0-14; falling back to pH 7.0",
        RuntimeWarning,
        stacklevel=2,
    )
    return 7.0

# ---------------------------
# SIMULATION FUNCTION
# ---------------------------
def simulate_urea(kla_CO2, parameter_kinetic, initial_CO2, t_end=60.0, urea_0=0.030):
    """Integrate the urea-hydrolysis / gas-liquid model and return final urea.

    State ``y = [UREA, TA, TC, n_NH3_gas, n_CO2_gas]``:
      * liquid urea concentration;
      * total ammonia (NH3 + NH4+);
      * total inorganic carbon (CO2 + HCO3- + CO3--);
      * moles of NH3 in the headspace;
      * moles of CO2 in the headspace.

    Each hydrolysed urea adds 2 to TA and 1 to TC.  Gas transfer follows
    ``kla * (c_liq - H * P_gas)``.

    Args:
        kla_CO2: CO2 mass-transfer coefficient (per minute, per the fit plot).
        parameter_kinetic: scale factor on the enzyme loading
            (labelled f_ac in the confidence-ellipse plot).
        initial_CO2: initial headspace CO2 partial pressure, atm.
        t_end: end of integration; the default 60 matches ``UREA_exp_60``.
        urea_0: initial urea concentration.

    Returns:
        Urea concentration at ``t_end``.

    Raises:
        RuntimeError: if ``solve_ivp`` does not reach ``t_end``.
    """
    urease = 25 * 0.015 / 1000 * parameter_kinetic
    n_CO2_0 = initial_CO2 * V_gas / (R_gas * T_fixed)
    # Arrhenius factor is constant at T_fixed: compute once, not per RHS call.
    kT = np.exp(-Ea/R * (1/T_fixed - 1/T_star))

    def model(t, y):
        UREA, TA, TC, n_NH3_gas, n_CO2_gas = y
        pH = solve_pH_safe(TA, TC)
        H = 10**(-pH)
        P_CO2 = n_CO2_gas * R_gas * T_fixed / V_gas
        P_NH3 = n_NH3_gas * R_gas * T_fixed / V_gas
        denom = H**2 + KCO2*H + KCO2*KHCO3
        CO2_aq = TC * H**2 / denom
        NH3_free = TA * KH2O / (KNH3*H + KH2O)
        # Liquid -> gas transfer, positive when the liquid is supersaturated.
        rCO2 = kla_CO2 * (CO2_aq - H_CO2 * P_CO2)
        rNH3 = kla_NH3 * (NH3_free - H_NH3 * P_NH3)
        S = max(UREA, 1e-16)
        P = max(NH3_free, 1e-16)
        ve = kT * urease * S / ((KM+S)*(1+P/KP)*(1+H/KES1 + KES2/H))
        dUREA = -ve
        dTA = 2*ve - rNH3
        dTC = ve - rCO2
        dn_NH3 = rNH3 * V_liq
        dn_CO2 = rCO2 * V_liq
        return [dUREA, dTA, dTC, dn_NH3, dn_CO2]

    y0 = [urea_0, 0, 0, 0, n_CO2_0]
    sol = solve_ivp(model, (0, t_end), y0, method="BDF", t_eval=[t_end], rtol=1e-6, atol=1e-9)
    if not sol.success:
        # On failure sol.y is empty and indexing it would raise a cryptic IndexError.
        raise RuntimeError(
            f"solve_ivp failed for initial_CO2={initial_CO2}: {sol.message}"
        )
    return sol.y[0, -1]

# ---------------------------
# LEAST SQUARES OBJECTIVE FOR TWO PARAMETERS
# ---------------------------
def residuals_two_params(x):
    """Return model-minus-measurement urea residuals at 60 min.

    Args:
        x: ``(kla_CO2, parameter_kinetic)`` as proposed by the optimizer.

    Returns:
        List with one residual per initial CO2 pressure in ``initial_CO2_exp``.
    """
    kla_CO2, parameter_kinetic = x
    return [
        simulate_urea(kla_CO2, parameter_kinetic, pressure) - UREA_exp_60[idx]
        for idx, pressure in enumerate(initial_CO2_exp)
    ]

# ---------------------------
# FIT BOTH PARAMETERS
# ---------------------------
x0 = [0.01, 82.7]
bounds = ([1e-6, 1], [10, 500])

# Each residual evaluation integrates an adaptive ODE (rtol=1e-6), so the
# residuals carry numerical noise far above machine precision.  The default
# finite-difference step (~sqrt(eps) ~ 1.5e-8) would mostly difference that
# noise.  A relative step of 1e-4 keeps the Jacobian meaningful; that
# Jacobian is also reused below for the confidence intervals.
result = least_squares(residuals_two_params, x0=x0, bounds=bounds, diff_step=1e-4)
kla_CO2_fit, parameter_kinetic_fit = result.x

print("\nFITTED kla_CO2 =", kla_CO2_fit)
print("FITTED parameter_kinetic =", parameter_kinetic_fit)
print("Residuals =", result.fun)
if not result.success:
    # Statistics computed below are unreliable for an unconverged fit.
    print("WARNING: least_squares did not converge:", result.message)

# ---------------------------
# Compute R^2
# ---------------------------
# least_squares already stores the residuals at the optimum in result.fun.
# Calling residuals_two_params(result.x) again would repeat every ODE solve.
residuals_fit = np.asarray(result.fun)
SS_res = np.sum(residuals_fit**2)
SS_tot = np.sum((UREA_exp_60 - np.mean(UREA_exp_60))**2)
# R^2 is undefined when every measurement is identical (SS_tot == 0).
R2 = 1 - SS_res / SS_tot if SS_tot > 0 else float("nan")
print(f"R^2 = {R2:.4f}")

# ---------------------------
# Compute confidence intervals
# ---------------------------
# Linearised parameter covariance cov = s^2 (J^T J)^-1, with
# s^2 = SS_res / (n - p).  The inverse is built from the SVD of J rather
# than by inverting J^T J, which would square the condition number.
# Numerically zero singular values are dropped (pseudo-inverse), as
# scipy.optimize.curve_fit does for the 'trf' method.
n = len(UREA_exp_60)
p = 2  # number of parameters
J = result.jac
sigma2 = SS_res / (n - p)
_, sing_vals, VT = np.linalg.svd(J, full_matrices=False)
keep = sing_vals > np.finfo(float).eps * max(J.shape) * sing_vals[0]
if not keep.all():
    print("WARNING: Jacobian is rank deficient; parameters are not jointly identifiable.")
sing_vals, VT = sing_vals[keep], VT[keep]
cov = sigma2 * ((VT.T / sing_vals**2) @ VT)
se = np.sqrt(np.diag(cov))
# Only n - p residual degrees of freedom: use the Student-t quantile
# (about 3.18 for 3 dof) instead of the normal 1.96, which understates the CI.
t_crit = stats.t.ppf(0.975, n - p)
ci95 = t_crit * se
print(f"95% CI kla_CO2: {kla_CO2_fit:.4f} ± {ci95[0]:.4f}")
print(f"95% CI parameter_kinetic: {parameter_kinetic_fit:.4f} ± {ci95[1]:.4f}")

# ---------------------------
# Plot confidence ellipse
# ---------------------------
def plot_confidence_ellipse(mean, cov, ax, n_std=2.0, **kwargs):
    """Add a covariance ellipse centred on ``mean`` to ``ax``.

    The ellipse axes follow the eigenvectors of the 2x2 symmetric matrix
    ``cov``.  Each full axis length is ``2 * n_std * sqrt(eigenvalue)``.
    Extra keyword arguments are forwarded to ``matplotlib.patches.Ellipse``.

    Args:
        mean: ``(x, y)`` centre of the ellipse.
        cov: 2x2 symmetric covariance matrix.
        ax: Matplotlib axes the patch is added to.
        n_std: multiplier applied to the square-rooted eigenvalues.
    """
    eigvals, eigvecs = np.linalg.eigh(cov)
    # eigh returns eigenvalues in ascending order; put the major axis first.
    descending = eigvals.argsort()[::-1]
    eigvals = eigvals[descending]
    eigvecs = eigvecs[:, descending]
    major_x, major_y = eigvecs[:, 0]
    angle_deg = np.degrees(np.arctan2(major_y, major_x))
    width, height = 2 * n_std * np.sqrt(eigvals)
    patch = Ellipse(
        xy=mean,
        width=width,
        height=height,
        angle=angle_deg,
        **kwargs,
    )
    ax.add_patch(patch)

# Joint 95 % confidence region for the two fitted parameters.  A radius of
# 1.96 standard deviations covers only ~85 % of a 2-D normal.  The
# linearised least-squares joint region with estimated variance is
#   (theta - theta_hat)^T cov^-1 (theta - theta_hat) <= p * F_0.95(p, n - p),
# so the ellipse is scaled by sqrt(p * F).
ellipse_scale = np.sqrt(p * stats.f.ppf(0.95, p, n - p))
fig, ax = plt.subplots(figsize=(6,5))
ax.scatter(kla_CO2_fit, parameter_kinetic_fit, color='red', label='Fitted')
plot_confidence_ellipse([kla_CO2_fit, parameter_kinetic_fit], cov, ax, n_std=ellipse_scale, edgecolor='blue', facecolor='none', lw=2, label='95% CI ellipse')
# Raw strings: "\m" in a normal string literal is an invalid escape sequence.
ax.set_xlabel(r"k$_\mathrm{L}$a / min$^{-1}$")
ax.set_ylabel(r"$f_\mathrm{ac}$ / -")
ax.legend()
plt.show()

# ---------------------------
# FINAL SIMULATION AND PLOTTING
# ---------------------------
# Re-simulate every experiment with the fitted parameters over 0-60 min.
# The integrator tolerances match those used inside simulate_urea during the
# fit, so the printed 60-min values are the predictions the fit was scored on.
plt.figure(figsize=(6,4))
time = np.linspace(0,60,300)

# Loop-invariant quantities, computed once instead of on every RHS call.
urease_fit = 25 * 0.015 / 1000 * parameter_kinetic_fit
kT_fit = np.exp(-Ea/R * (1/T_fixed - 1/T_star))


def model_final(t, y):
    """Right-hand side of the reactor model using the fitted parameters.

    ``y = [UREA, TA, TC, n_NH3_gas, n_CO2_gas]``; returns their time
    derivatives.  ``t`` is unused but required by ``solve_ivp``.
    """
    UREA, TA, TC, n_NH3_gas, n_CO2_gas = y
    pH = solve_pH_safe(TA, TC)
    H = 10**(-pH)
    P_CO2 = n_CO2_gas * R_gas * T_fixed / V_gas
    P_NH3 = n_NH3_gas * R_gas * T_fixed / V_gas
    denom = H**2 + KCO2*H + KCO2*KHCO3
    CO2_aq = TC * H**2 / denom
    NH3_free = TA * KH2O / (KNH3*H + KH2O)
    # Liquid -> gas transfer, positive when the liquid is supersaturated.
    rCO2 = kla_CO2_fit * (CO2_aq - H_CO2 * P_CO2)
    rNH3 = kla_NH3 * (NH3_free - H_NH3 * P_NH3)
    S = max(UREA, 1e-16)
    P = max(NH3_free, 1e-16)
    ve = kT_fit * urease_fit * S / ((KM+S)*(1+P/KP)*(1+H/KES1 + KES2/H))
    # Each hydrolysed urea adds 2 to total ammonia and 1 to total carbon.
    return [-ve, 2*ve - rNH3, ve - rCO2, rNH3 * V_liq, rCO2 * V_liq]


for P0 in initial_CO2_exp:
    n_CO2_0 = P0 * V_gas / (R_gas * T_fixed)
    y0 = [0.030, 0, 0, 0, n_CO2_0]

    sol = solve_ivp(model_final, (0, 60), y0, t_eval=time, method="BDF",
                    rtol=1e-6, atol=1e-9)
    if not sol.success:
        raise RuntimeError(f"solve_ivp failed for initial CO2 = {P0} atm: {sol.message}")
    # Pressures are in atm (R_gas is in L atm/(mol K)).
    plt.plot(time, sol.y[0], label=f"Initial CO2 pressure {P0} atm")

    # Print final urea and pH
    UREA_final = sol.y[0, -1]
    TA_final = sol.y[1, -1]
    TC_final = sol.y[2, -1]
    pH_final = solve_pH_safe(TA_final, TC_final)
    print(f"Initial CO2 = {P0} atm -> Urea at 60 min: {UREA_final:.6f} mol/L, pH: {pH_final:.2f}")

# Experimental urea concentrations, all sampled at t = 60 min.
plt.scatter([60] * len(UREA_exp_60), UREA_exp_60, color='red', label="Experiment")

plt.xlabel("t / min")
plt.ylabel("$c_U$ / mol L$^{-1}$")
plt.legend()
plt.show()