#!/usr/bin/env python
# coding: utf-8

import os
import threading
import time
import traceback
from datetime import timedelta

import networkx as nx
import numpy as np
import psutil
import pyvista as pv

# ---------------------------
# SETTINGS
# ---------------------------
# When False, folders whose output video already exists are skipped;
# when True, the video is rendered again and replaced.
OVERWRITE = False

# Off-screen render window size passed to pv.Plotter, and the movie framerate
# passed to plotter.open_movie.
WINDOW_SIZE = (1000, 1000)
FRAMERATE = 10

# Marker size of polymer sites and line width of drawn bonds (plotter.add_mesh).
POINT_SIZE_POLYMER = 10
LINE_WIDTH_BOND = 4

# Opacity and line width of the wireframe simulation box.
BOX_OPACITY = 0.35
BOX_LINE_WIDTH = 1.0

# Hide covalent bonds that cross periodic boundaries.
# Through the periodic boundary such bonds are short, but drawn between the stored
# (wrapped) coordinates they appear as long artificial lines across the box.
HIDE_PBC_CROSSING_BONDS = True

# Absolute tolerance (np.allclose atol) used by bond_crosses_pbc when comparing
# the raw and minimum-image bond vectors.
PBC_TOL = 1e-8


# ---------------------------
# Helpers
# ---------------------------

def load_gel_graph_npz(filename):
    """
    Rebuild a networkx gel graph from raw arrays stored in an ``.npz`` file.

    The archive must contain ``node_ids`` paired element-wise with
    ``node_coords``, and ``edges`` (pairs ``(u, v)``) paired element-wise with
    ``edge_types``. Each node gets a ``coord`` attribute (a tuple) and each edge
    a ``type`` attribute (a str).

    Parameters
    ----------
    filename : str
        Path to the ``.npz`` archive.

    Returns
    -------
    networkx.Graph
        The reconstructed graph.

    Raises
    ------
    KeyError
        If one of the required arrays is missing from the archive.
    """
    # NOTE(review): allow_pickle=True lets np.load unpickle object arrays,
    # which can execute code; only load files from trusted simulation runs.
    # The context manager closes the underlying zip file handle.
    with np.load(filename, allow_pickle=True) as data:
        node_ids = data["node_ids"]
        node_coords = data["node_coords"]
        edges = data["edges"]
        edge_types = data["edge_types"]

    gel_graph = nx.Graph()

    for node, coord in zip(node_ids, node_coords):
        gel_graph.add_node(int(node), coord=tuple(coord))

    for (u, v), t in zip(edges, edge_types):
        gel_graph.add_edge(int(u), int(v), type=str(t))

    return gel_graph


def read_params_from_files(folder):
    """
    Determine the snapshot interval and total step count for a run folder.

    ``trajectories.npz`` in *folder* must exist and contain
    ``lattice_snapshots``; its length is the number of recorded frames. If
    ``final.npz`` exists, ``print_interval`` is read from it, and
    ``total_steps`` (or, as a fallback, ``steps``). A ``final.npz`` that cannot
    be read is reported and ignored.

    Defaults: ``print_interval = 1`` and
    ``total_steps = max(num_frames - 1, 0) * print_interval``.

    Parameters
    ----------
    folder : str
        Run directory.

    Returns
    -------
    tuple[int, int]
        ``(print_interval, total_steps)``.
    """
    traj_file = os.path.join(folder, "trajectories.npz")
    final_file = os.path.join(folder, "final.npz")

    # Close the archive as soon as the frame count is known.
    with np.load(traj_file, allow_pickle=True) as data:
        num_frames = len(data["lattice_snapshots"])

    print_interval = 1
    total_steps = None

    if os.path.exists(final_file):
        try:
            with np.load(final_file, allow_pickle=True) as fin:
                if "print_interval" in fin:
                    print_interval = int(fin["print_interval"])

                if "total_steps" in fin:
                    total_steps = int(fin["total_steps"])
                elif "steps" in fin:
                    total_steps = int(fin["steps"])
        except Exception as exc:
            # Best effort: a broken final.npz must not block rendering,
            # but the fallback to defaults should not be silent either.
            print(f"[{final_file}] Could not read run parameters ({exc}); using defaults.")

    if total_steps is None:
        # Clamp so an empty trajectory yields 0 rather than -print_interval.
        total_steps = max(num_frames - 1, 0) * print_interval

    return print_interval, total_steps


def get_graph_for_index(i, gel_graph_snapshots, shared_graph):
    """
    Return the gel graph to use for snapshot *i*.

    ``gel_graph_snapshots[i]`` is used when it holds a graph. Entries may be
    stored directly, or wrapped in a size-1 numpy array (unwrapped with
    ``.item()``). Indexing a 1-D object array returns the stored object itself,
    not an array. In every other case, *shared_graph* is returned.

    Parameters
    ----------
    i : int
        Snapshot index into *gel_graph_snapshots*.
    gel_graph_snapshots : sequence or None
        Per-snapshot graphs (e.g. an object array loaded from npz), or None.
    shared_graph : object or None
        Fallback graph.

    Returns
    -------
    object or None
        A graph-like object exposing ``nodes`` and ``edges``, or *shared_graph*.
    """
    if gel_graph_snapshots is None or len(gel_graph_snapshots) <= i:
        return shared_graph

    candidate = gel_graph_snapshots[i]

    # Unwrap 0-d / size-1 arrays; anything larger cannot be a single graph.
    if isinstance(candidate, np.ndarray):
        candidate = candidate.item() if candidate.size == 1 else None

    # Duck-type check: the renderer only needs ``.nodes`` and ``.edges``.
    if candidate is None or not (hasattr(candidate, "nodes") and hasattr(candidate, "edges")):
        return shared_graph

    return candidate


def minimum_image_delta(p1, p2, box_lengths):
    """
    Compute the displacement p2 - p1 both as stored and under the minimum image convention.

    For each axis, a component larger than half the box length is shifted down
    by one box length. A component smaller than minus half the box length is
    shifted up by one. A component exactly at +/- half is left as is. A
    difference between the two returned vectors means the bond crosses a
    periodic boundary.

    Returns
    -------
    tuple
        ``(raw, wrapped)``; *wrapped* is a separate copy with the same dtype as *raw*.
    """
    raw = p2 - p1
    wrapped = raw.copy()

    for axis, length in enumerate(box_lengths):
        half = length / 2.0
        component = wrapped[axis]
        if component > half:
            wrapped[axis] = component - length
        elif component < -half:
            wrapped[axis] = component + length

    return raw, wrapped


def bond_crosses_pbc(p1, p2, box_lengths, tol=PBC_TOL):
    """
    Tell whether the bond p1-p2 wraps around a periodic boundary.

    With stored (wrapped) coordinates, a boundary-crossing bond looks long,
    while its minimum-image displacement is short. The bond is flagged when the
    raw and minimum-image displacements are not equal under
    ``np.allclose(..., atol=tol)``.

    Returns
    -------
    bool
        True if the bond crosses a boundary.
    """
    raw_vec, image_vec = minimum_image_delta(p1, p2, box_lengths)
    unchanged = np.allclose(raw_vec, image_vec, atol=tol)
    return not unchanged


# -----------------------------------
# 3D animation export, polymer only
# -----------------------------------

def _polymer_bond_lines(graph, coords_poly, vals_poly, box_lengths):
    """
    Build the flat VTK line-connectivity list for covalent bonds between visible polymer sites.

    Parameters
    ----------
    graph : graph-like
        Provides ``edges(data=True)``; edges without a ``type`` attribute are
        treated as covalent.
    coords_poly : numpy.ndarray
        Coordinates of the polymer sites, row-aligned with *vals_poly*.
    vals_poly : numpy.ndarray
        Node id stored at each polymer site.
    box_lengths : numpy.ndarray
        Periodic box length along each axis.

    Returns
    -------
    tuple[list[int], int, int]
        ``(lines_list, drawn_bonds, hidden_pbc_bonds)``. ``lines_list`` is
        ``[2, i, j, 2, k, l, ...]`` with indices into the rows of *coords_poly*.
    """
    # If an id occupies several sites, the last one wins.
    id_to_idx = {int(node_id): local_idx for local_idx, node_id in enumerate(vals_poly)}

    lines_list = []
    drawn_bonds = 0
    hidden_pbc_bonds = 0

    for u, v, data in graph.edges(data=True):
        if data.get("type", "covalent") != "covalent":
            continue

        iu = id_to_idx.get(int(u))
        iv = id_to_idx.get(int(v))
        if iu is None or iv is None:
            # At least one endpoint is not present on the lattice in this frame.
            continue

        if HIDE_PBC_CROSSING_BONDS and bond_crosses_pbc(coords_poly[iu], coords_poly[iv], box_lengths):
            hidden_pbc_bonds += 1
            continue

        lines_list.extend((2, iu, iv))
        drawn_bonds += 1

    return lines_list, drawn_bonds, hidden_pbc_bonds


def export_lattice_3d_animation_polymer(
    lattice_snapshots,
    gel_graph_snapshots,
    shared_graph,
    filename,
    SNAP,
    STEPS
):
    """
    Render the polymer sites and covalent bonds of each 3-D lattice snapshot to a movie.

    Each frame draws three things:

    - lattice sites whose value is a node of that frame's gel graph, as black spheres;
    - the graph's covalent bonds between those sites, as lines, skipping
      boundary-crossing bonds when HIDE_PBC_CROSSING_BONDS is set;
    - a wireframe box around the lattice.

    The plotter (and movie writer) is closed even if rendering fails.

    Parameters
    ----------
    lattice_snapshots : sequence
        Snapshots; only entries that are 3-D numpy arrays are rendered. The box
        size comes from the first valid snapshot.
    gel_graph_snapshots : sequence or None
        Per-snapshot graphs, indexed like *lattice_snapshots*.
    shared_graph : graph-like or None
        Fallback graph when no per-snapshot graph is available.
    filename : str
        Output movie path.
    SNAP : int
        Steps between consecutive snapshots (used for the timestep label).
    STEPS : int
        Total number of steps (used for the timestep label).
    """
    # Keep each snapshot's position in the input. Per-snapshot graphs and
    # timestep labels then stay aligned even when invalid entries are skipped.
    snapshots = [
        (orig_idx, snap)
        for orig_idx, snap in enumerate(lattice_snapshots)
        if isinstance(snap, np.ndarray) and snap.ndim == 3
    ]

    if not snapshots:
        print(f"[{filename}] No valid 3D snapshots found.")
        return

    N, M, L = snapshots[0][1].shape
    box_lengths = np.array([N, M, L], dtype=np.float32)

    # Sites sit at integer coordinates, so the box extends half a cell beyond them.
    box = pv.Box(
        bounds=(
            -0.5, N - 0.5,
            -0.5, M - 0.5,
            -0.5, L - 0.5,
        )
    )

    total_hidden_pbc_bonds = 0
    total_drawn_bonds = 0

    plotter = pv.Plotter(off_screen=True, window_size=WINDOW_SIZE)
    try:
        plotter.open_movie(filename, framerate=FRAMERATE)

        print(f"[{filename}] Rendering polymer only with PBC bond filtering...")

        for frame_no, (orig_idx, lattice) in enumerate(snapshots):
            plotter.clear()

            graph = get_graph_for_index(orig_idx, gel_graph_snapshots, shared_graph)
            polymer_ids = set(graph.nodes) if graph is not None else set()

            if polymer_ids:
                mask_polymer = np.isin(lattice, list(polymer_ids))
            else:
                mask_polymer = np.zeros(lattice.shape, dtype=bool)

            if np.any(mask_polymer):
                # argwhere and boolean-mask indexing both traverse in C order, so
                # row k of coords_poly is the site holding vals_poly[k].
                coords_poly = np.argwhere(mask_polymer).astype(np.float32)
                vals_poly = lattice[mask_polymer].astype(int).ravel()

                plotter.add_mesh(
                    pv.PolyData(coords_poly),
                    render_points_as_spheres=True,
                    point_size=POINT_SIZE_POLYMER,
                    color="black",
                    opacity=1.0,
                )

                # A non-empty mask implies polymer_ids, and therefore graph, is non-empty.
                lines_list, drawn_bonds, hidden_pbc_bonds = _polymer_bond_lines(
                    graph, coords_poly, vals_poly, box_lengths
                )
                total_drawn_bonds += drawn_bonds
                total_hidden_pbc_bonds += hidden_pbc_bonds

                if lines_list:
                    bond_mesh = pv.PolyData()
                    bond_mesh.points = coords_poly
                    bond_mesh.lines = np.asarray(lines_list, dtype=np.int64)

                    plotter.add_mesh(
                        bond_mesh,
                        color="black",
                        line_width=LINE_WIDTH_BOND,
                        opacity=1.0,
                    )

            plotter.add_mesh(
                box,
                color="black",
                style="wireframe",
                line_width=BOX_LINE_WIDTH,
                opacity=BOX_OPACITY,
            )

            plotter.reset_camera()
            plotter.camera_position = [
                (max(N, M, L) * 3.0, M * 0.5, L * 0.5),
                (N * 0.5, M * 0.5, L * 0.5),
                (0, 0, 1),
            ]

            plotter.add_axes()
            plotter.add_text(f"Timestep {orig_idx * SNAP} / {STEPS}", font_size=14)

            plotter.write_frame()

            if frame_no % 50 == 0:
                print(
                    f"[{filename}] Frame {frame_no + 1}/{len(snapshots)} | "
                    f"drawn bonds: {total_drawn_bonds} | "
                    f"hidden PBC bonds: {total_hidden_pbc_bonds}"
                )
    finally:
        # Release the render window and finalize the movie writer, even on error.
        plotter.close()

    print(f"[{filename}] Saved polymer-only PBC video.")
    print(f"[{filename}] Total drawn covalent bonds over all frames: {total_drawn_bonds}")
    print(f"[{filename}] Total hidden PBC crossing bonds over all frames: {total_hidden_pbc_bonds}")


# ------------------------
# Resource monitoring
# ------------------------

def monitor_resource_usage(func, *args, **kwargs):
    """
    Run ``func(*args, **kwargs)`` while sampling this process's CPU and memory use.

    A background daemon thread records ``cpu_percent`` and resident set size
    (in MB) about once per second. A final sample is taken after *func*
    returns, so even short calls yield data. The sampler thread is always
    stopped and joined. If *func* raises, the exception propagates. *func*'s
    return value is discarded.

    Returns
    -------
    tuple
        ``(elapsed_seconds, avg_cpu, max_cpu, avg_mem_mb, max_mem_mb)``; each
        statistic is 0.0 when no sample could be taken.
    """
    process = psutil.Process(os.getpid())

    cpu_samples = []
    mem_samples = []
    stop_event = threading.Event()

    def take_sample():
        """Append one CPU/memory sample; return False if the process has vanished."""
        try:
            cpu_samples.append(process.cpu_percent(interval=None))
            mem_samples.append(process.memory_info().rss / (1024 ** 2))
        except psutil.NoSuchProcess:
            return False
        return True

    def sampler():
        # Event.wait returns True as soon as a stop is requested, so shutdown
        # does not have to wait out a full sleep interval.
        while not stop_event.wait(1.0):
            if not take_sample():
                break

    # The first non-blocking cpu_percent call only sets a baseline and returns
    # 0.0; prime it so that recorded samples are meaningful.
    process.cpu_percent(interval=None)

    start_time = time.perf_counter()

    thread = threading.Thread(target=sampler, daemon=True)
    thread.start()

    try:
        func(*args, **kwargs)
    finally:
        stop_event.set()
        thread.join()

    elapsed = time.perf_counter() - start_time

    take_sample()

    avg_cpu = np.mean(cpu_samples) if cpu_samples else 0.0
    max_cpu = np.max(cpu_samples) if cpu_samples else 0.0
    avg_mem = np.mean(mem_samples) if mem_samples else 0.0
    max_mem = np.max(mem_samples) if mem_samples else 0.0

    return elapsed, avg_cpu, max_cpu, avg_mem, max_mem


# ------------------------
# Sequential processing
# ------------------------

def _load_shared_graph(graph_path):
    """
    Load the run-wide gel graph from ``gel_graph_data.npz``, returning None if unavailable.

    A pickled ``gel_graph`` object is preferred. If it is absent or cannot be
    restored, the graph is rebuilt from raw ``node_ids``/``node_coords``/
    ``edges``/``edge_types`` arrays when those are present.
    """
    if not os.path.exists(graph_path):
        return None

    saved_graph = None

    with np.load(graph_path, allow_pickle=True) as gdata:
        keys = set(gdata.files)

        if "gel_graph" in keys:
            try:
                saved_graph = gdata["gel_graph"].item()
            except Exception as exc:
                print(f"[{graph_path}] Could not load saved gel_graph ({exc}).")

    if saved_graph is not None:
        print(f"[{graph_path}] Using saved gel_graph object.")
        return saved_graph

    if {"node_ids", "node_coords", "edges", "edge_types"}.issubset(keys):
        graph = load_gel_graph_npz(graph_path)
        print(f"[{graph_path}] Reconstructed gel_graph from raw arrays.")
        return graph

    if "gel_graph" not in keys:
        print(f"[{graph_path}] No valid graph keys found.")

    return None


def process_directory(folder):
    """
    Render the polymer-only PBC video for one run folder.

    The folder is skipped (returning None) when ``trajectories.npz`` is missing,
    or when the output video already exists and OVERWRITE is False. Any error
    while loading or rendering is reported with a traceback and also yields
    None, so one bad folder does not stop a batch.

    Returns
    -------
    tuple or None
        ``(elapsed, avg_cpu, max_cpu, avg_mem, max_mem)`` from
        ``monitor_resource_usage`` on success, otherwise None.
    """
    traj_file = os.path.join(folder, "trajectories.npz")
    output_file = os.path.join(folder, "trajectories_all_polymer_PBC.mp4")

    if not os.path.exists(traj_file):
        print(f"[{folder}] trajectories.npz not found.")
        return None

    if os.path.exists(output_file) and not OVERWRITE:
        print(f"[{folder}] Polymer PBC video already exists, skipping.")
        return None

    print(f"[{folder}] Rendering polymer-only PBC video...")

    try:
        SNAP, STEPS = read_params_from_files(folder)

        # Indexing an NpzFile reads the array fully, so the archive can be
        # closed right after extraction.
        with np.load(traj_file, allow_pickle=True) as data:
            lattice_snapshots = data["lattice_snapshots"]
            gel_graph_snapshots = (
                data["gel_graph_snapshots"] if "gel_graph_snapshots" in data else None
            )

        shared_graph = _load_shared_graph(os.path.join(folder, "gel_graph_data.npz"))

        elapsed, avg_cpu, max_cpu, avg_mem, max_mem = monitor_resource_usage(
            export_lattice_3d_animation_polymer,
            lattice_snapshots,
            gel_graph_snapshots,
            shared_graph,
            output_file,
            SNAP,
            STEPS,
        )

        print(f"[{folder}] Render time: {str(timedelta(seconds=int(elapsed)))}")
        print(
            f"[{folder}] Memory avg={avg_mem:.1f} MB, max={max_mem:.1f} MB | "
            f"CPU avg={avg_cpu:.1f}%, max={max_cpu:.1f}%"
        )

        return elapsed, avg_cpu, max_cpu, avg_mem, max_mem

    except Exception as e:
        # Top-level per-folder boundary: report with traceback and carry on.
        print(f"[{folder}] Error: {e}")
        traceback.print_exc()
        return None


# ---------------
# Entry point
# ---------------

if __name__ == "__main__":
    base_dir = os.getcwd()

    # Sorted so that folders are processed and logged in a reproducible order
    # (os.listdir order is arbitrary). isfile on the child path implies the
    # folder itself is a directory.
    subdirs = sorted(
        os.path.join(base_dir, d)
        for d in os.listdir(base_dir)
        if os.path.isfile(os.path.join(base_dir, d, "trajectories.npz"))
    )

    print(f"Found {len(subdirs)} folders with trajectories.npz")

    results = []

    start_global = time.time()

    for folder in subdirs:
        res = process_directory(folder)

        if res is not None:
            results.append(res)

    if results:
        elapsed_all = time.time() - start_global
        # Columns: elapsed, avg_cpu, max_cpu, avg_mem, max_mem (per folder).
        arr = np.array(results)

        print("")
        print("==============================")
        print(f"All PBC videos rendered in {timedelta(seconds=int(elapsed_all))}")
        print(f"Average CPU: {arr[:, 1].mean():.1f}% | Max CPU: {arr[:, 2].max():.1f}%")
        print(f"Average RAM: {arr[:, 3].mean():.1f} MB | Max RAM: {arr[:, 4].max():.1f} MB")
        print("==============================")
    else:
        print("No videos were rendered.")
