#!/usr/bin/env python
# coding: utf-8

import os
import time
import psutil
import numpy as np
import pyvista as pv
import networkx as nx
from datetime import timedelta

# ---------------------------
# SETTINGS
# ---------------------------
# Output handling
OVERWRITE = True  # re-render even when both PNGs already exist in a folder

# Rendering appearance
WINDOW_SIZE = (1000, 1000)  # off-screen plotter window size
POINT_SIZE_POLYMER = 10
POINT_SIZE_SOLVENT = 6
CMAP = "coolwarm"  # colormap applied to solvent site values
SOLVENT_OPACITY = 0.35
SHOW_SOLVENT = False  # set True to also draw non-polymer (solvent) sites

# A covalent bond that wraps through a periodic boundary is short physically,
# but drawn in plain box coordinates it becomes a long line straight across
# the box. When this flag is set, such bonds are not drawn.
HIDE_PBC_CROSSING_BONDS = True

# Absolute tolerance used when comparing raw and minimum-image bond vectors.
PBC_TOL = 1e-8


# ---------------------------
# Helpers
# ---------------------------

def load_gel_graph_npz(filename):
    """Rebuild a gel graph from the raw arrays stored in an ``.npz`` archive.

    The archive is read for the keys ``node_ids``, ``node_coords`` (zipped
    element-wise with ``node_ids``), ``edges`` (pairs of node ids) and
    ``edge_types`` (zipped element-wise with ``edges``).

    Each node id is cast to ``int`` and given a ``coord`` attribute holding its
    coordinate row as a tuple. Each edge gets a ``type`` attribute holding its
    label converted to ``str``.

    Parameters
    ----------
    filename : str
        Path of the ``.npz`` archive.

    Returns
    -------
    networkx.Graph
    """
    gel_graph = nx.Graph()

    # Use the archive as a context manager so the underlying file handle is
    # released once the arrays have been consumed (np.load on .npz keeps it
    # open otherwise). NOTE(review): allow_pickle=True executes pickled data;
    # only load archives produced by trusted simulation runs.
    with np.load(filename, allow_pickle=True) as data:
        for node, coord in zip(data["node_ids"], data["node_coords"]):
            gel_graph.add_node(int(node), coord=tuple(coord))

        for (u, v), t in zip(data["edges"], data["edge_types"]):
            gel_graph.add_edge(int(u), int(v), type=str(t))

    return gel_graph


def read_params_from_files(folder):
    """Return ``(print_interval, total_steps)`` for the run stored in *folder*.

    ``print_interval`` and ``total_steps`` are read from ``final.npz`` when that
    file exists and contains the corresponding keys. Otherwise
    ``print_interval`` defaults to 1, and ``total_steps`` is derived from the
    number of entries in ``trajectories.npz["lattice_snapshots"]`` as
    ``(num_frames - 1) * print_interval``. The derived value is clamped at 0 so
    that an empty trajectory does not produce a negative step count.

    ``final.npz`` is optional: if it cannot be read, a message is printed and
    the defaults are used. Errors opening ``trajectories.npz`` propagate.

    Parameters
    ----------
    folder : str
        Run folder containing ``trajectories.npz`` and optionally ``final.npz``.

    Returns
    -------
    tuple[int, int]
    """
    traj_file = os.path.join(folder, "trajectories.npz")
    final_file = os.path.join(folder, "final.npz")

    # Only the frame count is needed; the context manager closes the archive.
    with np.load(traj_file, allow_pickle=True) as data:
        num_frames = len(data["lattice_snapshots"])

    print_interval = 1
    total_steps = None

    if os.path.exists(final_file):
        try:
            with np.load(final_file, allow_pickle=True) as fin:
                if "print_interval" in fin:
                    print_interval = int(fin["print_interval"])

                if "total_steps" in fin:
                    total_steps = int(fin["total_steps"])

        except Exception as exc:
            # Best effort: final.npz is optional metadata, so report the
            # problem instead of hiding it, and fall back to the defaults.
            print(f"[{final_file}] Could not read run parameters ({exc}); using defaults.")

    if total_steps is None:
        total_steps = max(0, (num_frames - 1) * print_interval)

    return print_interval, total_steps


def get_graph_for_index(i, gel_graph_snapshots, shared_graph):
    """Return the gel graph for frame *i*, falling back to *shared_graph*.

    ``gel_graph_snapshots`` is an indexable per-frame sequence (for example an
    object array loaded from an ``.npz`` file) or ``None``.

    An entry may be either of these:

    - the graph object itself, which is how elements of a 1-D object array
      come back;
    - an ndarray wrapping a single object, which is unwrapped with ``.item()``.

    *shared_graph* is returned when no usable entry exists for *i*. That covers
    no snapshots, an index out of range, an unindexable container, a
    multi-element ndarray entry, or a ``None`` entry.

    Parameters
    ----------
    i : int
        Frame index into ``gel_graph_snapshots``.
    gel_graph_snapshots : sequence or None
        Per-frame graphs.
    shared_graph : object or None
        Fallback graph used for every frame without its own entry.
    """
    graph = None

    if gel_graph_snapshots is not None:
        try:
            entry = gel_graph_snapshots[i]
        except (IndexError, TypeError):
            # Out of range, or a 0-d container that cannot be indexed per frame.
            entry = None

        # Only ndarray entries need unwrapping; plain objects such as a graph
        # have no .item(), and calling it on them used to discard every
        # per-frame graph.
        if isinstance(entry, np.ndarray):
            try:
                entry = entry.item()
            except ValueError:
                # Empty or multi-element array: not a single graph.
                entry = None

        graph = entry

    return shared_graph if graph is None else graph


def minimum_image_delta(p1, p2, box_lengths):
    """
    Compute the displacement p2 - p1 both as-is and under minimum image.

    For each axis, a component larger than half the box length is shifted
    down by one box length, and one smaller than minus half is shifted up by
    one box length. A component of exactly +/- half a box is left untouched.
    When the two returned vectors differ, the segment p1 -> p2 wraps through a
    periodic boundary.

    Returns
    -------
    (raw, wrapped) : tuple of ndarray
        ``wrapped`` is a copy of ``raw``, so ``raw`` is never modified.
    """
    raw = p2 - p1
    wrapped = raw.copy()

    sizes = np.asarray(box_lengths)
    half = sizes / 2.0

    # Mutually exclusive masks, mirroring an if/elif decision per axis.
    too_far_up = raw > half
    too_far_down = (raw < -half) & ~too_far_up

    wrapped[too_far_up] -= sizes[too_far_up]
    wrapped[too_far_down] += sizes[too_far_down]

    return raw, wrapped


def bond_crosses_pbc(p1, p2, box_lengths, tol=PBC_TOL):
    """
    Tell whether the bond p1 -> p2 wraps around a periodic boundary.

    The displacement between the two ends is computed both directly and under
    the minimum image convention. If wrapping changes it beyond ``np.allclose``
    with absolute tolerance ``tol`` (plus NumPy's default relative tolerance),
    the bond is treated as crossing a boundary. Such a bond would appear as a
    long line across the box.

    Returns
    -------
    bool
    """
    raw, wrapped = minimum_image_delta(p1, p2, box_lengths)
    unchanged_by_wrapping = np.allclose(raw, wrapped, atol=tol)
    return not unchanged_by_wrapping


# ---------------------------
# Rendering
# ---------------------------

def render_frame_to_png(lattice, graph, png_path, idx, SNAP, STEPS):
    """Render one lattice snapshot to a PNG using an off-screen PyVista plotter.

    The image contains:

    - black spheres for lattice sites whose value is a node id of ``graph``;
    - black lines for edges whose ``type`` attribute is ``"covalent"`` (or
      missing). Bonds that cross a periodic boundary are skipped when
      ``HIDE_PBC_CROSSING_BONDS`` is set;
    - when ``SHOW_SOLVENT`` is set, the remaining sites with value > 0, coloured
      by value;
    - the box outline and axes;
    - the caption ``"Timestep {idx * SNAP} / {STEPS}"``.

    Parameters
    ----------
    lattice : np.ndarray
        3-D array of site values; any other input is reported and skipped.
    graph : networkx.Graph or None
        Graph whose node ids mark polymer sites; ``None`` draws no polymer.
    png_path : str
        Output path, also used as the log prefix.
    idx : int
        Frame index; the caption shows ``idx * SNAP``.
    SNAP : int
        Steps between saved frames.
    STEPS : int
        Total step count shown in the caption.
    """
    if not isinstance(lattice, np.ndarray) or lattice.ndim != 3:
        print(f"[{png_path}] Invalid lattice array. Skipped.")
        return

    N, M, L = lattice.shape
    box_lengths = np.array([N, M, L], dtype=np.float32)

    plotter = pv.Plotter(off_screen=True, window_size=WINDOW_SIZE)

    # Always release the off-screen plotter, even when adding a mesh or taking
    # the screenshot raises, so a batch over many folders does not leak
    # rendering resources.
    try:
        mask_polymer = _add_polymer_and_bonds(plotter, lattice, graph, box_lengths, png_path)

        if SHOW_SOLVENT:
            _add_solvent(plotter, lattice, mask_polymer)

        # Outline of the lattice cells: site centres sit on integer coordinates.
        box = pv.Box(
            bounds=(
                -0.5, N - 0.5,
                -0.5, M - 0.5,
                -0.5, L - 0.5,
            )
        )
        plotter.add_mesh(
            box,
            color="black",
            style="wireframe",
            line_width=1.0,
            opacity=0.35,
        )

        plotter.reset_camera()

        # Camera on the +x side at 3x the largest box dimension, aimed at the
        # box centre, with z pointing up.
        plotter.camera_position = [
            (max(N, M, L) * 3.0, M * 0.5, L * 0.5),
            (N * 0.5, M * 0.5, L * 0.5),
            (0, 0, 1),
        ]

        plotter.add_axes()
        plotter.add_text(f"Timestep {idx * SNAP} / {STEPS}", font_size=14)

        plotter.screenshot(png_path)
    finally:
        plotter.close()


def _add_polymer_and_bonds(plotter, lattice, graph, box_lengths, png_path):
    """Draw polymer sites and covalent bonds on *plotter*; return the polymer-site mask."""
    polymer_ids = set(graph.nodes) if graph is not None else set()

    if polymer_ids:
        mask_polymer = np.isin(lattice, list(polymer_ids))
    else:
        mask_polymer = np.zeros(lattice.shape, dtype=bool)

    if not np.any(mask_polymer):
        return mask_polymer

    coords_poly = np.argwhere(mask_polymer).astype(np.float32)
    vals_poly = lattice[mask_polymer].astype(int).ravel()

    plotter.add_mesh(
        pv.PolyData(coords_poly),
        render_points_as_spheres=True,
        point_size=POINT_SIZE_POLYMER,
        color="black",
        opacity=1.0,
    )

    # A non-empty mask implies polymer_ids was non-empty, hence graph is not None.
    lines, drawn_bonds, hidden_pbc_bonds = _covalent_bond_lines(
        graph, coords_poly, vals_poly, box_lengths
    )

    if hidden_pbc_bonds > 0:
        print(f"[{png_path}] Hidden PBC crossing bonds: {hidden_pbc_bonds}")

    print(f"[{png_path}] Drawn covalent bonds: {drawn_bonds}")

    if lines.size:
        bond_mesh = pv.PolyData()
        bond_mesh.points = coords_poly
        bond_mesh.lines = lines

        plotter.add_mesh(
            bond_mesh,
            color="black",
            line_width=4,
            opacity=1.0,
        )

    return mask_polymer


def _covalent_bond_lines(graph, coords_poly, vals_poly, box_lengths):
    """Return ``(lines, drawn, hidden)``: VTK ``[2, i, j, ...]`` connectivity and bond counts."""
    # np.argwhere and boolean-mask indexing both traverse the lattice in C
    # order, so vals_poly[k] is the node id sitting at coords_poly[k].
    id_to_idx = {
        int(node_id): local_idx
        for local_idx, node_id in enumerate(vals_poly)
    }

    lines_list = []
    drawn = 0
    hidden = 0

    for u, v, data in graph.edges(data=True):
        # Edges without a "type" attribute are treated as covalent.
        if data.get("type", "covalent") != "covalent":
            continue

        iu = id_to_idx.get(int(u))
        iv = id_to_idx.get(int(v))

        if iu is None or iv is None:
            # At least one endpoint is not present on this lattice frame.
            continue

        if HIDE_PBC_CROSSING_BONDS and bond_crosses_pbc(
            coords_poly[iu], coords_poly[iv], box_lengths
        ):
            hidden += 1
            continue

        lines_list.extend([2, iu, iv])
        drawn += 1

    return np.asarray(lines_list, dtype=np.int64), drawn, hidden


def _add_solvent(plotter, lattice, mask_polymer):
    """Draw non-polymer sites with value > 0, coloured by their value."""
    mask_solvent = (lattice > 0) & (~mask_polymer)

    if not np.any(mask_solvent):
        return

    coords_sol = np.argwhere(mask_solvent).astype(np.float32)
    vals_sol = lattice[mask_solvent].astype(np.float32).ravel()

    cloud_sol = pv.PolyData(coords_sol)
    cloud_sol["concentration"] = vals_sol

    plotter.add_mesh(
        cloud_sol,
        render_points_as_spheres=True,
        point_size=POINT_SIZE_SOLVENT,
        scalars="concentration",
        cmap=CMAP,
        opacity=SOLVENT_OPACITY,
        show_scalar_bar=True,
        scalar_bar_args={"title": "Solvent value"},
    )


def render_two_pngs(folder):
    """Render the first and last valid lattice snapshots of *folder* to PNG.

    Inputs are read from ``trajectories.npz``:

    - ``lattice_snapshots`` holds the frames;
    - ``gel_graph_snapshots`` is optional and holds per-frame graphs.

    When there are no per-frame graphs, a shared graph is loaded from
    ``gel_graph_data.npz`` if that file exists.

    The output files are ``first_frame_polymer.png`` and
    ``last_frame_polymer.png``. With a single valid frame, the second file is a
    copy of the first. Existing PNGs are kept when ``OVERWRITE`` is False.

    Snapshots that are not 3-D arrays are skipped. Frame indices still refer to
    positions in the original ``lattice_snapshots`` sequence, so the graph
    lookup and the ``idx * SNAP`` timestep caption stay aligned with the stored
    trajectory.
    """
    traj_file = os.path.join(folder, "trajectories.npz")

    if not os.path.exists(traj_file):
        print(f"[{folder}] trajectories.npz not found.")
        return

    first_png = os.path.join(folder, "first_frame_polymer.png")
    last_png = os.path.join(folder, "last_frame_polymer.png")

    if not OVERWRITE and os.path.exists(first_png) and os.path.exists(last_png):
        print(f"[{folder}] PNGs already exist. Skipping.")
        return

    # Arrays read from an NpzFile are materialised in memory, so the archive
    # can be closed right away.
    with np.load(traj_file, allow_pickle=True) as data:
        all_snapshots = data["lattice_snapshots"]
        if "gel_graph_snapshots" in data:
            gel_graph_snapshots = data["gel_graph_snapshots"]
        else:
            gel_graph_snapshots = None

    # Keep original frame indices: skipping invalid frames must not shift the
    # index used for graph lookup and for the timestep caption.
    valid_indices = [
        i
        for i, snap in enumerate(all_snapshots)
        if isinstance(snap, np.ndarray) and snap.ndim == 3
    ]

    if not valid_indices:
        print(f"[{folder}] No valid 3D snapshots found.")
        return

    first_idx = valid_indices[0]
    last_idx = valid_indices[-1]

    shared_graph = _load_shared_graph(folder) if gel_graph_snapshots is None else None

    g_first = get_graph_for_index(first_idx, gel_graph_snapshots, shared_graph)
    g_last = get_graph_for_index(last_idx, gel_graph_snapshots, shared_graph)

    SNAP, STEPS = read_params_from_files(folder)

    render_frame_to_png(
        all_snapshots[first_idx],
        g_first,
        first_png,
        first_idx,
        SNAP,
        STEPS,
    )

    if last_idx != first_idx:
        render_frame_to_png(
            all_snapshots[last_idx],
            g_last,
            last_png,
            last_idx,
            SNAP,
            STEPS,
        )
    else:
        from shutil import copyfile
        copyfile(first_png, last_png)

    print(f"[{folder}] Saved first_frame_polymer.png and last_frame_polymer.png")


def _load_shared_graph(folder):
    """Load the run-wide gel graph from ``gel_graph_data.npz`` in *folder*, or return None."""
    graph_path = os.path.join(folder, "gel_graph_data.npz")

    if not os.path.exists(graph_path):
        return None

    with np.load(graph_path, allow_pickle=True) as gdata:
        keys = set(gdata.files)

        if "gel_graph" in keys:
            try:
                graph = gdata["gel_graph"].item()
            except Exception as exc:
                print(f"[{graph_path}] Could not unpack gel_graph: {exc}")
                return None
            print(f"[{graph_path}] Using saved gel_graph object.")
            return graph

    # Rebuild from raw arrays after the archive is closed; the loader reopens it.
    if {"node_ids", "node_coords", "edges", "edge_types"}.issubset(keys):
        graph = load_gel_graph_npz(graph_path)
        print(f"[{graph_path}] Reconstructed gel_graph from raw arrays.")
        return graph

    print(f"[{graph_path}] No valid graph keys found.")
    return None


# ------------------------
# Per folder processing with monitoring
# ------------------------

def process_directory(folder, sample_interval=0.2):
    """Render the PNGs for *folder* and report wall time, peak RSS and average CPU.

    ``render_two_pngs(folder)`` runs on the calling thread. Meanwhile a daemon
    thread samples this process with ``psutil``:

    - CPU percent, each measurement spanning ``sample_interval`` seconds;
    - resident set size, in MiB.

    The statistics therefore describe the render itself, not the idle time
    before it. Any exception raised by the render is printed and not
    propagated, so a batch over many folders keeps going.

    Parameters
    ----------
    folder : str
        Run folder containing ``trajectories.npz``.
    sample_interval : float, optional
        Seconds covered by each CPU measurement (default 0.2).
    """
    import threading

    process = psutil.Process(os.getpid())

    cpu_samples = []
    mem_samples = []
    stop_sampling = threading.Event()

    def _sample_usage():
        # cpu_percent(interval=...) blocks for the interval, which paces the
        # loop. NOTE(review): if native rendering code holds the GIL for long
        # stretches, samples may be sparser than sample_interval.
        while not stop_sampling.is_set():
            try:
                cpu_samples.append(process.cpu_percent(interval=sample_interval))
                mem_samples.append(process.memory_info().rss / (1024 ** 2))
            except psutil.Error:
                break

    print(f"\n[{folder}] Starting render...")
    start_time = time.perf_counter()

    sampler = threading.Thread(target=_sample_usage, daemon=True)
    sampler.start()

    try:
        render_two_pngs(folder)
    except Exception as e:
        # Per-folder boundary: report and let the caller continue with the next folder.
        print(f"[{folder}] Error: {e}")

    elapsed = time.perf_counter() - start_time

    stop_sampling.set()
    sampler.join()

    # Final reading so even renders shorter than one sample report memory.
    mem_samples.append(process.memory_info().rss / (1024 ** 2))

    avg_cpu = round(sum(cpu_samples) / len(cpu_samples), 1) if cpu_samples else 0.0
    peak_mem = round(max(mem_samples), 1)

    print(f"[{folder}] Render time: {str(timedelta(seconds=int(elapsed)))}")
    print(f"[{folder}] Peak RAM: {peak_mem} MB | Avg CPU: {avg_cpu}%\n")


# ---------------
# Entry point
# ---------------

if __name__ == "__main__":
    # Render every immediate subfolder of the current working directory that
    # contains a trajectories.npz file.
    base_dir = os.getcwd()

    # os.listdir returns entries in arbitrary order; sorting makes the
    # processing order and the log output reproducible between runs.
    subdirs = sorted(
        os.path.join(base_dir, d)
        for d in os.listdir(base_dir)
        if os.path.isdir(os.path.join(base_dir, d))
        and os.path.exists(os.path.join(base_dir, d, "trajectories.npz"))
    )

    print(f"Found {len(subdirs)} folders with trajectories.npz")

    for folder in subdirs:
        process_directory(folder)

    print("All PNG renders complete.")
