from smbus2 import SMBus, i2c_msg
import time, json, os, smbus2 as smbus, csv
from datetime import datetime
import matplotlib
matplotlib.use("TkAgg")
import matplotlib.pyplot as plt

# ------------------ I2C ADDRESSES ------------------
ADS1115_ADDR = 0x48  # ADC whose reading is converted to a level
PUMP_ADDR = 0x6D  # pump controller, driven with NUL-terminated ASCII commands
CAL_FILE = "etape_calibration.txt"  # JSON calibration consumed as cal["m"], cal["b"]

# ------------------ ADS1115 CONFIG ------------------
ADS1115_CONVERSION = 0x00  # register pointer: conversion result
ADS1115_CONFIG = 0x01  # register pointer: configuration word
# Config word written before every read (== 0xC303).
# NOTE(review): the per-field meanings below come from the ADS1115 datasheet,
# not from this file -- verify against it before changing any bit.
CONFIG = (
    0x8000    # OS: begin a single conversion
    | 0x4000  # MUX: AIN0 measured against GND
    | 0x0200  # PGA: +/-4.096 V full scale (read_voltage scales by 4.096 V)
    | 0x0100  # MODE: single-shot; DR bits left at 0 (8 SPS)
    | 0x0003  # COMP_QUE: comparator disabled
)

# ------------------ ADS1115 FUNCTIONS ------------------
def read_voltage(bus):
    """Run one ADS1115 conversion and return the result in volts.

    Args:
        bus: An open SMBus-style object used for the I2C transfers.

    Returns:
        The signed 16-bit conversion result scaled by 4.096 V / 32768.
    """
    # The config register takes the 16-bit CONFIG word most-significant byte first.
    config_payload = list((CONFIG & 0xFFFF).to_bytes(2, "big"))
    bus.write_i2c_block_data(ADS1115_ADDR, ADS1115_CONFIG, config_payload)
    # Fixed wait for the conversion to finish (the ready bit is not polled).
    time.sleep(0.3)
    reading = bus.read_i2c_block_data(ADS1115_ADDR, ADS1115_CONVERSION, 2)
    msb, lsb = reading[0], reading[1]
    # The conversion register holds a big-endian two's-complement value.
    counts = int.from_bytes(bytes((msb, lsb)), "big", signed=True)
    return counts * 4.096 / 32768.0

def voltage_to_level(v, cal):
    """Convert a sensor voltage to a level with the linear calibration ``cal``.

    ``cal`` supplies the slope under ``"m"`` and the intercept under ``"b"``;
    the result is ``m * v + b``.
    """
    slope, intercept = cal["m"], cal["b"]
    return slope * v + intercept

def load_calibration(path=None):
    """Load the sensor calibration (a JSON object) from disk.

    Args:
        path: File to read. Defaults to ``CAL_FILE`` when omitted.

    Returns:
        The decoded JSON object; ``voltage_to_level`` reads its ``"m"``
        (slope) and ``"b"`` (intercept) entries.

    Raises:
        SystemExit: With status 1 if the file does not exist, after printing
            a hint to run calibration first. (The previous bare ``exit()``
            exited with status 0, reporting this failure as success.)
    """
    if path is None:
        path = CAL_FILE
    # EAFP: opening directly avoids the race between an exists() check and open().
    try:
        cal_file = open(path, "r", encoding="utf-8")
    except FileNotFoundError:
        print(" No calibration file found. Run calibration first.")
        raise SystemExit(1) from None
    with cal_file:
        return json.load(cal_file)

# ------------------ PUMP FUNCTIONS ------------------
def raw_send_command(cmd):
    """Write ``cmd`` to the pump as a NUL-terminated ASCII string on I2C bus 1.

    The bus is opened only for this single transfer. Afterwards the function
    sleeps 0.3 s (presumably so the pump can process the command -- confirm
    against the pump's documentation).
    """
    payload = bytes(cmd, "ascii") + b"\x00"
    write_msg = i2c_msg.write(PUMP_ADDR, payload)
    with SMBus(1) as i2c:
        i2c.i2c_rdwr(write_msg)
    time.sleep(0.3)

def send_and_print(cmd, wait=0.5):
    """Send ``cmd`` to the pump, then block for ``wait`` seconds.

    The wait comes on top of the fixed 0.3 s pause inside ``raw_send_command``.
    NOTE(review): despite the name, nothing is printed and no reply is read.
    """
    raw_send_command(cmd)
    # Caller-controlled delay; the control loop uses it to set its own pacing.
    time.sleep(wait)


# ------------------ MAIN ------------------
def _append_log_row(log_filename, level, error, output):
    """Append one sample (wall-clock timestamp, level, error, pump output) to the CSV log.

    The file is reopened in append mode for each row, so every row already
    written is on disk even if the script is killed mid-run.
    """
    with open(log_filename, "a", newline="") as log_file:
        csv.writer(log_file).writerow(
            [datetime.now().isoformat(), f"{level:.3f}", f"{error:.3f}", f"{output:.3f}"]
        )


def _stop_pump():
    """Send the pump stop command ("x"), reporting instead of raising I2C errors.

    Called during shutdown, where an exception here would hide the original
    reason the control loop ended.
    """
    try:
        send_and_print("x")
    except OSError as exc:
        print(f" WARNING: could not send stop command to pump: {exc}")


def main():
    """Run the closed-loop level controller with a live plot and CSV log.

    Prompts for the target level and PID gains. It then repeatedly does the following:
    - reads the ADC;
    - converts the reading to a level with the saved calibration;
    - drives the pump: full flow when far below target, PID-controlled flow
      otherwise, and a stop command when at or above target;
    - appends a row to a timestamped CSV file;
    - refreshes the plot.

    The loop ends on Ctrl+C or when the plot window is closed. The pump is
    always sent "x" on the way out, even when an I2C or other error aborts
    the loop, and the I2C bus handle is closed.
    """
    print("Starting PID Control with Real-Time Plot...\n")
    max_flow = 440  # full-speed flow and upper clamp of the PID output (mL/min)
    full_speed_error = 100  # error (mm) at or above which the PID is bypassed

    with smbus.SMBus(1) as bus:
        cal = load_calibration()
        send_and_print("*OK,1")

        # ---- PID parameters ----
        target = float(input("Enter target level (mm): "))
        Kp = float(input("Enter proportional gain (Kp): "))
        Ki = float(input("Enter integral gain (Ki): "))
        Kd = float(input("Enter derivative gain (Kd): "))
        # Wait after each PID flow command (s); also sent as the DC command's
        # second field.
        Ts = 1.0

        # ---- Prepare log and plot ----
        timestamp = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
        log_filename = f"pi_log_{timestamp}.csv"
        with open(log_filename, "w", newline="") as log_file:
            csv.writer(log_file).writerow(
                ["Timestamp", "Level_mm", "Error_mm", "PumpOutput_mLmin"]
            )

        print(f" Logging data to {log_filename}")
        print(" Starting real-time plot (close window or Ctrl+C to stop)...")

        plt.ion()
        fig, ax = plt.subplots()
        times, levels, setpoints = [], [], []
        ax.set_xlabel("Time (s)")
        ax.set_ylabel("Level (mm)")
        (line_meas,) = ax.plot([], [], label="Measured Level", color="tab:blue")
        (line_set,) = ax.plot([], [], label="Setpoint", color="tab:red", linestyle="--")
        ax.legend()
        plt.show(block=False)

        # ---- PID state ----
        integral = 0.0
        last_error = None  # no previous sample yet -> no derivative kick on step 1
        output = 0.0
        derivative = 0.0
        # Monotonic clock: dt must not jump if the wall clock is adjusted.
        start_time = time.monotonic()
        last_time = start_time
        # True while the pump may still be in full-speed mode, so it must be
        # stopped before handing over to PID-controlled flow.
        pump = True

        try:
            # plt.pause() below runs the GUI event loop, so closing the plot
            # window is detected here and ends the loop as the prompt promises.
            while plt.fignum_exists(fig.number):
                v = read_voltage(bus)
                level = voltage_to_level(v, cal)
                error = target - level
                now = time.monotonic()
                dt = now - last_time
                last_time = now
                full_speed = error >= full_speed_error

                if error <= 0:
                    # At or above target: stop the pump and clear the PID memory.
                    send_and_print("x")
                    integral = 0
                    derivative = 0
                    output = 0

                if full_speed:
                    # Far below target: bypass the PID and run at full flow.
                    send_and_print(f"dc,{max_flow}")
                    pump = True
                else:
                    if pump:
                        # Leaving full-speed mode: stop before switching to PID flow.
                        send_and_print("x")
                        pump = False

                    integral += error * dt
                    if last_error is not None and dt > 0:
                        derivative = (error - last_error) / dt
                    else:
                        derivative = 0.0
                    output = Kp * error + Ki * integral + Kd * derivative
                    output = max(0, min(output, max_flow))

                    if 0 < output < 10:
                        # NOTE(review): 0.001 is sent as "0.00" by the :.2f
                        # format below -- confirm the intended minimum flow.
                        output = 0.001

                    # Anti-windup: undo this step's integration while the
                    # output is clamped at either limit.
                    if output == 0 or output == max_flow:
                        integral -= error * dt

                    # send flow command
                    send_and_print(f"DC,{output:.2f},{Ts}", wait=Ts)

                # Remember the error on every sample, including full-speed ones,
                # so the next derivative is never computed from a stale value.
                last_error = error

                # log data -- the logged output is the flow actually commanded
                # (previously full-speed rows were logged as 480 instead of 440).
                _append_log_row(log_filename, level, error, max_flow if full_speed else output)

                # update plot
                times.append(now - start_time)
                levels.append(level)
                setpoints.append(target)
                line_meas.set_data(times, levels)
                line_set.set_data(times, setpoints)
                ax.relim()
                ax.autoscale_view()
                plt.pause(0.01)

                if full_speed:
                    print(f"Level: {level:7.2f} mm | Error: {error:7.2f} | Output: {max_flow} mL/min | Integral: 0 | Derivative: 0 ")
                else:
                    print(f"Level: {level:7.2f} mm | Error: {error:7.2f} | Output: {output:7.2f} mL/min | Integral: {integral:7.2f} | Derivative: {derivative:7.2f} ")
        except KeyboardInterrupt:
            print("\n Stopping pump and closing plot...")
        else:
            print("\n Plot window closed; stopping pump...")
        finally:
            # Always stop the pump, including when an I2C or other error
            # aborts the loop.
            _stop_pump()

    plt.ioff()
    plt.show()
    print(f" Data logged to {log_filename}")


if __name__ == "__main__":
    main()
    
