import numpy as np
from scipy.integrate import solve_ivp
from scipy.optimize import brentq
import matplotlib.pyplot as plt

# ---------------------------
# PHYSICAL CONSTANTS / OPERATING POINT
# ---------------------------

R = 8.314        # gas constant, J/(mol·K); used in the Arrhenius factor
R_gas = 0.08206  # gas constant, L·atm/(mol·K); used for headspace pressures

T_fixed = 308.15  # operating temperature, K (35 °C)

# ---------------------------
# KINETIC PARAMETERS (CABEQ paper)
# ---------------------------

Ea = 35300      # activation energy, J/mol
T_star = 414.6  # reference temperature of the Arrhenius factor, K

# Rate-law constants, all in mol/L.
KM = 3.21e-3    # Michaelis constant for urea
KES1 = 7.57e-7  # pH term: [H+] is divided by this constant
KES2 = 1.27e-8  # pH term: this constant is divided by [H+]
KP = 1.22e-2    # product-inhibition constant (free NH3)

# Kinetic scaling factor applied to the enzyme loading, and the resulting
# enzyme term in the rate law (must be proportional to g/L).
# NOTE(review): the meaning of the factors 25, 0.015 and /1000 is not
# documented here -- confirm against the source of the parameters.
parameter_kinetic = 40.04
urease = 25 * 0.015 / 1000 * parameter_kinetic

# ---------------------------
# REACTOR GEOMETRY / MASS TRANSFER
# ---------------------------

V_liq, V_gas = 0.015, 0.005  # liquid and headspace volumes, L

# Volumetric gas-liquid mass-transfer coefficients; presumably 1/min since
# the simulated time axis is in minutes -- confirm.
kla_CO2 = kla_NH3 = 0.0092

# ---------------------------
# TEMPERATURE DEPENDENT CONSTANTS
# ---------------------------

def temperature_constants(T):
    """Evaluate the equilibrium and Henry's-law correlations at temperature T.

    Parameters
    ----------
    T : float or ndarray
        Absolute temperature in K.

    Returns
    -------
    tuple
        (KNH3, KCO2, KHCO3, KH2O, H_CO2, H_NH3). The first four are the
        dissociation constants used by the speciation formulas; the last two
        are Henry coefficients that multiply a gas partial pressure to give
        the corresponding equilibrium dissolved concentration.
    """
    log_T = np.log(T)

    # ln K = a - b/T - c*ln(T) + d*T  (KHCO3 and KH2O have no ln(T) term).
    ln_KNH3 = 191.97 - 8451.61/T - 31.4335*log_T + 0.0152123*T
    ln_KCO2 = 2767.92 - 80063.5/T - 478.653*log_T + 0.714984*T
    ln_KHCO3 = 12.405 - 6286.89/T - 0.050628*T
    ln_KH2O = 14.01708 - 10294.83/T - 0.039282*T

    # The Henry correlations enter the exponent with a negative sign.
    ln_H_CO2 = -(1082.37 - 34417.2/T - 182.28*log_T + 0.25159*T)
    ln_H_NH3 = -(160.559 - 8621.06/T - 25.6767*log_T + 0.035388*T)

    exponents = (ln_KNH3, ln_KCO2, ln_KHCO3, ln_KH2O, ln_H_CO2, ln_H_NH3)
    return tuple(np.exp(e) for e in exponents)

# Evaluate the temperature-dependent constants once at the operating
# temperature; the speciation and rate code below reads these globals.
(KNH3, KCO2, KHCO3, KH2O,
 H_CO2, H_NH3) = temperature_constants(T_fixed)

# ---------------------------
# pH SOLVER
# ---------------------------

def electroneutrality(pH, TA, TC):
    """Return the charge-balance residual of the NH3 / CO2 / water system.

    Parameters
    ----------
    pH : float
        Candidate pH.
    TA : float
        Total ammonia (NH3 + NH4+), mol/L.
    TC : float
        Total dissolved inorganic carbon (CO2 + HCO3- + CO3^2-), mol/L.

    Returns
    -------
    float
        [NH4+] + [H+] - ([HCO3-] + 2[CO3^2-] + [OH-]); zero at the
        equilibrium pH.
    """
    h = 10**(-pH)
    carbonate_den = h**2 + KCO2*h + KCO2*KHCO3

    # Cations: protonated ammonia and free protons.
    positive = TA * (h*KNH3)/(KNH3*h + KH2O) + h

    # Anions: bicarbonate, doubly charged carbonate, hydroxide.
    bicarbonate = TC * (KCO2*h)/carbonate_den
    carbonate = TC * (KCO2*KHCO3)/carbonate_den
    negative = bicarbonate + 2*carbonate + KH2O / h

    return positive - negative


def solve_pH(TA, TC):
    """Return the pH in [2, 12] at which the charge balance is satisfied.

    Totals at or below 1e-16 mol/L (including small negative values) are
    floored at 1e-16 before solving. brentq raises ValueError if the
    residual does not change sign over the bracket.
    """
    floor = 1e-16
    ph_low, ph_high = 2, 12
    return brentq(
        electroneutrality,
        ph_low,
        ph_high,
        args=(max(TA, floor), max(TC, floor)),
    )

# ---------------------------
# MODEL
# ---------------------------

def model(t, y):
    """Return d(state)/dt for the urea-hydrolysis reactor (solve_ivp RHS).

    State ``y`` (same order as the initial-condition list):
        [urea (mol/L), total ammonia TA (mol/L), total inorganic carbon
         TC (mol/L), NH3 in the gas phase (mol), CO2 in the gas phase (mol)].

    ``t`` is unused (the right-hand side has no explicit time dependence)
    but is required by the solve_ivp calling convention. Liquid balances
    are per litre of liquid; gas balances are the liquid-side flux times
    V_liq, i.e. mol per time unit.
    """
    urea, total_nh3, total_co2, gas_nh3_mol, gas_co2_mol = y

    # Speciation: pH from the charge balance at the current totals.
    pH = solve_pH(total_nh3, total_co2)
    h = 10**(-pH)

    # Headspace partial pressures from the ideal-gas law (atm, via R_gas).
    p_co2 = gas_co2_mol * R_gas * T_fixed / V_gas
    p_nh3 = gas_nh3_mol * R_gas * T_fixed / V_gas

    # Neutral (strippable) forms in solution.
    carb_den = h**2 + KCO2*h + KCO2*KHCO3
    co2_dissolved = total_co2 * h**2 / carb_den
    nh3_dissolved = total_nh3 * KH2O / (KNH3*h + KH2O)

    # Liquid -> gas transfer fluxes (positive = loss from the liquid).
    flux_co2 = kla_CO2 * (co2_dissolved - H_CO2 * p_co2)
    flux_nh3 = kla_NH3 * (nh3_dissolved - H_NH3 * p_nh3)

    # Enzymatic rate: Arrhenius factor x Michaelis-Menten in urea, with
    # free-NH3 product inhibition and a pH-dependent term.
    substrate = max(urea, 1e-16)
    product = max(nh3_dissolved, 1e-16)
    arrhenius = np.exp(-Ea/R * (1/T_fixed - 1/T_star))
    denominator = (KM + substrate) * (1 + product / KP) * (1 + h/KES1 + KES2/h)
    rate = arrhenius * urease * substrate / denominator

    # Stoichiometry: each mole of urea yields 2 mol ammonia and 1 mol CO2.
    return [
        -rate,
        2*rate - flux_nh3,
        rate - flux_co2,
        flux_nh3 * V_liq,
        flux_co2 * V_liq,
    ]

# ---------------------------
# INITIAL CONDITIONS
# ---------------------------

# State order must match the unpacking in model():
# liquid totals in mol/L, then gas-phase amounts in mol.
y0 = [
    0.015,                 # urea, mol/L
    0.0,                   # total ammonia, mol/L
    0.0,                   # total inorganic carbon, mol/L
    0.0,                   # NH3 in gas phase, mol
    4.504545e-04 / 1.09,   # CO2 in gas phase, mol
    # NOTE(review): origin of the 4.504545e-04 and /1.09 factors is not
    # documented here -- confirm.
]

t_span = (0, 60)
# Output grid derived from t_span so the two cannot drift apart.
t_eval = np.linspace(t_span[0], t_span[1], 300)

# ---------------------------
# SOLVE
# ---------------------------

# BDF: the pH/mass-transfer coupling can make the system stiff.
sol = solve_ivp(
    model,
    t_span,
    y0,
    method="BDF",
    t_eval=t_eval,
    rtol=1e-6,
    atol=1e-9
)

# Fail loudly instead of post-processing/plotting a truncated trajectory.
if not sol.success:
    raise RuntimeError(f"ODE integration failed: {sol.message}")

# ---------------------------
# POSTPROCESS
# ---------------------------

time = sol.t
UREA, TA, TC, n_NH3_gas, n_CO2_gas = sol.y

# The integrator stores only the totals, so recover pH at each output time
# by re-solving the charge balance.
pH = [solve_pH(ta, tc) for ta, tc in zip(TA, TC)]

# Ideal-gas headspace partial pressures in atm (R_gas is in L·atm/(mol·K)).
# Vectorized over all output times; converted to lists as before.
P_CO2 = list(n_CO2_gas * R_gas * T_fixed / V_gas)
P_NH3 = list(n_NH3_gas * R_gas * T_fixed / V_gas)

# ---------------------------
# PLOTS
# ---------------------------

# Headspace pressures are computed in atm (R_gas is in L·atm/(mol·K)),
# while the pressure axis is labelled in bar, so convert before plotting.
ATM_TO_BAR = 1.01325

fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(5, 7), sharex=True)

# ---------------------------
# Bottom panel: liquid-phase concentrations, pH on a twin axis
# ---------------------------
ax2.plot(time, UREA, label='U', color='green')
ax2.plot(time, TC, label='DIC', color='red')
ax2.plot(time, TA, label='TAN', color='black')
# Raw string: '\m' is an invalid escape sequence in a normal literal.
ax2.set_ylabel(r'$c_{\mathrm{i}}$ / mol L$^{-1}$', fontsize=14)
ax2.legend(frameon=False, loc='best')
ax2.tick_params(direction="out", length=4, width=1.2, labelsize=12)

ax2b = ax2.twinx()
ax2b.plot(time, pH, color='blue', linestyle='--', label='pH / -')
ax2b.set_ylabel('pH / -', fontsize=14, color='blue')
ax2b.tick_params(axis='y', colors='blue')
ax2b.tick_params(direction="out", length=4, width=1.2, labelsize=12)

# ---------------------------
# Top panel: headspace partial pressures
# ---------------------------
ax1.plot(time, np.asarray(P_CO2) * ATM_TO_BAR, label='CO$_{2}$', color='orange')
#ax1.plot(time, np.asarray(P_NH3) * ATM_TO_BAR, label='NH$_{3}$', color='purple')
ax1.set_ylabel(r'$p_\mathrm{i}$ / bar', fontsize=14)
ax1.legend(frameon=False, loc='best')
ax2.set_xlabel('$t$ / min', fontsize=14)
ax1.tick_params(direction="out", length=4, width=1.2, labelsize=12)

# Reference urea point at t = 60 min (green marker, black error bar).
# NOTE(review): presumably a measured value with its uncertainty -- confirm.
ax2.errorbar(
    60,
    0.00485,
    yerr=0.00081,
    fmt='o',
    color='green',
    ecolor='black',
    capsize=4,
    markersize=6,
)

ax2.set_xticks(np.arange(0, 61, 10))
ax2.set_yticks(np.arange(0, 0.040001, 0.01))
ax2b.set_yticks(np.arange(6, 8.01, 0.5))
ax1.set_yticks(np.arange(0.0, 4.01, 1.0))

# Draw all four spines in black on every axes, including the twin pH axis.
for ax in (ax1, ax2, ax2b):
    for spine in ax.spines.values():
        spine.set_visible(True)
        spine.set_linewidth(1.2)
        spine.set_color('black')

plt.tight_layout()
plt.savefig("sim2.pdf", format="pdf", bbox_inches="tight")

plt.show()

print(
    f"Final urea concentration: "
    f"{UREA[-1]:.6f} mol/L"
)

# pH is dimensionless -- no concentration unit.
print(f"Final pH: {pH[-1]:.6f}")