"""
Analysis of Young's modulus for hydrogel particles.

Goal of this script:
- Read force–displacement curves of particles from an Excel file.
- For each particle, fit a Hertz contact model to the first part of the curve.
- Extract Young’s modulus E (and an uncertainty estimate) from the fit.
- Plot each fit for visual checking.
- Export results to a new Excel file (one output sheet per input sheet).
"""

import numpy as np
# Plotting: figures will open in your browser
import plotly.io as pio
pio.renderers.default='browser' 
import plotly.express as px
import plotly.graph_objects as go
from scipy.optimize import curve_fit # non-linear least squares fitting
import pandas as pd

# ----------------------------
# 1) Global constants/settings
# ----------------------------
# --- Fit settings (shared by every particle) ---
end_distance_for_fit = 0.6   # Upper displacement limit (mm) of the data window that is fitted
v = 0.33                     # Poisson's ratio, assumed identical for all particles
Offset = 0.01                # Constant force offset added by the model (same unit as the force data, typically N)

# --- Optional single-particle export ---
# Name of ONE particle whose raw and fitted curve should be written to a dedicated output sheet.
selected_particle_name = "example"
# Placeholders for that particle's curve; they are filled in by the fitting loop
# when a particle with the selected name is encountered and stay None otherwise.
selected_displacement = selected_force = selected_force_fit = None

# ----------------------------
# 2) Hertz model used for fit
# ----------------------------

# The particle radius is deliberately a module-level variable:
# curve_fit treats every argument after x as a parameter to optimise, and only E
# should be fitted. R is therefore reassigned for each particle right before the fit.
R = 0


def HertzTheory(x, E):
    """Return the force predicted by the Hertz contact model.

    Parameters
    ----------
    x : displacement value(s) (mm)
    E : Young's modulus (the fit parameter)

    Module-level values read by this function
    -----------------------------------------
    R      : radius of the current particle (set just before curve_fit is called)
    v      : Poisson's ratio
    Offset : constant force offset

    Returns
    -------
    Predicted force value(s), in the same unit as the force data.

    Note: the displacement is halved before the 3/2 power is applied -
    presumably because the total compression is shared between the two
    contact points of the particle; confirm against the experimental setup.
    """
    # Geometric prefactor of the Hertz model: 4/3 * sqrt(R)
    contact_prefactor = (4 * R ** (1 / 2)) / 3
    # Scale by the reduced modulus E / (1 - v^2)
    stiffness = contact_prefactor * E / (1 - v ** 2)
    return Offset + stiffness * (x / 2) ** (3 / 2)

# ----------------------------
# 3) Load Excel input workbook
# ----------------------------
file_path = "Velocity_Hydrogel_cutted_Test.xlsx"   # Input Excel file created by your preprocessing script
xlsx = pd.ExcelFile(file_path)

# Fallback starting distance (mm) used when a value is missing in the input sheet.
DEFAULT_STARTING_DISTANCE_MM = 3.0

# Read the sheet that contains one "Starting Distance" per particle measurement.
# This is used to compute a radius R = starting_distance / 2 for the Hertz model.
# The already opened workbook handle is reused so the file is not opened a second time.
data = pd.read_excel(xlsx, sheet_name="Starting Distances")

# Convert explicitly to float: np.nan_to_num returns non-float arrays unchanged,
# so NaNs in an object-typed column would otherwise not be replaced.
starting_distances_list = data['Starting Distance'].to_numpy(dtype=float)

# Make the fallback visible instead of substituting values silently.
missing_starting_distances = int(np.isnan(starting_distances_list).sum())
if missing_starting_distances:
    print(f"Warning: {missing_starting_distances} missing starting distance(s) "
          f"replaced by {DEFAULT_STARTING_DISTANCE_MM} mm.")
starting_distances_list = np.nan_to_num(starting_distances_list, nan=DEFAULT_STARTING_DISTANCE_MM)


# ---------------------------------------------------------
# 4) Fit each particle in each sheet; write results to Excel
# ---------------------------------------------------------

def _r_squared(observed, predicted):
    """Return the coefficient of determination R^2 = 1 - SS_res / SS_tot.

    SS_tot is taken around the mean of the *observed* values. NaN is returned
    when the observed values have no variance, because R^2 is undefined then.
    """
    observed = np.asarray(observed, dtype=float)
    predicted = np.asarray(predicted, dtype=float)
    ss_res = np.sum((observed - predicted) ** 2)           # deviation of the data from the fit
    ss_tot = np.sum((observed - observed.mean()) ** 2)     # overall variation of the data
    if ss_tot == 0:
        return float("nan")
    return 1 - ss_res / ss_tot


# Number of particles in each sheet processed so far. Its running sum is the offset
# into starting_distances_list, which is indexed continuously across all sheets.
# NOTE(review): this assumes the "Starting Distances" rows follow the same sheet and
# particle order as the data sheets - confirm against the preprocessing script.
amount_measurements_sheet_lst = []

# The Excel engine is selected once, on the writer itself.
with pd.ExcelWriter("Youngs_Modulus_Particles.xlsx", engine="openpyxl") as writer:

    for sheet_name in xlsx.sheet_names:  # ADJUST: you could restrict sheets here if needed
        if sheet_name == "Starting Distances":
            print(f"Skipping sheet: {sheet_name}")
            continue

        # Parse from the already opened workbook instead of re-opening the file per sheet.
        data = pd.read_excel(xlsx, sheet_name=sheet_name)
        start_idx = int(sum(amount_measurements_sheet_lst))
        # The script assumes the sheet structure is:
        # Each particle uses exactly 2 columns: [Distance, Force]
        # So number_of_particles = number_of_columns // 2
        amount_measurements_sheet_lst.append(data.shape[1] // 2)

        print(start_idx)

        Youngs_modulus_mean = []    # fitted E values (scaled to kPa)
        Youngs_modulus_stabw = []   # standard error of E (scaled to kPa)
        Particle_names = []         # particle name
        R2 = []                     # R^2 quality measure per fit

        # ----------------------------
        # 5) Loop over particles in sheet
        # ----------------------------
        for ID in range(amount_measurements_sheet_lst[-1]):
            # The script assumes particle data is laid out like this:
            # Column 0: distance for particle 0
            # Column 1: force    for particle 0
            # Column 2: distance for particle 1
            # Column 3: force    for particle 1
            # etc.
            #
            # It also assumes (positions refer to the parsed DataFrame):
            # - Row index 1 contains the particle name (in the distance column)
            # - Row index 3 onward contains numeric data (distance/force values)
            particle_name = data.iloc[1, ID * 2]

            # Drop NaN padding row-wise so every distance stays paired with its force value,
            # then convert to float arrays for the numerical work below.
            curve = data.iloc[3:, [ID * 2, ID * 2 + 1]].dropna()
            distance = curve.iloc[:, 0].to_numpy(dtype=float)
            force = curve.iloc[:, 1].to_numpy(dtype=float)

            # First index where the displacement reaches the fitting limit.
            # If the curve never reaches the limit, the whole curve is fitted.
            beyond_limit = np.flatnonzero(distance >= end_distance_for_fit)
            end_fit = beyond_limit[0] if beyond_limit.size else distance.size
            # Only fit the first part of the curve (small displacement regime)
            x_start_fit = distance[:end_fit]
            y_start_fit = force[:end_fit]

            starting_distance_this_trial = starting_distances_list[start_idx + ID]
            print(starting_distance_this_trial)
            # Radius for the Hertz model. R is the module-level variable read by HertzTheory,
            # so it has to be assigned before curve_fit is called.
            R = starting_distance_this_trial / 2

            # A fit needs data: record NaN for this particle instead of aborting the whole run.
            if x_start_fit.size < 2:
                print(f"Particle {ID} ({particle_name}): fewer than 2 data points below "
                      f"{end_distance_for_fit} mm - no fit possible, results set to NaN")
                Youngs_modulus_mean.append(float("nan"))
                Youngs_modulus_stabw.append(float("nan"))
                Particle_names.append(particle_name)
                R2.append(float("nan"))
                continue

            # ----------------------------
            # 6) Fit Hertz model to data
            # ----------------------------

            # curve_fit finds E such that HertzTheory(x, E) best matches y (least squares)
            parameters, covariance = curve_fit(HertzTheory, x_start_fit, y_start_fit)
            fit_E = parameters[0]                    # fitted Young's modulus E
            fit_F = HertzTheory(x_start_fit, fit_E)  # fitted force curve for comparison/plot

            # optional: Store data for the one selected particle (for later export)
            if particle_name == selected_particle_name:
                selected_displacement = x_start_fit
                selected_force = y_start_fit
                selected_force_fit = fit_F

            # Plot data + fit for visual inspection
            fig = go.Figure()
            fig.add_trace(go.Scatter(x=x_start_fit, y=y_start_fit, name="original data", mode="markers"))
            # go.Line is deprecated; a Scatter trace in "lines" mode is the supported equivalent.
            fig.add_trace(go.Scatter(x=x_start_fit, y=fit_F, name="fitted curve", mode="lines"))
            fig.update_layout(title=f"Force–Displacement Curve: {particle_name}", font_size=20, xaxis_title="Displacement / mm", yaxis_title="Force / N", template="presentation")
            fig.show()

            # ----------------------------
            # 7) Estimate uncertainty + fit quality
            # ----------------------------
            # Standard error of E: square root of the diagonal of the parameter covariance
            SE_E = np.sqrt(np.diag(covariance))[0]
            # R^2: 1 means perfect fit
            r2 = _r_squared(y_start_fit, fit_F)

            # Store results; multiply by 1e3 so values are written as "kPa" (as labeled later)
            Youngs_modulus_mean.append(fit_E*10**3)
            Youngs_modulus_stabw.append(SE_E*10**3)
            Particle_names.append(particle_name)
            R2.append(r2)

            print(f"Particle {ID} ({particle_name}): Young's modulus = {(fit_E*10**3):.2f} kPa & Standard error = {(SE_E*10**3):.3f} & R^2: {(r2):.3f}")

        # ----------------------------
        # 8) Write per-sheet results table to output Excel
        # ----------------------------

        # Two header rows that will be placed above the data.
        # NOTE(review): the "Standard deviation" column holds the standard error of the fit
        # parameter, scaled like E (kPa), although its unit row says "-". The labels are
        # left unchanged here so the output format stays the same.
        header_row_1 = ["Partikelname", "Youngs Modulus", "Standard deviation", "Bestimmtheitsmaß (R^2)"]
        header_row_2 = ["-", "kPa", "-", "-"]
        df = pd.DataFrame({
            'Partikelname': Particle_names,
            'Youngs Modulus (mean)': Youngs_modulus_mean,
            'Youngs Modulus (Standard deviation)': Youngs_modulus_stabw,
            'Bestimmtheitsmaß (R^2)': R2,
        })
        headers_df = pd.DataFrame([header_row_1, header_row_2])
        # Give both frames identical, generic column names so they can be stacked
        headers_df.columns = range(headers_df.shape[1])
        df.columns = range(df.shape[1])
        # Put the header rows on top of the data rows
        combined_with_headers = pd.concat([headers_df, df], axis=0, ignore_index=True)

        # Excel limits sheet names to 31 characters, hence the truncation
        combined_with_headers.to_excel(writer, sheet_name=sheet_name[:31], index=False, header=False)

    # ----------------------------
    # 9) Export one selected particle (optional)
    # ----------------------------
    # Done once, after all sheets were processed, so the sheet is written a single time
    # and the warning only appears if the particle was found in none of the sheets.
    if selected_displacement is not None and selected_force is not None:
        df_selected = pd.DataFrame({
            "Displacement (mm)": selected_displacement,
            "Force (N)": selected_force,
            "Force Fit (N)": selected_force_fit
        })
        df_selected.to_excel(writer, sheet_name="Selected Particle", index=False)
    else:
        print(f"\n Warning: Selected particle '{selected_particle_name}' was not found.")

print()
