import matplotlib.pyplot as plt
import numpy as np

# ---------------------------
# DATA
# ---------------------------

# Model-predicted and experimentally measured U concentrations at 60 min,
# in mol/L (the unit shown on the axis labels), one entry per pressure below.
U_model = np.array([
    0.027954,
    0.022963,
    0.019129,
    0.013156,
    0.015996,
])
U_exp = np.array([
    0.02614146539,
    0.0223338683,
    0.0176443878,
    0.0148875675,
    0.0156582799,
])

# Experimental uncertainties. The raw numbers are scaled by 1/1000, presumably
# converting mmol/L to mol/L -- TODO confirm the unit of the source values.
U_exp_err = np.array([
    0.70493295,
    1.637198533,
    1.373243502,
    4.309374753,
    1.987814535,
]) / 1e3

# Point labels, in the same order as the arrays above.
pressures = [
    "0 bar",
    "0.5 bar",
    "1 bar",
    "2.1 bar",
    "3.5 bar",
]

# ---------------------------
# STYLE
# ---------------------------

# Global figure style: serif text at 12 pt, tick labels at 11 pt.
# NOTE: the top/right spines are hidden and the axes line width set to 1 here,
# but the "FULL BOX" section further down makes every spine visible again at
# width 1.1, which overrides these spine settings for this figure.
plt.rc("font", family="serif", size=12)
plt.rc("axes", labelsize=12, linewidth=1)
plt.rc("xtick", labelsize=11)
plt.rc("ytick", labelsize=11)
plt.rc("axes.spines", top=False, right=False)

# Single axes, sized in inches.
fig, ax = plt.subplots(figsize=(4.8, 4.1))

# ---------------------------
# PARITY LINE
# ---------------------------

# Combined data range of model and experiment, padded by 0.001 on each side.
# These bounds are also used when setting the axis limits further down.
min_val = min(np.min(U_model), np.min(U_exp)) - 0.001
max_val = max(np.max(U_model), np.max(U_exp)) + 0.001

# 1:1 (y = x) reference line. axline draws an infinite line clipped to the
# axes, so it always reaches the plot corners -- even when set_xticks /
# set_yticks later widen the view beyond [min_val, max_val]. A two-point
# ax.plot segment would stop short of the corners in that case.
ax.axline(
    (min_val, min_val),
    slope=1,
    linestyle='--',
    color='black',
    linewidth=1,
)

# ---------------------------
# DATA POINTS (all pressures)
# ---------------------------

# Model (x) vs. experiment (y) for every pressure, with vertical error bars
# for the experimental uncertainty. Markers are white with a black edge.
# ecolor is not given, so per matplotlib's errorbar docs the bars and caps use
# the colour of the (hidden) data line, i.e. the current property-cycle colour.
# NOTE(review): pass ecolor='black' if a fully monochrome figure is intended.
ax.errorbar(
    U_model,
    U_exp,
    yerr=U_exp_err,
    fmt='o',
    markersize=6,
    markerfacecolor='white',
    markeredgecolor='black',
    capsize=4,
    linewidth=1,
)



# ---------------------------
# PRESSURE LABELS
# ---------------------------

# Tag each point with its pressure label, offset 10 pt left of and 15 pt
# below the marker.
for pressure_label, x_model, y_exp in zip(pressures, U_model, U_exp):
    ax.annotate(
        pressure_label,
        (x_model, y_exp),
        xytext=(-10, -15),
        textcoords="offset points",
        fontsize=10,
    )

# ---------------------------
# LABELS
# ---------------------------

# Axis titles (mathtext): experimental concentration on y, model on x,
# both at 60 min in mol/L. Raw strings keep the LaTeX backslashes literal.
ax.set_ylabel(r"$C_\mathrm{U}^\mathrm{exp}$(60 min) / mol L$^{-1}$")
ax.set_xlabel(r"$C_\mathrm{U}^\mathrm{model}$(60 min) / mol L$^{-1}$")

# ---------------------------
# LIMITS AND TICKS
# ---------------------------

# Shared tick grid for both axes: 0.010, 0.015, 0.020, 0.025, 0.030.
# linspace states the endpoints exactly instead of relying on a float-step
# np.arange stopping before its (inexact) stop value.
ticks = np.linspace(0.01, 0.03, 5)

# set_xticks / set_yticks expand the view limits so that every tick is
# visible, so calling them after set_xlim/set_ylim silently widens
# [min_val, max_val]. Take the union of the data range and the tick range
# explicitly (a square window so the 1:1 line runs corner to corner) and
# set the limits last so they are final.
axis_lo = min(ticks[0], min_val)
axis_hi = max(ticks[-1], max_val)

ax.set_xticks(ticks)
ax.set_yticks(ticks)
ax.set_xlim(axis_lo, axis_hi)
ax.set_ylim(axis_lo, axis_hi)

# ---------------------------
# FULL BOX
# ---------------------------

# Closed frame: every spine visible at width 1.1. This overrides the
# spine-related rcParams set near the top of the script.
for frame_edge in ax.spines.values():
    frame_edge.set(visible=True, linewidth=1.1)

# ---------------------------
# R²
# ---------------------------

# Coefficient of determination of the experimental values when the model
# value is used directly as the prediction (i.e. against the y = x line).
residuals = U_exp - U_model
deviations = U_exp - np.mean(U_exp)
SS_res = np.sum(residuals**2)
SS_tot = np.sum(deviations**2)
R2 = 1 - SS_res / SS_tot

print(f"U concentration R² = {R2:.4f}")

# ---------------------------
# SAVE
# ---------------------------

# Fit the layout to the figure, write the PDF (path relative to the current
# working directory), then open the interactive window.
fig.tight_layout()
plt.savefig(
    "parity_plot_U.pdf",
    format='pdf',
    bbox_inches='tight',
)
plt.show()