#!/usr/bin/env python
# coding: utf-8

# -----------------------------------------------------------------------------
# This script is licensed under the Creative Commons Attribution 4.0
# International License (CC BY 4.0).
#
# You are free to use, modify, and distribute this code for any purpose,
# including commercial applications, provided that proper credit is given.
#
# When using this script for scientific research or derivative works, please cite:
#
#   “3D Cellular Automata of Polymer Gel Response to Solvent Change”
#    Vasilii Korotenko, Irina Smirnova, Pavel Gurikov
#    Thermal Separation Processes, TUHH, Eissendorfer Str. 38,
#    21073 Hamburg, Germany
#    E-mail: pavel.gurikov@tuhh.de
#
# A link to the license:
# https://creativecommons.org/licenses/by/4.0/
#
# © Authors of the corresponding publication
# -----------------------------------------------------------------------------

import os
import re
import time
import numpy as np
import matplotlib.pyplot as plt
from scipy.signal import correlate
from scipy.optimize import curve_fit
import csv

# ------------------- SETTINGS -------------------
base_dir = "."              # Working directory scanned for trajectories_* folders (within a single attempt)
fit_window = 200            # Max number of leading ACF lags used in the exponential tau fit
overwrite = False           # Recompute Rg_vs_frame.csv / tau_relaxation.csv even if they already exist

# ------------------- HELPERS -------------------
def extract_params_from_folder(folder_name):
    """Parse simulation parameters out of a trajectory folder name.

    Searches ``folder_name`` for the tokens ``L<int>``, ``W<num>``,
    ``Epp<num>``, ``Ewp<num>`` and ``MAX_ITER<int>``, where ``<num>`` may be
    negative and/or have a decimal part. For each token only the first match
    is used. Note that the ``L`` pattern matches the first capital ``L``
    followed by digits anywhere in the name.

    Returns:
        dict mapping each found key to the matched text (a string, not yet
        converted to a number). Keys whose token is absent are omitted.
    """
    token_patterns = (
        ("L", r"L(\d+)"),
        ("W", r"W(-?\d+(?:\.\d+)?)"),
        ("Epp", r"Epp(-?\d+(?:\.\d+)?)"),
        ("Ewp", r"Ewp(-?\d+(?:\.\d+)?)"),
        ("MAX_ITER", r"MAX_ITER(\d+)"),
    )
    found = {}
    for name, regex in token_patterns:
        match = re.search(regex, folder_name)
        if match is None:
            continue
        found[name] = match.group(1)
    return found

def exp_decay(t, tau):
    """Unit-amplitude exponential decay ``exp(-t / tau)`` used as the ACF model.

    Works element-wise when ``t`` is a numpy array.
    """
    exponent = -t / tau
    return np.exp(exponent)

def normalized_acf(x):
    """Return the one-sided autocorrelation of ``x``, normalized to 1 at lag 0.

    The mean is subtracted first. Element ``k`` of the result is the summed
    product of the centered series with itself shifted by ``k``, divided by
    the lag-0 value. If the lag-0 value is zero (e.g. a constant series), the
    unnormalized ACF is returned instead of dividing by zero.
    """
    centered = x - np.mean(x)
    full = correlate(centered, centered, mode="full")
    # "full" output has length 2n-1; index n-1 (== size // 2) is lag 0.
    one_sided = full[full.size // 2:]
    zero_lag = one_sided[0]
    if zero_lag == 0:
        return one_sided
    return one_sided / zero_lag

def compute_rg_series(snapshots, polymer_min_id=1000):
    """Compute the radius of gyration Rg(t) of the polymer in each snapshot.

    A lattice site counts as polymer when its value is ``>= polymer_min_id``
    (default 1000). For each snapshot the polymer-site index coordinates are
    weighted equally and

        Rg = sqrt( mean_i |r_i - r_cm|^2 ),   r_cm = mean_i r_i

    Snapshots that contain no polymer sites are skipped, so the result can be
    shorter than ``snapshots``. In that case a result index no longer equals
    the original frame index.

    NOTE(review): coordinates are used as-is (no periodic unwrapping). If the
    lattice is periodic, confirm the polymer never straddles a boundary.

    Args:
        snapshots: iterable of lattices; anything ``np.asarray`` accepts
            (ndarrays, nested lists).
        polymer_min_id: smallest site value treated as polymer.

    Returns:
        1-D float ndarray with one Rg value per non-empty snapshot.
    """
    rg_series = []
    for lattice in snapshots:
        # asarray lets nested lists / object-array elements be compared element-wise.
        coords = np.argwhere(np.asarray(lattice) >= polymer_min_id)
        if coords.size == 0:
            continue
        displacement = coords - coords.mean(axis=0)
        rg_series.append(np.sqrt(np.mean(np.sum(displacement ** 2, axis=1))))
    return np.array(rg_series, dtype=float)

def read_rg_series(csv_path):
    """Read an Rg(t) series from CSV.

    The first row is treated as a header and skipped. For every remaining row
    with at least two columns, the second column is parsed as a float. Rows
    that are too short, or whose second column is not a valid float, are
    silently ignored.

    Returns:
        1-D ndarray of the parsed values, in file order.
    """
    values = []
    with open(csv_path, newline="") as handle:
        rows = csv.reader(handle)
        next(rows, None)  # drop the header line (no-op on an empty file)
        for row in rows:
            if len(row) < 2:
                continue
            try:
                value = float(row[1])
            except ValueError:
                continue
            values.append(value)
    return np.array(values)

def write_rg_series(csv_path, rg_list):
    """Write an Rg(t) series to CSV under a ``frame,Rg`` header.

    Each value goes on its own row, prefixed by its 0-based position in
    ``rg_list`` as the ``frame`` column. An existing file is overwritten.
    """
    header = ["frame", "Rg"]
    with open(csv_path, "w", newline="") as handle:
        writer = csv.writer(handle)
        writer.writerow(header)
        writer.writerows([idx, value] for idx, value in enumerate(rg_list))

def fit_tau_from_series(rg_series, fit_window):
    """Fit an exponential decay to the autocorrelation of ``rg_series``.

    The normalized ACF (``normalized_acf``) is fitted with ``exp(-t / tau)``
    (``exp_decay``) over its first ``min(fit_window, len(acf))`` lags, where
    ``t`` is the lag in frames.

    Args:
        rg_series: 1-D sequence of Rg values, one per frame.
        fit_window: maximum number of leading ACF lags used in the fit.

    Returns:
        (tau, tau_err, n_frames):
        - tau: the fitted relaxation time, in frames.
        - tau_err: its one-sigma error, from the covariance diagonal.
        - n_frames: ``len(rg_series)``.

    Raises:
        ValueError: if the series has fewer than 2 values, is constant (its
            ACF is undefined), or the fit window covers fewer than 2 lags.
        RuntimeError: propagated from ``curve_fit`` when the fit fails.
    """
    rg = np.asarray(rg_series, dtype=float)
    if rg.size < 2:
        raise ValueError(f"need at least 2 Rg values to fit tau, got {rg.size}")
    if np.ptp(rg) == 0:
        raise ValueError("Rg series is constant; autocorrelation is undefined")

    acf = normalized_acf(rg)
    fw = min(fit_window, acf.size)
    if fw < 2:
        raise ValueError(f"fit window must cover at least 2 lags, got {fw}")
    lags = np.arange(fw)
    # p0 is one initial guess per fit parameter; tau is the only one, so a 1-tuple.
    p0 = (max(5, fw // 2),)
    popt, pcov = curve_fit(exp_decay, lags, acf[:fw], p0=p0)
    tau = float(popt[0])
    tau_err = float(np.sqrt(pcov[0, 0])) if pcov.size > 0 else np.nan
    return tau, tau_err, len(rg_series)

def _cell_or_default(row, key, default):
    """Return ``row[key]``, or ``default`` when the cell is absent or blank.

    ``row.get`` yields ``None`` both for a column missing from the header and,
    with ``csv.DictReader``, for a cell missing from a short row. Empty cells
    come back as ``""``.
    """
    value = row.get(key)
    if value is None or not value.strip():
        return default
    return value


def read_tau_csv(csv_path):
    """Read a cached tau result from a CSV with a header row.

    Only the first data row is used. ``tau`` is required. ``tau_err``,
    ``n_frames`` and ``fit_window`` fall back to NaN / 0 / 0 when their column
    is missing or their cell is blank.

    Returns:
        dict with keys ``tau`` (float), ``tau_err`` (float), ``n_frames`` (int)
        and ``fit_window`` (int).

    Raises:
        ValueError: if there is no data row, ``tau`` is missing or blank, or a
            present value cannot be parsed.
    """
    with open(csv_path, newline="") as f:
        row = next(csv.DictReader(f), None)
    if row is None:
        raise ValueError("Empty CSV")
    tau_text = _cell_or_default(row, "tau", None)
    if tau_text is None:
        raise ValueError("tau value missing in CSV")
    return {
        "tau": float(tau_text),
        "tau_err": float(_cell_or_default(row, "tau_err", "nan")),
        "n_frames": int(_cell_or_default(row, "n_frames", "0")),
        "fit_window": int(_cell_or_default(row, "fit_window", "0")),
    }

def write_tau_csv(csv_path, meta):
    """Write one tau result row, plus run metadata, to CSV.

    The columns, in order, are tau, tau_err, n_frames, fit_window, L, W, Epp,
    Ewp and traj_folder. Keys missing from ``meta`` are written as empty
    cells, and extra keys are ignored. An existing file is overwritten.
    """
    columns = ["tau", "tau_err", "n_frames", "fit_window",
               "L", "W", "Epp", "Ewp", "traj_folder"]
    with open(csv_path, "w", newline="") as handle:
        writer = csv.DictWriter(handle, fieldnames=columns)
        writer.writeheader()
        record = {}
        for column in columns:
            record[column] = meta.get(column, "")
        writer.writerow(record)

# ------------------- MAIN -------------------
results_all = {}  # {(W, Epp, Ewp, L): [tau, ...]} -- one tau per trajectory folder
total_time = 0.0  # wall time summed over folders that reach the end of the loop body
recalculated = 0  # folders whose tau was fitted (and written) in this run
skipped = 0       # folders whose tau was loaded from an existing tau_relaxation.csv

print("🔍 Scanning trajectories folders in:", base_dir)

for folder in sorted(os.listdir(base_dir)):
    if not folder.startswith("trajectories_"):
        continue

    params = extract_params_from_folder(folder)
    if not all(k in params for k in ["L", "W", "Epp", "Ewp"]):
        print(f"⛔ Skip (no params): {folder}")
        continue

    L = int(params["L"])
    W = float(params["W"])
    Epp = float(params["Epp"])
    Ewp = float(params["Ewp"])

    folder_path = os.path.join(base_dir, folder)
    npz_path = os.path.join(folder_path, "trajectories.npz")
    rg_csv_path = os.path.join(folder_path, "Rg_vs_frame.csv")
    tau_csv_path = os.path.join(folder_path, "tau_relaxation.csv")

    if not os.path.exists(npz_path):
        print(f"⚠️ Missing trajectories.npz in {folder}")
        continue

    start_time = time.time()

    # Step 1: Rg(t). Reuse the cached CSV unless overwriting. If the CSV is
    # absent or unreadable, fall through and compute from the trajectory file.
    rg_series = None
    if os.path.exists(rg_csv_path) and not overwrite:
        try:
            rg_series = read_rg_series(rg_csv_path)
            print(f"📄 Loaded Rg_vs_frame.csv: {len(rg_series)} frames")
        except Exception as e:
            print(f"❌ Failed to read Rg CSV ({e}), recalculating")
            rg_series = None

    if rg_series is None:
        try:
            # NOTE: allow_pickle=True unpickles object arrays, and pickle is not
            # secure against maliciously constructed files -- trusted data only.
            # The context manager closes the npz's underlying file handle.
            with np.load(npz_path, allow_pickle=True) as data:
                if "lattice_snapshots" not in data:
                    print(f"⚠️ No lattice_snapshots in {folder}")
                    continue
                snapshots = list(data["lattice_snapshots"])
            rg_series = compute_rg_series(snapshots)
            write_rg_series(rg_csv_path, rg_series)
            print(f"💾 Saved Rg_vs_frame.csv ({len(rg_series)} frames)")
        except Exception as e:
            print(f"❌ Error computing Rg(t): {e}")
            continue

    # Every path that reaches here has assigned rg_series (the others `continue`).
    if len(rg_series) < 10:
        print(f"⚠️ Too few valid frames ({len(rg_series)}). Skipping.")
        continue

    # Step 2: tau -- reuse tau_relaxation.csv unless overwriting or it is unreadable.
    use_cached = os.path.exists(tau_csv_path) and not overwrite
    tau = None
    tau_err = np.nan
    n_frames = len(rg_series)

    if use_cached:
        try:
            cached = read_tau_csv(tau_csv_path)
            tau = cached["tau"]
            tau_err = cached["tau_err"]
            skipped += 1
            print(f"📄 Loaded tau={tau:.2f} ± {tau_err:.2f} from {tau_csv_path}")
        except Exception as e:
            print(f"❌ Failed to read tau CSV ({e}), recalculating")
            use_cached = False

    if not use_cached:
        try:
            tau, tau_err, _ = fit_tau_from_series(rg_series, fit_window)
            write_tau_csv(tau_csv_path, {
                "tau": tau,
                "tau_err": tau_err,
                "n_frames": n_frames,
                "fit_window": fit_window,
                "L": L, "W": W, "Epp": Epp, "Ewp": Ewp,
                "traj_folder": folder
            })
            recalculated += 1
            print(f"💾 Saved tau={tau:.2f} ± {tau_err:.2f} to {tau_csv_path}")
        except Exception as e:
            print(f"❌ Error fitting tau in {folder}: {e}")
            continue

    elapsed = time.time() - start_time
    total_time += elapsed
    print(f"⏱️ Time for {folder}: {elapsed:.2f} s")

    results_all.setdefault((W, Epp, Ewp, L), []).append(tau)

# ------------------- AVERAGING -------------------
if not results_all:
    print("⚠️ No tau values found.")
    raise SystemExit(0)


def _nan_mean_std(values):
    """Return (mean, std) of ``values`` as floats, ignoring NaNs (std uses ddof=0)."""
    as_array = np.array(values, dtype=float)
    return np.nanmean(as_array), np.nanstd(as_array)


# {(W, Epp, Ewp, L): (mean_tau, std_tau)} over all folders sharing that key.
averaged = {group: _nan_mean_std(taus) for group, taus in results_all.items()}

# ------------------- PLOTTING -------------------
# One figure per (W, Epp, Ewp): mean tau vs chain length L on log-log axes,
# plus a least-squares power-law fit tau ~ L^z when enough points allow it.
for (W, Epp, Ewp) in sorted({k[:3] for k in averaged}):
    plt.figure(figsize=(6, 5))

    L_vals = sorted(k[3] for k in averaged if k[:3] == (W, Epp, Ewp))
    tau_means = np.array([averaged[(W, Epp, Ewp, L)][0] for L in L_vals])
    tau_stds = np.array([averaged[(W, Epp, Ewp, L)][1] for L in L_vals])

    plt.errorbar(L_vals, tau_means, yerr=tau_stds, fmt="o-", capsize=3, label=f"Ewp={Ewp}")
    # Set log scales explicitly so the axes are log-log even when no fit is drawn.
    plt.xscale("log")
    plt.yscale("log")

    # Fit log(tau) = z*log(L) + b only on positive, finite tau (log would give
    # NaN/-inf otherwise). np.polyfit(..., cov=True) raises ValueError unless
    # there are more points than coefficients, i.e. >= 3 for a line.
    L_arr = np.asarray(L_vals, dtype=float)
    ok = np.isfinite(tau_means)
    ok[ok] = tau_means[ok] > 0  # compare only finite entries (avoids NaN comparison)
    if np.count_nonzero(ok) >= 3:
        coeffs, cov = np.polyfit(np.log(L_arr[ok]), np.log(tau_means[ok]), 1, cov=True)
        z, b = coeffs
        z_err = float(np.sqrt(cov[0, 0]))

        L_fit = np.linspace(L_arr[ok][0], L_arr[ok][-1], 200)
        tau_fit = np.exp(b) * L_fit**z
        plt.plot(L_fit, tau_fit, "--", label=f"tau ~ L^{z:.2f} ± {z_err:.2f}")
    else:
        print(f"⚠️ Too few valid points for power-law fit (W={W}, Epp={Epp}, Ewp={Ewp})")

    plt.xlabel("L (chain length)")
    plt.ylabel("Relaxation time tau")
    plt.title(f"Relaxation scaling W={W} Epp={Epp} Ewp={Ewp}")
    plt.legend()
    plt.grid(True, which="both")
    plt.tight_layout()

    outname = f"tau_vs_L_avg_W{W}_Epp{Epp}_Ewp{Ewp}.png"
    plt.savefig(outname, dpi=150)
    print(f"📈 Saved plot: {outname}")

plt.show()

# ------------------- SUMMARY -------------------
# Final run statistics; counters are accumulated in the main loop.
summary_lines = (
    "\n✅ Averaged relaxation analysis complete.",
    "\n📊 Summary:",
    f"  Recalculated tau files : {recalculated}",
    f"  Skipped already exist  : {skipped}",
    f"  Total compute time     : {total_time:.2f} s",
)
for summary_line in summary_lines:
    print(summary_line)
