#!/usr/bin/env python
# coding: utf-8

# -----------------------------------------------------------------------------
# This script is licensed under the Creative Commons Attribution 4.0
# International License (CC BY 4.0).
#
# You are free to use, modify, and distribute this code for any purpose,
# including commercial applications, provided that proper credit is given.
#
# When using this script for scientific research or derivative works, please cite:
#
#   “3D Cellular Automata of Polymer Gel Response to Solvent Change”
#    Vasilii Korotenko, Irina Smirnova, Pavel Gurikov
#    Thermal Separation Processes, TUHH, Eissendorfer Str. 38,
#    21073 Hamburg, Germany
#    E-mail: pavel.gurikov@tuhh.de
#
# A link to the license:
# https://creativecommons.org/licenses/by/4.0/
#
# © Authors of the corresponding publication
# -----------------------------------------------------------------------------

import os
import re
import time
import numpy as np
import matplotlib.pyplot as plt
import csv

# ------------------- USER SETTINGS -------------------
# Adjust these before running the script.
base_dir: str = '.'      # Folder scanned for `trajectories_*` sub-folders
overwrite: bool = False  # True: always recompute Rg and rewrite Rg_vs_frame.csv;
                         # False: reuse an existing CSV when it can be read
show_plot: bool = True   # True: display each figure after saving it;
                         # False: only save the PNG and close the figure
smooth_window: int = 10  # Moving-average window (frames) applied before the CSV
                         # is written; any value <= 1 disables smoothing

# ------------------- PARSER -------------------
def extract_params_from_folder(folder_name):
    """Parse the numeric parameters L, W, Epp and Ewp out of a folder name.

    Each parameter is located as its key immediately followed by a number
    that may be negative and may carry a fractional part (e.g. ``L50``,
    ``Epp-1.5``). Only the first occurrence of each key is used.

    Args:
        folder_name: Name of the folder to parse.

    Returns:
        dict mapping every key that was found to its value as a float.
        Keys that do not occur in ``folder_name`` are absent from the dict.
    """
    found = {}
    for name in ('L', 'W', 'Epp', 'Ewp'):
        match = re.search(name + r'(-?\d+(?:\.\d+)?)', folder_name)
        if match is None:
            continue
        found[name] = float(match.group(1))
    return found

# ------------------- MAIN -------------------
results_all = {}  # (W, Epp, Ewp, L) -> list of Rg-per-frame arrays, one per run folder
total_time = 0.0  # seconds spent recomputing Rg (CSV reloads are not counted)
recalculated = 0  # folders whose Rg_vs_frame.csv was (re)written
skipped = 0       # folders served from an existing Rg_vs_frame.csv

print("🔍 Scanning trajectories folders in:", base_dir)

for folder in sorted(os.listdir(base_dir)):
    if not folder.startswith('trajectories_'):
        continue

    params = extract_params_from_folder(folder)
    if not all(k in params for k in ['L', 'W', 'Epp', 'Ewp']):
        print(f"⛔ Skip (no params): {folder}")
        continue

    # L and W are parsed as floats; truncate them to int for grouping/labels.
    L, W, Epp, Ewp = int(params['L']), int(params['W']), params['Epp'], params['Ewp']
    folder_path = os.path.join(base_dir, folder)
    filepath = os.path.join(folder_path, 'trajectories.npz')
    csv_path = os.path.join(folder_path, 'Rg_vs_frame.csv')

    if not os.path.exists(filepath):
        print(f"⚠️ No trajectories.npz in {folder}")
        continue

    start_time = time.time()

    # Per-folder decision: recompute when forced by the user setting or when no
    # cached CSV exists. Using a local flag (instead of mutating the global
    # `overwrite` setting) keeps the fallback after an unreadable CSV scoped to
    # this folder only.
    recalc = overwrite or not os.path.exists(csv_path)

    # --- If the CSV already exists, reuse it ---
    if not recalc:
        try:
            rg_list = []
            with open(csv_path, newline='') as f:
                reader = csv.reader(f)
                next(reader, None)  # skip header
                for row in reader:
                    if len(row) >= 2:
                        rg_list.append(float(row[1]))
            rg_list = np.array(rg_list)
            print(f"📄 Loaded existing CSV: {csv_path}")
            skipped += 1
        except Exception as e:
            print(f"❌ Error reading CSV, will recalc: {e}")
            recalc = True  # fall back to recomputation for this folder only

    # --- Otherwise recompute Rg from the lattice snapshots ---
    if recalc:
        try:
            # NOTE(review): allow_pickle=True lets np.load unpickle object
            # arrays, which can execute arbitrary code; only run this script
            # on trusted trajectory files.
            # The `with` block closes the underlying .npz file handle.
            with np.load(filepath, allow_pickle=True) as data:
                if 'lattice_snapshots' not in data:
                    print(f"⚠️ No 'lattice_snapshots' in {folder}")
                    continue
                # Materialize the snapshots while the file is still open.
                snapshots = list(data['lattice_snapshots'])

            rg_list = []
            for lattice in snapshots:
                # Index coordinates of every site whose value is >= 1000
                # (presumably polymer sites; confirm against the simulation's
                # lattice encoding).
                coords = np.argwhere(lattice >= 1000)
                if coords.size == 0:
                    rg_list.append(np.nan)  # no such sites in this frame
                    continue
                # Rg = sqrt(mean squared distance of the sites from their
                # centroid), in lattice-index units.
                r_cm = coords.mean(axis=0)
                rg = np.sqrt(np.mean(np.sum((coords - r_cm) ** 2, axis=1)))
                rg_list.append(rg)

            rg_list = np.array(rg_list)
            # Moving average. mode='valid' shortens the series by
            # smooth_window - 1 frames, and a NaN frame propagates into every
            # window that contains it. The CSV stores the smoothed series, so a
            # cached CSV reflects the smooth_window in effect when it was written.
            if smooth_window > 1 and len(rg_list) > smooth_window:
                rg_list = np.convolve(rg_list, np.ones(smooth_window)/smooth_window, mode='valid')

            with open(csv_path, 'w', newline='') as f:
                writer = csv.writer(f)
                writer.writerow(['frame', 'Rg'])
                writer.writerows(enumerate(rg_list))

            elapsed = time.time() - start_time
            total_time += elapsed
            recalculated += 1
            print(f"💾 Saved CSV: {csv_path}  ⏱️ {elapsed:.2f} s")

        except Exception as e:
            # Per-folder boundary: report the failure and move on.
            print(f"❌ Error in {folder}: {e}")
            continue

    results_all.setdefault((W, Epp, Ewp, L), []).append(np.array(rg_list))
    print(f"✅ L={L}, W={W}, Epp={Epp}, Ewp={Ewp}, len={len(rg_list)}")

# ------------------- AVERAGING -------------------
if not results_all:
    print("⚠️ No valid data found.")
    # `exit()` is injected by the `site` module for interactive use and is not
    # guaranteed to exist (e.g. under `python -S`); raising SystemExit is the
    # portable equivalent (exit status 0) and needs no import.
    raise SystemExit

# Average repeated runs of the same (W, Epp, Ewp, L) frame by frame.
averaged = {}  # (W, Epp, Ewp, L) -> (mean Rg per frame, std Rg per frame)
for key, runs in results_all.items():
    # Runs may differ in length; truncate all of them to the shortest one.
    min_len = min(len(r) for r in runs)
    aligned = np.array([r[:min_len] for r in runs])
    # NaN-aware statistics ignore frames where a run produced NaN.
    mean = np.nanmean(aligned, axis=0)
    std = np.nanstd(aligned, axis=0)  # population std (ddof=0)
    averaged[key] = (mean, std)

# ------------------- PLOTTING -------------------
# Bucket the averaged curves: one figure per (W, Epp, Ewp), one curve per L.
# A single pass over the sorted keys makes both the figure order and the
# curve order inside each figure ascending.
l_values_by_group = {}
for full_key in sorted(averaged):
    l_values_by_group.setdefault(full_key[:3], []).append(full_key[3])

for group, l_values in l_values_by_group.items():
    W, Epp, Ewp = group
    plt.figure(figsize=(9, 6))

    for L in l_values:
        mean, std = averaged[group + (L,)]
        frame_idx = np.arange(len(mean))
        # Mean curve plus a +/- 1 std band across runs.
        plt.plot(frame_idx, mean, lw=1.8, label=f"L={L}")
        plt.fill_between(frame_idx, mean - std, mean + std, alpha=0.3)

    plt.xlabel("Frame")
    plt.ylabel("Radius of gyration (Rg)")
    plt.title(f"Average Rg(t)\n(W={W}, Epp={Epp}, Ewp={Ewp})")
    plt.legend(title="Chain length L")
    plt.grid(True)
    plt.tight_layout()

    # The PNG is written to the current working directory (not base_dir).
    outname = f"Rg_vs_frame_avg_W{W}_Epp{Epp}_Ewp{Ewp}.png"
    plt.savefig(outname, dpi=120)
    print(f"📈 Saved averaged plot: {outname}")

    if not show_plot:
        plt.close()
    else:
        plt.show()

# ------------------- SUMMARY -------------------
print("\n✅ Averaging and plotting complete.")
# Final run statistics, one line each.
for summary_line in (
    "\n📊 Summary:",
    f"  Recalculated CSV files : {recalculated}",
    f"  Skipped (already exist) : {skipped}",
    f"  Total compute time      : {total_time:.2f} s",
):
    print(summary_line)
