from smbus2 import SMBus, i2c_msg
import time, json, os, smbus2 as smbus, csv
from datetime import datetime
import matplotlib
matplotlib.use("TkAgg")
import matplotlib.pyplot as plt

# ------------------ I2C ADDRESSES ------------------
ADS1115_ADDR = 0x48             # ADS1115 ADC that samples the eTape voltage
PUMP_ADDR = 0x6D                # pump controller (ASCII command interface)
CAL_FILE = "etape_calibration.txt"  # JSON calibration written by the calibration step

# ------------------ ADS1115 CONFIG ------------------
# Register pointers.
ADS1115_CONVERSION = 0x00
ADS1115_CONFIG = 0x01

# Config-register fields OR-ed together below. Field meanings follow the
# ADS1115 datasheet config-register layout (fields not listed stay 0).
_ADS_OS_START_SINGLE = 0x8000   # OS: begin a single conversion
_ADS_MUX_AIN0_GND = 0x4000      # MUX: AIN0 referenced to GND
_ADS_PGA_4_096V = 0x0200        # PGA: +/-4.096 V full scale
_ADS_MODE_SINGLE_SHOT = 0x0100  # MODE: single-shot
_ADS_COMP_DISABLED = 0x0003     # COMP_QUE: comparator disabled
CONFIG = (
    _ADS_OS_START_SINGLE
    | _ADS_MUX_AIN0_GND
    | _ADS_PGA_4_096V
    | _ADS_MODE_SINGLE_SHOT
    | _ADS_COMP_DISABLED
)

# ------------------ FILTER SETTINGS ------------------
# Exponential-moving-average weight given to each new sample:
#   near 0.0 -> heavy smoothing, slow to follow changes
#   near 1.0 -> light smoothing, fast to follow changes
EMA_ALPHA = 0.2

# ------------------ ADS1115 FUNCTIONS ------------------
def read_voltage(bus):
    """Run one ADS1115 conversion and return the result in volts.

    Writes ``CONFIG`` to the config register, waits a fixed 0.3 s, then reads
    the 16-bit conversion register, interprets it as a signed (two's
    complement) value and scales it by 4.096 V per 32768 counts.

    Args:
        bus: Open SMBus-style handle exposing ``write_i2c_block_data`` and
            ``read_i2c_block_data``.

    Returns:
        float: The measured voltage.
    """
    # Config register is sent MSB first.
    cfg_payload = list((CONFIG & 0xFFFF).to_bytes(2, "big"))
    bus.write_i2c_block_data(ADS1115_ADDR, ADS1115_CONFIG, cfg_payload)
    time.sleep(0.3)

    reply = bus.read_i2c_block_data(ADS1115_ADDR, ADS1115_CONVERSION, 2)
    counts = int.from_bytes(bytes(reply[:2]), "big", signed=True)
    return counts * 4.096 / 32768.0

def load_calibration(path=None):
    """Load and validate the 5-point piecewise eTape calibration file.

    Expected JSON layout::

        {"model": "piecewise_5",
         "points": [{"voltage": <number>, "level_mm": <number>}, ...]}

    with exactly five points. The legacy linear ``{"m": .., "b": ..}`` format
    is rejected.

    Args:
        path: Calibration file to read. Defaults to ``CAL_FILE`` when None.

    Returns:
        dict: The parsed calibration, with ``cal["points"]`` sorted by
        ascending voltage.

    Raises:
        SystemExit: (code 1) if the file is missing, is not valid JSON, is not
            a piecewise_5 calibration, does not have exactly 5 points, has a
            point without numeric ``voltage``/``level_mm``, or has two points
            with the same voltage (which would make interpolation divide by
            zero).
    """
    if path is None:
        path = CAL_FILE

    if not os.path.exists(path):
        print(" No calibration file found. Run 5-point calibration first.")
        raise SystemExit(1)

    try:
        with open(path, "r", encoding="utf-8") as f:
            cal = json.load(f)
    except json.JSONDecodeError as exc:
        print(f" Calibration file is not valid JSON ({exc}). Run 5-point calibration again.")
        raise SystemExit(1) from exc

    if not isinstance(cal, dict) or cal.get("model") != "piecewise_5":
        print(" Calibration file format not recognized. Run 5-point calibration again.")
        raise SystemExit(1)

    pts = cal.get("points", [])
    if not isinstance(pts, list) or len(pts) != 5:
        print(" Calibration file has model=piecewise_5 but does not contain exactly 5 points.")
        raise SystemExit(1)

    # Both fields feed interpolation arithmetic, so they must be numbers.
    numeric = (int, float)
    if not all(
        isinstance(p, dict)
        and isinstance(p.get("voltage"), numeric)
        and isinstance(p.get("level_mm"), numeric)
        for p in pts
    ):
        print(" Calibration points must each have numeric 'voltage' and 'level_mm' values.")
        raise SystemExit(1)

    pts_sorted = sorted(pts, key=lambda p: p["voltage"])

    # Equal neighbouring voltages produce a zero-width segment (division by zero).
    if any(a["voltage"] == b["voltage"] for a, b in zip(pts_sorted, pts_sorted[1:])):
        print(" Calibration points must have distinct voltages. Run 5-point calibration again.")
        raise SystemExit(1)

    cal["points"] = pts_sorted
    return cal

def lerp(x0, y0, x1, y1, x):
    """Return y at ``x`` on the straight line through (x0, y0) and (x1, y1).

    Works for interpolation (x between x0 and x1) and extrapolation.

    Raises:
        ZeroDivisionError: if ``x0 == x1``.
    """
    return y0 + (y1 - y0) * (x - x0) / (x1 - x0)


def voltage_to_level(v, cal):
    """Convert a sensor voltage to a level (mm) with the piecewise calibration.

    ``cal["points"]`` must be sorted by ascending ``voltage``; load_calibration
    returns them that way. Segment k joins points k and k+1. Inside the
    calibrated range the segment containing ``v`` is used. Outside it the
    nearest end segment is extrapolated:

      - below v0 -> extrapolate using segment (v0, v1)
      - above v4 -> extrapolate using segment (v3, v4)

    Args:
        v: Measured voltage.
        cal: Calibration dict with ``model == "piecewise_5"`` and ``points``.

    Returns:
        float: Level in mm.

    Raises:
        ValueError: if ``cal`` is not a piecewise_5 calibration with at least
            two points. This avoids silently returning None and failing later
            in the caller.
    """
    if cal.get("model") != "piecewise_5" or "points" not in cal:
        raise ValueError(
            "Unsupported calibration format: expected model 'piecewise_5' with 'points'."
        )

    pts = cal["points"]
    if len(pts) < 2:
        raise ValueError("Piecewise calibration needs at least two points.")

    # Choose the first segment whose upper *interior* breakpoint is >= v.
    # Values at/below pts[1] (including below the range) land in segment 0.
    # Values above the last interior breakpoint (including above the range)
    # fall through to the final segment.
    seg = next(
        (k for k, p in enumerate(pts[1:-1]) if v <= p["voltage"]),
        len(pts) - 2,
    )
    lo, hi = pts[seg], pts[seg + 1]
    return lerp(lo["voltage"], lo["level_mm"], hi["voltage"], hi["level_mm"], v)

def ema_update(prev_ema, new_value, alpha):
    """Advance an exponential moving average by one sample.

    Args:
        prev_ema: Previous filtered value, or None when there is no history.
        new_value: Newest raw sample.
        alpha: Weight given to ``new_value`` (0..1).

    Returns:
        ``new_value`` when ``prev_ema`` is None (the filter is seeded with the
        first sample), otherwise ``alpha * new_value + (1 - alpha) * prev_ema``.
    """
    has_history = prev_ema is not None
    return alpha * new_value + (1.0 - alpha) * prev_ema if has_history else new_value

# ------------------ PUMP FUNCTIONS ------------------
def raw_send_command(cmd):
    """Send one ASCII command string to the pump over I2C bus 1.

    The command is NUL-terminated and sent as a single write message on a
    bus handle opened just for this call. The call then sleeps a fixed 0.3 s.

    Args:
        cmd: Command text; must be ASCII-encodable.
    """
    terminated = bytes(cmd, "ascii") + b"\x00"
    write_msg = i2c_msg.write(PUMP_ADDR, terminated)
    with SMBus(1) as pump_bus:
        pump_bus.i2c_rdwr(write_msg)
    time.sleep(0.3)

def send_and_print(cmd, wait=0.5):
    """Send ``cmd`` to the pump, then block for an extra ``wait`` seconds.

    Despite the name, nothing is printed. The ``wait`` delay comes on top of
    whatever delay ``raw_send_command`` itself adds.

    Args:
        cmd: ASCII pump command.
        wait: Extra seconds to sleep after sending (default 0.5).
    """
    raw_send_command(cmd)
    extra_delay_s = wait
    time.sleep(extra_delay_s)

# ------------------ MAIN ------------------
def _init_plot():
    """Open the non-blocking live plot.

    Returns:
        tuple: ``(fig, ax, line_meas, line_raw, line_set)``. The three lines
        are the initially empty EMA-level, raw-level and setpoint traces.
    """
    plt.ion()
    fig, ax = plt.subplots()
    ax.set_xlabel("Time (s)")
    ax.set_ylabel("Level (mm)")
    line_meas, = ax.plot([], [], label="Measured Level (EMA)", color="tab:blue")
    line_raw, = ax.plot([], [], label="Measured Level (Raw)", color="tab:gray", linestyle=":")
    line_set, = ax.plot([], [], label="Setpoint", color="tab:red", linestyle="--")
    ax.legend()
    plt.show(block=False)
    return fig, ax, line_meas, line_raw, line_set


def _stop_pump():
    """Send the pump stop command ("x"); report instead of raising on I2C errors."""
    try:
        send_and_print("x")
    except OSError as exc:
        # Best effort: a failed stop must not mask whatever ended the loop.
        print(f" WARNING: could not send stop command to pump: {exc}")


def main():
    """Run the level controller interactively with a live plot and CSV log.

    Prompts for the target level (mm) and the PID gains. It then repeats until
    Ctrl+C or until the plot window is closed:

    1. Read the eTape voltage and convert it to a level with the calibration.
       Smooth it with an EMA; the controller acts on the smoothed level.
    2. Drive the pump from ``error = target - level``:
       * ``error <= 0``: send "x" and reset the PID state.
       * ``error >= 100``: send "dc,440".
       * otherwise: compute the PID output, clamp it to [0, 440] mL/min and
         send it as a DC command. The integral step is undone while the
         output sits at either clamp (anti-windup).
    3. Append a row to ``pi_log_<timestamp>.csv``, update the plot and print
       a status line.

    On every exit path the pump is sent "x" and the I2C bus is closed,
    including on unexpected exceptions.
    """
    print(" Starting PID Control with Real-Time Plot...\n")
    bus = smbus.SMBus(1)
    try:
        cal = load_calibration()
        send_and_print("*OK,1")

        # ---- PID parameters ----
        target = float(input("Enter target level (mm): "))
        Kp = float(input("Enter proportional gain (Kp): "))
        Ki = float(input("Enter integral gain (Ki): "))
        Kd = float(input("Enter derivative gain (Kd): "))
        # Seconds. Sent as the third field of each DC command and also used
        # as the post-command wait, so it sets the loop cadence.
        Ts = 1.0

        # ---- Prepare log and plot ----
        timestamp = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
        log_filename = f"pi_log_{timestamp}.csv"
        with open(log_filename, "w", newline="") as log_file:
            csv.writer(log_file).writerow(
                ["Timestamp", "LevelRaw_mm", "LevelEMA_mm", "Error_mm", "PumpOutput_mLmin"]
            )

        print(f" Logging data to {log_filename}")
        print(" Starting real-time plot (close window or Ctrl+C to stop)...")

        fig, ax, line_meas, line_raw, line_set = _init_plot()
        times, levels_ema, levels_raw, setpoints = [], [], [], []

        # ---- PID state ----
        integral = 0.0
        last_error = None  # None => no previous sample yet, so no derivative kick
        output = 0.0
        derivative = 0.0
        # Monotonic clock: dt must not jump if the system clock is adjusted.
        start_time = time.monotonic()
        last_time = start_time
        # True initially and after "dc,440"; the next PID step sends "x" first.
        pump = True

        # ---- EMA state ----
        level_ema = None

        try:
            while True:
                # Closing the plot window ends the run, as announced above.
                if not plt.fignum_exists(fig.number):
                    print("\n Plot window closed. Stopping pump...")
                    break

                v = read_voltage(bus)

                # Raw level from calibration, then the EMA-filtered level
                # that the controller uses.
                level_raw = voltage_to_level(v, cal)
                level_ema = ema_update(level_ema, level_raw, EMA_ALPHA)
                level = level_ema

                error = target - level
                now = time.monotonic()
                dt = now - last_time
                last_time = now

                if error <= 0:
                    # At/above target: stop and reset. This must NOT fall
                    # through to the PID branch, which would immediately send
                    # a new DC command (possibly with positive flow from the
                    # derivative term). wait=Ts keeps the loop cadence.
                    send_and_print("x", wait=Ts)
                    pump = False
                    integral = 0.0
                    derivative = 0.0
                    output = 0.0

                elif error >= 100:
                    send_and_print("dc,440")
                    pump = True

                else:
                    if pump:
                        send_and_print("x")
                        pump = False

                    integral += error * dt
                    if last_error is not None and dt > 0:
                        derivative = (error - last_error) / dt
                    else:
                        derivative = 0.0

                    output = Kp * error + Ki * integral + Kd * derivative
                    output = max(0, min(output, 440))

                    # NOTE(review): the ":.2f" format in the DC command below
                    # renders 0.001 as "0.00", so this effectively commands zero
                    # flow (the CSV still records 0.001). Confirm that is intended.
                    if 0 < output < 10:
                        output = 0.001

                    # Anti-windup: undo this step's integration while clamped.
                    if output == 0 or output == 440:
                        integral -= error * dt

                    # send flow command
                    send_and_print(f"DC,{output:.2f},{Ts}", wait=Ts)

                # Track the latest error in every branch so the derivative
                # never compares against a stale sample.
                last_error = error

                # log data (reopened per row so each sample reaches disk)
                t = time.monotonic() - start_time
                logged_output = "440" if error >= 100 else f"{output:.3f}"
                with open(log_filename, "a", newline="") as log_file:
                    csv.writer(log_file).writerow([
                        datetime.now().isoformat(),
                        f"{level_raw:.3f}",
                        f"{level_ema:.3f}",
                        f"{error:.3f}",
                        logged_output,
                    ])

                # update plot
                times.append(t)
                levels_ema.append(level_ema)
                levels_raw.append(level_raw)
                setpoints.append(target)

                line_meas.set_data(times, levels_ema)
                line_raw.set_data(times, levels_raw)
                line_set.set_data(times, setpoints)
                ax.relim()
                ax.autoscale_view()
                plt.pause(0.01)

                if error >= 100:
                    print(f"Level(EMA): {level_ema:7.2f} mm | Level(raw): {level_raw:7.2f} mm | Error: {error:7.2f} | Output: 440 mL/min | Integral: 0 | Derivative: 0")
                else:
                    print(f"Level(EMA): {level_ema:7.2f} mm | Level(raw): {level_raw:7.2f} mm | Error: {error:7.2f} | Output: {output:7.2f} mL/min | Integral: {integral:7.2f} | Derivative: {derivative:7.2f}")

        except KeyboardInterrupt:
            print("\n Stopping pump and closing plot...")
        finally:
            # Safety: stop the pump however the loop ended (Ctrl+C, window
            # closed, or an exception such as an I2C failure).
            _stop_pump()

        plt.ioff()
        plt.show()
        print(f" Data logged to {log_filename}")
    finally:
        bus.close()


if __name__ == "__main__":
    main()
