#!/usr/bin/env python
# coding: utf-8

import os
import time
import psutil
import numpy as np
import pyvista as pv
import networkx as nx
from datetime import timedelta

# ---------------------------
# SETTINGS
# ---------------------------
# Re-render even when a folder already has both PNGs.
OVERWRITE = True

# Size of the off-screen render window, passed to pv.Plotter(window_size=...).
WINDOW_SIZE = (1000, 1000)

# Sphere sizes for polymer sites and (optional) solvent sites.
POINT_SIZE_POLYMER, POINT_SIZE_SOLVENT = 10, 6

# Solvent colour map and transparency; only used when SHOW_SOLVENT is True.
CMAP, SOLVENT_OPACITY = "coolwarm", 0.35

# Set to True to also draw non-polymer sites whose lattice value is > 0.
SHOW_SOLVENT = False

# ---------------------------
# Helpers
# ---------------------------

def load_gel_graph_npz(filename):
    """Rebuild a gel graph from the raw node/edge arrays in an ``.npz`` file.

    Expected keys: ``node_ids`` and ``node_coords`` (zipped pairwise), and
    ``edges`` (pairs ``(u, v)``) with ``edge_types`` (zipped pairwise).

    Returns:
        networkx.Graph with integer node ids. Each node has a ``coord`` tuple
        attribute and each edge has a string ``type`` attribute.
    """
    # Use NpzFile as a context manager so the underlying file handle is released.
    with np.load(filename, allow_pickle=True) as data:
        node_ids = data["node_ids"]
        node_coords = data["node_coords"]
        edges = data["edges"]
        edge_types = data["edge_types"]

    gel_graph = nx.Graph()
    for node, coord in zip(node_ids, node_coords):
        gel_graph.add_node(int(node), coord=tuple(coord))
    for (u, v), t in zip(edges, edge_types):
        gel_graph.add_edge(int(u), int(v), type=str(t))
    return gel_graph


def read_params_from_files(folder):
    """Return ``(print_interval, total_steps)`` for the run stored in *folder*.

    The frame count comes from ``trajectories.npz["lattice_snapshots"]``.
    ``print_interval`` and ``total_steps`` are taken from ``final.npz`` when
    present and readable; otherwise ``print_interval`` defaults to 1.
    ``total_steps`` defaults to ``(num_frames - 1) * print_interval``, clamped
    at 0 for an empty trajectory.

    Raises:
        FileNotFoundError: if ``trajectories.npz`` is missing.
    """
    traj_file = os.path.join(folder, "trajectories.npz")
    final_file = os.path.join(folder, "final.npz")

    with np.load(traj_file, allow_pickle=True) as data:
        num_frames = len(data["lattice_snapshots"])

    print_interval = 1
    total_steps = None

    if os.path.exists(final_file):
        try:
            with np.load(final_file, allow_pickle=True) as fin:
                if "print_interval" in fin:
                    print_interval = int(fin["print_interval"])
                if "total_steps" in fin:
                    total_steps = int(fin["total_steps"])
        except Exception as exc:
            # Best effort: keep whatever was read so far, but don't fail silently.
            print(f"[{final_file}] Could not read run parameters ({exc}); using defaults.")

    if total_steps is None:
        # The last frame index is num_frames - 1; clamp so an empty run gives 0, not a negative count.
        total_steps = max(0, num_frames - 1) * print_interval

    return print_interval, total_steps


def get_graph_for_index(i, gel_graph_snapshots, shared_graph):
    """Return the graph for frame *i*, falling back to *shared_graph*.

    Each entry of *gel_graph_snapshots* may be either the graph object
    itself, as stored in a 1-D object array, or an ndarray wrapping one object
    (e.g. a 0-d object array), which is unwrapped with ``.item()``.
    *shared_graph* is returned in any of these cases:
    - *gel_graph_snapshots* is None;
    - *i* is out of range;
    - the entry is None;
    - the entry is an ndarray holding more than one element.
    """
    graph = None
    if gel_graph_snapshots is not None:
        n = len(gel_graph_snapshots)
        if -n <= i < n:
            entry = gel_graph_snapshots[i]
            if isinstance(entry, np.ndarray):
                # Only ndarray entries need unwrapping; calling .item() on a bare
                # Graph would raise and discard a perfectly good graph.
                try:
                    entry = entry.item()
                except ValueError:
                    entry = None
            graph = entry
    return shared_graph if graph is None else graph


# ---------------------------
# Rendering
# ---------------------------

def _covalent_bond_lines(graph, node_ids):
    """Build a flat ``[2, a, b, 2, c, d, ...]`` line-cell array for covalent edges.

    ``node_ids[k]`` is the lattice value at the k-th polymer point. Edges
    without a ``type`` attribute are treated as covalent. Returns None when no
    covalent edge has both endpoints present on the lattice.
    """
    # If an id occurs at several sites, the last occurrence wins.
    position = {int(node_id): k for k, node_id in enumerate(node_ids)}
    cells = []
    for u, v, attrs in graph.edges(data=True):
        if attrs.get("type", "covalent") != "covalent":
            continue
        iu = position.get(int(u))
        iv = position.get(int(v))
        if iu is not None and iv is not None:
            cells.extend((2, iu, iv))
    return np.asarray(cells, dtype=np.int64) if cells else None


def _add_solvent_points(plotter, lattice, mask_polymer):
    """Add non-polymer sites with lattice value > 0 as a colour-mapped point cloud."""
    mask_solvent = (lattice > 0) & ~mask_polymer
    if not np.any(mask_solvent):
        return
    cloud_sol = pv.PolyData(np.argwhere(mask_solvent).astype(np.float32))
    cloud_sol["concentration"] = lattice[mask_solvent].astype(np.float32).ravel()
    plotter.add_mesh(
        cloud_sol,
        render_points_as_spheres=True,
        point_size=POINT_SIZE_SOLVENT,
        scalars="concentration",
        cmap=CMAP,
        opacity=SOLVENT_OPACITY,
        show_scalar_bar=True,
        scalar_bar_args={"title": "Solvent value"},
    )


def render_frame_to_png(lattice, graph, png_path, idx, SNAP, STEPS):
    """Render one 3-D lattice snapshot off-screen and save it to *png_path*.

    What is drawn:
    - Polymer sites (lattice values that are node ids of *graph*) as black spheres.
    - Covalent graph edges between polymer sites as black lines.
    - When SHOW_SOLVENT is True, other sites with value > 0 as a colour-mapped cloud.
    - A wireframe box outlining the lattice.

    Args:
        lattice: 3-D ndarray of site values. Anything else is skipped with a message.
        graph: graph exposing ``.nodes`` and ``.edges(data=True)``, or None.
        png_path: output image path.
        idx: frame index, used in the caption ``Timestep {idx * SNAP} / {STEPS}``.
        SNAP: step interval between frames (caption multiplier).
        STEPS: total step count shown in the caption.
    """
    if not isinstance(lattice, np.ndarray) or lattice.ndim != 3:
        print(f"[{png_path}] Invalid lattice array. Skipped.")
        return

    N, M, L = lattice.shape
    plotter = pv.Plotter(off_screen=True, window_size=WINDOW_SIZE)
    try:
        polymer_ids = list(graph.nodes) if graph is not None else []
        if polymer_ids:
            mask_polymer = np.isin(lattice, polymer_ids)
        else:
            mask_polymer = np.zeros(lattice.shape, dtype=bool)

        if np.any(mask_polymer):
            coords_poly = np.argwhere(mask_polymer).astype(np.float32)
            plotter.add_mesh(
                pv.PolyData(coords_poly),
                render_points_as_spheres=True,
                point_size=POINT_SIZE_POLYMER,
                color="black",
                opacity=1.0,
            )
            # argwhere and boolean-mask indexing both traverse in C order, so the
            # k-th value belongs to the k-th coordinate. A non-empty mask implies
            # graph is not None.
            lines = _covalent_bond_lines(graph, lattice[mask_polymer].astype(int).ravel())
            if lines is not None:
                bond_mesh = pv.PolyData()
                bond_mesh.points = coords_poly
                bond_mesh.lines = lines
                plotter.add_mesh(bond_mesh, color="black", line_width=4, opacity=1.0)

        if SHOW_SOLVENT:
            _add_solvent_points(plotter, lattice, mask_polymer)

        box = pv.Box(bounds=(-0.5, N - 0.5, -0.5, M - 0.5, -0.5, L - 0.5))
        plotter.add_mesh(box, color="black", style="wireframe", line_width=1.0, opacity=0.35)
        plotter.reset_camera()
        # (camera position, focal point, view-up): look along -x at the lattice centre, z up.
        plotter.camera_position = [
            (max(N, M, L) * 3.0, M * 0.5, L * 0.5),
            (N * 0.5, M * 0.5, L * 0.5),
            (0, 0, 1),
        ]
        plotter.add_axes()
        plotter.add_text(f"Timestep {idx * SNAP} / {STEPS}", font_size=14)
        plotter.screenshot(png_path)
    finally:
        # Release the render window even if any drawing step fails.
        plotter.close()


def _load_shared_graph(folder):
    """Load one gel graph from ``<folder>/gel_graph_data.npz``, or return None.

    The file can hold either a pickled graph under ``gel_graph``, or the raw
    ``node_ids``/``node_coords``/``edges``/``edge_types`` arrays, which are
    passed to load_gel_graph_npz.
    """
    graph_path = os.path.join(folder, "gel_graph_data.npz")
    if not os.path.exists(graph_path):
        return None

    with np.load(graph_path, allow_pickle=True) as gdata:
        if "gel_graph" in gdata:
            try:
                graph = gdata["gel_graph"].item()
            except Exception as exc:
                print(f"[{graph_path}] Could not read gel_graph ({exc}).")
                return None
            print(f"[{graph_path}] Using saved gel_graph object.")
            return graph
        has_raw_arrays = {"node_ids", "node_coords", "edges", "edge_types"}.issubset(gdata.files)

    if has_raw_arrays:
        graph = load_gel_graph_npz(graph_path)
        print(f"[{graph_path}] Reconstructed gel_graph from raw arrays.")
        return graph

    print(f"[{graph_path}] No valid graph keys found.")
    return None


def render_two_pngs(folder):
    """Render the first and last valid lattice frames of *folder* to PNGs.

    Outputs ``first_frame_polymer.png`` and ``last_frame_polymer.png`` in
    *folder*. Graphs come from ``gel_graph_snapshots`` in
    ``trajectories.npz`` when present; otherwise from ``gel_graph_data.npz``.
    Skips the folder in any of these cases:
    - ``trajectories.npz`` is missing;
    - both PNGs exist and OVERWRITE is False;
    - no 3-D snapshot is found.
    """
    traj_file = os.path.join(folder, "trajectories.npz")
    if not os.path.exists(traj_file):
        print(f"[{folder}] trajectories.npz not found.")
        return

    first_png = os.path.join(folder, "first_frame_polymer.png")
    last_png = os.path.join(folder, "last_frame_polymer.png")

    if not OVERWRITE and os.path.exists(first_png) and os.path.exists(last_png):
        print(f"[{folder}] PNGs already exist. Skipping.")
        return

    with np.load(traj_file, allow_pickle=True) as data:
        lattice_snapshots = data["lattice_snapshots"]
        gel_graph_snapshots = data["gel_graph_snapshots"] if "gel_graph_snapshots" in data else None

    # Keep ORIGINAL frame indices of valid snapshots, so that the graph lookup
    # and the "Timestep idx * SNAP" caption refer to the right frame even when
    # some frames are invalid.
    valid_frames = [
        i for i, snap in enumerate(lattice_snapshots)
        if isinstance(snap, np.ndarray) and snap.ndim == 3
    ]
    if not valid_frames:
        print(f"[{folder}] No valid 3D snapshots found.")
        return

    first_idx, last_idx = valid_frames[0], valid_frames[-1]
    shared_graph = _load_shared_graph(folder) if gel_graph_snapshots is None else None

    g_first = get_graph_for_index(first_idx, gel_graph_snapshots, shared_graph)
    g_last = get_graph_for_index(last_idx, gel_graph_snapshots, shared_graph)
    SNAP, STEPS = read_params_from_files(folder)

    render_frame_to_png(lattice_snapshots[first_idx], g_first, first_png, first_idx, SNAP, STEPS)
    if last_idx != first_idx:
        render_frame_to_png(lattice_snapshots[last_idx], g_last, last_png, last_idx, SNAP, STEPS)
    else:
        # Only one valid frame: the "last" image is the same picture.
        from shutil import copyfile
        copyfile(first_png, last_png)

    print(f"[{folder}] Saved first_frame_polymer.png and last_frame_polymer.png")


# ------------------------
# Per folder processing with monitoring
# ------------------------

def process_directory(folder):
    """Render the PNGs for *folder* and report wall time, peak RSS and average CPU.

    A background thread samples this process's CPU usage and resident memory
    roughly every 0.2 s while ``render_two_pngs(folder)`` runs. Rendering
    errors are printed rather than raised, so a bad folder does not abort a batch.
    """
    import threading

    process = psutil.Process(os.getpid())
    cpu_samples = []
    mem_samples = []
    stop = threading.Event()

    def sample_usage():
        # The first non-blocking cpu_percent() call only establishes a baseline.
        process.cpu_percent(interval=None)
        while True:
            mem_samples.append(process.memory_info().rss / (1024 ** 2))
            finished = stop.wait(0.2)
            # CPU % used by this process since the previous call.
            cpu_samples.append(process.cpu_percent(interval=None))
            if finished:
                break

    print(f"\n[{folder}] 🟢 Starting render...")
    start_time = time.perf_counter()
    sampler = threading.Thread(target=sample_usage, daemon=True)
    sampler.start()
    try:
        render_two_pngs(folder)
    except Exception as e:
        print(f"[{folder}] ❌ Error: {e}")
    finally:
        stop.set()
        sampler.join()

    elapsed = time.perf_counter() - start_time
    if cpu_samples and mem_samples:
        avg_cpu = round(sum(cpu_samples) / len(cpu_samples), 1)
        peak_mem = round(max(mem_samples), 1)
    else:
        avg_cpu = 0.0
        peak_mem = 0.0

    print(f"[{folder}] ⏱ Render time: {str(timedelta(seconds=int(elapsed)))}")
    print(f"[{folder}] 📊 Peak RAM: {peak_mem} MB | Avg CPU: {avg_cpu}%\n")


# ---------------
# Entry point
# ---------------

if __name__ == "__main__":
    base_dir = os.getcwd()
    # Immediate sub-folders containing trajectories.npz. Sorted because
    # os.listdir order is arbitrary and a stable order keeps logs reproducible.
    subdirs = sorted(
        path
        for path in (os.path.join(base_dir, name) for name in os.listdir(base_dir))
        if os.path.isdir(path) and os.path.exists(os.path.join(path, "trajectories.npz"))
    )

    print(f"Found {len(subdirs)} folders with trajectories.npz")

    for folder in subdirs:
        process_directory(folder)

    print("✅ All PNG renders complete.")
