"""
Kinetic analysis and segmented Arrhenius plots for wet wollastonite carbonation.

This script calculates kinetic constants using:
1. Shrinking-core model (surface reaction control)
2. Shrinking-core model (diffusion control)
3. Avrami model with corrected rate constant kA^(1/n)

It also generates segmented Arrhenius results and publication-style plots.

Author: Sahra Homaee
"""

from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from scipy.stats import linregress


# Gas constant in J mol-1 K-1. arrhenius_fit multiplies the fitted slope by
# this value and divides by 1000 to report activation energies in kJ mol-1.
R = 8.314

# Folder that receives every exported table and figure. The path is relative,
# and the folder is created as soon as this module is imported.
OUTPUT_DIR = Path("kinetic_analysis_outputs")
OUTPUT_DIR.mkdir(exist_ok=True)


# ---------------------------------------------------------------------
# Input data
# ---------------------------------------------------------------------

# Conversion-versus-time table: one shared "time_min" column plus one
# conversion column per temperature. Column names must follow the pattern
# "X_<T>C" because calculate_kinetic_constants builds the key as
# f"X_{int(temperature)}C". Only points with time > 0 and 0 < X < 1 are used
# in the fits.
data = {
    "time_min": [5, 10, 15, 20, 30, 45, 60, 75, 90, 120, 150, 180],
    "X_25C": [0.043341, 0.051269, 0.058912, 0.061579, 0.063951, 0.075748, 0.086619, 0.095742, 0.102369, 0.108997, 0.124176, 0.133521],
    "X_45C": [0.056129, 0.072583, 0.089848, 0.101103, 0.120857, 0.153228, 0.185618, 0.209421, 0.238705, 0.275583, 0.310424, 0.332808],
    "X_60C": [0.058942, 0.083102, 0.101484, 0.128261, 0.169194, 0.216349, 0.250666, 0.289397, 0.320954, 0.384117, 0.428329, 0.454231],
    "X_75C": [0.057090, 0.095925, 0.128422, 0.154090, 0.205940, 0.261846, 0.309737, 0.364050, 0.408463, 0.485447, 0.536270, 0.573774],
    "X_90C": [0.032017, 0.045657, 0.080522, 0.101834, 0.155221, 0.226481, 0.301375, 0.346319, 0.389315, 0.491529, 0.560442, 0.618824],
}

# Temperatures (deg C) to analyse; each one needs a matching "X_<T>C" column
# in `data`.
temperatures_C = np.array([25, 45, 60, 75, 90], dtype=float)

# Temperature groups (deg C) for the piecewise Arrhenius fits. Neighbouring
# stages share their boundary temperature (45 and 75 appear twice).
# Stage I and Stage III contain only two points, so their straight line passes
# exactly through both and the reported R2 carries no goodness-of-fit
# information for those stages.
segments = {
    "Stage I: 25-45 C": [25, 45],
    "Stage II: 45-75 C": [45, 60, 75],
    "Stage III: 75-90 C": [75, 90],
}

# Display label -> name of the rate-constant column written by
# calculate_kinetic_constants.
models = {
    "SCM surface reaction": "ks_SCM_surface",
    "SCM diffusion": "kd_SCM_diffusion",
    "Avrami corrected kA^(1/n)": "kA_corrected",
}


# ---------------------------------------------------------------------
# Helper functions
# ---------------------------------------------------------------------

def linear_fit(x_values, y_values):
    """Fit ``y = slope * x + intercept`` by ordinary least squares.

    Parameters
    ----------
    x_values, y_values : array-like
        One-dimensional sequences of equal length. Lists, tuples and NumPy
        arrays are accepted; both are converted to float arrays first so the
        arithmetic below never depends on the input container type.

    Returns
    -------
    dict
        ``slope`` and ``intercept`` of the fitted line, ``R2`` (squared
        correlation coefficient), ``RSS`` (residual sum of squares), ``RMSE``
        (``sqrt(RSS / (n - 2))``), and the arrays ``y_fit`` and ``residuals``.
        ``RMSE`` is NaN when fewer than three points are supplied, because
        the fit then has no residual degrees of freedom.
    """
    x = np.asarray(x_values, dtype=float)
    y = np.asarray(y_values, dtype=float)

    fit = linregress(x, y)
    y_fit = fit.slope * x + fit.intercept
    residuals = y - y_fit
    rss = np.sum(residuals**2)

    # Two fitted parameters leave n - 2 degrees of freedom. Guard the division
    # so a two-point fit yields NaN instead of a divide-by-zero warning.
    degrees_of_freedom = y.size - 2
    rmse = np.sqrt(rss / degrees_of_freedom) if degrees_of_freedom > 0 else np.nan

    return {
        "slope": fit.slope,
        "intercept": fit.intercept,
        "R2": fit.rvalue**2,
        "RSS": rss,
        "RMSE": rmse,
        "y_fit": y_fit,
        "residuals": residuals,
    }


def arrhenius_fit(temperature_C, k_values):
    """Regress ln(k) on 1/T and derive the activation energy.

    Parameters
    ----------
    temperature_C : array-like
        Temperatures in degrees Celsius; converted to kelvin by adding 273.15.
    k_values : array-like
        Rate constants, one per temperature. Their natural logarithm is taken,
        so they are expected to be positive.

    Returns
    -------
    dict
        The temperature arrays (``temperature_C``, ``temperature_K``,
        ``inverse_temperature``), ``ln_k``, the regression ``slope``,
        ``intercept`` and ``R2``, and ``Ea_kJ_mol`` computed as
        ``-slope * R / 1000``.
    """
    celsius = np.array(temperature_C, dtype=float)
    kelvin = celsius + 273.15
    reciprocal_kelvin = 1 / kelvin
    log_k = np.log(k_values)

    regression = linregress(reciprocal_kelvin, log_k)

    summary = {
        "temperature_C": celsius,
        "temperature_K": kelvin,
        "inverse_temperature": reciprocal_kelvin,
        "ln_k": log_k,
        "slope": regression.slope,
        "intercept": regression.intercept,
        "R2": regression.rvalue**2,
        # Slope of ln(k) vs 1/T is -Ea/R; dividing by 1000 converts J to kJ.
        "Ea_kJ_mol": -regression.slope * R / 1000,
    }
    return summary


# ---------------------------------------------------------------------
# Kinetic constant calculation
# ---------------------------------------------------------------------

def calculate_kinetic_constants(data_frame, temperatures):
    """Fit the three kinetic models at every temperature.

    Parameters
    ----------
    data_frame : pandas.DataFrame
        Must contain ``time_min`` and one ``X_<T>C`` conversion column for
        each entry of ``temperatures``.
    temperatures : iterable of float
        Temperatures in degrees Celsius; ``int(temperature)`` is used to build
        the conversion column name.

    Returns
    -------
    pandas.DataFrame
        One row per temperature with the rate constants of the two
        shrinking-core models, the Avrami ``kA``, ``n`` and ``kA ** (1 / n)``,
        followed by R2, RSS and RMSE of each linearised fit.
    """
    records = []

    for temp_c in temperatures:
        all_times = data_frame["time_min"].to_numpy(dtype=float)
        all_conversions = data_frame[f"X_{int(temp_c)}C"].to_numpy(dtype=float)

        # Drop points for which ln(t) or ln(-ln(1 - X)) would be undefined.
        keep = (all_times > 0) & (all_conversions > 0) & (all_conversions < 1)
        elapsed = all_times[keep]
        unreacted = 1 - all_conversions[keep]

        # Shrinking core, surface reaction control: 1 - (1 - X)^(1/3) = ks t
        surface = linear_fit(elapsed, 1 - unreacted ** (1 / 3))

        # Shrinking core, diffusion control:
        # 1 - 3(1 - X)^(2/3) + 2(1 - X) = kd t
        diffusion = linear_fit(
            elapsed,
            1 - 3 * unreacted ** (2 / 3) + 2 * unreacted,
        )

        # Avrami: ln[-ln(1 - X)] = ln(kA) + n ln(t)
        avrami = linear_fit(np.log(elapsed), np.log(-np.log(unreacted)))
        exponent = avrami["slope"]
        k_avrami = np.exp(avrami["intercept"])

        record = {
            "Temperature_C": temp_c,
            "Temperature_K": temp_c + 273.15,
            "ks_SCM_surface": surface["slope"],
            "kd_SCM_diffusion": diffusion["slope"],
            "kA_Avrami": k_avrami,
            "n_Avrami": exponent,
            "kA_corrected": k_avrami ** (1 / exponent),
        }

        # Append the fit statistics grouped by statistic, then by model, which
        # fixes the column order of the resulting table.
        fits_by_tag = (
            ("SCM_surface", surface),
            ("SCM_diffusion", diffusion),
            ("Avrami", avrami),
        )
        for statistic in ("R2", "RSS", "RMSE"):
            for tag, model_fit in fits_by_tag:
                record[f"{statistic}_{tag}"] = model_fit[statistic]

        records.append(record)

    return pd.DataFrame(records)


def calculate_segmented_arrhenius(results, model_dict, segment_dict):
    """Run an Arrhenius fit for every model / temperature-stage combination.

    Parameters
    ----------
    results : pandas.DataFrame
        Table with a ``Temperature_C`` column and the rate-constant columns
        named in ``model_dict``.
    model_dict : dict
        Maps a model label to the name of its rate-constant column.
    segment_dict : dict
        Maps a stage name to the temperatures (deg C) belonging to it. Every
        temperature must be present in ``results``.

    Returns
    -------
    pandas.DataFrame
        One row per (model, stage) with ``Ea_kJ_mol``, ``R2``, ``Slope`` and
        ``Intercept``.
    """

    def fit_stage(k_column, stage_temperatures):
        # Pick the rate constant stored for each stage temperature, then fit.
        stage_temperatures = np.array(stage_temperatures, dtype=float)
        rate_constants = np.array(
            [
                results.loc[results["Temperature_C"] == value, k_column].values[0]
                for value in stage_temperatures
            ]
        )
        return arrhenius_fit(stage_temperatures, rate_constants)

    table = []
    for label, column in model_dict.items():
        for stage, stage_temperatures in segment_dict.items():
            outcome = fit_stage(column, stage_temperatures)
            table.append(
                {
                    "Model": label,
                    "Stage": stage,
                    "Ea_kJ_mol": outcome["Ea_kJ_mol"],
                    "R2": outcome["R2"],
                    "Slope": outcome["slope"],
                    "Intercept": outcome["intercept"],
                }
            )

    return pd.DataFrame(table)


# ---------------------------------------------------------------------
# Plotting
# ---------------------------------------------------------------------

def plot_segmented_arrhenius(
    results,
    k_column,
    model_label,
    file_name,
    show_segment_lines=True,
    segment_dict=None,
):
    """Draw ln(k) versus 1/T with shaded stages and save the figure.

    Parameters
    ----------
    results : pandas.DataFrame
        Table holding ``Temperature_C``, ``Temperature_K`` and ``k_column``.
    k_column : str
        Name of the rate-constant column to plot.
    model_label : str
        Text appended to the plot title.
    file_name : str
        Name of the image file written inside ``OUTPUT_DIR``.
    show_segment_lines : bool, optional
        When true, an Arrhenius line is fitted and drawn for every stage.
    segment_dict : dict, optional
        Maps a stage name such as ``"Stage I: 25-45 C"`` to the temperatures
        (deg C) of that stage. Defaults to the module-level ``segments``.
        Passing the same mapping that was given to
        ``calculate_segmented_arrhenius`` keeps the figure and the table in
        agreement. Each stage needs at least one temperature, and at least
        two present in ``results`` when ``show_segment_lines`` is true.
    """
    stage_map = segments if segment_dict is None else segment_dict

    temperature_C = results["Temperature_C"].to_numpy()
    inverse_temperature = 1 / results["Temperature_K"].to_numpy()
    ln_k = np.log(results[k_column].to_numpy())

    y_min = ln_k.min() - 0.6
    y_max = ln_k.max() + 0.6

    # 1/T extent of every stage, derived from the stage temperatures instead
    # of hard-coded boundaries.
    stage_spans = {}
    for segment_name, stage_temperatures in stage_map.items():
        stage_inverse = 1 / (np.array(stage_temperatures, dtype=float) + 273.15)
        stage_spans[segment_name] = (stage_inverse.min(), stage_inverse.max())

    # The x-range covers both the plotted points and every shaded stage.
    x_low = min([inverse_temperature.min()] + [low for low, _ in stage_spans.values()])
    x_high = max([inverse_temperature.max()] + [high for _, high in stage_spans.values()])

    fig, ax = plt.subplots(figsize=(7.2, 5.2))
    try:
        # Shaded temperature regions with alternating opacity, each labelled
        # with the part of the stage name that precedes the colon.
        for index, (segment_name, (span_low, span_high)) in enumerate(stage_spans.items()):
            ax.axvspan(
                span_low,
                span_high,
                alpha=0.20 if index % 2 == 0 else 0.12,
                zorder=0,
            )
            ax.text(
                (span_low + span_high) / 2,
                y_max - 0.05,
                segment_name.split(":")[0],
                ha="center",
                va="top",
            )

        ax.scatter(
            inverse_temperature,
            ln_k,
            marker="s",
            s=70,
            edgecolor="black",
            linewidth=0.4,
            zorder=3,
            label="Kinetic constants",
        )

        if show_segment_lines:
            for segment_name, stage_temperatures in stage_map.items():
                stage_temperatures = np.array(stage_temperatures, dtype=float)
                k_values = np.array(
                    [
                        results.loc[results["Temperature_C"] == temp, k_column].values[0]
                        for temp in stage_temperatures
                    ]
                )

                fit = arrhenius_fit(stage_temperatures, k_values)

                # Draw the fitted line only across the stage's own 1/T range.
                x_fit = np.linspace(
                    fit["inverse_temperature"].min(),
                    fit["inverse_temperature"].max(),
                    100,
                )
                y_fit = fit["slope"] * x_fit + fit["intercept"]

                ax.plot(
                    x_fit,
                    y_fit,
                    linestyle="--",
                    linewidth=1.8,
                    zorder=2,
                    label=f"{segment_name}: Ea = {fit['Ea_kJ_mol']:.1f} kJ mol$^{{-1}}$",
                )

        # Annotate every data point with its temperature.
        for x_value, y_value, temp in zip(inverse_temperature, ln_k, temperature_C):
            ax.text(x_value, y_value + 0.12, f"{int(temp)} C", ha="center", fontsize=9)

        ax.set_xlabel("1/T (K$^{-1}$)")
        ax.set_ylabel("ln(k)")
        ax.set_title(f"Segmented Arrhenius plot: {model_label}")
        ax.set_xlim(x_low - 0.00003, x_high + 0.00003)
        ax.set_ylim(y_min, y_max)
        ax.legend(fontsize=8)

        fig.tight_layout()
        fig.savefig(OUTPUT_DIR / file_name, dpi=300, bbox_inches="tight")
    finally:
        # Release the figure even when drawing or saving fails.
        plt.close(fig)


def plot_avrami_exponent(results):
    """Save a line plot of the Avrami exponent against temperature.

    ``results`` must provide the ``Temperature_C`` and ``n_Avrami`` columns.
    The figure is written to ``Avrami_n_vs_temperature.png`` inside
    ``OUTPUT_DIR`` and closed afterwards.
    """
    figure, axes = plt.subplots(figsize=(6.5, 4.5))

    axes.plot(results["Temperature_C"], results["n_Avrami"], marker="o")
    axes.set(
        xlabel="Temperature (C)",
        ylabel="Avrami exponent, n",
        title="Evolution of Avrami exponent with temperature",
    )

    figure.tight_layout()
    target = OUTPUT_DIR / "Avrami_n_vs_temperature.png"
    figure.savefig(target, dpi=300, bbox_inches="tight")
    plt.close(figure)


# ---------------------------------------------------------------------
# Main workflow
# ---------------------------------------------------------------------

def main():
    """Run the full workflow: export data, fit models, write tables and plots.

    All files are written into ``OUTPUT_DIR``; a confirmation message with the
    resolved folder path is printed at the end.
    """
    conversion_table = pd.DataFrame(data)
    conversion_table.to_excel(
        OUTPUT_DIR / "conversion_data_used_for_kinetic_analysis.xlsx",
        index=False,
    )

    results = calculate_kinetic_constants(conversion_table, temperatures_C)
    results.to_excel(OUTPUT_DIR / "Table_Kinetic_constants.xlsx", index=False)

    arrhenius_table = calculate_segmented_arrhenius(results, models, segments)
    arrhenius_table.to_excel(OUTPUT_DIR / "Table_Arrhenius_results.xlsx", index=False)

    # (rate-constant column, title label, output file) for each Arrhenius plot.
    figure_specs = (
        ("kA_corrected", "Avrami corrected kA^(1/n)", "Arrhenius_Avrami_corrected.png"),
        ("ks_SCM_surface", "SCM surface reaction", "Arrhenius_SCM_surface.png"),
        ("kd_SCM_diffusion", "SCM diffusion", "Arrhenius_SCM_diffusion.png"),
    )
    for k_column, model_label, file_name in figure_specs:
        plot_segmented_arrhenius(
            results,
            k_column,
            model_label,
            file_name,
            show_segment_lines=True,
        )

    plot_avrami_exponent(results)

    print(f"Analysis finished. Results were saved in: {OUTPUT_DIR.resolve()}")


# Run the complete analysis only when this file is executed as a script.
if __name__ == "__main__":
    main()
