import cv2
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.animation as animation
import glob
import os
import pandas as pd
from PIL import Image
from PIL.ExifTags import TAGS
from datetime import datetime

# Shared state between the OpenCV mouse callbacks and select_color_and_region().
# Rectangle corners recorded by shape_selection(), in display (window) coordinates.
ref_point = []
# Colour sample points recorded by color_selection(), in original-image coordinates.
color_points = []
# True while the left mouse button is held down during a rectangle drag.
cropping = False
# Display-size / original-size factors returned by display_image().
scale_x, scale_y = 1, 1
# Working copy of the image the selection rectangle is drawn onto.
image = None

def shape_selection(event, x, y, flags, param):
    """OpenCV mouse callback that records a rectangle drawn by dragging.

    On left-button press the start corner is stored in the global
    ``ref_point``. On release the end corner is added, the two corners are
    reordered to (top-left, bottom-right), and the rectangle is drawn onto
    the global ``image``.

    Args:
        event: OpenCV mouse event code.
        x: Cursor x position in display (resized window) coordinates.
        y: Cursor y position in display (resized window) coordinates.
        flags: Unused; required by the OpenCV callback signature.
        param: Unused; required by the OpenCV callback signature.
    """
    global ref_point, cropping, scale_x, scale_y

    if event == cv2.EVENT_LBUTTONDOWN:
        ref_point = [(x, y)]
        cropping = True
    elif event == cv2.EVENT_LBUTTONUP:
        cropping = False
        # A release without a matching press (e.g. "r" pressed mid-drag
        # cleared ref_point) has no start corner; ignore it instead of
        # raising IndexError inside the callback.
        if len(ref_point) != 1:
            return

        (x0, y0), (x1, y1) = ref_point[0], (x, y)
        # Store the corners ordered so that a drag in any direction yields a
        # usable (non-empty) region later on.
        ref_point = [(min(x0, x1), min(y0, y1)), (max(x0, x1), max(y0, y1))]

        # Map display coordinates back to original-image pixels: x uses
        # scale_x and y uses scale_y.
        x_start = int(ref_point[0][0] / scale_x)
        y_start = int(ref_point[0][1] / scale_y)
        x_end = int(ref_point[1][0] / scale_x)
        y_end = int(ref_point[1][1] / scale_y)

        if x_start < x_end and y_start < y_end:
            cv2.rectangle(image, (x_start, y_start), (x_end, y_end), (0, 255, 0), 2)
            cv2.imshow("Select Region", image)

def color_selection(event, x, y, flags, param):
    """OpenCV mouse callback that collects five colour sample points.

    Each left click is converted from display coordinates to original-image
    coordinates and appended to the global ``color_points``. After the fifth
    point the callback is detached and the "Select Color" window is closed.

    Args:
        event: OpenCV mouse event code.
        x: Cursor x position in display coordinates.
        y: Cursor y position in display coordinates.
        flags: Unused; required by the OpenCV callback signature.
        param: Unused; required by the OpenCV callback signature.
    """
    global color_points, scale_x, scale_y

    if event != cv2.EVENT_LBUTTONDOWN:
        return

    sample = (int(x / scale_x), int(y / scale_y))
    color_points.append(sample)

    if len(color_points) == 5:
        # Swap in a no-op handler before the window is destroyed.
        cv2.setMouseCallback("Select Color", lambda *_ignored: None)
        cv2.destroyWindow("Select Color")

def display_image(image, window_name="image"):
    """Show *image* in a resizable window, scaled to fit the screen.

    Args:
        image: Image to display.
        window_name: Title of the OpenCV window.

    Returns:
        ``(scale_x, scale_y)``: displayed size divided by original size for
        each axis, as returned by resize_with_aspect_ratio().
    """
    cv2.namedWindow(window_name, cv2.WINDOW_NORMAL)
    max_width, max_height = get_screen_resolution()
    shown, factor_x, factor_y = resize_with_aspect_ratio(image, max_width, max_height)
    shown_height, shown_width = shown.shape[:2]
    cv2.resizeWindow(window_name, shown_width, shown_height)
    cv2.imshow(window_name, shown)
    return factor_x, factor_y

def get_screen_resolution(default=(1920, 1080)):
    """Return the primary screen size in pixels as ``(width, height)``.

    Uses the Win32 ``GetSystemMetrics`` API. On platforms where
    ``ctypes.windll`` does not exist (Linux, macOS), or when the API reports a
    non-positive size, *default* is returned instead of raising.

    Args:
        default: Fallback ``(width, height)``.

    Returns:
        ``(width, height)`` of the screen, or *default*.
    """
    import ctypes

    windll = getattr(ctypes, "windll", None)
    if windll is None:
        return default

    user32 = windll.user32
    screen_width = user32.GetSystemMetrics(0)   # SM_CXSCREEN
    screen_height = user32.GetSystemMetrics(1)  # SM_CYSCREEN
    if screen_width <= 0 or screen_height <= 0:
        return default
    return screen_width, screen_height

def resize_with_aspect_ratio(image, target_width, target_height):
    """Scale *image* to the largest size that fits the target box, keeping aspect ratio.

    Args:
        image: Image array; its first two dimensions are (height, width).
        target_width: Maximum output width in pixels.
        target_height: Maximum output height in pixels.

    Returns:
        ``(resized_image, scale_x, scale_y)``, where each scale is the new
        size divided by the original size along that axis.

    Raises:
        ValueError: If *image* has zero width or height.
    """
    h, w = image.shape[:2]
    if h == 0 or w == 0:
        raise ValueError("Cannot resize an empty image.")
    aspect_ratio = w / h

    if aspect_ratio > target_width / target_height:
        new_width = target_width
        new_height = int(target_width / aspect_ratio)
    else:
        new_height = target_height
        new_width = int(target_height * aspect_ratio)

    # Very elongated images can round one side down to 0, which cv2.resize
    # rejects; keep at least one pixel.
    new_width = max(1, new_width)
    new_height = max(1, new_height)

    # INTER_AREA avoids aliasing when shrinking; bilinear when enlarging.
    interpolation = cv2.INTER_AREA if new_width < w else cv2.INTER_LINEAR
    resized_image = cv2.resize(image, (new_width, new_height), interpolation=interpolation)
    return resized_image, new_width / w, new_height / h

def increase_saturation(image, scale=1.5):
    """Return a BGR copy of *image* with its HSV saturation multiplied by *scale*.

    The multiplication is done with ``cv2.multiply`` on the 8-bit saturation
    channel, so results above 255 are clipped.
    """
    hue, saturation, value = cv2.split(cv2.cvtColor(image, cv2.COLOR_BGR2HSV))
    boosted = cv2.multiply(saturation, scale).astype(np.uint8)
    return cv2.cvtColor(cv2.merge([hue, boosted, value]), cv2.COLOR_HSV2BGR)

def adjust_brightness_contrast(image, brightness=0, contrast=0):
    """Apply a linear brightness adjustment, then a linear contrast adjustment.

    Args:
        image: Input image.
        brightness: Positive values map the range to [brightness, 255]
            (brighter). Negative values map it to [0, 255 + brightness]
            (darker). 0 skips the step.
        contrast: Contrast offset. Positive values give a gain > 1 around the
            fixed point 127. 0 skips the step.

    Returns:
        The adjusted image. If both arguments are 0, the input object itself
        is returned.
    """
    if brightness:
        if brightness > 0:
            low, high = brightness, 255
        else:
            low, high = 0, 255 + brightness
        image = cv2.addWeighted(image, (high - low) / 255, image, 0, low)

    if contrast:
        gain = 131 * (contrast + 127) / (127 * (131 - contrast))
        # Offset chosen so that the value 127 stays unchanged.
        image = cv2.addWeighted(image, gain, image, 0, 127 * (1 - gain))

    return image

def calculate_average_color(image, points):
    """Return the average HSV colour of *image* at the given sample points.

    OpenCV's 8-bit hue is circular (0-179, two degrees per unit), so hue is
    averaged as an angle. A plain arithmetic mean of 178 and 2 would give
    cyan (90) instead of red (0). Saturation and value use an arithmetic mean.

    Args:
        image: BGR image.
        points: Sequence of ``(x, y)`` pixel coordinates.

    Returns:
        Integer array ``[hue, saturation, value]``.

    Raises:
        ValueError: If *points* is empty.
    """
    if len(points) == 0:
        raise ValueError("At least one sample point is required.")

    hsv = cv2.cvtColor(image, cv2.COLOR_BGR2HSV)
    samples = np.array([hsv[y, x] for x, y in points], dtype=float)

    # Circular mean of hue: 1 hue unit = 2 degrees = pi/90 radians.
    angles = samples[:, 0] * (np.pi / 90.0)
    mean_angle = np.arctan2(np.sin(angles).mean(), np.cos(angles).mean())
    # Round before wrapping so float noise cannot truncate e.g. 10 to 9,
    # and 179.6 wraps to 0.
    mean_hue = np.rint(mean_angle * (90.0 / np.pi)) % 180

    mean_saturation, mean_value = samples[:, 1:].mean(axis=0)
    return np.array([mean_hue, mean_saturation, mean_value]).astype(int)

def extract_color_region(image, is_white=True):
    """Return a binary mask of the low-saturation, high-value (whitish) pixels of *image*.

    Steps:
        1. Brighten (+50) and boost contrast (+20) with adjust_brightness_contrast().
        2. Convert to HSV and keep pixels with H 0-180, S 0-60, V 180-255.
        3. Clean up with a 5x5 elliptical morphology: close, erode once, and
           dilate twice. The net extra dilation leaves regions slightly
           larger than the raw threshold.

    Args:
        image: BGR image (typically the region of interest).
        is_white: Must be True; no other colour range is configured.

    Returns:
        Single-channel mask produced by ``cv2.inRange`` and the morphology steps.

    Raises:
        ValueError: If *is_white* is False.
    """
    hsv = cv2.cvtColor(
        adjust_brightness_contrast(image, brightness=50, contrast=20),
        cv2.COLOR_BGR2HSV,
    )

    if not is_white:
        raise ValueError("Currently, only white color region extraction is configured.")
    lower_color = np.array([0, 0, 180])
    upper_color = np.array([180, 60, 255])

    mask = cv2.inRange(hsv, lower_color, upper_color)

    ellipse = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (5, 5))
    mask = cv2.morphologyEx(mask, cv2.MORPH_CLOSE, ellipse)
    mask = cv2.erode(mask, ellipse, iterations=1)
    return cv2.dilate(mask, ellipse, iterations=2)

def calculate_diameter(contour):
    """Return the longer axis length of the ellipse fitted to *contour*.

    ``cv2.fitEllipse`` needs at least five points, so smaller contours return 0.
    """
    if len(contour) < 5:
        return 0
    _center, axes, _angle = cv2.fitEllipse(contour)
    return max(axes)

def calculate_centroid(contour):
    """Return the integer centroid ``(cx, cy)`` of *contour* from its image moments.

    A contour with zero area (``m00 == 0``) yields ``(0, 0)``.
    """
    moments = cv2.moments(contour)
    area = moments["m00"]
    if area == 0:
        return (0, 0)
    return (int(moments["m10"] / area), int(moments["m01"] / area))

def get_image_timestamp(image_path):
    """Return the EXIF ``DateTimeOriginal`` of *image_path* as a naive datetime.

    Returns None when the file cannot be opened by PIL, has no EXIF data, has
    no ``DateTimeOriginal`` tag, or the tag value does not parse as
    ``YYYY:MM:DD HH:MM:SS``. This keeps one bad image from aborting a batch.

    Args:
        image_path: Path to the image file.

    Returns:
        ``datetime`` or None.
    """
    try:
        # The context manager closes the file handle that Image.open keeps open.
        with Image.open(image_path) as img:
            read_exif = getattr(img, "_getexif", None)
            exif_data = read_exif() if read_exif is not None else None
    except OSError:
        return None

    if not exif_data:
        return None

    for tag, value in exif_data.items():
        if TAGS.get(tag, tag) != 'DateTimeOriginal':
            continue
        if isinstance(value, bytes):
            value = value.decode("ascii", "ignore")
        if not isinstance(value, str):
            return None
        try:
            # Some writers pad the ASCII value with NULs/spaces.
            return datetime.strptime(value.strip("\x00 "), '%Y:%m:%d %H:%M:%S')
        except ValueError:
            return None
    return None

def save_mask_image(mask, original_image_path, output_folder, count):
    """Write *mask* as ``<name>_mask_<count>.png`` into *output_folder*.

    The folder is created if needed. ``cv2.imwrite`` signals failure only
    through its return value, so a warning is printed when writing fails.

    Args:
        mask: Image to write.
        original_image_path: Source image path; its base name prefixes the file name.
        output_folder: Destination folder.
        count: Sequence number embedded in the file name.

    Returns:
        The path that was written (or attempted).
    """
    # exist_ok avoids the race between an existence check and creation.
    os.makedirs(output_folder, exist_ok=True)

    name = os.path.splitext(os.path.basename(original_image_path))[0]
    mask_image_path = os.path.join(output_folder, f"{name}_mask_{count}.png")

    if not cv2.imwrite(mask_image_path, mask):
        print(f"Warning: Could not write mask image to path: {mask_image_path}")
    return mask_image_path

def update_plot(frame, time_seconds, diameters, ax):
    """FuncAnimation callback: redraw the diameter-over-time curve on *ax*.

    Points whose time is None (images without an EXIF timestamp) are skipped.
    If no point remains, *ax* is left untouched. All drawing goes through *ax*
    and its figure rather than the pyplot "current" figure, so it is safe
    when several figures exist.

    Args:
        frame: Frame index supplied by FuncAnimation (unused).
        time_seconds: Sequence of times in seconds (entries may be None).
        diameters: Sequence of diameters in pixels, parallel to *time_seconds*.
        ax: Matplotlib Axes to draw on.
    """
    points = [(t, d) for t, d in zip(time_seconds, diameters) if t is not None]
    if not points:
        return

    times, values = zip(*points)
    ax.clear()
    ax.plot(times, values, marker='o')
    ax.set_xlabel('Time (seconds)')
    ax.set_ylabel('Diameter (pixels)')
    ax.set_title('Diameter vs. Time')
    ax.grid(True)
    ax.tick_params(axis='x', labelrotation=45)
    ax.figure.tight_layout()

def select_color_and_region(initial_image_path):
    """Interactively pick colour samples and a region of interest on one image.

    Three OpenCV windows are shown in turn:

    1. "Select Color": left-click five sample points (``q`` aborts).
    2. "Select Region": drag a rectangle. ``r`` clears it, ``c`` confirms it
       (only if it covers at least one pixel), ``q`` aborts.
    3. "Detected Color Region": preview of extract_color_region() on the
       chosen region. ``c`` accepts, ``r`` restarts the whole selection.

    Args:
        initial_image_path: Path of the image used for the selection.

    Returns:
        ``(avg_color, ref_point, (scale_x, scale_y))``:
          - ``avg_color`` is the average HSV colour of the samples.
          - ``ref_point`` holds the (top-left, bottom-right) rectangle corners
            in display coordinates.
          - The scale factors are displayed size / original size.
        Returns ``(None, None, None)`` if the image cannot be loaded or the
        user aborts.
    """
    global image, ref_point, color_points, scale_x, scale_y

    initial_image = cv2.imread(initial_image_path)
    if initial_image is None:
        print(f"Error: Could not load image from path: {initial_image_path}")
        return None, None, None

    # Retry with a loop (not recursion) so repeated restarts are safe.
    while True:
        # Every attempt starts from a clean image and empty selections.
        image = initial_image.copy()
        color_points = []
        ref_point = []

        # Step 1: colour samples. color_selection() closes the window itself
        # after the fifth click.
        cv2.namedWindow("Select Color", cv2.WINDOW_NORMAL)
        scale_x, scale_y = display_image(image, "Select Color")
        cv2.setMouseCallback("Select Color", color_selection)

        while len(color_points) < 5:
            if cv2.waitKey(1) & 0xFF == ord('q'):
                cv2.destroyWindow("Select Color")
                return None, None, None

        avg_color = calculate_average_color(image, color_points)

        # Step 2: rectangular region.
        cv2.namedWindow("Select Region", cv2.WINDOW_NORMAL)
        cv2.setMouseCallback("Select Region", shape_selection)

        while True:
            # Redraw every iteration so the rectangle that shape_selection()
            # draws into ``image`` becomes visible.
            scale_x, scale_y = display_image(image, "Select Region")
            key = cv2.waitKey(1) & 0xFF

            if key == ord("r"):
                image = initial_image.copy()
                ref_point = []

            elif key == ord("c"):
                if len(ref_point) == 2:
                    # Order the corners so a drag in any direction works.
                    (x0, y0), (x1, y1) = ref_point[0], ref_point[1]
                    ref_point = [(min(x0, x1), min(y0, y1)), (max(x0, x1), max(y0, y1))]
                    # Display -> original pixels: x uses scale_x, y uses scale_y.
                    # Clamp the start at 0; negative slice indices would wrap.
                    x_start = max(0, int(ref_point[0][0] / scale_x))
                    y_start = max(0, int(ref_point[0][1] / scale_y))
                    x_end = int(ref_point[1][0] / scale_x)
                    y_end = int(ref_point[1][1] / scale_y)
                    # Crop from the untouched original: ``image`` contains the
                    # green selection rectangle, which would leak into the mask.
                    roi = initial_image[y_start:y_end, x_start:x_end]
                    if roi.size > 0:
                        cv2.destroyWindow("Select Region")
                        break
                    print("Error: Region of interest has size 0.")

            elif key == ord("q"):
                cv2.destroyWindow("Select Region")
                return None, None, None

        # Step 3: preview of the detected region.
        color_region = extract_color_region(roi)

        while True:
            cv2.imshow("Detected Color Region", color_region)
            key = cv2.waitKey(1) & 0xFF

            if key == ord('c'):
                cv2.destroyWindow("Detected Color Region")
                return avg_color, ref_point, (scale_x, scale_y)
            if key == ord('r'):
                cv2.destroyWindow("Detected Color Region")
                break  # restart the outer selection loop

def process_images(image_directory, avg_color, ref_point, scale_factors):
    """Measure the largest whitish region inside the selected area of every JPEG in a folder.

    For each ``.jpg``/``.jpeg`` file (extension case-insensitive) directly in
    *image_directory*, in sorted path order:
      - crop the selected region;
      - build the mask with extract_color_region();
      - fit ellipses to the external contours and keep the longest axis.
    Every fifth mask is saved to ``<image_directory>/Kontrollbilder`` for
    visual checking. The results go to ``analysis_results.xlsx``, or to
    ``analysis_results.csv`` if pandas has no Excel writer installed.

    Args:
        image_directory: Folder containing the images.
        avg_color: Average HSV colour from the selection step (not used here).
        ref_point: Two rectangle corners in display coordinates.
        scale_factors: ``(scale_x, scale_y)`` = displayed size / original size.

    Returns:
        ``(time_seconds, diameters)`` for images where a diameter > 0 was found.
        Times are seconds since the first timestamped image in sorted order
        (negative if file names are not chronological), or None when an image
        has no EXIF timestamp. Diameters are in pixels. Returns ``([], [])``
        when no images are found.
    """
    # Escape the directory so glob metacharacters in the path (e.g. "[")
    # are taken literally.
    image_files = sorted(
        path
        for path in glob.glob(os.path.join(glob.escape(image_directory), '*'))
        if os.path.splitext(path)[1].lower() in ('.jpg', '.jpeg')
    )
    if not image_files:
        print(f"Error: No .jpg/.jpeg images found in: {image_directory}")
        return [], []

    mask_output_folder = os.path.join(image_directory, 'Kontrollbilder')
    os.makedirs(mask_output_folder, exist_ok=True)

    # The region is the same for every image, so map it from display to
    # original-image pixels once. x uses scale_x and y uses scale_y, the
    # corners are ordered, and the start is clamped at 0.
    scale_x, scale_y = scale_factors
    (px0, py0), (px1, py1) = ref_point[0], ref_point[1]
    x_start = max(0, int(min(px0, px1) / scale_x))
    y_start = max(0, int(min(py0, py1) / scale_y))
    x_end = int(max(px0, px1) / scale_x)
    y_end = int(max(py0, py1) / scale_y)
    offset = np.array([x_start, y_start])

    diameters = []
    time_seconds = []
    first_timestamp = None

    for count, image_path in enumerate(image_files):
        original_image = cv2.imread(image_path)
        if original_image is None:
            print(f"Error: Could not load image from path: {image_path}")
            continue

        timestamp = get_image_timestamp(image_path)
        if timestamp:
            if first_timestamp is None:
                first_timestamp = timestamp
            timestamp_seconds = (timestamp - first_timestamp).total_seconds()
        else:
            timestamp_seconds = None

        roi = original_image[y_start:y_end, x_start:x_end]
        if roi.size == 0:
            print("Error: Region of interest has size 0.")
            continue

        color_region = extract_color_region(roi)

        if (count + 1) % 5 == 0:
            save_mask_image(color_region, image_path, mask_output_folder, (count + 1) // 5)

        contours, _ = cv2.findContours(color_region, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)

        # Contours are shifted into full-image coordinates before fitting.
        max_diameter = max(
            (calculate_diameter(contour + offset) for contour in contours),
            default=0,
        )

        if max_diameter > 0:
            diameters.append(max_diameter)
            time_seconds.append(timestamp_seconds)
            print(f"Image: {image_path}, Diameter: {max_diameter} pixels")

    df = pd.DataFrame({
        'Time (seconds)': time_seconds,
        'Diameter (pixels)': diameters
    })
    results_path = os.path.join(image_directory, 'analysis_results.xlsx')
    try:
        df.to_excel(results_path, index=False)
    except ImportError:
        # to_excel needs an optional writer package (e.g. openpyxl); fall
        # back to CSV rather than losing the whole batch of measurements.
        results_path = os.path.join(image_directory, 'analysis_results.csv')
        df.to_csv(results_path, index=False)
    print(f"Results saved to {results_path}")

    return time_seconds, diameters

def main():
    """Ask for the input paths, run the interactive selection and the batch analysis, then plot."""
    # Strip surrounding whitespace and quotes: paths pasted from a file
    # manager (e.g. Windows Explorer "Copy as path") are often quoted, and
    # cv2.imread would fail on the quoted string.
    initial_image_path = input("Bitte den Pfad zum Ausgangsbild eingeben: ").strip().strip('"\'')
    image_directory = input("Bitte den Pfad zum Ordner mit den Bilddateien eingeben: ").strip().strip('"\'')

    avg_color, ref_point, scale_factors = select_color_and_region(initial_image_path)
    if avg_color is None or ref_point is None or scale_factors is None:
        return

    time_seconds, diameters = process_images(image_directory, avg_color, ref_point, scale_factors)
    if not diameters:
        print("No diameters were measured; nothing to plot.")
        return

    fig, ax = plt.subplots()
    # Keep a reference to the animation object; matplotlib stops an
    # animation whose object has been garbage-collected.
    ani = animation.FuncAnimation(fig, update_plot, fargs=(time_seconds, diameters, ax), interval=1000)
    plt.show()

# Start the interactive workflow only when run as a script, not when imported.
if __name__ == "__main__":
    main()