from smbus2 import SMBus
import time
import json
import os
import csv
from datetime import datetime
from collections import deque
import matplotlib
matplotlib.use("TkAgg")
import matplotlib.pyplot as plt

ADS1115_ADDR = 0x48  # I2C address of the ADS1115
ADS1115_CONVERSION = 0x00  # register pointer: conversion result
ADS1115_CONFIG = 0x01  # register pointer: configuration
CAL_FILE = "etape_calibration.txt"  # JSON calibration written by save_calibration()

CSV_FILE = "level_log.csv"  # opened in append mode; rows accumulate across runs

EMA_ALPHA = 0.2  # weight of the newest sample in the EMA (see ema_update)

ENABLE_PLOT = True
PLOT_WINDOW_SAMPLES = 300  # number of most recent samples kept in the live plot

# ADS1115 config register (datasheet field names in brackets):
# AIN0 single-ended, ±4.096 V full scale, single-shot, 8 SPS, comparator off.
# NOTE: the data-rate field DR[7:5] is left at 0b000, which selects 8 SPS
# (~125 ms per conversion), not 128 SPS; 128 SPS would need DR=0b100 (0x0080).
# read_voltage() waits a fixed 0.3 s after starting a conversion, which
# covers the 8 SPS conversion time.
CONFIG = (
    0x8000  # OS[15]=1: start a single conversion
    | 0x4000  # MUX[14:12]=100: AIN0 vs GND (single-ended)
    | 0x0200  # PGA[11:9]=001: ±4.096 V full-scale range
    | 0x0100  # MODE[8]=1: single-shot / power-down mode
    # DR[7:5]=000 (no bits set): 8 samples per second
    | 0x0003  # COMP_QUE[1:0]=11: comparator disabled
)

# ------------------ READ VOLTAGE ------------------
def read_voltage(bus):
    """Run one single-shot ADS1115 conversion and return the result in volts.

    Writes CONFIG to the config register, waits a fixed 0.3 s for the
    conversion, then reads the 16-bit two's-complement conversion register
    and scales it for a ±4.096 V full-scale range.
    """
    time.sleep(0.1)  # fixed 0.1 s delay before each conversion
    payload = list((CONFIG & 0xFFFF).to_bytes(2, "big"))  # [MSB, LSB]
    bus.write_i2c_block_data(ADS1115_ADDR, ADS1115_CONFIG, payload)
    time.sleep(0.3)  # fixed wait for the conversion (the OS bit is not polled)
    reply = bus.read_i2c_block_data(ADS1115_ADDR, ADS1115_CONVERSION, 2)
    counts = int.from_bytes(bytes((reply[0], reply[1])), "big", signed=True)
    return counts * 4.096 / 32768.0

# ------------------ CALIBRATION I/O ------------------
def save_calibration(cal_dict):
    """Persist *cal_dict* as JSON in CAL_FILE and print a confirmation.

    The JSON is first written to a temporary file next to CAL_FILE and then
    moved over it with os.replace(). A serialization error, full disk or
    Ctrl+C during the write therefore cannot leave a truncated calibration
    file behind; the previous file stays intact instead.
    """
    tmp_path = CAL_FILE + ".tmp"
    try:
        with open(tmp_path, "w") as f:
            json.dump(cal_dict, f)
        os.replace(tmp_path, CAL_FILE)
    except BaseException:
        # Best-effort removal of the partial temp file; the error is re-raised.
        try:
            os.remove(tmp_path)
        except OSError:
            pass
        raise
    print(f"\n✅ Calibration saved to {CAL_FILE}\n")

def load_calibration():
    """Load the calibration dict saved in CAL_FILE.

    Returns:
        The decoded dict. Returns None if CAL_FILE does not exist, is not
        valid JSON, or does not contain a JSON object. The latter two cases
        also print a warning explaining why the file is ignored.

    Raises:
        OSError: read errors other than a missing file (e.g. permissions)
            are propagated unchanged.
    """
    try:
        with open(CAL_FILE, "r") as f:
            cal = json.load(f)
    except FileNotFoundError:
        return None
    except ValueError as err:
        # json.JSONDecodeError (and UnicodeDecodeError) derive from ValueError.
        print(f"\n⚠️  Ignoring invalid calibration file {CAL_FILE}: {err}")
        return None

    if not isinstance(cal, dict):
        print(f"\n⚠️  Ignoring calibration file {CAL_FILE}: expected a JSON object.")
        return None
    return cal

# ------------------ 5-POINT PIECEWISE CALIBRATION ------------------
def _prompt_float(prompt):
    """Prompt until the user enters something float() accepts; return it."""
    while True:
        text = input(prompt)
        try:
            return float(text)
        except ValueError:
            print(f"  '{text}' is not a number, please try again.")


def run_five_point_piecewise_calibration(bus):
    """
    Interactively record 5 calibration points (voltage, level in mm) and save them.

    For each point the user enters the target level (re-prompted until it is
    a number) and positions the sensor. One reading is then taken with
    read_voltage(bus). Points are stored sorted by increasing voltage;
    voltage_to_level() interpolates linearly between adjacent points.

    Returns:
        The calibration dict, also written with save_calibration(). Returns
        None if any two adjacent sorted voltages differ by less than 1e-6 V,
        because such a zero-width segment would make interpolation divide by
        zero.
    """
    print("\n--- Five-Point Calibration (piecewise linear, units: mm) ---")
    pts = []

    for i in range(1, 6):
        print(f"\nCalibration Point {i}:")
        target = _prompt_float("Enter target level (mm): ")
        input("Position sensor at that level and press Enter to measure...")
        v = read_voltage(bus)
        print(f"Measured voltage: {v:.3f} V")
        pts.append((v, target))

    # sort by voltage increasing
    pts.sort(key=lambda p: p[0])

    # Every adjacent pair of voltages must be distinct, not just the first ones.
    if any(abs(hi[0] - lo[0]) < 1e-6 for lo, hi in zip(pts, pts[1:])):
        print("\n❌ Calibration failed: two voltages are too close / identical. Try again with more separated levels.")
        return None

    cal = {
        "model": "piecewise_5",
        "unit": "mm",
        "points": [{"voltage": v, "level_mm": mm} for v, mm in pts],
    }

    print("\n--- Calibration Complete ---")
    print("Stored points (sorted by voltage):")
    for p in cal["points"]:
        print(f"  V={p['voltage']:.4f}  ->  {p['level_mm']:.2f} mm")

    save_calibration(cal)
    return cal

# ------------------ LEVEL COMPUTATION ------------------
def lerp(x0, y0, x1, y1, x):
    """Return the value at *x* on the straight line through (x0, y0) and (x1, y1).

    Extrapolates along the same line when x lies outside [x0, x1].
    Raises ZeroDivisionError when x0 == x1.
    """
    rise = y1 - y0
    run = x1 - x0
    return y0 + rise * (x - x0) / run

def voltage_to_level(v, cal, clamp=True):
    """Convert a measured voltage to a level in mm using a piecewise calibration.

    The "piecewise_5" model stores its points as
    ``{"voltage": V, "level_mm": mm}`` dicts. The level is linearly
    interpolated between the two points whose voltages bracket *v*.

    Outside the calibrated voltage range the result is clamped to the level
    of the nearest end point, avoiding runaway extrapolation. Pass
    ``clamp=False`` to extrapolate the outermost segment linearly instead.

    Args:
        v: Measured voltage in volts.
        cal: Calibration dict with "model" == "piecewise_5" and a "points" list.
        clamp: Clamp outside the calibrated range (default) instead of
            extrapolating.

    Returns:
        The level in mm.

    Raises:
        ValueError: if *cal* is not a piecewise calibration with at least two
            points.
    """
    if cal.get("model") != "piecewise_5" or "points" not in cal:
        raise ValueError("Calibration file is missing required fields.")

    # Sort defensively (stable, by voltage) so a hand-edited file still works.
    pts = sorted(
        ((p["voltage"], p["level_mm"]) for p in cal["points"]),
        key=lambda vp: vp[0],
    )
    if len(pts) < 2:
        raise ValueError("Calibration file is missing required fields.")

    if clamp:
        if v <= pts[0][0]:
            return float(pts[0][1])
        if v >= pts[-1][0]:
            return float(pts[-1][1])

    # Pick the first segment whose upper voltage is >= v. Voltages above the
    # last point fall into the final segment (reachable only with clamp=False).
    last = len(pts) - 1
    i = next((k for k in range(1, last) if v <= pts[k][0]), last)
    (x0, y0), (x1, y1) = pts[i - 1], pts[i]
    return lerp(x0, y0, x1, y1, v)

# ------------------ EMA FILTER ------------------
def ema_update(prev_ema, new_value, alpha):
    """Fold *new_value* into an exponential moving average.

    The first sample (prev_ema is None) seeds the average unchanged. After
    that the result is ``alpha * new_value + (1 - alpha) * prev_ema``.
    """
    return (
        new_value
        if prev_ema is None
        else alpha * new_value + (1.0 - alpha) * prev_ema
    )

# ------------------ CSV LOGGING ------------------
def init_csv(path):
    """Open *path* for appending CSV rows, writing the header if the file is empty.

    The header is written whenever the opened file has size 0. This covers a
    brand-new file and also an existing but empty one (e.g. left by an
    interrupted run).

    Args:
        path: CSV file path; created if it does not exist.

    Returns:
        (file, writer): the open text file, which the caller must close, and
        a csv.writer bound to it.
    """
    f = open(path, "a", newline="", encoding="utf-8")
    writer = csv.writer(f)
    try:
        # Size of the opened file, not just existence of the path.
        if os.fstat(f.fileno()).st_size == 0:
            writer.writerow(["timestamp_iso", "voltage_V", "level_raw_mm", "level_ema_mm"])
            f.flush()
    except BaseException:
        f.close()  # don't leak the handle if the header cannot be written
        raise
    return f, writer

# ------------------ PLOTTING ------------------
def init_plot():
    """Create an interactive figure with empty "Raw" and "EMA" level traces.

    Returns:
        (fig, ax, raw_line, ema_line): the figure, its axes, and the two line
        artists that update_plot() refreshes.
    """
    plt.ion()  # interactive mode: the figure updates without a blocking show()
    fig, ax = plt.subplots()
    for setter, text in (
        (ax.set_title, "Level Sensor (Raw vs EMA)"),
        (ax.set_xlabel, "Sample"),
        (ax.set_ylabel, "Level (mm)"),
    ):
        setter(text)

    raw_line, ema_line = (ax.plot([], [], label=name)[0] for name in ("Raw", "EMA"))
    ax.legend(loc="best")
    return fig, ax, raw_line, ema_line

def update_plot(ax, raw_line, ema_line, raw_vals, ema_vals):
    """Redraw both traces from the sample buffers and rescale the axes.

    Both buffers are plotted against the same x positions
    0..len(raw_vals)-1, so they are expected to hold equally many samples.
    """
    xs = list(range(len(raw_vals)))
    for line, values in ((raw_line, raw_vals), (ema_line, ema_vals)):
        line.set_data(xs, list(values))
    # Recompute data limits from the new line data, then autoscale to them.
    ax.relim()
    ax.autoscale_view()
    plt.pause(0.001)  # brief pause lets the GUI event loop draw the update

# ------------------ MAIN ------------------
def _describe_points(points):
    """Format saved points as 'V=v0->y0mm, v1->y1mm, ...' for the reuse prompt."""
    return "V=" + ", ".join(
        f"{p['voltage']:.3f}->{p['level_mm']:.1f}mm" for p in points
    )


def _select_calibration(bus):
    """Return the calibration to use for this run.

    Offers to reuse a saved piecewise calibration. Otherwise, or if the user
    answers "n", runs a new interactive calibration. Returns None if a new
    calibration was attempted and failed.
    """
    cal = load_calibration()

    if isinstance(cal, dict) and cal.get("model") == "piecewise_5" and cal.get("points"):
        choice = input(
            "Use saved 5-point piecewise calibration "
            f"({_describe_points(cal['points'])})? [Y/n]: "
        ).strip().lower()
        if choice == "n":
            cal = run_five_point_piecewise_calibration(bus)
        return cal

    if cal:
        print("Found old/other calibration. Switching to 5-point piecewise calibration.")
    else:
        print("No saved calibration found.")
    return run_five_point_piecewise_calibration(bus)


def _run_readout(bus, cal):
    """Sample the sensor roughly once per second until Ctrl+C.

    Each sample is converted to a level, smoothed with an EMA, printed,
    appended to CSV_FILE and, when ENABLE_PLOT is set, drawn in the live
    plot. On exit the CSV file is closed. If a plot was opened, it is then
    shown (blocking) until its window is closed.
    """
    csv_f, csv_writer = init_csv(CSV_FILE)
    print(f"\nLogging to CSV: {CSV_FILE}")

    plot_open = False
    try:
        # `with` closes the CSV file even if plot setup or the loop fails.
        with csv_f:
            if ENABLE_PLOT:
                _fig, ax, raw_line, ema_line = init_plot()
                plot_open = True
                raw_vals = deque(maxlen=PLOT_WINDOW_SAMPLES)
                ema_vals = deque(maxlen=PLOT_WINDOW_SAMPLES)

            print("\nStarting continuous readout (Ctrl+C to stop)...\n")

            ema_level = None
            while True:
                v = read_voltage(bus)
                level_raw = voltage_to_level(v, cal)
                ema_level = ema_update(ema_level, level_raw, EMA_ALPHA)

                print(
                    f"Voltage: {v:.3f} V  |  Level(raw): {level_raw:.2f} mm  |  Level(EMA): {ema_level:.2f} mm"
                )

                ts = datetime.now().isoformat(timespec="seconds")
                csv_writer.writerow([ts, f"{v:.6f}", f"{level_raw:.3f}", f"{ema_level:.3f}"])
                csv_f.flush()  # hand each row to the OS immediately

                if plot_open:
                    raw_vals.append(level_raw)
                    ema_vals.append(ema_level)
                    update_plot(ax, raw_line, ema_line, raw_vals, ema_vals)

                time.sleep(1)
    except KeyboardInterrupt:
        print("\nExiting. Goodbye!")
    finally:
        if plot_open:
            # Keep the final plot on screen until the user closes the window.
            plt.ioff()
            plt.show()


def main():
    """Open I2C bus 1, select or create a calibration, then run the readout loop."""
    with SMBus(1) as bus:
        print("=== eTape Sensor 5-Point Piecewise Calibration (ADS1115, mm) ===\n")

        cal = _select_calibration(bus)
        if not cal:
            raise SystemExit("Calibration not available. Exiting.")

        _run_readout(bus, cal)


if __name__ == "__main__":
    main()
