import numpy as np
from scipy.optimize import brentq
import matplotlib.pyplot as plt

# -------------------------------
# Constants
# -------------------------------
R = 8.31446261815324  # J/mol/K
T_ref = 298.15  # K (25°C)

# Tris buffer
tris_total = 0.5  # M
pKa_tris_ref = 8.07
dpKa_tris_dT = -0.028  # per °C

# Carbonic acid
pK1_ref = 6.35
pK2_ref = 10.33
K1_ref = 10 ** (-pK1_ref)
K2_ref = 10 ** (-pK2_ref)
# Standard dissociation enthalpies used by K_vant_hoff, whose convention is
# K(T) = K_ref * exp(-dH/R * (1/T - 1/T_ref)) (same as dH_w below).
# Both ionisations of carbonic acid are endothermic (pK1 and pK2 decrease
# from 0 to ~50 °C), so dH must be positive for K1/K2 to rise with T.
dH_K1 = 9000.0   # J/mol
dH_K2 = 15000.0  # J/mol

# CO2 solubility (Henry constant; dissolution is exothermic, so it falls with T)
H_CO2_ref = 0.034       # mol/(L·atm)
dH_CO2 = -20000.0       # J/mol

# Water
Kw_ref = 1e-14
dH_w = 55700.0  # J/mol

# -------------------------------
# Functions
# -------------------------------
def K_vant_hoff(K_ref, dH, T, T_ref=298.15):
    """Extrapolate an equilibrium constant from T_ref to T (integrated van 't Hoff).

    Assumes dH is independent of temperature:
        K(T) = K_ref * exp(-(dH / R) * (1/T - 1/T_ref))

    Args:
        K_ref: Equilibrium constant at T_ref.
        dH: Reaction enthalpy in J/mol (R is in J/mol/K).
        T: Target temperature in K; scalars or NumPy arrays work.
        T_ref: Reference temperature in K.

    Returns:
        The constant at T, in the same units as K_ref.
    """
    reciprocal_shift = 1.0 / T - 1.0 / T_ref
    return K_ref * np.exp(-(dH / R) * reciprocal_shift)

def pKw_vs_T(T, Kw_ref=Kw_ref, dH_w=dH_w, T_ref=298.15):
    """Return the water ion product Kw at absolute temperature T (K).

    Despite the ``pKw`` in its name, this returns Kw itself (not -log10 Kw):
    the van 't Hoff extrapolation of ``Kw_ref`` from ``T_ref`` using the
    ionisation enthalpy ``dH_w`` (J/mol).
    """
    Kw_at_T = K_vant_hoff(Kw_ref, dH_w, T, T_ref=T_ref)
    return Kw_at_T

def davies_activity_coeff(z, I):
    """Return the Davies-equation activity coefficient of an ion.

    log10(gamma) = -0.5 * z**2 * (sqrt(I) / (1 + sqrt(I)) - 0.3 * I)

    The Debye-Hueckel prefactor is fixed at 0.5 (no temperature dependence).

    Args:
        z: Ionic charge number.
        I: Ionic strength in mol/L.
    """
    root_I = np.sqrt(I)
    bracket = root_I / (1.0 + root_I) - 0.3 * I
    return 10 ** (-0.5 * z * z * bracket)

# -------------------------------
# Speciation solver for loaded Tris (with CO2)
# -------------------------------
def solve_speciation(T_C=25.0, pCO2=0.15, tris_total=tris_total,
                     pKa_ref=None, dpKa_dT=None):
    """Solve the equilibrium speciation of a Tris solution loaded with CO2.

    The free proton concentration is found from electroneutrality,

        [H+] + [BH+] = [HCO3-] + 2 [CO3^2-] + [OH-],

    with Davies activity coefficients for the ions. The root solve and the
    ionic strength are iterated (with 50 % damping) until both the ionic
    strength and log10[H+] change by less than 1e-8.

    Args:
        T_C: Temperature in °C.
        pCO2: CO2 partial pressure in atm (H_CO2_ref is in mol/(L·atm)).
        tris_total: Total Tris concentration in mol/L.
        pKa_ref: Tris pKa at 25 °C. Defaults to the module-level
            ``pKa_tris_ref`` (read at call time).
        dpKa_dT: Temperature slope of the Tris pKa, per °C. Defaults to the
            module-level ``dpKa_tris_dT`` (read at call time).

    Returns:
        dict with the temperature, pCO2, species concentrations (mol/L),
        ionic strength and the temperature-corrected Tris pKa. ``pH`` is
        -log10 of the proton *concentration*, not of its activity.

    Warns:
        RuntimeWarning if the iteration does not converge within 500 steps;
        the last iterate is returned.
    """
    import warnings

    if pKa_ref is None:
        pKa_ref = pKa_tris_ref
    if dpKa_dT is None:
        dpKa_dT = dpKa_tris_dT

    T = T_C + 273.15
    # equilibrium constants at T
    H_CO2 = K_vant_hoff(H_CO2_ref, dH_CO2, T)
    K1 = K_vant_hoff(K1_ref, dH_K1, T)
    K2 = K_vant_hoff(K2_ref, dH_K2, T)
    pKa_T = pKa_ref + dpKa_dT * (T_C - 25.0)
    Ka_tris = 10 ** (-pKa_T)
    Kw = pKw_vs_T(T)

    CO2_aq = H_CO2 * pCO2
    T_tris = tris_total

    def activity_coeffs(I):
        # Neutral species (CO2(aq), free-base Tris B) are treated as ideal.
        return {
            'H': davies_activity_coeff(+1, I),
            'OH': davies_activity_coeff(-1, I),
            'HCO3': davies_activity_coeff(-1, I),
            'CO3': davies_activity_coeff(-2, I),
            'CO2': 1.0,
            'B': 1.0,
            'BH': davies_activity_coeff(+1, I),
        }

    def species(H, g):
        """Concentrations (HCO3-, CO3^2-, BH+, B, OH-) at free [H+] = H."""
        HCO3 = K1 * (g['CO2'] / (g['H'] * g['HCO3'])) * CO2_aq / H
        CO3 = K2 * (g['HCO3'] / (g['H'] * g['CO3'])) * HCO3 / H
        r = Ka_tris * (g['BH'] / (g['B'] * g['H'])) / H  # [B] / [BH+]
        BH = T_tris / (1.0 + r)
        B = T_tris - BH
        # Kw = a_H * a_OH, so the OH- activity coefficient belongs here too.
        OH = Kw / (g['H'] * g['OH'] * H)
        return HCO3, CO3, BH, B, OH

    def charge_balance(log10_H, g):
        H = 10 ** log10_H
        HCO3, CO3, BH, _, OH = species(H, g)
        return (H + BH) - (HCO3 + 2 * CO3 + OH)

    I = 0.01
    log10_H_guess = -7.0
    for _ in range(500):
        gammas = activity_coeffs(I)
        log10_H_new = brentq(charge_balance, -14, 0, args=(gammas,),
                             maxiter=500, xtol=1e-12)

        H = 10 ** log10_H_new
        HCO3, CO3, BH, B, OH = species(H, gammas)

        # Ionic strength I = 1/2 * sum(c_i * z_i**2): linear in concentration.
        I_new = 0.5 * (H + OH + BH + HCO3 + 4 * CO3)
        if abs(I_new - I) < 1e-8 and abs(log10_H_new - log10_H_guess) < 1e-8:
            I = I_new
            break
        # Damped fixed-point update to avoid oscillation between I and [H+].
        I = 0.5 * (I + I_new)
        log10_H_guess = 0.5 * (log10_H_guess + log10_H_new)
    else:
        warnings.warn(
            "solve_speciation: ionic-strength iteration did not converge "
            f"(T_C={T_C}, pCO2={pCO2}); returning last iterate",
            RuntimeWarning,
        )

    return {
        'T_K': T,
        'T_C': T_C,
        'pCO2': pCO2,
        'CO2_aq': CO2_aq,
        'H_plus': H,
        'pH': -np.log10(H),
        'OH_minus': OH,
        'HCO3_minus': HCO3,
        'CO3_2minus': CO3,
        'Tris_B': B,
        'Tris_BH_plus': BH,
        'Ionic_strength': I,
        'pKa_tris': pKa_T,
    }

# -------------------------------
# Unloaded Tris pH (no CO2)
# -------------------------------
def pH_unloaded_tris(T_C, tris_total=tris_total):
    """Return the pH of a CO2-free Tris solution at T_C (°C).

    Ideal solution (no activity corrections). Because the free base B is
    neutral and protonated Tris BH+ is the only cation besides H+, the
    electroneutrality condition is

        [H+] + [BH+] = [OH-],   with [BH+] = tris_total * H / (H + Ka).

    The Tris pKa is taken from the module-level ``pKa_tris_ref`` and
    ``dpKa_tris_dT`` (read at call time). Kw is corrected to T_C with
    ``pKw_vs_T``, consistent with ``solve_speciation``.

    Args:
        T_C: Temperature in °C.
        tris_total: Total Tris concentration in mol/L.

    Returns:
        pH as -log10 of the proton concentration.
    """
    pKa_T = pKa_tris_ref + dpKa_tris_dT * (T_C - 25.0)
    Ka = 10 ** (-pKa_T)
    Kw = pKw_vs_T(T_C + 273.15)

    def f(logH):
        H = 10 ** logH
        BH = tris_total * H / (H + Ka)  # protonated fraction of total Tris
        OH = Kw / H
        return (H + BH) - OH

    # f(-14) < 0 (OH- dominates) and f(0) > 0 (H+ dominates) for any
    # tris_total >= 0, so this bracket always contains the root.
    logH_solution = brentq(f, -14, 0)
    return -logH_solution

# -------------------------------
# Parameter sweep ranges
# -------------------------------
# Reference pKa at 25 °C: solve_speciation shifts it by dpKa/dT * (T_C - 25).
pKa_vals = np.linspace(7.0, 11, 60)
dpKa_vals = np.linspace(-0.035, -0.015, 60)   # dpKa/dT (per °C)

# meshgrid's default 'xy' indexing: P[i, j] = pKa_vals[j], D[i, j] = dpKa_vals[i].
P, D = np.meshgrid(pKa_vals, dpKa_vals)
delta_loading = np.zeros_like(P)

pCO2_fixed = 0.15  # atm (same units as solve_speciation's pCO2)
A_tot = tris_total  # total amine concentration, M

# -------------------------------
# Modified solver wrapper
# -------------------------------
def loading_at_T(T_C, pKa_ref, dpKa_dT):
    """Return the CO2 loading (mol dissolved inorganic carbon per mol amine).

    Models a hypothetical amine with the given 25 °C pKa and pKa slope by
    temporarily overriding the module-level ``pKa_tris_ref`` /
    ``dpKa_tris_dT`` that ``solve_speciation`` reads. The previous values are
    restored afterwards, even on error, so a sweep does not leak its last
    grid point into later default calls.

    Args:
        T_C: Temperature in °C.
        pKa_ref: Amine pKa at 25 °C.
        dpKa_dT: Temperature slope of the pKa, per °C.

    Returns:
        (CO2(aq) + HCO3- + CO3^2-) / A_tot at pCO2 = pCO2_fixed.
    """
    global pKa_tris_ref, dpKa_tris_dT
    saved = (pKa_tris_ref, dpKa_tris_dT)
    pKa_tris_ref, dpKa_tris_dT = pKa_ref, dpKa_dT
    try:
        res = solve_speciation(T_C=T_C, pCO2=pCO2_fixed, tris_total=A_tot)
    finally:
        pKa_tris_ref, dpKa_tris_dT = saved

    DIC = res['CO2_aq'] + res['HCO3_minus'] + res['CO3_2minus']
    return DIC / A_tot

# -------------------------------
# Sweep grid
# -------------------------------
for i, j in np.ndindex(P.shape):
    pKa_ij, dpKa_ij = P[i, j], D[i, j]
    # Loading difference between 20 °C and 60 °C (the quantity plotted below);
    # the 20 °C point is evaluated first, then the 60 °C point.
    delta_loading[i, j] = (
        loading_at_T(20.0, pKa_ij, dpKa_ij)
        - loading_at_T(60.0, pKa_ij, dpKa_ij)
    )

# -------------------------------
# Contour plot
# -------------------------------
plt.figure(figsize=(5, 3.8))

cont = plt.contourf(P, D, delta_loading, levels=25, cmap='coolwarm')
plt.colorbar(cont, label=r'$\Delta L_{\mathrm{CO_2}}$ (20°C − 60°C) / mol/mol')

plt.xlabel(r'$pK_a$(25°C)', fontsize=12)
plt.ylabel(r'$\mathrm{d}pK_a/\mathrm{d}T$ (per °C)', fontsize=12)

plt.xticks(np.arange(7, 11.01, 1), fontsize=12)
# linspace instead of a float-step arange: the number of ticks from arange
# depends on rounding, and an extra tick would stretch the axis past the data.
# These five ticks span exactly the dpKa/dT sweep range.
plt.yticks(np.linspace(-0.035, -0.015, 5), fontsize=12)


# -------------------------------
# Points to overlay
# -------------------------------
points = [
    (8.07, -0.025, "AHPD"),
    (8.5,  -0.018, "MDEA"),
    (9.694, -0.028, "AMP"),
    (8.79, -0.026, "AMPD"),
    (7.77, -0.019, "TEA"),
]

for x, y, label in points:
    plt.plot(x, y, 'o', color='black', markersize=3)  # circle marker
    # Label sits slightly above its marker (offset in dpKa/dT units).
    plt.text(x, y + 0.001, label, ha='center', va='bottom', fontsize=10)

plt.tight_layout()
plt.savefig('Figconta2.pdf')
plt.show()

