import os

# --- IMPORTANT CRASH WORKAROUND (OMP Error #15) ---
# Must stay before the cv2/numpy/matplotlib imports below so the setting is in
# place before any OpenMP runtime gets loaded. It allows the Intel OpenMP
# runtime to be initialized more than once (typically happens when several
# installed packages each bundle their own copy of it).
# NOTE(review): Intel describes this as an unsafe, unsupported workaround that
# may crash or silently produce incorrect results; prefer fixing the duplicate
# runtime in the Python environment if possible.
os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"

import cv2
import numpy as np
import matplotlib.pyplot as plt
import glob
import pandas as pd
from PIL import Image
from PIL.ExifTags import TAGS
from datetime import datetime

# --- HILFSFUNKTIONEN ---

def get_image_timestamp(image_path):
    """Return the EXIF ``DateTimeOriginal`` capture time of an image.

    This is a deliberate best-effort lookup: any failure (unreadable file,
    no EXIF data, tag missing, malformed value) yields ``None`` instead of
    raising, so the caller can fall back to a default time.

    Args:
        image_path: Path to the image file.

    Returns:
        A naive ``datetime`` parsed with the EXIF format
        ``'%Y:%m:%d %H:%M:%S'``, or ``None`` if it cannot be determined.
    """
    try:
        # The context manager closes the file handle that Image.open keeps
        # open; without it one handle leaks per processed image.
        with Image.open(image_path) as image:
            exif_data = image._getexif()
        if not exif_data:
            return None
        for tag, value in exif_data.items():
            if TAGS.get(tag, tag) == 'DateTimeOriginal':
                # Some cameras pad the string with NUL bytes or spaces,
                # which would otherwise make strptime fail.
                return datetime.strptime(value.strip('\x00 '), '%Y:%m:%d %H:%M:%S')
    except Exception:
        # Intentionally broad: formats without _getexif, corrupt files and
        # unparsable values are all treated as "no timestamp available".
        pass
    return None

def save_debug_image(roi_img, mask, ellipse_local, original_image_path, output_folder, count,
                     max_height=600):
    """Write a side-by-side control image (ROI with fitted ellipse | mask) as PNG.

    The file is named ``<original stem>_debug_<count>.png`` inside
    ``output_folder`` (created if needed). Rendering and writing are
    best-effort: failures are reported on stdout and never abort processing.

    Args:
        roi_img: BGR image crop that was analysed.
        mask: Single-channel binary mask with the same height as ``roi_img``.
        ellipse_local: ``cv2.fitEllipse`` result in ROI coordinates, or ``None``.
        original_image_path: Path of the source image (used for the file name).
        output_folder: Directory for the control images.
        count: Running index appended to the file name.
        max_height: Combined images taller than this are downscaled to it
            (default 600 px), keeping the aspect ratio.
    """
    # exist_ok avoids the race between an existence check and creation.
    os.makedirs(output_folder, exist_ok=True)

    name = os.path.splitext(os.path.basename(original_image_path))[0]

    try:
        debug_img = roi_img.copy()
        if ellipse_local is not None:
            cv2.ellipse(debug_img, ellipse_local, (0, 255, 0), 2)

        mask_bgr = cv2.cvtColor(mask, cv2.COLOR_GRAY2BGR)
        combined = np.hstack([debug_img, mask_bgr])

        h, w = combined.shape[:2]
        if h > max_height:
            factor = max_height / h
            # INTER_AREA gives cleaner results than the default when shrinking.
            combined = cv2.resize(combined, (max(1, int(w * factor)), max_height),
                                  interpolation=cv2.INTER_AREA)

        debug_image_path = os.path.join(output_folder, f"{name}_debug_{count}.png")
        # imwrite signals many failures (bad path, unsupported chars) only via
        # its return value, not via an exception.
        if not cv2.imwrite(debug_image_path, combined):
            print(f"Fehler beim Speichern des Debug-Bildes: {debug_image_path}")
    except Exception as e:
        print(f"Fehler beim Speichern des Debug-Bildes: {e}")

def update_plot(frame, time_seconds, diameters, ax):
    """Redraw the diameter-vs-time curve on the given Matplotlib axes.

    Does nothing when there is no data.

    Args:
        frame: Unused. NOTE(review): presumably kept so the function can act
            as an animation-style callback (frame index first); ``main``
            passes 0.
        time_seconds: X values in seconds.
        diameters: Y values in pixels, same length as ``time_seconds``.
        ax: Matplotlib ``Axes`` to draw on; it is cleared first.
    """
    if len(time_seconds) == 0:
        return

    ax.clear()
    ax.plot(time_seconds, diameters, marker='o', linestyle='-', markersize=4)
    ax.set_xlabel('Time (seconds)')
    ax.set_ylabel('Diameter (pixels)')
    ax.set_title('Diameter vs. Time')
    ax.grid(True)
    # Lay out the figure that owns *ax*; plt.tight_layout() would act on
    # pyplot's "current" figure, which is not necessarily ax's figure.
    ax.figure.tight_layout()

# --- HAUPTFUNKTIONEN ZUR ERKENNUNG ---

def _list_jpg_files(image_directory):
    """Return sorted paths of the JPEG files directly inside ``image_directory``.

    ``.jpg`` and ``.jpeg`` are matched case-insensitively: cameras often write
    ``.JPG``, which a ``*.jpg`` glob misses on case-sensitive file systems.
    Hidden files (leading ``.``) and directories are skipped. A missing or
    unreadable directory yields an empty list.
    """
    try:
        names = os.listdir(image_directory or '.')
    except OSError:
        return []
    candidates = (
        os.path.join(image_directory, name)
        for name in names
        if not name.startswith('.') and name.lower().endswith(('.jpg', '.jpeg'))
    )
    return sorted(path for path in candidates if os.path.isfile(path))

def select_region_robust(image_directory):
    """Let the user draw the analysis region (ROI) on the first image of a folder.

    The first JPEG (sorted by path) is shown in a window, downscaled to at
    most 900 px height so it fits on screen. The rectangle the user draws is
    mapped back to full-resolution pixel coordinates and clipped to the image.

    Args:
        image_directory: Folder containing the ``.jpg``/``.jpeg`` images.

    Returns:
        ``((x, y, w, h), image_files)`` on success, where ``image_files`` is the
        sorted list of image paths. ``(None, None)`` if no images were found,
        the first image could not be read, the selection window failed, or
        the selection was empty (cancelled).
    """
    image_files = _list_jpg_files(image_directory)
    if not image_files:
        print("Keine JPG-Bilder im Ordner gefunden!")
        return None, None

    initial_image_path = image_files[0]
    img = cv2.imread(initial_image_path)
    if img is None:
        print(f"Fehler: Konnte Bild nicht laden: {initial_image_path}")
        return None, None

    h_orig, w_orig = img.shape[:2]
    screen_max_h = 900
    scale_factor = 1.0
    if h_orig > screen_max_h:
        scale_factor = screen_max_h / h_orig
        display_img = cv2.resize(img, (int(w_orig * scale_factor), int(h_orig * scale_factor)))
    else:
        display_img = img

    print("\n--- ANLEITUNG ---")
    print("1. Ziehe ein Rechteck um die Kugel.")
    print("   WICHTIG: Das Rechteck sollte den Stab oben beinhalten.")
    print("2. Bestätige mit LEERTASTE oder ENTER.")

    window_name = "Bereich auswaehlen"
    try:
        # Window creation is inside the try too: on OpenCV builds without GUI
        # support these calls already raise cv2.error.
        cv2.namedWindow(window_name, cv2.WINDOW_NORMAL)
        cv2.resizeWindow(window_name, display_img.shape[1], display_img.shape[0])
        rect = cv2.selectROI(window_name, display_img, showCrosshair=True, fromCenter=False)
    except cv2.error as e:
        print(f"Fehler bei der Bereichsauswahl: {e}")
        return None, None
    finally:
        try:
            cv2.destroyWindow(window_name)
            cv2.waitKey(1)  # pump the GUI event loop so the window really closes
        except cv2.error:
            pass  # window never created or already closed by the user

    x_small, y_small, w_small, h_small = rect
    if w_small == 0 or h_small == 0:
        # selectROI returns an empty rectangle when the selection is cancelled.
        return None, None

    # Map the selection back to full-resolution coordinates (truncating toward
    # zero) and clip it to the image bounds.
    x_real = max(0, int(x_small / scale_factor))
    y_real = max(0, int(y_small / scale_factor))
    w_real = min(w_orig - x_real, int(w_small / scale_factor))
    h_real = min(h_orig - y_real, int(h_small / scale_factor))

    return (x_real, y_real, w_real, h_real), image_files

def find_sphere_cut_rod_smooth(roi_img, threshold=50, breaker_size=17, shaver_size=9,
                               polisher_size=15, max_aspect_ratio=1.3, min_area=50):
    """Detect the sphere in a BGR ROI, cut off the holding rod and smooth the bump it leaves.

    Pipeline:
      1. Grayscale, Gaussian blur, then binary threshold at ``threshold``.
         Pixels brighter than the threshold become 255, so this assumes the
         sphere is brighter than the background (TODO confirm for your setup).
      2. Strong erosion (elliptic kernel ``breaker_size``, 2 iterations) to
         break the thin neck between sphere and rod. Keep only the largest
         remaining blob, then dilate it back with the same kernel.
      3. Light erosion (``shaver_size``) to shave off protruding bumps, then a
         closing (``polisher_size``) to smooth the outline.
      4. Fit an ellipse to the largest final contour.

    NOTE(review): the "shaving" erosion in step 3 is never undone, so the
    final outline is roughly ``shaver_size - 1`` px smaller in diameter than
    the thresholded sphere. The offset is constant for a fixed kernel size,
    but it biases absolute diameters — confirm this is acceptable.

    Args:
        roi_img: BGR image crop.
        threshold: Gray-level threshold for the binary mask.
        breaker_size: Kernel size used to cut the rod off the sphere.
        shaver_size: Kernel size of the light erosion that removes bumps.
        polisher_size: Kernel size of the final smoothing closing.
        max_aspect_ratio: If major/minor axis exceeds this, the fit is treated
            as distorted (presumably by rod residue) and only the minor axis is
            used as diameter; otherwise the mean of both axes is used.
        min_area: Minimum contour area (px²) accepted as a detection.

    Returns:
        ``(ellipse, diameter, mask)``. ``ellipse`` is the ``cv2.fitEllipse``
        result in ROI coordinates and ``diameter`` is in pixels. On failure
        the result is ``(None, 0, mask)``. ``mask`` is the raw threshold mask
        if the strong erosion removed everything, otherwise the final
        smoothed mask.
    """
    # 1. Grayscale, blur, threshold.
    gray = cv2.cvtColor(roi_img, cv2.COLOR_BGR2GRAY)
    blurred = cv2.GaussianBlur(gray, (9, 9), 2)
    _, mask = cv2.threshold(blurred, threshold, 255, cv2.THRESH_BINARY)

    # 2. Cut the rod: strong erosion, keep the largest blob, dilate back.
    breaker_kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (breaker_size, breaker_size))
    mask_eroded = cv2.erode(mask, breaker_kernel, iterations=2)

    contours_temp, _ = cv2.findContours(mask_eroded, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    if not contours_temp:
        return None, 0, mask

    sphere_only = np.zeros_like(mask)
    cv2.drawContours(sphere_only, [max(contours_temp, key=cv2.contourArea)], -1, 255, -1)
    mask_dilated = cv2.dilate(sphere_only, breaker_kernel, iterations=2)

    # 3. Shave off the remaining bump, then polish the edge with a closing.
    shaver_kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (shaver_size, shaver_size))
    mask_shaved = cv2.erode(mask_dilated, shaver_kernel, iterations=1)

    polisher_kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (polisher_size, polisher_size))
    mask_final = cv2.morphologyEx(mask_shaved, cv2.MORPH_CLOSE, polisher_kernel, iterations=1)

    # 4. Final contour and ellipse fit.
    contours, _ = cv2.findContours(mask_final, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    if not contours:
        return None, 0, mask_final

    best_contour = max(contours, key=cv2.contourArea)
    # fitEllipse needs at least 5 points.
    if cv2.contourArea(best_contour) < min_area or len(best_contour) < 5:
        return None, 0, mask_final

    ellipse = cv2.fitEllipse(best_contour)
    # ellipse[1] holds the two full axis lengths; sort them so the logic does
    # not depend on which axis OpenCV reports first.
    minor_axis, major_axis = sorted(ellipse[1])
    aspect_ratio = major_axis / minor_axis if minor_axis > 0 else 0

    if aspect_ratio > max_aspect_ratio:
        diameter = minor_axis
    else:
        diameter = (minor_axis + major_axis) / 2

    return ellipse, diameter, mask_final

def process_images(image_files, image_directory, roi_rect):
    """Measure the sphere diameter in each image and export the results.

    Each image is cropped to ``roi_rect = (x, y, w, h)`` and analysed with
    ``find_sphere_cut_rod_smooth``. A control image is written for every
    image to ``<image_directory>/Kontrollbilder_GlatterSchnitt``.

    Times are seconds relative to the EXIF capture time of the first
    *detected* frame that has one. Detected frames without an EXIF timestamp
    get time 0.

    If anything was detected, the table is saved as
    ``analysis_results_glatt.xlsx``, or as ``analysis_results_glatt.csv``
    when the Excel export fails.

    Args:
        image_files: Image paths in processing order.
        image_directory: Folder where results and control images are written.
        roi_rect: ``(x, y, w, h)`` region in full-resolution pixel coordinates.

    Returns:
        ``(time_seconds, diameters)``: parallel lists covering only the frames
        where a sphere was detected.
    """
    mask_output_folder = os.path.join(image_directory, 'Kontrollbilder_GlatterSchnitt')

    diameters = []
    time_seconds = []
    first_timestamp = None

    x, y, w, h = roi_rect
    print(f"Verarbeite {len(image_files)} Bilder mit glattem Schnitt...")

    for count, image_path in enumerate(image_files):
        original_image = cv2.imread(image_path)
        if original_image is None:
            print(f"Warnung: Konnte Bild nicht laden: {image_path}")
            continue

        roi = original_image[y:y + h, x:x + w]
        if roi.size == 0:
            continue

        ellipse_local, diameter, debug_mask = find_sphere_cut_rod_smooth(roi)

        if ellipse_local is not None and diameter > 0:
            diameters.append(diameter)

            timestamp = get_image_timestamp(image_path)
            ts_sec = 0
            if timestamp:
                if first_timestamp is None:
                    first_timestamp = timestamp
                ts_sec = (timestamp - first_timestamp).total_seconds()
            time_seconds.append(ts_sec)

            print(f"[{count+1}] D ≈ {diameter:.2f} px")

        save_debug_image(roi, debug_mask, ellipse_local, image_path, mask_output_folder, count)

    if diameters:
        _save_results(time_seconds, diameters, image_directory)
        print(f"Kontrollbilder: {mask_output_folder}")

    return time_seconds, diameters

def _save_results(time_seconds, diameters, image_directory):
    """Write the measurement table to Excel, falling back to CSV; return the written path or None."""
    df = pd.DataFrame({'Time (seconds)': time_seconds, 'Diameter (pixels)': diameters})

    excel_path = os.path.join(image_directory, 'analysis_results_glatt.xlsx')
    try:
        df.to_excel(excel_path, index=False)
    except (ImportError, OSError) as e:
        # ImportError: Excel writer engine (e.g. openpyxl) not installed.
        # OSError/PermissionError: e.g. the file is still open in Excel.
        # Either way, do not lose the whole run's measurements.
        print(f"\nExcel-Export fehlgeschlagen ({e}), speichere als CSV.")
    else:
        print(f"\nErgebnisse: {excel_path}")
        return excel_path

    csv_path = os.path.join(image_directory, 'analysis_results_glatt.csv')
    try:
        df.to_csv(csv_path, index=False)
    except OSError as e:
        print(f"CSV-Export fehlgeschlagen: {e}")
        return None
    print(f"\nErgebnisse: {csv_path}")
    return csv_path

def main():
    """Interactive entry point: ask for the image folder, select the ROI, measure, and plot."""
    # Remove surrounding whitespace and the quotes that "copy as path" or
    # drag-and-drop into a terminal commonly add; a stray trailing space or
    # quote would otherwise make the folder lookup fail.
    image_directory = input("Pfad zum Ordner mit Bildern: ").strip().strip('"\'')
    roi_rect, image_files = select_region_robust(image_directory)

    if not (roi_rect and image_files):
        print("Abbruch.")
        return

    time_seconds, diameters = process_images(image_files, image_directory, roi_rect)
    cv2.destroyAllWindows()
    cv2.waitKey(1)

    if time_seconds:
        print("Zeige Diagramm...")
        _, ax = plt.subplots(figsize=(10, 6))
        update_plot(0, time_seconds, diameters, ax)
        plt.show()

if __name__ == "__main__":
    # Start the interactive workflow only when run as a script, not on import.
    main()