import matplotlib.pyplot as plt
import numpy as np

# ---------------------------
# DATA
# ---------------------------

# One row per applied pressure:
#   (pressure label, model pH, experimental pH, experimental error)
# Keeping every pressure's values on a single row keeps the label and the
# three numeric columns aligned by construction.
_MEASUREMENTS = (
    ("0 bar", 8.90, 9.046666667, 0.032145503),
    ("0.5 bar", 8.31, 7.946666667, 0.083864971),
    ("1 bar", 8.04, 7.803333333, 0.096090235),
    ("2.1 bar", 7.50, 7.096666667, 0.223681321),
    ("3.5 bar", 6.64, 6.923333333, 0.087368949),
)

# Transpose the rows into columns.
_labels, _model_col, _exp_col, _err_col = zip(*_MEASUREMENTS)

pH_model = np.array(_model_col)
pH_exp = np.array(_exp_col)
# Used as the vertical error-bar half-length in the data-points section.
pH_exp_err = np.array(_err_col)

# Pressure labels
pressures = list(_labels)

# ---------------------------
# STYLE
# ---------------------------

# Global matplotlib settings: serif text at 12 pt with slightly smaller tick
# labels, and a 1 pt axes frame with the top/right spines switched off by
# default (the full-box section below turns them back on for this axes).
_PLOT_STYLE = {
    "font.family": "serif",
    "font.size": 12,
    "axes.labelsize": 12,
    "xtick.labelsize": 11,
    "ytick.labelsize": 11,
    "axes.spines.top": False,
    "axes.spines.right": False,
    "axes.linewidth": 1,
}
plt.rcParams.update(_PLOT_STYLE)

# Single axes; size given in inches.
fig, ax = plt.subplots(figsize=(4.4, 3.9))

# ---------------------------
# PARITY LINE
# ---------------------------

# Common range of model and experimental pH, padded by 0.2 on both ends.
# The same bounds are used again as the x/y axis limits.
_all_pH = np.concatenate((pH_model, pH_exp))
min_val = _all_pH.min() - 0.2
max_val = _all_pH.max() + 0.2

# Dashed y = x reference: a point on this line means model == experiment.
_diag_x = [min_val, max_val]
_diag_y = [min_val, max_val]
ax.plot(_diag_x, _diag_y, linestyle='--', color='black', linewidth=1)

# ---------------------------
# DATA POINTS
# ---------------------------

# Open circles at (model, experiment) with vertical error bars from
# pH_exp_err (assumed to be the experimental uncertainty - confirm its
# definition, e.g. std. dev. vs. std. error).
# ecolor is set explicitly: errorbar() otherwise draws the bars and caps in
# the next property-cycle colour (default blue), which clashes with the
# black marker edges and the black parity line.
ax.errorbar(
    pH_model,
    pH_exp,
    yerr=pH_exp_err,
    fmt='o',
    markersize=6,
    markerfacecolor='white',
    markeredgecolor='black',
    ecolor='black',
    capsize=4,
    linewidth=1,
)

# ---------------------------
# PRESSURE LABELS
# ---------------------------

# Tag each point with its pressure, offset 10 pt left and 15 pt below the
# marker. Labels are matched to points purely by position, so fail loudly if
# the sequences have different lengths (zip() would silently truncate).
if not len(pressures) == len(pH_model) == len(pH_exp):
    raise ValueError(
        f"Expected one pressure label per data point, got {len(pressures)} "
        f"labels for {len(pH_model)} model / {len(pH_exp)} experimental values"
    )

for label, x_model, y_exp in zip(pressures, pH_model, pH_exp):
    ax.annotate(
        label,
        (x_model, y_exp),
        xytext=(-10, -15),
        textcoords="offset points",
        fontsize=10,
    )

# ---------------------------
# LABELS
# ---------------------------

# Mathtext superscripts mark the data source; "/ -" denotes a
# dimensionless quantity.
ax.set(
    xlabel=r"pH$^\mathrm{model}$ / -",
    ylabel=r"pH$^\mathrm{exp}$ / -",
)

# ---------------------------
# LIMITS AND TICKS
# ---------------------------

# Ticks every 0.5 pH unit, restricted to the padded data range
# [min_val, max_val]. Axes.set_xticks/set_yticks expand the view limits to
# include every requested tick, so fixed ticks outside that range would widen
# the axes and leave the dashed parity line short of the plot edges.
tick_step = 0.5
first_tick = np.ceil(min_val / tick_step) * tick_step
last_tick = np.floor(max_val / tick_step) * tick_step
# Half-step tolerance so last_tick itself is included despite float rounding.
ticks = np.arange(first_tick, last_tick + tick_step / 2, tick_step)
ax.set_xticks(ticks)
ax.set_yticks(ticks)

# Apply the limits after the ticks so they define the final view and match
# the parity line's extent exactly.
ax.set_xlim(min_val, max_val)
ax.set_ylim(min_val, max_val)

# Same scale on both axes so the 1:1 parity line is drawn at 45 degrees.
ax.set_aspect("equal", adjustable="box")

# ---------------------------
# FULL BOX
# ---------------------------

# Re-enable every spine (the rcParams above hid top/right) and draw the
# frame at 1.1 pt, slightly heavier than axes.linewidth.
for _spine_name in ax.spines:
    _frame_edge = ax.spines[_spine_name]
    _frame_edge.set_visible(True)
    _frame_edge.set_linewidth(1.1)

# ---------------------------
# R²
# ---------------------------

# Coefficient of determination of the experiment against the model taken
# as-is (i.e. against the 1:1 parity line, not a fitted regression), so it
# can be negative when the model is worse than the experimental mean.
_residual = pH_exp - pH_model
_spread = pH_exp - pH_exp.mean()
SS_res = np.square(_residual).sum()
SS_tot = np.square(_spread).sum()
R2 = 1 - SS_res / SS_tot

print(f"pH R² = {R2:.4f}")

# ---------------------------
# SAVE
# ---------------------------

# Tighten the layout, write a vector PDF cropped to the drawn content, then
# open the interactive window.
fig.tight_layout()
_output_path = "parity_plot_pH.pdf"
plt.savefig(_output_path, format='pdf', bbox_inches='tight')
plt.show()