#!/usr/bin/env python
# coding: utf-8

import os
import time
import random
import numpy as np
import argparse
import psutil
import threading
from datetime import datetime, timedelta
from numba import njit, set_num_threads

# ---------------- ARGUMENT PARSING ----------------
def _build_parser():
    """Return the argparse parser describing this script's command line."""
    cli = argparse.ArgumentParser(description="3DCA polymer–solvent simulation")
    # Physical inputs: all required.
    cli.add_argument("--L", type=int, required=True, help="Chain length (number of monomers)")
    cli.add_argument("--W", type=int, required=True, help="Initial solvent concentration per cell")
    cli.add_argument("--Epp", type=float, required=True, help="Polymer–polymer interaction energy")
    cli.add_argument("--Es1p", type=float, required=True, help="Polymer–water interaction energy")
    # Run controls: optional, with defaults.
    cli.add_argument("--maxc", type=int, default=30, help="Maximum cell solvent concentration")
    cli.add_argument("--steps", type=int, default=1000, help="Number of simulation steps")
    cli.add_argument("--interval", type=int, default=100, help="Snapshot interval")
    return cli


parser = _build_parser()
args = parser.parse_args()

# ---------------- SETTINGS ----------------
nproc = 1  # thread count handed to numba below
N, M = 50, 50  # lattice extent along the first two axes; the third axis has length L
Es1p = args.Es1p
Es2p = 0.0
first_node_number = 1000  # lattice values >= this are monomer node ids, smaller values are solvent
max_cell_concentration = args.maxc
steps = args.steps
print_interval = args.interval
polymer_fraction_moving = 0.25  # NOTE(review): not referenced by the code in this script
covalent_spring_constant = 30.0  # NOTE(review): not referenced by the code in this script
max_allowed_stretch = 2.0  # bonds longer than this make a proposed move invalid

# Reject arguments that would otherwise crash deep inside the run
# (division/modulo by zero) or silently corrupt it (solvent values that
# collide with monomer node ids).
if args.L < 1:
    parser.error("--L must be at least 1")
if not 0 <= args.W < first_node_number:
    parser.error(f"--W must be in [0, {first_node_number}) so solvent values never collide with node ids")
if max_cell_concentration <= 0:
    parser.error("--maxc must be positive")
if steps < 0:
    parser.error("--steps must be non-negative")
if print_interval <= 0:
    parser.error("--interval must be positive")

set_num_threads(nproc)

# ---------------- RESOURCE MONITOR ----------------
# State shared between the main thread and the monitor thread.
stop_monitor = False  # flipped to True by the main thread to end sampling
cpu_usage, mem_usage = [], []  # CPU-percent samples; process RSS samples in MiB

def monitor_resources():
    """Record CPU and RAM usage roughly every 0.5 s until ``stop_monitor`` is set.

    Appends ``psutil.cpu_percent()`` readings to the module-level ``cpu_usage``
    list and this process's resident set size, in MiB, to ``mem_usage``.
    Intended to run as the target of a background thread.
    """
    # The process handle does not change, so create it once instead of per sample.
    proc = psutil.Process(os.getpid())
    # psutil: the first cpu_percent(interval=None) call has no baseline and
    # returns a meaningless 0.0. Take it once and discard it so it does not
    # drag down the reported average.
    psutil.cpu_percent(interval=None)
    while not stop_monitor:
        time.sleep(0.5)
        # Each reading now covers the ~0.5 s since the previous call.
        cpu_usage.append(psutil.cpu_percent(interval=None))
        mem_usage.append(proc.memory_info().rss / 1024 / 1024)

# ---------------- INITIALIZATION ----------------
def initialization(N, M, L, w_cell_concentration):
    """Build the initial lattice holding one straight polymer chain.

    Every cell of the (N, M, L) int32 lattice starts at the solvent value
    ``w_cell_concentration``. The ``L`` monomers occupy the column
    ``(N // 2, M // 2, z)`` for ``z = 0 .. L-1``. Monomer ``z`` gets node id
    ``first_node_number + z``, and that id is written into its cell. The third
    lattice axis has length ``L``, so the chain fills that whole column.

    Returns
    -------
    lattice : int32 ndarray, shape (N, M, L)
        Solvent cells hold a concentration; polymer cells hold a node id.
    coords : int32 ndarray, shape (L, 3)
        Position of each monomer.
    node_ids : int32 ndarray, shape (L,)
    bonds : int32 ndarray, shape (max(L - 1, 0), 2)
        Consecutive node-id pairs along the chain.
    edge_types : str ndarray, shape (max(L - 1, 0),)
        ``'covalent'`` for every bond.

    Raises
    ------
    ValueError
        If ``w_cell_concentration >= first_node_number``. Such solvent cells
        would be classified as polymer nodes.
    """
    if w_cell_concentration >= first_node_number:
        raise ValueError(
            f"solvent concentration {w_cell_concentration} must be below "
            f"first_node_number={first_node_number}"
        )

    lattice = np.full((N, M, L), w_cell_concentration, dtype=np.int32)
    x, y = N // 2, M // 2

    node_ids = np.arange(first_node_number, first_node_number + L, dtype=np.int32)
    # The chain fills the entire z-column, so a single slice places every monomer.
    lattice[x, y, :] = node_ids

    coords = np.empty((L, 3), dtype=np.int32)
    coords[:, 0] = x
    coords[:, 1] = y
    coords[:, 2] = np.arange(L)

    bonds = np.column_stack([node_ids[:-1], node_ids[1:]])
    # np.full keeps a string dtype even with zero bonds (L <= 1). The previous
    # np.array([] ) form produced a float64 array in that case.
    edge_types = np.full(bonds.shape[0], "covalent")
    return lattice, coords, node_ids, bonds, edge_types

# ---------------- HELPERS ----------------
@njit(inline='always')
def von_neumann_neighborhood(i, j, k, N, M, L):
    """Return the six face-adjacent cells of (i, j, k) on a periodic N x M x L grid.

    The result is a (6, 3) int32 array with rows ordered -x, +x, -y, +y, -z, +z.
    Only the coordinate that changes in a row is wrapped modulo its axis length.
    """
    nbrs = np.empty((6, 3), dtype=np.int32)
    # Start every row at the centre cell...
    for row in range(6):
        nbrs[row, 0] = i
        nbrs[row, 1] = j
        nbrs[row, 2] = k
    # ...then shift the one coordinate each row differs in, with periodic wrap.
    nbrs[0, 0] = (i - 1) % N
    nbrs[1, 0] = (i + 1) % N
    nbrs[2, 1] = (j - 1) % M
    nbrs[3, 1] = (j + 1) % M
    nbrs[4, 2] = (k - 1) % L
    nbrs[5, 2] = (k + 1) % L
    return nbrs

@njit(fastmath=True)
def cell_type(value):
    """Classify a lattice value: 1 for a polymer node id, 0 for a solvent value."""
    if value >= first_node_number:
        return 1
    return 0

# ---------------- ENERGY FUNCTIONS ----------------
@njit(fastmath=True)
def interaction_energy(cell, neighbors, Es1p, Epp, Es2p):
    """Contact energy of one lattice cell with the given neighbour values.

    A solvent value is read as a water amount, capped at
    ``max_cell_concentration``. The rest of the cap is the second solvent
    (ethanol). A solvent/polymer contact costs
    ``Es1p * water_frac + Es2p * ethanol_frac``, using the fractions of
    whichever side is solvent.

    * Polymer cell: ``Epp`` per polymer neighbour, plus one solvent contact
      per solvent neighbour (weighted by that neighbour's composition).
    * Solvent cell: one solvent contact per polymer neighbour (weighted by
      this cell's own composition). Solvent/solvent pairs contribute nothing.
    """
    total = 0.0

    if cell_type(cell) == 1:
        for nb in neighbors:
            if cell_type(nb) == 0:
                w = min(nb, max_cell_concentration)
                total += Es1p * (w / max_cell_concentration) + Es2p * (
                    (max_cell_concentration - w) / max_cell_concentration
                )
            else:
                total += Epp
        return total

    # Solvent cell: every polymer contact uses this cell's own composition.
    w = min(cell, max_cell_concentration)
    contact = Es1p * (w / max_cell_concentration) + Es2p * (
        (max_cell_concentration - w) / max_cell_concentration
    )
    for nb in neighbors:
        if cell_type(nb) == 1:
            total += contact
    return total

@njit(fastmath=True)
def bond_hard_wall_energy(node_idx, coords, bonds):
    """Hard-wall penalty for the bonds attached to monomer ``node_idx``.

    Each bond whose Euclidean length is not 1, sqrt(2) or sqrt(3) (within
    1e-6) adds 1e6. Admissible bonds add nothing. Bond endpoints are node ids,
    and a node id maps to its ``coords`` row via ``id - first_node_number``.
    Distances are taken between the stored coordinates directly, with no
    periodic minimum image.
    """
    me = first_node_number + node_idx
    root2 = np.sqrt(2.0)
    root3 = np.sqrt(3.0)
    tol = 1e-6
    penalty = 0.0

    for row in range(bonds.shape[0]):
        left = bonds[row, 0]
        right = bonds[row, 1]
        if left == me:
            partner = right - first_node_number
        elif right == me:
            partner = left - first_node_number
        else:
            continue  # bond does not touch this monomer

        ddx = coords[node_idx, 0] - coords[partner, 0]
        ddy = coords[node_idx, 1] - coords[partner, 1]
        ddz = coords[node_idx, 2] - coords[partner, 2]
        length = (ddx * ddx + ddy * ddy + ddz * ddz) ** 0.5

        on_lattice_bond = (
            abs(length - 1.0) < tol
            or abs(length - root2) < tol
            or abs(length - root3) < tol
        )
        if not on_lattice_bond:
            penalty += 1e6
    return penalty

# ---------------- MAIN SIMULATION ----------------
def _bond_stretch_exceeded(node_idx, coords, bonds, limit):
    """Return True if any bond attached to monomer ``node_idx`` is longer than ``limit``.

    ``bonds`` holds node-id pairs, and a node id maps to its ``coords`` row via
    ``id - first_node_number``. Lengths are plain Euclidean distances between
    the stored coordinates, with no periodic minimum image.
    """
    node_id = first_node_number + node_idx
    partner_ids = np.concatenate(
        (bonds[bonds[:, 0] == node_id, 1], bonds[bonds[:, 1] == node_id, 0])
    )
    if partner_ids.size == 0:
        return False
    diff = coords[partner_ids - first_node_number] - coords[node_idx]
    return bool(np.any(np.sqrt((diff * diff).sum(axis=1)) > limit))


def main_code(N, M, L, Epp, Es1p, Es2p, lattice, coords, node_ids, bonds, steps, print_interval):
    """Run the Monte Carlo simulation in place and collect lattice snapshots.

    Performs exactly ``steps`` sweeps. The print/snapshot check runs at
    ``step = 0 .. steps`` inclusive: before each sweep, and once more after
    the last sweep. A snapshot is taken when ``step % print_interval == 0``.
    The first snapshot is therefore the initial state. When ``steps`` is a
    multiple of ``print_interval``, the last snapshot equals the final
    lattice.

    In each sweep, every monomer is visited in random order. A monomer is
    skipped if a chain neighbour (index +-1) already moved in this sweep.
    Otherwise it is offered a swap with a random face-adjacent solvent cell.
    The swap is rejected outright if it would stretch a bond beyond
    ``max_allowed_stretch``. If not, it is accepted with probability
    ``min(1, exp(-dE))``; there is no temperature factor, i.e. kT = 1.

    Parameters
    ----------
    N, M, L : int
        Lattice dimensions (periodic).
    Epp, Es1p, Es2p : float
        Interaction energies forwarded to ``interaction_energy``.
    lattice : ndarray
        Lattice state; mutated in place.
    coords : ndarray, shape (n_monomers, 3)
        Monomer positions; mutated in place.
    node_ids : ndarray
        Not used here; kept for call compatibility.
    bonds : ndarray, shape (n_bonds, 2)
        Bonded node-id pairs.
    steps : int
        Number of sweeps to perform.
    print_interval : int
        Print/snapshot period, in sweeps.

    Returns
    -------
    tuple
        ``(lattice_snapshots, coords, bonds)``: a list of lattice copies, the
        mutated ``coords`` array, and ``bonds`` unchanged.
    """
    lattice_snapshots = []

    def attempt_move(node_idx):
        """Propose swapping monomer ``node_idx`` with a random solvent neighbour.

        Returns ``(node_idx, old_cell, new_cell)`` if the move is accepted,
        otherwise ``None``. ``coords`` is left unchanged either way.
        """
        i, j, k = coords[node_idx]
        pc = (i, j, k)
        neigh = von_neumann_neighborhood(i, j, k, N, M, L)
        solv = [tuple(n) for n in neigh if cell_type(lattice[tuple(n)]) == 0]
        if not solv:
            return None
        sc = random.choice(solv)
        sc_val = lattice[sc]
        pc_val = lattice[pc]

        # Neighbourhoods of both cells, excluding the swap partner itself.
        neigh_pc = np.array([lattice[tuple(n)] for n in neigh if tuple(n) != sc])
        neigh_sc = np.array([lattice[tuple(n)] for n in von_neumann_neighborhood(*sc, N, M, L) if tuple(n) != pc])

        e0 = interaction_energy(pc_val, neigh_pc, Es1p, Epp, Es2p) \
            + interaction_energy(sc_val, neigh_sc, Es1p, Epp, Es2p) \
            + bond_hard_wall_energy(node_idx, coords, bonds)

        # Tentatively place the monomer at the proposed cell. Every check that
        # depends on the new position must run before coords is restored.
        # Previously, the stretch check ran after the restore and so only
        # examined the old, already-valid position.
        orig = coords[node_idx].copy()
        coords[node_idx] = sc
        try:
            e1 = interaction_energy(sc_val, neigh_pc, Es1p, Epp, Es2p) \
                + interaction_energy(pc_val, neigh_sc, Es1p, Epp, Es2p) \
                + bond_hard_wall_energy(node_idx, coords, bonds)
            too_far = _bond_stretch_exceeded(node_idx, coords, bonds, max_allowed_stretch)
        finally:
            coords[node_idx] = orig

        if too_far:
            return None

        # Metropolis criterion. Downhill moves are always accepted, which also
        # avoids overflowing np.exp for large negative energy changes.
        delta = e1 - e0
        if delta <= 0.0 or np.random.rand() < np.exp(-delta):
            return (node_idx, pc, sc)
        return None

    def sweep():
        """Offer one move to every monomer, in random order."""
        order = list(range(coords.shape[0]))
        random.shuffle(order)
        moved = set()
        for idx in order:
            # Never move a monomer whose chain neighbour already moved this sweep.
            if (idx - 1) in moved or (idx + 1) in moved:
                continue
            result = attempt_move(idx)
            if result is None:
                continue
            node_idx, pc, sc = result
            lattice[pc], lattice[sc] = lattice[sc], lattice[pc]
            coords[node_idx] = sc
            moved.add(idx)

    for step in range(steps + 1):
        if step % print_interval == 0:
            print(f"L={L}, step={step}, Epp={Epp}, Es1p={Es1p}")
            lattice_snapshots.append(lattice.copy())
        # The step == steps iteration only records the final state. Sweeping
        # there (as the old `while step <= steps` loop did) ran steps + 1
        # sweeps in total.
        if step < steps:
            sweep()
    return lattice_snapshots, coords, bonds

# ---------------- SAVE ----------------
def save_gel_graph_npz(node_ids, coords, bonds, edge_types, filename):
    """Write the chain graph to a compressed ``.npz`` archive at ``filename``.

    Stored keys: ``node_ids``, ``node_coords`` and ``edges``, each cast to
    int32, plus ``edge_types`` as given.
    """
    arrays = {
        "node_ids": node_ids.astype(np.int32),
        "node_coords": coords.astype(np.int32),
        "edges": bonds.astype(np.int32),
        "edge_types": edge_types,
    }
    np.savez_compressed(filename, **arrays)

# ---------------- RUN ----------------
if __name__ == "__main__":
    start = time.time()
    # monitor_resources only exits once stop_monitor is set. daemon=True plus
    # the try/finally below ensure a crash in the run cannot leave the
    # process hanging on this thread.
    monitor_thread = threading.Thread(target=monitor_resources, daemon=True)
    monitor_thread.start()

    try:
        date_str = datetime.now().strftime("%Y%m%d_%H%M%S")

        L_value = args.L
        W_value = args.W
        Epp_value = args.Epp
        Es1p_value = args.Es1p

        trajectories_dir = f"trajectories_{date_str}_L{L_value}_W{W_value}_Epp{Epp_value}_Es1p{Es1p_value}"
        os.makedirs(trajectories_dir, exist_ok=True)

        lattice, coords, node_ids, bonds, edge_types = initialization(N, M, L_value, W_value)
        lattice_snapshots, final_coords, final_bonds = main_code(
            N, M, L_value, Epp_value, Es1p_value, Es2p, lattice, coords, node_ids, bonds, steps, print_interval
        )

        # main_code mutates `lattice` in place, so it already holds the final state.
        np.savez_compressed(
            f"{trajectories_dir}/final.npz",
            lattice=lattice,
            Epp=Epp_value, Es1p=Es1p_value,
            total_nodes=L_value,
            steps=steps,
            print_interval=print_interval
        )

        # All snapshots share one shape, so stack them into a single int32
        # array of shape (n_snapshots, N, M, L). The old dtype=object array
        # was pickled element by element, bloated the file, and needed
        # allow_pickle=True to load.
        if lattice_snapshots:
            snapshot_array = np.stack(lattice_snapshots)
        else:
            snapshot_array = np.empty((0,) + lattice.shape, dtype=lattice.dtype)
        np.savez_compressed(
            f"{trajectories_dir}/trajectories.npz",
            lattice_snapshots=snapshot_array
        )

        save_gel_graph_npz(node_ids, final_coords, final_bonds, edge_types, f"{trajectories_dir}/gel_graph_data.npz")
        with open(os.path.join(trajectories_dir, "mpsetfile"), "w"):
            pass  # empty marker file

        with open(f"{trajectories_dir}/settings.txt", "w") as f:
            f.write(f"# Simulation settings\nnproc = {nproc}\n")
            # Record the run parameters as well, so the output directory is self-describing.
            f.write(
                f"N = {N}\nM = {M}\nL = {L_value}\nW = {W_value}\n"
                f"Epp = {Epp_value}\nEs1p = {Es1p_value}\nEs2p = {Es2p}\n"
                f"max_cell_concentration = {max_cell_concentration}\n"
                f"steps = {steps}\nprint_interval = {print_interval}\n"
            )
    finally:
        # Stop monitoring even when the run failed.
        stop_monitor = True
        monitor_thread.join()

    elapsed = time.time() - start
    avg_cpu = np.mean(cpu_usage) if cpu_usage else 0
    peak_cpu = np.max(cpu_usage) if cpu_usage else 0
    avg_mem = np.mean(mem_usage) if mem_usage else 0
    peak_mem = np.max(mem_usage) if mem_usage else 0

    print(f"✅ Finished nproc={nproc}, L={L_value}, W={W_value}, Epp={Epp_value}, Es1p={Es1p_value}, saved to {trajectories_dir}")
    print(f"✅ Elapsed: {str(timedelta(seconds=int(elapsed)))}")
    print("\n📊 Resource usage summary:")
    print(f"   Average CPU: {avg_cpu:.1f}%")
    print(f"   Peak CPU:    {peak_cpu:.1f}%")
    print(f"   Average RAM: {avg_mem:.1f} MB")
    print(f"   Peak RAM:    {peak_mem:.1f} MB")
