#!/usr/bin/env python
# coding: utf-8

# -----------------------------------------------------------------------------
# This script is licensed under the Creative Commons Attribution 4.0
# International License (CC BY 4.0).
#
# You are free to use, modify, and distribute this code for any purpose,
# including commercial applications, provided that proper credit is given.
#
# When using this script for scientific research or derivative works, please cite:
#
#   “3D Cellular Automata of Polymer Gel Response to Solvent Change”
#    Vasilii Korotenko, Irina Smirnova, Pavel Gurikov
#    Thermal Separation Processes, TUHH, Eissendorfer Str. 38,
#    21073 Hamburg, Germany
#    E-mail: pavel.gurikov@tuhh.de
#
# A link to the license:
# https://creativecommons.org/licenses/by/4.0/
#
# © Authors of the corresponding publication
# -----------------------------------------------------------------------------

import os
import re
import time
import numpy as np
import matplotlib.pyplot as plt
import csv

# ------------------- SETTINGS -------------------
base_dir = "."       # Working directory (of a single run/attempt); scanned for "trajectories_*" folders
overwrite = False    # If True, recompute Rg(t) and overwrite any existing Rg_vs_frame.csv files

# ------------------- HELPERS -------------------
def extract_params_from_folder(folder_name):
    """Extract the L, W, Epp, Ewp and MAX_ITER parameters from a folder name.

    Each parameter is located with ``re.search`` as its name immediately
    followed by a number: ``L`` and ``MAX_ITER`` take unsigned integers,
    while ``W``, ``Epp`` and ``Ewp`` take optionally negative decimals.
    The first match for each name wins.

    Args:
        folder_name: The directory name to parse.

    Returns:
        dict mapping parameter name -> matched number as a *string* (no
        numeric conversion is done here). Parameters that are not found are
        simply absent from the result.
    """
    found = {}
    for name, regex in (
        ("L", r"L(\d+)"),
        ("W", r"W(-?\d+(?:\.\d+)?)"),
        ("Epp", r"Epp(-?\d+(?:\.\d+)?)"),
        ("Ewp", r"Ewp(-?\d+(?:\.\d+)?)"),
        ("MAX_ITER", r"MAX_ITER(\d+)"),
    ):
        match = re.search(regex, folder_name)
        if match is None:
            continue
        found[name] = match.group(1)
    return found


def read_rg_series(csv_path):
    """Read an Rg(t) series from a CSV file with a ``[frame, Rg]`` header.

    The first row is treated as the header and skipped. Rows with fewer than
    two fields, or whose frame (int) or Rg (float) cannot be parsed, are
    ignored as a whole, so the two returned arrays always have the same
    length and stay aligned row by row. ``"nan"`` is a valid Rg value and is
    read back as NaN.

    Args:
        csv_path: Path to the CSV file.

    Returns:
        (frames, rg_vals): an integer array of frame indices and a float
        array of Rg values, both of the same length (possibly 0).
    """
    frames, rg_vals = [], []
    with open(csv_path, newline="") as f:
        reader = csv.reader(f)
        next(reader, None)  # skip the header row
        for row in reader:
            if len(row) < 2:
                continue
            # Parse both fields before appending either one, so a bad Rg
            # value cannot leave an orphan frame index behind.
            try:
                frame = int(row[0])
                rg = float(row[1])
            except ValueError:
                continue
            frames.append(frame)
            rg_vals.append(rg)
    return np.array(frames, dtype=int), np.array(rg_vals, dtype=float)


def write_rg_series(csv_path, rg_list):
    """Write an Rg(t) series to ``csv_path`` as a two-column CSV.

    The file starts with a ``frame,Rg`` header; each following row holds the
    zero-based position of the value in ``rg_list`` and the value itself.
    An existing file at ``csv_path`` is overwritten.

    Args:
        csv_path: Destination path.
        rg_list: Iterable of Rg values, one per frame.
    """
    with open(csv_path, "w", newline="") as handle:
        writer = csv.writer(handle)
        writer.writerow(["frame", "Rg"])
        writer.writerows([idx, value] for idx, value in enumerate(rg_list))


# ------------------- MAIN -------------------
def _rg_from_lattice(lattice):
    """Return the radius of gyration of the polymer sites in one lattice snapshot.

    Polymer sites are the lattice cells whose value is >= 1000
    (NOTE(review): presumably monomers are encoded with ids >= 1000 — confirm
    against the simulation that writes the snapshots). Rg is the RMS distance
    of those sites' index coordinates from their mean position; no
    periodic-boundary unwrapping is applied. Returns NaN when the snapshot
    contains no polymer site.
    """
    coords = np.argwhere(lattice >= 1000)
    if coords.size == 0:
        return np.nan
    r_cm = coords.mean(axis=0)
    return np.sqrt(np.mean(np.sum((coords - r_cm) ** 2, axis=1)))


results_all = {}   # (W, Epp, Ewp, L) -> list of (tail mean Rg, tail std Rg), one entry per folder
total_time = 0.0   # accumulated wall time spent loading/computing Rg(t) series
recalculated = 0   # folders whose Rg(t) was computed from trajectories.npz
skipped = 0        # folders whose Rg(t) was taken from an existing CSV

print("🔍 Scanning trajectories folders in:", base_dir)

for folder in sorted(os.listdir(base_dir)):
    if not folder.startswith("trajectories_"):
        continue

    params = extract_params_from_folder(folder)
    if not all(k in params for k in ["L", "W", "Epp", "Ewp"]):
        print(f"⛔ Skip (no params): {folder}")
        continue

    L = int(params["L"])
    W = float(params["W"])
    Epp = float(params["Epp"])
    Ewp = float(params["Ewp"])

    folder_path = os.path.join(base_dir, folder)
    npz_path = os.path.join(folder_path, "trajectories.npz")
    csv_path = os.path.join(folder_path, "Rg_vs_frame.csv")

    if not os.path.exists(npz_path):
        print(f"⚠️ No trajectories.npz in {folder}")
        continue

    start_time = time.time()

    # --- Use the cached Rg(t) series if one exists (unless overwrite is set) ---
    use_cached = os.path.exists(csv_path) and not overwrite
    rg_series = None

    if use_cached:
        try:
            _, rg_series = read_rg_series(csv_path)
        except Exception as e:
            print(f"❌ Failed to read CSV ({e}), recalculating")
            use_cached = False
        else:
            if len(rg_series) == 0:
                # A cache with no data rows (e.g. left by an interrupted write)
                # carries no information; rebuild it from the trajectories.
                print("⚠️ Rg_vs_frame.csv has no data rows, recalculating")
                use_cached = False
            else:
                print(f"📄 Loaded Rg_vs_frame.csv: {len(rg_series)} frames")
                skipped += 1

    if not use_cached:
        try:
            # NOTE(review): allow_pickle=True lets np.load unpickle object
            # arrays, which can execute arbitrary code — only load trajectory
            # files produced by a trusted source.
            # The `with` block closes the underlying NpzFile on every path,
            # including the early `continue` below.
            with np.load(npz_path, allow_pickle=True) as data:
                if "lattice_snapshots" not in data:
                    print(f"⚠️ No 'lattice_snapshots' in {folder}")
                    continue
                rg_series = np.array(
                    [_rg_from_lattice(lattice) for lattice in data["lattice_snapshots"]]
                )

            write_rg_series(csv_path, rg_series)
            recalculated += 1
            print(f"💾 Saved Rg_vs_frame.csv: {len(rg_series)} frames")

        except Exception as e:
            print(f"❌ Error in {folder}: {e}")
            continue

    elapsed = time.time() - start_time
    total_time += elapsed
    print(f"⏱️  Time for {folder}: {elapsed:.2f} s")

    # --- Average over the last 20% of the valid (non-NaN) frames ---
    valid = rg_series[~np.isnan(rg_series)]
    if len(valid) < 10:
        print(f"⚠️ Too few valid frames ({len(valid)}). Skipping.")
        continue

    n_tail = max(1, len(valid) // 5)
    tail = valid[-n_tail:]
    rg_avg = np.mean(tail)
    rg_std = np.std(tail)
    print(f"✅ Rg_avg = {rg_avg:.3f} ± {rg_std:.3f}")

    results_all.setdefault((W, Epp, Ewp, L), []).append((rg_avg, rg_std))


# ------------------- AVERAGING AND PLOTTING -------------------
if len(results_all) == 0:
    print("⚠️ No Rg data found.")
    raise SystemExit(0)

# Collapse repeated runs that share the same (W, Epp, Ewp, L) key into one
# mean and one spread. The combined variance is the mean within-run variance
# plus the (population) variance of the per-run means.
averaged = {}
for run_key in results_all:
    run_stats = np.asarray(results_all[run_key])
    run_means, run_stds = run_stats[:, 0], run_stats[:, 1]
    pooled_var = np.mean(run_stds**2) + np.std(run_means)**2
    averaged[run_key] = (np.mean(run_means), np.sqrt(pooled_var))

for (W, Epp, Ewp) in sorted(set(k[:3] for k in averaged.keys())):
    L_vals = np.array(sorted([k[3] for k in averaged if k[:3] == (W, Epp, Ewp)]))
    Rg_means = np.array([averaged[(W, Epp, Ewp, L)][0] for L in L_vals])
    Rg_stds = np.array([averaged[(W, Epp, Ewp, L)][1] for L in L_vals])

    # Fit the scaling law Rg ~ L^nu as a straight line in log-log space.
    # Only strictly positive points can be log-transformed (Rg is 0 when a
    # snapshot contains a single polymer site).
    fit_mask = (L_vals > 0) & (Rg_means > 0)
    n_fit = int(np.count_nonzero(fit_mask))
    fit = None
    if n_fit >= 2:
        log_L = np.log(L_vals[fit_mask])
        log_Rg = np.log(Rg_means[fit_mask])
        if n_fit >= 3:
            coeffs, cov = np.polyfit(log_L, log_Rg, 1, cov=True)
            nu_err = np.sqrt(cov[0, 0])
        else:
            # np.polyfit(..., cov=True) raises for only two points: the line
            # passes through both exactly, leaving no scatter to estimate from.
            coeffs = np.polyfit(log_L, log_Rg, 1)
            nu_err = np.nan
        nu, b = coeffs
        L_fit = np.linspace(L_vals[fit_mask].min(), L_vals[fit_mask].max(), 200)
        fit = (L_fit, np.exp(b) * L_fit**nu, nu, nu_err)
    else:
        print(f"⚠️ Too few chain lengths ({n_fit}) to fit Rg ~ L^nu "
              f"for W={W}, Epp={Epp}, Ewp={Ewp}")

    plt.figure(figsize=(6, 5))
    plt.errorbar(L_vals, Rg_means, yerr=Rg_stds, fmt="o-", capsize=5, label=f"Ewp={Ewp}")

    # --- Individual runs, jittered horizontally by up to ±5% of L so
    # overlapping points remain visible ---
    for L in L_vals:
        individual_values = [v[0] for v in results_all[(W, Epp, Ewp, L)]]
        jitter = (np.random.rand(len(individual_values)) - 0.5) * 0.1 * L
        plt.scatter(np.array([L] * len(individual_values)) + jitter,
                    individual_values, s=35, color="black", alpha=0.6, zorder=3)

    if fit is not None:
        L_fit, Rg_fit, nu, nu_err = fit
        if np.isnan(nu_err):
            fit_label = f"⟨Rg⟩ ∼ L^{nu:.2f}"
        else:
            fit_label = f"⟨Rg⟩ ∼ L^{nu:.2f} ± {nu_err:.2f}"
        plt.loglog(L_fit, Rg_fit, "--", label=fit_label)
    else:
        # No fit line: still use log-log axes, like the fitted plots.
        plt.xscale("log")
        plt.yscale("log")
    plt.xlabel("L (chain length)")
    plt.ylabel("Average ⟨Rg⟩")
    plt.title(f"Scaling of ⟨Rg⟩ (W={W}, Epp={Epp}, Ewp={Ewp})")
    plt.legend()
    plt.grid(True, which="both")
    plt.tight_layout()

    outname = f"Rg_vs_L_avg_W{W}_Epp{Epp}_Ewp{Ewp}.png"
    plt.savefig(outname, dpi=150)
    print(f"📈 Saved: {outname}")

plt.show()

# ------------------- SUMMARY -------------------
# Final report: series recomputed vs. loaded from cache, plus the wall time
# accumulated while loading/computing them.
summary_lines = [
    "\n✅ Averaged Rg analysis complete.",
    f"Recalculated CSV files : {recalculated}",
    f"Skipped (already exist): {skipped}",
    f"Total compute time     : {total_time:.2f} s",
]
print("\n".join(summary_lines))
