import numpy as np
from scipy.optimize import brentq
import matplotlib.pyplot as plt

# -------------------------------
# Constants
# -------------------------------
# All reference values below are taken at 25 °C (298.15 K) and are shifted to
# other temperatures by the functions defined further down in this file.
R = 8.31446261815324  # molar gas constant, J/mol/K
T_ref = 298.15  # K (25°C); NOTE: the functions visible in this file carry their own 298.15 default

# Tris buffer
tris_total = 0.5  # total Tris concentration (B + BH+), M
pKa_tris_ref = 8.07  # pKa of BH+ at 25 °C
dpKa_tris_dT = -0.028  # linear pKa temperature slope, pKa units per °C

# Carbonic acid
pK1_ref = 6.35  # CO2(aq) + H2O <-> H+ + HCO3-, at 25 °C
pK2_ref = 10.33  # HCO3- <-> H+ + CO3^2-, at 25 °C
K1_ref = 10 ** (-pK1_ref)
K2_ref = 10 ** (-pK2_ref)
# NOTE(review): with the van't Hoff form used in this file a negative dH makes
# K1 *decrease* on heating. Tabulated ionization enthalpies for both carbonic
# acid steps are usually positive and of larger magnitude for the second
# step -- TODO confirm the sign of dH_K1 and the value of dH_K2 against the
# original data source before relying on the temperature trends.
dH_K1 = -8300.0  # J/mol
dH_K2 = 2400.0   # J/mol

# CO2 solubility
H_CO2_ref = 0.034       # Henry solubility at 25 °C, mol/(L·atm)
dH_CO2 = -19300.0       # enthalpy of dissolution, J/mol (negative: less soluble when hot)

# Water
Kw_ref = 1e-14  # ion product of water at 25 °C
dH_w = 55700.0  # enthalpy of water self-ionization, J/mol

# -------------------------------
# Functions
# -------------------------------
def K_vant_hoff(K_ref, dH, T, T_ref=298.15):
    """Shift an equilibrium constant from ``T_ref`` to ``T`` (van't Hoff).

    Evaluates ``K_ref * exp(-dH / R * (1/T - 1/T_ref))`` with a
    temperature-independent reaction enthalpy.

    Parameters
    ----------
    K_ref : equilibrium constant at ``T_ref``.
    dH : reaction enthalpy in J/mol (same unit system as the module's ``R``).
    T : target temperature in K.
    T_ref : reference temperature in K.

    Returns
    -------
    The equilibrium constant at ``T``.
    """
    reciprocal_shift = 1.0 / T - 1.0 / T_ref
    exponent = -dH / R * reciprocal_shift
    return K_ref * np.exp(exponent)


def pKw_vs_T(T, Kw_ref=Kw_ref, dH_w=dH_w, T_ref=298.15):
    """Return the ion product of water, Kw, at temperature ``T`` (in K).

    Despite the ``pKw`` in the name, the value returned is Kw itself (no
    ``-log10`` is applied); it is the van't Hoff extrapolation of ``Kw_ref``
    from ``T_ref`` using the enthalpy ``dH_w``.
    """
    Kw_at_T = K_vant_hoff(K_ref=Kw_ref, dH=dH_w, T=T, T_ref=T_ref)
    return Kw_at_T


# -------------------------------
# Extended Debye-Huckel activity coefficient
# -------------------------------

def edh_activity_coeff(z, I, a_nm=0.9, A=0.509, B=0.3285):
    """Return a single-ion activity coefficient (extended Debye-Hückel).

    Evaluates ``log10(gamma) = -A * z**2 * sqrt(I) / (1 + B * a * sqrt(I))``.

    Parameters
    ----------
    z : charge number of the ion; only ``z**2`` is used, so the sign is
        irrelevant.
    I : ionic strength in mol/L (scalar). Values ``<= 0`` return 1.0.
    a_nm : effective ion-size parameter in nanometres.
    A : Debye-Hückel ``A`` parameter; the default is the value commonly
        tabulated for water at 25 °C.
    B : Debye-Hückel ``B`` parameter in 1/ångström; the default is the value
        commonly tabulated for water at 25 °C.

    Returns
    -------
    The activity coefficient ``gamma`` (dimensionless).
    """
    if I <= 0:
        return 1.0
    sqrtI = np.sqrt(I)
    # B is expressed per ångström while the ion size is supplied in nm, so the
    # size must be converted (1 nm = 10 Å) to make B * a dimensionless.
    a_angstrom = 10.0 * a_nm
    log10_gamma = -A * (z ** 2) * sqrtI / (1.0 + B * a_angstrom * sqrtI)
    return 10 ** log10_gamma


# -------------------------------
# Speciation solver using EDH
# -------------------------------

def solve_speciation_edh(T_C=25.0, pCO2=0.15, tris_total=tris_total):
    """Solve the Tris / CO2 / water speciation with EDH activity corrections.

    The dissolved CO2 concentration is fixed by Henry's law at the given CO2
    partial pressure. The H+ concentration is then found from the charge
    balance ``[H+] + [BH+] = [HCO3-] + 2[CO3^2-] + [OH-]``. Ionic strength and
    activity coefficients are iterated to self-consistency with a damped
    fixed-point scheme (at most 500 iterations).

    Parameters
    ----------
    T_C : temperature in °C.
    pCO2 : CO2 partial pressure, in the pressure unit of ``H_CO2_ref``.
    tris_total : total Tris concentration (B + BH+) in mol/L.

    Returns
    -------
    dict with keys ``T_K``, ``T_C``, ``pCO2``, ``CO2_aq``, ``H_plus``, ``pH``,
    ``OH_minus``, ``HCO3_minus``, ``CO3_2minus``, ``Tris_B``,
    ``Tris_BH_plus``, ``Ionic_strength`` and ``pKa_tris``. Concentrations are
    in mol/L; ``pH`` is ``-log10`` of the H+ *concentration* (no activity
    coefficient applied).

    Warns
    -----
    RuntimeWarning
        If the ionic-strength iteration has not converged after 500 passes;
        the last iterate is still returned.
    """
    import warnings  # local import: only needed for the non-convergence notice

    T = T_C + 273.15
    # equilibrium constants (temperature-corrected)
    H_CO2 = K_vant_hoff(H_CO2_ref, dH_CO2, T)
    K1 = K_vant_hoff(K1_ref, dH_K1, T)
    K2 = K_vant_hoff(K2_ref, dH_K2, T)
    pKa_tris = pKa_tris_ref + dpKa_tris_dT * (T_C - 25.0)
    Ka_tris = 10 ** (-pKa_tris)
    Kw = pKw_vs_T(T)  # returns Kw itself, despite the helper's name

    CO2_aq = H_CO2 * pCO2

    def speciate(log10_H, gammas):
        """Return (H, OH, HCO3, CO3, BH) concentrations at a trial log10[H+]."""
        H = 10 ** log10_H
        g_H = gammas['H']
        HCO3 = K1 * (gammas['CO2'] / (g_H * gammas['HCO3'])) * CO2_aq / H
        CO3 = K2 * (gammas['HCO3'] / (g_H * gammas['CO3'])) * HCO3 / H
        # r = [B]/[BH+] from the Tris acid dissociation equilibrium
        r = Ka_tris * (gammas['BH'] / (gammas['B'] * g_H)) / H
        BH = tris_total / (1.0 + r)
        # Kw is an activity product, so it gets the same gamma treatment as
        # the other equilibria.
        OH = Kw / (g_H * gammas['OH'] * H)
        return H, OH, HCO3, CO3, BH

    def charge_balance(log10_H, gammas):
        """Net charge (cations minus anions); zero at the solution."""
        H, OH, HCO3, CO3, BH = speciate(log10_H, gammas)
        return (H + BH) - (HCO3 + 2 * CO3 + OH)

    # iterate to self-consistent ionic strength & gammas (fixed-point)
    I = 0.01
    log10_H_guess = -7.0
    for _ in range(500):
        gammas = {
            'H': edh_activity_coeff(+1, I, a_nm=0.9),
            'OH': edh_activity_coeff(-1, I, a_nm=0.9),
            'HCO3': edh_activity_coeff(-1, I, a_nm=0.9),
            'CO3': edh_activity_coeff(-2, I, a_nm=1.2),  # larger effective size
            'CO2': 1.0,  # neutral species treated as ideal
            'B': 1.0,
            'BH': edh_activity_coeff(+1, I, a_nm=0.9)
        }

        # solve the charge balance for log10[H+] at the current gammas
        log10_H_new = brentq(
            lambda logH: charge_balance(logH, gammas),
            -14, 0, maxiter=500, xtol=1e-12,
        )
        H, OH, HCO3, CO3, BH = speciate(log10_H_new, gammas)

        # I = 1/2 * sum(c_i * z_i**2): concentrations enter linearly and only
        # the charges are squared (all |z| = 1 except carbonate, z = -2).
        I_new = 0.5 * (H + OH + BH + HCO3 + 4 * CO3)

        if abs(I_new - I) < 1e-8 and abs(log10_H_new - log10_H_guess) < 1e-8:
            I = I_new
            log10_H_guess = log10_H_new
            break

        # 50 % damping keeps the fixed-point iteration from oscillating
        I = 0.5 * (I + I_new)
        log10_H_guess = 0.5 * (log10_H_guess + log10_H_new)
    else:
        warnings.warn(
            'solve_speciation_edh: ionic-strength iteration did not converge '
            'within 500 iterations; returning the last iterate',
            RuntimeWarning,
            stacklevel=2,
        )

    B = tris_total - BH

    return {
        'T_K': T,
        'T_C': T_C,
        'pCO2': pCO2,
        'CO2_aq': CO2_aq,
        'H_plus': H,
        'pH': -np.log10(H),
        'OH_minus': OH,
        'HCO3_minus': HCO3,
        'CO3_2minus': CO3,
        'Tris_B': B,
        'Tris_BH_plus': BH,
        'Ionic_strength': I,
        'pKa_tris': pKa_tris
    }




# -------------------------------
# Main: sweep temperature and plot
# -------------------------------

def main():
    """Sweep temperature from 10 to 90 °C and produce two figures.

    * ``Fig3_edh.pdf``: CO2(aq), HCO3-, CO3^2- and total dissolved inorganic
      carbon (DIC) versus temperature at pCO2 = 0.15.
    * ``Fig4_edh.pdf``: CO2 loading (DIC divided by the module-level
      ``tris_total``) versus temperature for several pCO2 values.

    Both figures are saved to the current working directory and shown.
    """
    temps_C = np.linspace(10, 90, 17)

    # Speciation along the temperature sweep at the reference pCO2.
    results = [solve_speciation_edh(T_C=T, pCO2=0.15) for T in temps_C]
    CO2_list = [res['CO2_aq'] for res in results]
    HCO3_list = [res['HCO3_minus'] for res in results]
    CO3_list = [res['CO3_2minus'] for res in results]
    DIC_list = [
        res['CO2_aq'] + res['HCO3_minus'] + res['CO3_2minus']
        for res in results
    ]

    # -------------------------------
    # Plot CO2 speciation vs T (EDH)
    # -------------------------------
    plt.figure(figsize=(4.3, 4))
    plt.plot(temps_C, CO2_list, '--', linewidth=2, label='CO$_2$(aq)')
    plt.plot(temps_C, HCO3_list, ':', linewidth=2, label='HCO$_3^-$')
    plt.plot(temps_C, CO3_list, '-.', linewidth=2, label='CO$_3^{2-}$')
    plt.plot(temps_C, DIC_list, '-', linewidth=2, label='DIC')
    plt.xlabel('$T$ / °C', fontsize=14)
    plt.ylabel('C / M', fontsize=14)
    plt.xticks(np.arange(0, 101, 20), fontsize=12)
    plt.ylim(0, 0.5)
    # the tiny overshoot on the stop value keeps the last tick (0.5) in range
    plt.yticks(np.arange(0, 0.5000001, 0.1), fontsize=12)
    plt.grid(False)
    plt.legend(fontsize=12)
    plt.tight_layout()
    plt.savefig('Fig3_edh.pdf', format='pdf')
    plt.show()

    # -------------------------------
    # CO2 loadings vs temperature (EDH)
    # -------------------------------
    pCO2_list = [0.15, 0.125, 0.10, 0.075, 0.05]
    cmap = plt.get_cmap('plasma')
    colors = [cmap(i / (len(pCO2_list) - 1)) for i in range(len(pCO2_list))]

    plt.figure(figsize=(4.3, 4))
    for pCO2, color in zip(pCO2_list, colors):
        loading = []
        for T in temps_C:
            res = solve_speciation_edh(T_C=T, pCO2=pCO2)
            DIC = res['CO2_aq'] + res['HCO3_minus'] + res['CO3_2minus']
            loading.append(DIC / tris_total)
        # three decimals: two would print 0.125 and 0.075 as 0.12 and 0.07
        plt.plot(temps_C, loading, color=color, linewidth=2, label=f'{pCO2:.3f} mol/mol')

    plt.xlabel('$T$ / °C', fontsize=14)
    plt.ylabel('$C_{CO_2,load}$ / mol/mol', fontsize=14)
    plt.xticks(np.arange(0, 101, 20), fontsize=12)
    plt.ylim(0, 1.00012)
    # stop value slightly above 1 so that the 1.0 tick is included
    plt.yticks(np.arange(0, 1.00012, 0.2), fontsize=12)
    plt.grid(False)
    plt.legend(fontsize=12)
    plt.tight_layout()
    plt.savefig('Fig4_edh.pdf', format='pdf')
    plt.show()




# Run the temperature sweep and plotting only when executed as a script,
# not when this module is imported.
if __name__ == '__main__':
    main()
