import numpy as np
import matplotlib.pyplot as plt
from scipy.optimize import brentq, fsolve

# Set temperature and constants
T = 308.15  # temperature in Kelvin (35 °C)
R = 0.082057  # gas constant, L atm / (mol K)

# Equilibrium constants at temperature T from empirical ln K(T) correlations
# (TODO: cite the source of the coefficients). Their roles follow from the
# formulas in electroneutrality():
#   KNH3  - ammonia base constant Kb:      NH3 + H2O <=> NH4+ + OH-
#   KCO2  - first dissociation constant:   CO2 + H2O <=> HCO3- + H+
#   KHCO3 - second dissociation constant:  HCO3- <=> CO3^2- + H+
#   KH2O  - ion product of water:          Kw = [H3O+][OH-]
KNH3 = np.exp(191.97 - 8451.61/T - 31.4335*np.log(T) + 0.0152123*T)
KCO2 = np.exp(2767.92 - 80063.5/T - 478.653*np.log(T) + 0.714984*T)
KHCO3 = np.exp(12.405 - 6286.89/T - 0.050628*T)
KH2O = np.exp(14.01708 - 10294.83/T -0.039282*T)

# Fixed background ion concentrations (adjust as needed); they enter the
# charge balance as the monovalent ions K+ (cation) and Cl- (anion).
K_conc = 0  # mol/L (potassium)
Cl_conc = 0  # mol/L (chloride)

# Define the electroneutrality function
def electroneutrality(H3O, TAN, TIC):
    """Return the charge-balance residual of the NH3/CO2/H2O system.

    Parameters
    ----------
    H3O : float or ndarray
        Hydronium concentration [H3O+] in mol/L. Must be positive: it is the
        divisor used to obtain [OH-] from KH2O.
    TAN : float or ndarray
        Total ammoniacal nitrogen, [NH3] + [NH4+], in mol/L.
    TIC : float or ndarray
        Total inorganic carbon, [CO2] + [HCO3-] + [CO3^2-], in mol/L.

    Returns
    -------
    float or ndarray
        Total cation charge minus total anion charge (mol/L). It is zero at
        the equilibrium [H3O+] and increases with H3O.

    Notes
    -----
    Uses the module-level constants KNH3, KCO2, KHCO3, KH2O, K_conc and Cl_conc.
    """
    OH = KH2O / H3O

    # Fraction of TAN present as NH4+. KNH3 acts as the base constant Kb, so
    # [NH4+]/[NH3] = KNH3*[H3O+]/KH2O.
    a_NH4 = KNH3 * H3O / (KNH3 * H3O + KH2O)

    # Carbonate fractions share one denominator. Only the charged species
    # HCO3- and CO3^2- enter the balance, so the CO2 fraction is not needed.
    carb_denom = H3O**2 + KCO2 * H3O + KCO2 * KHCO3
    a_HCO3 = KCO2 * H3O / carb_denom
    a_CO3 = KCO2 * KHCO3 / carb_denom

    # Species concentrations
    NH4 = a_NH4 * TAN
    HCO3 = a_HCO3 * TIC
    CO3 = a_CO3 * TIC

    # Cations: NH4+, H3O+, K+.  Anions: HCO3-, CO3^2- (charge 2), OH-, Cl-.
    cations = NH4 + H3O + K_conc
    anions = HCO3 + 2 * CO3 + OH + Cl_conc
    return cations - anions

# Ranges of TIC and TAN to explore
TIC_values = np.linspace(0.000001, 0.06, 50)  # mol/L
TAN_values = np.linspace(0.000001, 0.06, 50)  # mol/L

# Create meshgrid: rows follow TAN, columns follow TIC (np.meshgrid 'xy' indexing)
TIC_grid, TAN_grid = np.meshgrid(TIC_values, TAN_values)
# Points where no root is found stay NaN, so they are not drawn.
pH_grid = np.full_like(TIC_grid, np.nan)

# pH bracket for the root search. The charge-balance residual increases
# monotonically with [H3O+]: the cation terms rise and the total anion charge
# falls. It therefore has a single root. Searching in pH rather than in
# [H3O+] keeps the solver well scaled across many orders of magnitude.
pH_search_min, pH_search_max = -1.0, 15.0


def _residual_at_pH(pH, TAN, TIC):
    """Return the electroneutrality residual evaluated at [H3O+] = 10**-pH."""
    return electroneutrality(10.0 ** (-pH), TAN, TIC)


# Loop over the grid and calculate pH
for (i, j), TIC in np.ndenumerate(TIC_grid):
    TAN = TAN_grid[i, j]
    try:
        pH_grid[i, j] = brentq(
            _residual_at_pH, pH_search_min, pH_search_max, args=(TAN, TIC)
        )
    except (ValueError, RuntimeError):
        # ValueError: no sign change inside the bracket (e.g. extreme
        # K_conc / Cl_conc). RuntimeError: brentq did not converge.
        # Either way the point is left as NaN.
        pass


# Colour levels for the filled map: 25 evenly spaced pH values from 4 to 11.2.
# NOTE: with fixed levels and no `extend`, matplotlib leaves pH values outside
# [4, 11.2] unfilled.
levels = np.linspace(4, 11.2, 25)

# Plot contour plot (figure is 70 % of an 8 x 6 inch canvas)
plt.figure(figsize=(8*0.7, 6*0.7))
cp = plt.contourf(
    TIC_grid,
    TAN_grid,
    pH_grid,
    levels=levels,
    cmap='Blues',
    vmin=4,
    vmax=11.2,
)
cbar = plt.colorbar(cp)




# Dashed guide lines of slope 2 (TAN = 2*TIC + offset) across the plotted TIC range
TIC_min, TIC_max = TIC_values[0], TIC_values[-1]
TIC_line = np.linspace(TIC_min, TIC_max, 200)

# Vertical (TAN-axis) offsets of the guide lines, in mol/L
offsets = np.linspace(-0.148, 0.03, 8)  # 8 evenly spaced offsets

for offset in offsets:
    TAN_line = 2 * TIC_line + offset
    # Replace points outside the plotted TAN range with NaN so each line stops at the frame
    in_range = (TAN_line >= TAN_values[0]) & (TAN_line <= TAN_values[-1])
    TAN_line_masked = np.where(in_range, TAN_line, np.nan)
    plt.plot(TIC_line, TAN_line_masked, color='black', linestyle='--', linewidth=0.6)




# Add labelled black contour lines at pH = 6.8 and pH = 7.2
contour_levels = [6.8, 7.2]
contours = plt.contour(TIC_grid, TAN_grid, pH_grid, levels=contour_levels, colors='black', linewidths=0.6)
plt.clabel(contours, fmt='%1.1f', colors='black', fontsize=10)  # label lines with their pH values





# Draw a full black box around the axes (all four spines visible)
ax = plt.gca()
for spine in ax.spines.values():
    spine.set_visible(True)
    spine.set_linewidth(1)
    spine.set_color('black')

plt.xlabel(r'$c_\mathrm{DIC}$ / mol L$^{-1}$', fontsize=14)
plt.ylabel(r'$c_\mathrm{TAN}$ / mol L$^{-1}$', fontsize=14)
cbar.set_label('pH / -', fontsize=14)

# Optional white reference line TAN = 2*TIC. It runs from the smallest TIC to
# the TIC at which it reaches the top of the TAN axis. Off by default.
SHOW_TAN_2TIC_LINE = False
if SHOW_TAN_2TIC_LINE:
    max_TIC_for_line = TAN_values[-1] / 2
    TIC_ref = np.linspace(TIC_values[0], max_TIC_for_line, 100)
    plt.plot(TIC_ref, 2 * TIC_ref, color='white', linestyle='--', linewidth=2, label='TAN = 2×TIC')

cbar.ax.tick_params(labelsize=12)
axis_ticks = [0, 0.01, 0.020, 0.03, 0.040, 0.05, 0.06]
plt.xticks(axis_ticks, fontsize=12)
plt.yticks(axis_ticks, fontsize=12)

# Save the current figure as a PDF, then display it
plt.savefig("ph_contour_plot.pdf", format='pdf', bbox_inches='tight')

plt.show()
