import warnings

import numpy as np
from scipy.optimize import brentq
import matplotlib.pyplot as plt

# -------------------------------
# Constants
# -------------------------------
R = 8.31446261815324  # J/mol/K, molar gas constant
T_ref = 298.15  # K (25°C), temperature at which the *_ref values below apply

# Tris buffer
tris_total = 0.5  # M, total Tris concentration
pKa_tris_ref = 8.07  # pKa of TrisH+ at 25 °C
dpKa_tris_dT = -0.028  # per °C

# Carbonic acid
pK1_ref = 6.35
pK2_ref = 10.33
K1_ref = 10 ** (-pK1_ref)
K2_ref = 10 ** (-pK2_ref)
# Both carbonic-acid dissociations are endothermic: pK1 and pK2 decrease as the
# temperature rises (literature dH ~ +9.1 and +14.9 kJ/mol at 25 °C). With the
# van't Hoff form K = K_ref * exp(-dH/R * (1/T - 1/T_ref)) used in this script,
# that requires positive enthalpies. This is the same sign convention as dH_w
# below, where Kw increases with T.
dH_K1 = 9000.0  # J/mol
dH_K2 = 15000.0   # J/mol

# CO2 solubility
H_CO2_ref = 0.034       # mol/(L·atm), Henry's-law solubility at 25 °C
dH_CO2 = -20000.0       # J/mol, negative: CO2 solubility falls with T

# Water
Kw_ref = 1e-14
dH_w = 55700.0  # J/mol, positive: Kw rises with T

# -------------------------------
# Functions
# -------------------------------
def K_vant_hoff(K_ref, dH, T, T_ref=298.15):
    """Extrapolate an equilibrium constant from T_ref to T (van't Hoff).

    Uses the integrated van't Hoff equation with a temperature-independent
    reaction enthalpy:

        ln(K / K_ref) = -(dH / R) * (1/T - 1/T_ref)

    K_ref is the constant at T_ref, dH the reaction enthalpy in J/mol
    (R is in J/mol/K), and T / T_ref are absolute temperatures in K.
    Arithmetic is NumPy-based, so T may also be an array.
    """
    reciprocal_gap = 1.0 / T - 1.0 / T_ref
    exponent = -(dH / R) * reciprocal_gap
    return K_ref * np.exp(exponent)

def pKw_vs_T(T, Kw_ref=Kw_ref, dH_w=dH_w, T_ref=298.15):
    """Ion product of water, Kw, at absolute temperature T (K).

    Note: despite its name this returns Kw itself, not pKw = -log10(Kw).
    The value is a van't Hoff extrapolation from Kw_ref at T_ref using the
    autoionisation enthalpy dH_w (J/mol). The Kw_ref / dH_w defaults are the
    module constants, captured when the function is defined.
    """
    Kw_at_T = K_vant_hoff(Kw_ref, dH_w, T, T_ref=T_ref)
    return Kw_at_T

def davies_activity_coeff(z, I, A=0.5):
    """Return the Davies-equation activity coefficient of an ion.

        log10(gamma) = -A * z**2 * (sqrt(I) / (1 + sqrt(I)) - 0.3 * I)

    Parameters
    ----------
    z : int or array_like
        Ion charge number; z = 0 gives gamma = 1.
    I : float or array_like
        Ionic strength in mol/L. The Davies equation is commonly quoted as
        reliable only up to a few tenths molar.
    A : float, optional
        Debye-Hückel constant. Defaults to 0.5 (about 0.509 in water at
        25 °C). A increases with temperature, so callers working far from
        25 °C may pass a temperature-dependent value.

    Returns
    -------
    float or ndarray
        Activity coefficient gamma (dimensionless).
    """
    sqrt_I = np.sqrt(I)
    log10_gamma = -A * z * z * (sqrt_I / (1.0 + sqrt_I) - 0.3 * I)
    return 10 ** log10_gamma

# -------------------------------
# Speciation solver for loaded Tris (with CO2)
# -------------------------------
def solve_speciation(T_C=25.0, pCO2=0.15, tris_total=tris_total,
                     max_iter=500, tol=1e-8):
    """Solve the speciation of a CO2-loaded aqueous Tris solution.

    Dissolved CO2 is fixed by Henry's law. [H+] is found from the charge
    balance

        [H+] + [TrisH+] = [HCO3-] + 2 [CO3^2-] + [OH-]

    with equilibrium constants written in activities (Davies activity
    coefficients). The activity coefficients depend on the ionic strength,
    which depends on the speciation, so the two are iterated to a damped
    fixed point.

    Parameters
    ----------
    T_C : float
        Temperature in °C.
    pCO2 : float
        CO2 partial pressure in atm (the Henry constant is in mol/(L·atm)).
    tris_total : float
        Total Tris concentration in M.
    max_iter : int
        Maximum number of ionic-strength fixed-point iterations (>= 1).
    tol : float
        Convergence tolerance applied to both I and log10[H+].

    Returns
    -------
    dict
        Species concentrations in M, 'pH' (= -log10 of the H+
        *concentration*), 'Ionic_strength' in M and the
        temperature-corrected 'pKa_tris'.

    Raises
    ------
    ValueError
        If max_iter < 1.

    Warns
    -----
    RuntimeWarning
        If the fixed-point iteration does not converge within max_iter
        steps. The last iterate is then returned.
    """
    if max_iter < 1:
        raise ValueError("max_iter must be at least 1")

    T = T_C + 273.15
    # Temperature-corrected equilibrium constants.
    H_CO2 = K_vant_hoff(H_CO2_ref, dH_CO2, T)
    K1 = K_vant_hoff(K1_ref, dH_K1, T)
    K2 = K_vant_hoff(K2_ref, dH_K2, T)
    pKa_tris = pKa_tris_ref + dpKa_tris_dT * (T_C - 25.0)
    Ka_tris = 10 ** (-pKa_tris)
    Kw = pKw_vs_T(T)

    # Henry's law: dissolved CO2 is set by the gas-phase partial pressure.
    CO2_aq = H_CO2 * pCO2

    def _activity_coeffs(I):
        """Davies activity coefficients of every species at ionic strength I."""
        return {
            'H': davies_activity_coeff(+1, I),
            'OH': davies_activity_coeff(-1, I),
            'HCO3': davies_activity_coeff(-1, I),
            'CO3': davies_activity_coeff(-2, I),
            'CO2': 1.0,  # neutral
            'B': 1.0,    # neutral Tris base
            'BH': davies_activity_coeff(+1, I),
        }

    def _species(H, g):
        """Return ([HCO3-], [CO3^2-], [TrisH+], [Tris], [OH-]) at a given [H+]."""
        HCO3 = K1 * (g['CO2'] / (g['H'] * g['HCO3'])) * CO2_aq / H
        CO3 = K2 * (g['HCO3'] / (g['H'] * g['CO3'])) * HCO3 / H
        ratio_B_BH = Ka_tris * (g['BH'] / (g['B'] * g['H'])) / H  # [B]/[BH+]
        BH = tris_total / (1.0 + ratio_B_BH)
        B = tris_total - BH
        # Kw is a product of activities, so both gamma_H and gamma_OH enter.
        OH = Kw / (g['H'] * g['OH'] * H)
        return HCO3, CO3, BH, B, OH

    def _charge_balance(log10_H, g):
        """Net positive charge in M; zero at the equilibrium [H+]."""
        H = 10 ** log10_H
        HCO3, CO3, BH, _, OH = _species(H, g)
        return (H + BH) - (HCO3 + 2 * CO3 + OH)

    I = 0.01  # initial ionic-strength guess, M
    log10_H_guess = -7.0
    for _ in range(max_iter):
        gammas = _activity_coeffs(I)
        # Root bracket spans log10[H+] from -14 to 0.
        log10_H_new = brentq(_charge_balance, -14, 0, args=(gammas,),
                             maxiter=500, xtol=1e-12)

        H = 10 ** log10_H_new
        HCO3, CO3, BH, B, OH = _species(H, gammas)

        # Ionic strength I = 1/2 * sum(c_i * z_i**2): linear in concentration;
        # only the charge numbers are squared.
        I_new = 0.5 * (H + OH + BH + HCO3 + 4 * CO3)
        if abs(I_new - I) < tol and abs(log10_H_new - log10_H_guess) < tol:
            I = I_new
            break
        # Damped (averaged) update keeps the fixed-point iteration stable.
        I = 0.5 * (I + I_new)
        log10_H_guess = 0.5 * (log10_H_guess + log10_H_new)
    else:
        warnings.warn(
            f"solve_speciation did not converge in {max_iter} iterations "
            f"(T_C={T_C}, pCO2={pCO2}); returning the last iterate.",
            RuntimeWarning,
            stacklevel=2,
        )

    return {
        'T_K': T,
        'T_C': T_C,
        'pCO2': pCO2,
        'CO2_aq': CO2_aq,
        'H_plus': H,
        'pH': -np.log10(H),
        'OH_minus': OH,
        'HCO3_minus': HCO3,
        'CO3_2minus': CO3,
        'Tris_B': B,
        'Tris_BH_plus': BH,
        'Ionic_strength': I,
        'pKa_tris': pKa_tris,
    }



# -------------------------------
# Sweep temperature 10-90°C
# -------------------------------
temps_C = np.linspace(10, 90, 17)  # 5 °C steps

# Solve the CO2-loaded Tris system (pCO2 = 0.15) once per temperature.
loaded_results = [solve_speciation(T_C=t_c, pCO2=0.15) for t_c in temps_C]

CO2_list = [r['CO2_aq'] for r in loaded_results]
HCO3_list = [r['HCO3_minus'] for r in loaded_results]
CO3_list = [r['CO3_2minus'] for r in loaded_results]
# Dissolved inorganic carbon = CO2(aq) + HCO3- + CO3^2-.
DIC_list = [r['CO2_aq'] + r['HCO3_minus'] + r['CO3_2minus'] for r in loaded_results]
pH_loaded_list = [r['pH'] for r in loaded_results]
pKa_list = [r['pKa_tris'] for r in loaded_results]
# NOTE(review): never populated in the code shown; kept for compatibility.
pH_unloaded_list = []


# -------------------------------
# Plot CO2 speciation vs T
# -------------------------------
plt.figure(figsize=(4.3, 4))
# (values, line style, colour, legend label) for each carbon species.
speciation_curves = [
    (CO2_list, '--', 'orange', 'CO$_2$(aq)'),
    (HCO3_list, ':', 'teal', 'HCO$_3^-$'),
    (CO3_list, '-.', 'purple', 'CO$_3^{2-}$'),
    (DIC_list, '-', 'maroon', 'DIC'),
]
for values, linestyle, color, label in speciation_curves:
    plt.plot(temps_C, values, linestyle, linewidth=2, color=color, label=label)
plt.xlabel('$T$ / °C', fontsize=14)
plt.ylabel('[C] / M', fontsize=14)
plt.ylim(0, 0.5)
plt.xticks(np.arange(0, 101, 20), fontsize=12)
# A single call sets both the tick positions and the tick-label size.
plt.yticks(np.arange(0, 0.501, 0.1), fontsize=12)
plt.grid(False)
plt.legend(fontsize=12)
plt.tight_layout()

plt.savefig('Fig3.pdf', format='pdf')

plt.show()

# -------------------------------
# CO2 loadings vs temperature
# -------------------------------
# Gas-phase CO2 partial pressures in atm (the Henry constant is in mol/(L·atm)).
pCO2_list = [0.15, 0.125, 0.10, 0.075, 0.05]
cmap = plt.get_cmap('plasma')
# Evenly spaced colormap samples; max(..., 1) avoids a division by zero
# when only one partial pressure is plotted.
n_steps = max(len(pCO2_list) - 1, 1)
colors = [cmap(i / n_steps) for i in range(len(pCO2_list))]

plt.figure(figsize=(4.3, 4))
for pCO2, color in zip(pCO2_list, colors):
    # Loading = mol dissolved inorganic carbon per mol total Tris.
    loading = []
    for T in temps_C:
        res = solve_speciation(T_C=T, pCO2=pCO2)
        DIC = res['CO2_aq'] + res['HCO3_minus'] + res['CO3_2minus']
        loading.append(DIC / tris_total)
    # The legend entry is the CO2 partial pressure, whose unit is atm.
    plt.plot(temps_C, loading, color=color, linewidth=2, label=f'{pCO2:.3f} atm')

plt.xlabel('$T$ / °C', fontsize=14)
# Raw string: '\m' in a plain string literal is an invalid escape sequence.
plt.ylabel(r'$L_{\mathrm{CO_2}}$ / mol/mol', fontsize=14)
plt.xticks(np.arange(0, 101, 20), fontsize=12)
plt.yticks(fontsize=12)
plt.ylim(0, 1)
plt.grid(False)
plt.legend(fontsize=12)
plt.tight_layout()

plt.savefig('Fig4.pdf', format='pdf')

plt.show()
