#!/usr/bin/env python
# coding: utf-8

import os
import numpy as np
import pyvista as pv
import networkx as nx
import psutil
import time
from datetime import timedelta

# ---------------------------
# SETTINGS
# ---------------------------
OVERWRITE = False   # True = always re-render and overwrite the video; False = skip folders whose video already exists

# ---------------------------
# Helpers
# ---------------------------

def load_gel_graph_npz(filename):
    """Rebuild a gel graph from the raw arrays stored in an ``.npz`` archive.

    The archive must provide four entries: ``node_ids``, ``node_coords``,
    ``edges`` (pairs of node ids) and ``edge_types`` (one label per edge).

    Args:
        filename: Path to the ``.npz`` archive.

    Returns:
        A ``networkx.Graph`` whose nodes are ``int`` ids carrying a ``coord``
        tuple attribute and whose edges carry a ``type`` string attribute.

    Raises:
        KeyError: If one of the required entries is missing from the archive.
    """
    gel_graph = nx.Graph()
    # NOTE(review): allow_pickle=True can execute arbitrary code while
    # unpickling; only load archives produced by a trusted simulation run.
    # The context manager closes the archive's file handle once the arrays
    # have been consumed.
    with np.load(filename, allow_pickle=True) as data:
        for node, coord in zip(data['node_ids'], data['node_coords']):
            gel_graph.add_node(int(node), coord=tuple(coord))
        for (u, v), t in zip(data['edges'], data['edge_types']):
            gel_graph.add_edge(int(u), int(v), type=str(t))
    return gel_graph


def read_params_from_files(folder):
    """Read the snapshot interval and total step count for a simulation run.

    ``trajectories.npz`` (required) supplies the number of stored frames.
    ``final.npz`` (optional) may supply ``print_interval`` and ``steps``.
    Anything that cannot be read from ``final.npz`` falls back to a default:
    ``print_interval`` defaults to 1 and the step count is derived from the
    number of frames as ``(num_frames - 1) * print_interval``.

    Args:
        folder: Directory containing ``trajectories.npz`` and, optionally,
            ``final.npz``.

    Returns:
        tuple[int, int]: ``(print_interval, total_steps)``.

    Raises:
        OSError: If ``trajectories.npz`` cannot be opened.
        KeyError: If ``trajectories.npz`` has no ``lattice_snapshots`` entry.
    """
    traj_file = os.path.join(folder, "trajectories.npz")
    final_file = os.path.join(folder, "final.npz")

    # Only the frame count is needed here; close the archive right away.
    with np.load(traj_file, allow_pickle=True) as data:
        num_frames = len(data["lattice_snapshots"])

    print_interval = 1
    total_steps = None
    if os.path.exists(final_file):
        try:
            with np.load(final_file, allow_pickle=True) as fin:
                if "print_interval" in fin:
                    print_interval = int(fin["print_interval"])
                if "steps" in fin:
                    total_steps = int(fin["steps"])
        except Exception as exc:
            # Reading final.npz is best-effort: keep whatever was parsed so
            # far and fall back to defaults, but tell the user why.
            print(f"[{final_file}] ⚠️ Could not read run parameters ({exc}); using defaults.")

    if total_steps is None:
        # An empty trajectory must not yield a negative step count.
        total_steps = max(num_frames - 1, 0) * print_interval

    return print_interval, total_steps


# -----------------------------------
# 3D animation export (polymer only)
# -----------------------------------
def _frame_graph(gel_graph_snapshots, shared_graph, index):
    """Return the graph to draw for frame *index*, or ``None`` if unavailable."""
    if gel_graph_snapshots is not None and len(gel_graph_snapshots) > index:
        entry = gel_graph_snapshots[index]
        # Indexing a 1-D object array yields the stored graph itself, whereas
        # a 0-d object array has to be unwrapped with .item(); accept both.
        if hasattr(entry, "nodes") and hasattr(entry, "edges"):
            return entry
        try:
            return entry.item()
        except Exception:
            # Unusable per-frame entry: draw this frame without a graph.
            return None
    return shared_graph


def _draw_polymer(plotter, lattice, graph):
    """Add polymer sites and their covalent bonds for one frame to *plotter*."""
    polymer_ids = set(graph.nodes) if graph is not None else set()

    # Lattice cells whose value is a node id of the graph belong to the polymer.
    mask_polymer = np.isin(lattice, list(polymer_ids))
    if not np.any(mask_polymer):
        return

    coords_poly = np.argwhere(mask_polymer).astype(np.float32)
    plotter.add_mesh(
        pv.PolyData(coords_poly),
        render_points_as_spheres=True,
        point_size=10,
        color="black",
        opacity=1.0,
    )

    # Covalent bonds. Boolean indexing and argwhere traverse the lattice in
    # the same order, so position k in vals_poly is point k in coords_poly.
    vals_poly = lattice[mask_polymer].astype(int).ravel()
    id_to_idx = {int(node_id): idx for idx, node_id in enumerate(vals_poly)}
    lines_list = []
    for u, v, data in graph.edges(data=True):
        # Edges without a "type" attribute are treated as covalent.
        if data.get("type", "covalent") != "covalent":
            continue
        iu = id_to_idx.get(int(u))
        iv = id_to_idx.get(int(v))
        if iu is None or iv is None:
            continue
        # VTK line cell layout: [number_of_points, point_index, point_index].
        lines_list.extend([2, iu, iv])
    if lines_list:
        bond_mesh = pv.PolyData()
        bond_mesh.points = coords_poly
        bond_mesh.lines = np.asarray(lines_list, dtype=np.int64)
        plotter.add_mesh(bond_mesh, color="black", line_width=4, opacity=1.0)


def _discard_partial_movie(filename):
    """Best-effort removal of a movie file left behind by a failed render."""
    try:
        os.remove(filename)
    except OSError:
        # Nothing to remove (or not permitted); the render error matters more.
        pass


def export_lattice_3d_animation_polymer(lattice_snapshots, gel_graph_snapshots, shared_graph,
                                        filename, SNAP, STEPS):
    """Render the polymer sites and covalent bonds of every snapshot to a movie.

    Entries of ``lattice_snapshots`` that are not 3-D ``numpy`` arrays are
    skipped. Each remaining snapshot keeps its original index, so the graph
    lookup and the timestep label stay aligned with the simulation even when
    some entries are skipped.

    Args:
        lattice_snapshots: Sequence of 3-D arrays holding one id per cell.
        gel_graph_snapshots: Optional per-frame graphs; when it has an entry
            for a frame, that entry is used for that frame.
        shared_graph: Optional graph used for frames that have no per-frame
            graph.
        filename: Path of the movie file to write.
        SNAP: Step interval between snapshots; frame ``i`` is labelled
            ``i * SNAP``.
        STEPS: Total number of steps, shown in the frame label.

    Returns:
        None. If rendering fails, the plotter is closed, the partially
        written movie is removed (so a later run does not mistake it for a
        finished video) and the exception propagates.
    """
    # Pair each valid snapshot with its index in the original sequence.
    frames = [
        (index, snap)
        for index, snap in enumerate(lattice_snapshots)
        if isinstance(snap, np.ndarray) and snap.ndim == 3
    ]
    if not frames:
        print(f"[{filename}] ❌ No valid 3D snapshots found.")
        return

    N, M, L = frames[0][1].shape
    plotter = pv.Plotter(off_screen=True, window_size=(1000, 1000))
    movie_opened = False
    finished = False
    try:
        plotter.open_movie(filename, framerate=10)
        movie_opened = True

        print(f"[{filename}] 🎥 Rendering polymer only...")

        box = pv.Box(bounds=(-0.5, N - 0.5, -0.5, M - 0.5, -0.5, L - 0.5))

        for index, lattice in frames:
            plotter.clear()

            graph = _frame_graph(gel_graph_snapshots, shared_graph, index)
            _draw_polymer(plotter, lattice, graph)

            # Scene decoration: bounding box, fixed camera, axes and label.
            plotter.add_mesh(box, color="black", style="wireframe", line_width=1.0, opacity=0.35)

            plotter.reset_camera()
            plotter.camera_position = [
                (max(N, M, L) * 3.0, M * 0.5, L * 0.5),
                (N * 0.5, M * 0.5, L * 0.5),
                (0, 0, 1),
            ]

            plotter.add_axes()
            plotter.add_text(f"Timestep {index * SNAP} / {STEPS}", font_size=14)
            plotter.write_frame()

        finished = True
    finally:
        # Always release the render window and finalize the movie writer.
        plotter.close()
        if movie_opened and not finished:
            _discard_partial_movie(filename)

    print(f"[{filename}] ✅ Saved polymer-only video.")


# ------------------------
# Resource monitoring
# ------------------------

def monitor_resource_usage(func, *args, **kwargs):
    """Run ``func(*args, **kwargs)`` while sampling this process's CPU and RSS.

    A background thread records CPU percentage and resident memory (in MB)
    about once per second while ``func`` runs. One more sample is taken after
    ``func`` returns, so even very short tasks produce a measurement.

    Args:
        func: Callable to execute; its return value is discarded.
        *args: Positional arguments forwarded to ``func``.
        **kwargs: Keyword arguments forwarded to ``func``.

    Returns:
        tuple: ``(elapsed_seconds, avg_cpu, max_cpu, avg_mem_mb, max_mem_mb)``.

    Raises:
        Exception: Whatever ``func`` raises; the sampler thread is stopped
            before the exception propagates.
    """
    import threading

    process = psutil.Process(os.getpid())
    cpu_samples = []
    mem_samples = []
    start_time = time.time()

    stop_event = threading.Event()

    def take_sample():
        """Record one CPU/memory sample; return False if the process is gone."""
        try:
            cpu_samples.append(process.cpu_percent(interval=None))
            mem_samples.append(process.memory_info().rss / (1024 ** 2))  # MB
        except psutil.NoSuchProcess:
            return False
        return True

    def sampler():
        # Event.wait acts as an interruptible one-second sleep, so stopping
        # the monitor does not have to wait for a full sleep to elapse.
        while not stop_event.wait(1.0):
            if not take_sample():
                break

    # The first non-blocking cpu_percent() call only sets the reference point
    # and returns a meaningless 0.0, so make that call before sampling starts.
    process.cpu_percent(interval=None)

    thread = threading.Thread(target=sampler, daemon=True)
    thread.start()

    try:
        func(*args, **kwargs)  # run the monitored task
    finally:
        # Stop the sampler even when the task fails, otherwise the thread
        # would keep running after this function has exited.
        stop_event.set()
        thread.join()

    elapsed = time.time() - start_time
    take_sample()  # guarantees at least one measurement for short tasks

    avg_cpu = np.mean(cpu_samples) if cpu_samples else 0
    max_cpu = np.max(cpu_samples) if cpu_samples else 0
    avg_mem = np.mean(mem_samples) if mem_samples else 0
    max_mem = np.max(mem_samples) if mem_samples else 0

    return elapsed, avg_cpu, max_cpu, avg_mem, max_mem


# ------------------------
# Sequential processing
# ------------------------

def process_directory(folder):
    """Render the polymer-only video for one simulation folder.

    The folder must contain ``trajectories.npz``. Graph information is taken
    from the ``gel_graph_snapshots`` entry of that archive when present;
    otherwise from ``gel_graph_data.npz`` in the same folder, either as a
    stored ``gel_graph`` object or rebuilt from raw node/edge arrays.

    Args:
        folder: Path of the simulation folder.

    Returns:
        tuple | None: ``(elapsed, avg_cpu, max_cpu, avg_mem, max_mem)`` when a
        video was rendered; ``None`` when the folder was skipped (no
        trajectory, or the video exists and ``OVERWRITE`` is false) or when
        rendering failed (the error is printed, not raised).
    """
    traj_file = os.path.join(folder, "trajectories.npz")
    output_file = os.path.join(folder, "trajectories_all_polymer.mp4")

    if not os.path.exists(traj_file):
        print(f"[{folder}] ⚠️ trajectories.npz not found.")
        return None

    if os.path.exists(output_file) and not OVERWRITE:
        print(f"[{folder}] ⏩ Polymer video already exists, skipping.")
        return None

    print(f"[{folder}] 🔁 Rendering polymer-only video...")

    try:
        SNAP, STEPS = read_params_from_files(folder)

        # Entries are read into memory on access, so the archive can be
        # closed before the (long) rendering step starts.
        with np.load(traj_file, allow_pickle=True) as data:
            lattice_snapshots = data["lattice_snapshots"]
            gel_graph_snapshots = data["gel_graph_snapshots"] if "gel_graph_snapshots" in data else None

        shared_graph = None

        if gel_graph_snapshots is None:
            graph_path = os.path.join(folder, "gel_graph_data.npz")
            if os.path.exists(graph_path):
                with np.load(graph_path, allow_pickle=True) as gdata:
                    if "gel_graph" in gdata:
                        try:
                            shared_graph = gdata["gel_graph"].item()
                            print(f"[{graph_path}] ℹ️ Using saved gel_graph object.")
                        except Exception as exc:
                            # Best-effort: render without bonds, but say why.
                            shared_graph = None
                            print(f"[{graph_path}] ⚠️ Could not load gel_graph object: {exc}")
                    elif {"node_ids", "node_coords", "edges", "edge_types"}.issubset(gdata.files):
                        shared_graph = load_gel_graph_npz(graph_path)
                        print(f"[{graph_path}] ℹ️ Reconstructed gel_graph from raw arrays.")
                    else:
                        print(f"[{graph_path}] ⚠️ No valid graph keys found.")

        # Render while monitoring resource usage.
        elapsed, avg_cpu, max_cpu, avg_mem, max_mem = monitor_resource_usage(
            export_lattice_3d_animation_polymer,
            lattice_snapshots,
            gel_graph_snapshots,
            shared_graph,
            output_file,
            SNAP,
            STEPS
        )

        print(f"[{folder}] ⏱ Render time: {str(timedelta(seconds=int(elapsed)))}")
        print(f"[{folder}] 🧠 Memory avg={avg_mem:.1f}MB, max={max_mem:.1f}MB | CPU avg={avg_cpu:.1f}%, max={max_cpu:.1f}%")
        return (elapsed, avg_cpu, max_cpu, avg_mem, max_mem)

    except Exception as e:
        # Batch boundary: report the failure and let the caller move on to
        # the next folder.
        print(f"[{folder}] ❌ Error: {e}")
        return None


# ---------------
# Entry point
# ---------------

if __name__ == "__main__":
    # Every immediate sub-folder of the working directory that contains a
    # trajectory archive is treated as one simulation run.
    root = os.getcwd()
    run_folders = []
    for entry in os.listdir(root):
        candidate = os.path.join(root, entry)
        if not os.path.isdir(candidate):
            continue
        if os.path.exists(os.path.join(candidate, "trajectories.npz")):
            run_folders.append(candidate)

    print(f"📂 Found {len(run_folders)} folders with trajectories.npz")

    # Render the runs one after another, keeping statistics only for the
    # folders that actually produced a video.
    stats = []
    batch_started = time.time()
    for run_folder in run_folders:
        outcome = process_directory(run_folder)
        if not outcome:
            continue
        stats.append(outcome)

    if not stats:
        print("⚠️ No videos were rendered.")
    else:
        batch_elapsed = time.time() - batch_started
        # Columns: elapsed, avg CPU, max CPU, avg memory, max memory.
        _, avg_cpu_col, max_cpu_col, avg_mem_col, max_mem_col = np.array(stats).T
        print("\n==============================")
        print(f"🎬 All videos rendered in {timedelta(seconds=int(batch_elapsed))}")
        print(f"Average CPU: {avg_cpu_col.mean():.1f}% | Max CPU: {max_cpu_col.max():.1f}%")
        print(f"Average RAM: {avg_mem_col.mean():.1f}MB | Max RAM: {max_mem_col.max():.1f}MB")
        print("==============================")
