import numpy as np
from scipy.optimize import brentq
import matplotlib.pyplot as plt

# -------------------------------
# Constants
# -------------------------------
R = 8.31446261815324  # molar gas constant, J/(mol·K)
T_ref = 298.15        # reference temperature, K (25 °C)

# Tris buffer
tris_total = 0.5       # total Tris concentration, M
pKa_tris_ref = 8.07    # pKa at 25 °C
dpKa_tris_dT = -0.028  # linear pKa temperature slope, per °C

# Carbonic acid: pK values at the reference temperature and the reaction
# enthalpies fed to K_vant_hoff for the temperature correction.
# NOTE(review): with K_vant_hoff's sign convention a negative dH_K1 makes K1
# decrease on heating - confirm sign and magnitude of dH_K1 / dH_K2 against
# the data source.
pK1_ref = 6.35
pK2_ref = 10.33
K1_ref = 10.0 ** -pK1_ref
K2_ref = 10.0 ** -pK2_ref
dH_K1 = -8300.0  # J/mol
dH_K2 = 2400.0   # J/mol

# CO2 solubility (Henry's-law constant at the reference temperature)
H_CO2_ref = 0.034   # mol/(L·atm)
dH_CO2 = -19300.0   # J/mol

# Water autoionization
Kw_ref = 1.0e-14
dH_w = 55700.0  # J/mol

# -------------------------------
# Functions
# -------------------------------
def K_vant_hoff(K_ref, dH, T, T_ref=298.15):
    """Scale an equilibrium constant from ``T_ref`` to ``T``.

    Evaluates ``K_ref * exp(-dH / R * (1/T - 1/T_ref))``, i.e. the
    integrated van 't Hoff relation with a temperature-independent ``dH``.

    Parameters
    ----------
    K_ref : value of the constant at ``T_ref``.
    dH : reaction enthalpy, J/mol (same unit basis as the module-level ``R``).
    T : target absolute temperature, K.
    T_ref : reference absolute temperature, K.

    Returns
    -------
    The constant at ``T``; scalars or NumPy arrays are accepted because
    only arithmetic and ``np.exp`` are applied.
    """
    inverse_T_shift = 1.0 / T - 1.0 / T_ref
    exponent = -dH / R * inverse_T_shift
    return K_ref * np.exp(exponent)

def pKw_vs_T(T, Kw_ref=Kw_ref, dH_w=dH_w, T_ref=298.15):
    """Return the water ion product Kw at absolute temperature ``T``.

    Despite the ``pKw`` in the name, the value returned is Kw itself (not
    its negative logarithm): the call simply forwards to ``K_vant_hoff``.

    Parameters
    ----------
    T : absolute temperature, K.
    Kw_ref : Kw at ``T_ref``.
    dH_w : enthalpy used for the temperature correction, J/mol.
    T_ref : reference absolute temperature, K.
    """
    Kw_at_T = K_vant_hoff(K_ref=Kw_ref, dH=dH_w, T=T, T_ref=T_ref)
    return Kw_at_T

# -------------------------------
# Speciation solver for ideal liquid (all activity coefficients = 1)
# -------------------------------
def solve_speciation_ideal(T_C=25.0, pCO2=0.15, tris_total=tris_total):
    """Solve the ideal-solution speciation of the Tris / CO2 / water system.

    All activity coefficients are taken as 1. Dissolved CO2 is fixed by
    Henry's law (``H_CO2 * pCO2``) and [H+] is obtained by root-finding the
    charge balance ``[H+] + [BH+] = [HCO3-] + 2[CO3^2-] + [OH-]`` on a
    log10 scale between pH 14 and pH 0.

    Parameters
    ----------
    T_C : temperature in °C.
    pCO2 : CO2 partial pressure, in the pressure unit of ``H_CO2_ref``
        (atm).
    tris_total : total Tris concentration (B + BH+), M.

    Returns
    -------
    dict with keys ``'T_K'``, ``'T_C'``, ``'pCO2'``, ``'CO2_aq'``,
    ``'H_plus'``, ``'pH'``, ``'OH_minus'``, ``'HCO3_minus'``,
    ``'CO3_2minus'``, ``'Tris_B'``, ``'Tris_BH_plus'``,
    ``'Ionic_strength'``, ``'pKa_tris'`` and ``'DIC'``.

    Raises
    ------
    ValueError
        Propagated from ``brentq`` when the charge balance does not change
        sign over the pH 0-14 bracket.
    """
    T = T_C + 273.15

    # Temperature-corrected equilibrium constants.
    H_CO2 = K_vant_hoff(H_CO2_ref, dH_CO2, T)
    K1 = K_vant_hoff(K1_ref, dH_K1, T)
    K2 = K_vant_hoff(K2_ref, dH_K2, T)
    pKa_tris = pKa_tris_ref + dpKa_tris_dT * (T_C - 25.0)
    Ka_tris = 10 ** (-pKa_tris)
    Kw = pKw_vs_T(T)

    CO2_aq = H_CO2 * pCO2

    def ionic_species(H):
        """Return (HCO3-, CO3^2-, BH+, OH-) concentrations at a given [H+]."""
        HCO3 = K1 * CO2_aq / H
        CO3 = K2 * HCO3 / H
        BH = tris_total * H / (Ka_tris + H)
        OH = Kw / H
        return HCO3, CO3, BH, OH

    def charge_balance(log10_H):
        """Cations minus anions (charge-weighted) at log10([H+])."""
        H = 10 ** log10_H
        HCO3, CO3, BH, OH = ionic_species(H)
        return (H + BH) - (HCO3 + 2 * CO3 + OH)

    # Solve in log10([H+]) so the bracket spans 14 decades evenly.
    log10_H_solution = brentq(charge_balance, -14, 0, maxiter=500, xtol=1e-12)
    H = 10 ** log10_H_solution

    # Same expressions as used inside the solver, evaluated at the root.
    HCO3, CO3, BH, OH = ionic_species(H)
    B = tris_total - BH
    DIC = CO2_aq + HCO3 + CO3

    # I = 1/2 * sum(c_i * z_i**2): concentrations enter linearly and only the
    # charges are squared (|z| = 1 for every ion here except CO3^2-, z**2 = 4).
    ionic_strength = 0.5 * (H + OH + BH + HCO3 + 4 * CO3)

    return {
        'T_K': T,
        'T_C': T_C,
        'pCO2': pCO2,
        'CO2_aq': CO2_aq,
        'H_plus': H,
        'pH': -np.log10(H),
        'OH_minus': OH,
        'HCO3_minus': HCO3,
        'CO3_2minus': CO3,
        'Tris_B': B,
        'Tris_BH_plus': BH,
        'Ionic_strength': ionic_strength,
        'pKa_tris': pKa_tris,
        'DIC': DIC
    }


# -------------------------------
# Temperature sweep and plotting
# -------------------------------
def _finalize_figure(pdf_path):
    """Apply the shared final styling, then save the current figure and show it."""
    plt.grid(False)
    plt.legend(fontsize=12)
    plt.tight_layout()
    plt.savefig(pdf_path, format='pdf')
    plt.show()


def main():
    """Sweep temperature from 10 to 90 °C and write three PDF figures.

    * ``Fig3_ideal_corrected.pdf`` - CO2(aq), HCO3-, CO3^2- and DIC versus
      temperature at pCO2 = 0.15.
    * ``Fig4_ideal_corrected.pdf`` - CO2 loading (DIC / total Tris) versus
      temperature for several pCO2 values.
    * ``Fig5_ideal_corrected.pdf`` - solution pH at pCO2 = 0.15 and the Tris
      pKa versus temperature.

    Each figure is also shown interactively via ``plt.show()``.
    """
    temps_C = np.linspace(10, 90, 17)

    # One solve per temperature; every curve of Fig3 and Fig5 is read from it.
    sweep = [solve_speciation_ideal(T_C=T, pCO2=0.15) for T in temps_C]
    CO2_list = [res['CO2_aq'] for res in sweep]
    HCO3_list = [res['HCO3_minus'] for res in sweep]
    CO3_list = [res['CO3_2minus'] for res in sweep]
    DIC_list = [res['DIC'] for res in sweep]
    pH_loaded_list = [res['pH'] for res in sweep]
    pKa_list = [res['pKa_tris'] for res in sweep]

    # --- Fig3: carbonate speciation -------------------------------------
    plt.figure(figsize=(4.3, 4))
    plt.plot(temps_C, CO2_list, '--', linewidth=2, label='CO$_2$(aq)')
    plt.plot(temps_C, HCO3_list, ':', linewidth=2, label='HCO$_3^-$')
    plt.plot(temps_C, CO3_list, '-.', linewidth=2, label='CO$_3^{2-}$')
    plt.plot(temps_C, DIC_list, '-', linewidth=2, label='DIC')
    plt.xlabel('$T$ / °C', fontsize=14)
    plt.ylabel('C / M', fontsize=14)
    plt.xticks(np.arange(0, 101, 20), fontsize=12)
    plt.ylim(0, 0.400001)
    # Upper bounds sit slightly above the last tick so arange includes it.
    plt.yticks(np.arange(0, 0.4000001, 0.1), fontsize=12)
    _finalize_figure('Fig3_ideal_corrected.pdf')

    # --- Fig4: CO2 loading for several partial pressures -----------------
    pCO2_list = [0.15, 0.125, 0.10, 0.075, 0.05]
    cmap = plt.get_cmap('plasma')
    colors = [cmap(i / (len(pCO2_list) - 1)) for i in range(len(pCO2_list))]

    plt.figure(figsize=(4.3, 4))
    for pCO2, color in zip(pCO2_list, colors):
        loading = [
            solve_speciation_ideal(T_C=T, pCO2=pCO2)['DIC'] / tris_total
            for T in temps_C
        ]
        # Three decimals: with two, 0.125 and 0.075 would be labelled
        # 0.12 and 0.07.
        plt.plot(temps_C, loading, color=color, linewidth=2,
                 label=f'{pCO2:.3f} mol/mol')
    plt.xlabel('$T$ / °C', fontsize=14)
    plt.ylabel('$C_{CO_2,load}$ / mol/mol', fontsize=14)
    plt.ylim(0, 1.00012)
    plt.xticks(np.arange(0, 101, 20), fontsize=12)
    plt.yticks(np.arange(0, 1.00012, 0.2), fontsize=12)
    _finalize_figure('Fig4_ideal_corrected.pdf')

    # --- Fig5: pH of the loaded solution and Tris pKa --------------------
    # Without an explicit figure and plot calls, savefig would write an
    # empty canvas to the PDF.
    plt.figure(figsize=(4.3, 4))
    plt.plot(temps_C, pH_loaded_list, '-', linewidth=2, label='pH')
    plt.plot(temps_C, pKa_list, '--', linewidth=2, label='p$K_a$ (Tris)')
    plt.xlabel('$T$ / °C', fontsize=14)
    plt.ylabel('pH, p$K_a$', fontsize=14)
    plt.xticks(np.arange(0, 101, 20), fontsize=12)
    plt.yticks(fontsize=12)
    _finalize_figure('Fig5_ideal_corrected.pdf')

# Run the sweep and write the figures only when executed as a script,
# not when this module is imported.
if __name__ == '__main__':
    main()
