#!/usr/bin/env python3
"""
pendant.py — Python HAL component for your mill pendant.

Notes:
 - This component exposes HAL pins that you should net to your Mesa inputs
   (or whatever IO board you use). In your HAL files use the [HMOT](CARD0)
   label when referring to the physical card, e.g.:

     net pendant-encoderA    pendant.encoderA    [HMOT](CARD0).gpio.016.in

 - The hal component object gives dictionary-style access to pin values
   (the item itself is the value, not a pin object):
       value = c["encoderA"]        -> read an input pin
       c["spindle_speed"] = value   -> set an output pin

Pins created by this component:

Inputs (hardware -> net these to [HMOT](CARD0) GPIO pins):
  - mode_sw           (S32)   hardware mode switch (0=auto,1=manual,2=mdi)
  - momentary         (BIT)   momentary multi-press button
  - encoderA          (BIT)   spindle encoder A (CLK)
  - encoderB          (BIT)   spindle encoder B (DT)
  - encoderBtn        (BIT)   spindle encoder push (direction toggle)
  - jogEncA           (BIT)   jog encoder A
  - jogEncB           (BIT)   jog encoder B
  - jogEncBtn         (BIT)   jog encoder push
  - dpad_up/down/left/right (BIT)
  - z_y_select        (BIT)   Z/Y selector (1 => Z, 0 => Y)
  - axis_x_sel/y_sel/z_sel (BIT)
  - coolant_hw        (BIT)

Outputs / events (created by this component; net to other HAL components/UI):
  - mode              (S32)
  - single_press      (BIT)
  - double_press      (BIT)
  - triple_press      (BIT)
  - spindle_speed     (FLOAT)
  - spindle_dir       (BIT)
  - jog_speed_idx     (S32)
  - jog_cw / jog_ccw  (BIT)
  - jog_up / jog_down (BIT)
  - coolant_toggle    (BIT)
  - probe_request     (BIT)
  - set_zero_request  (BIT)
  - ref_request       (BIT)
"""

import hal
import time

# -------------------------
# Config
# -------------------------
# Main-loop polling rate in Hz and the matching loop period in seconds.
LOOP_HZ = 100.0
LOOP_DT = 1.0 / LOOP_HZ

# Number of spindle-encoder ticks treated as one revolution when the tick
# count is scaled into spindle_speed, and the limits the spindle_speed
# output is clamped to.
SPINDLE_ENCODER_TICKS_PER_REV = 24
SPINDLE_SPEED_MIN = 0.0
SPINDLE_SPEED_MAX = 1.0

# Jog speed presets. Only the length of this list is used in this file, to
# bound the value written to jog_speed_idx.
JOG_SPEED_PRESETS = [0.1, 0.2, 0.5, 1.0]

# Quiet time in milliseconds after the latest momentary press before the
# accumulated press count is dispatched.
SINGLE_MAX_MS = 300
# NOTE(review): TRIPLE_MAX_MS is not referenced anywhere in the visible code.
TRIPLE_MAX_MS = 900
# Time in seconds an event pin raised by pulse() stays high.
EVENT_PULSE_TIME = 0.08

# -------------------------
# Rotary encoder helper
# -------------------------
class RotaryEncoder:
    """Software quadrature decoder for a two-channel (A/B) rotary encoder.

    Feed the sampled channel levels to :meth:`update` once per poll.  Each
    legal transition (exactly one channel changed) yields a step of +1 or -1;
    the running total is kept in ``pos``.

    Attributes:
        last_a: Channel A level seen by the previous ``update`` call (0 or 1).
        last_b: Channel B level seen by the previous ``update`` call (0 or 1).
        pos: Sum of all steps returned so far.
    """

    # Signed step for every legal transition, keyed by
    # ((previous_a, previous_b), (a, b)).  The +1 direction is the sequence
    # 00 -> 10 -> 11 -> 01 -> 00; the -1 direction is the reverse.  Pairs that
    # are not listed are either "nothing changed" or an illegal jump where
    # both channels changed between two polls; both decode to 0.
    _STEPS = {
        ((0, 0), (1, 0)): 1,
        ((1, 0), (1, 1)): 1,
        ((1, 1), (0, 1)): 1,
        ((0, 1), (0, 0)): 1,
        ((0, 0), (0, 1)): -1,
        ((0, 1), (1, 1)): -1,
        ((1, 1), (1, 0)): -1,
        ((1, 0), (0, 0)): -1,
    }

    def __init__(self):
        self.last_a = 0
        self.last_b = 0
        self.pos = 0

    def update(self, a, b):
        """Decode one sample of the A/B channels.

        Args:
            a: Current level of channel A (any truthy/falsy value).
            b: Current level of channel B (any truthy/falsy value).

        Returns:
            +1 or -1 for a legal single-channel transition, otherwise 0.
            A jump in which both channels changed carries no direction
            information, so it is reported as 0 instead of a guessed step;
            the stored state is still resynchronised to the new levels.
        """
        a = 1 if a else 0
        b = 1 if b else 0
        previous = (self.last_a, self.last_b)
        delta = self._STEPS.get((previous, (a, b)), 0)
        self.last_a = a
        self.last_b = b
        self.pos += delta
        return delta

# -------------------------
# HAL component
# -------------------------
c = hal.component("pendant")

# Pins driven by the pendant hardware, as (pin name, HAL type), listed in
# the order they are created.
_INPUT_PINS = (
    ("mode_sw", hal.HAL_S32),  # hardware mode switch
    ("momentary", hal.HAL_BIT),
    ("encoderA", hal.HAL_BIT),
    ("encoderB", hal.HAL_BIT),
    ("encoderBtn", hal.HAL_BIT),
    ("jogEncA", hal.HAL_BIT),
    ("jogEncB", hal.HAL_BIT),
    ("jogEncBtn", hal.HAL_BIT),
    ("dpad_up", hal.HAL_BIT),
    ("dpad_down", hal.HAL_BIT),
    ("dpad_left", hal.HAL_BIT),
    ("dpad_right", hal.HAL_BIT),
    ("z_y_select", hal.HAL_BIT),
    ("axis_x_sel", hal.HAL_BIT),
    ("axis_y_sel", hal.HAL_BIT),
    ("axis_z_sel", hal.HAL_BIT),
    ("coolant_hw", hal.HAL_BIT),
)

# Pins this component drives (state outputs and event/request pulses), in
# creation order.
_OUTPUT_PINS = (
    ("mode", hal.HAL_S32),
    ("single_press", hal.HAL_BIT),
    ("double_press", hal.HAL_BIT),
    ("triple_press", hal.HAL_BIT),
    ("spindle_speed", hal.HAL_FLOAT),
    ("spindle_dir", hal.HAL_BIT),
    ("jog_speed_idx", hal.HAL_S32),
    ("jog_cw", hal.HAL_BIT),
    ("jog_ccw", hal.HAL_BIT),
    ("jog_up", hal.HAL_BIT),
    ("jog_down", hal.HAL_BIT),
    ("coolant_toggle", hal.HAL_BIT),
    ("probe_request", hal.HAL_BIT),
    ("set_zero_request", hal.HAL_BIT),
    ("ref_request", hal.HAL_BIT),
)

for _pin_name, _pin_type in _INPUT_PINS:
    c.newpin(_pin_name, _pin_type, hal.HAL_IN)

for _pin_name, _pin_type in _OUTPUT_PINS:
    c.newpin(_pin_name, _pin_type, hal.HAL_OUT)

# All pins exist; mark the component as ready.
c.ready()

# -------------------------
# State
# -------------------------
# Quadrature decoders for the spindle encoder (encoderA/encoderB) and the
# jog encoder (jogEncA/jogEncB).
spindle_enc = RotaryEncoder()
jog_enc = RotaryEncoder()
# Accumulated spindle-encoder steps; scaled into the spindle_speed output.
spindle_ticks = 0
# Current value published on jog_speed_idx.
jog_idx = 0
# Current value published on spindle_dir; flipped on each encoderBtn press.
spindle_dir = 0
# Last coolant_hw level that was copied to coolant_toggle.
coolant_state = 0

# Multi-press detection for the momentary button: presses counted in the
# current burst, time.time() of the latest press, and whether a burst is
# still waiting for its timeout.
press_count = 0
last_press_time = 0.0
waiting_for_more = False

# Active event pulses: pin name -> time.time() value at which the pin is
# driven low again by update_pulses().
event_pulses = {}

# -------------------------
# Helpers
# -------------------------
def pulse(name):
    """Raise an event pin and schedule it to drop after EVENT_PULSE_TIME.

    The call does not block: it only records the expiry time
    (``time.time() + EVENT_PULSE_TIME``) in ``event_pulses``; the pin is
    driven low later by ``update_pulses()``.  Pulsing a pin that is already
    high replaces its expiry, i.e. the pulse is extended.

    Args:
        name: Name of a HAL_BIT output pin of component ``c``.
    """
    event_pulses[name] = time.time() + EVENT_PULSE_TIME
    # Indexing a hal component yields the pin *value*, not a pin object, so
    # pins are written by item assignment (there is no ``.set()`` on it).
    c[name] = True

def update_pulses(now):
    """Drop every event pin whose pulse time has elapsed.

    Args:
        now: Current time on the same clock ``pulse()`` uses (``time.time()``).
    """
    # Collect first so the dict is not modified while it is being iterated.
    expired = [name for name, expiry in event_pulses.items() if now >= expiry]
    for name in expired:
        # Item assignment writes the pin; component items are plain values.
        c[name] = False
        del event_pulses[name]

def handle_momentary_press(mode):
    """Record one press of the momentary button.

    Increments the module-level press counter, stamps the press with the
    current ``time.time()`` value and marks a multi-press burst as open.
    The ``mode`` argument is accepted but not used by this function.
    """
    global press_count, last_press_time, waiting_for_more
    press_count = press_count + 1
    last_press_time = time.time()
    waiting_for_more = True

def dispatch_press_event(count, mode):
    """Turn a finished multi-press burst into HAL events.

    The matching press event pin (single/double/triple) is always pulsed.
    In MANUAL mode (``mode == 1``) a request pin is raised as well:
    single -> set_zero_request, double -> probe_request, anything else ->
    ref_request.  In the other modes only the press event is emitted, for
    other logic to handle.

    This function only raises the request pin; it never lowers it.

    Args:
        count: Number of presses in the burst; 1 and 2 select single/double,
            every other value is treated as a triple press.
        mode: Current mode value (1 means MANUAL).
    """
    if count == 1:
        event_pin, request_pin = "single_press", "set_zero_request"
    elif count == 2:
        event_pin, request_pin = "double_press", "probe_request"
    else:
        event_pin, request_pin = "triple_press", "ref_request"

    pulse(event_pin)
    if mode == 1:  # MANUAL
        # Component items are pin values, so the pin is written by item
        # assignment rather than through a ``.set()`` method.
        c[request_pin] = True

def clear_request_flags():
    """Drive the probe, set-zero and reference request pins low."""
    for name in ("probe_request", "set_zero_request", "ref_request"):
        # Item assignment writes the pin; indexing the component returns the
        # pin value, which has no ``.set()`` method.
        c[name] = False

# -------------------------
# Main loop
# -------------------------
# Polls the pendant inputs LOOP_HZ times per second and updates the output
# pins.  Pins are accessed through the component's item interface:
# ``c[name]`` reads a pin value and ``c[name] = value`` writes one.
try:
    # Seed both quadrature decoders with the levels currently on the pins so
    # the first poll is compared against the real encoder state instead of an
    # assumed 0/0.
    spindle_enc.last_a = int(c["encoderA"])
    spindle_enc.last_b = int(c["encoderB"])
    jog_enc.last_a = int(c["jogEncA"])
    jog_enc.last_b = int(c["jogEncB"])

    # Previous button levels, used for rising-edge detection.
    prev_enc_btn = int(c["encoderBtn"])
    prev_jog_btn = int(c["jogEncBtn"])
    prev_momentary = int(c["momentary"])

    # initialize outputs
    c["mode"] = 0
    c["spindle_speed"] = 0.0
    c["spindle_dir"] = 0
    c["jog_speed_idx"] = jog_idx
    c["coolant_toggle"] = coolant_state
    for _pin in ("jog_cw", "jog_ccw", "jog_up", "jog_down"):
        c[_pin] = False

    # spindle_speed changes by this much per SPINDLE_ENCODER_TICKS_PER_REV
    # encoder ticks.
    spindle_speed_per_rev = 0.05
    # Tick counts that correspond to the speed limits.  The tick counter is
    # clamped to this range so that turning the knob past a limit does not
    # accumulate ticks that would have to be unwound before the speed moves.
    ticks_per_speed_unit = SPINDLE_ENCODER_TICKS_PER_REV / spindle_speed_per_rev
    min_spindle_ticks = int(round(SPINDLE_SPEED_MIN * ticks_per_speed_unit))
    max_spindle_ticks = int(round(SPINDLE_SPEED_MAX * ticks_per_speed_unit))

    # time.time() value at which raised request pins are dropped again, or
    # None while no request is pending.
    request_clear_at = None

    while True:
        now = time.time()

        # -- mode (read hardware mode switch)
        mode_val = int(c["mode_sw"])  # 0/1/2
        c["mode"] = mode_val

        # -- momentary multi-press detection (rising edge)
        m = int(c["momentary"])
        if m != prev_momentary:
            if m == 1:
                handle_momentary_press(mode_val)
            prev_momentary = m

        # Drop request pins only after they have been high for
        # EVENT_PULSE_TIME.  Clearing them in the same iteration that raised
        # them would leave them high for a few microseconds at most, too
        # short for a consumer to observe.
        if request_clear_at is not None and now >= request_clear_at:
            clear_request_flags()
            request_clear_at = None

        if waiting_for_more:
            elapsed = (now - last_press_time) * 1000.0
            if elapsed >= SINGLE_MAX_MS:
                dispatch_press_event(min(press_count, 3), mode_val)
                press_count = 0
                waiting_for_more = False
                request_clear_at = now + EVENT_PULSE_TIME

        # -- spindle encoder -> spindle speed
        d = spindle_enc.update(int(c["encoderA"]), int(c["encoderB"]))
        if d != 0:
            spindle_ticks = max(
                min_spindle_ticks, min(max_spindle_ticks, spindle_ticks + d)
            )
            speed = (
                spindle_ticks / SPINDLE_ENCODER_TICKS_PER_REV
            ) * spindle_speed_per_rev
            # Guard against float rounding at the range ends.
            speed = max(SPINDLE_SPEED_MIN, min(SPINDLE_SPEED_MAX, speed))
            c["spindle_speed"] = float(speed)

        # -- spindle encoder push button toggles the direction
        enc_btn = int(c["encoderBtn"])
        if enc_btn != prev_enc_btn:
            if enc_btn == 1:
                spindle_dir = 1 - spindle_dir
                c["spindle_dir"] = spindle_dir
            prev_enc_btn = enc_btn

        # -- jog encoder -> jog speed index (saturates at both ends)
        dd = jog_enc.update(int(c["jogEncA"]), int(c["jogEncB"]))
        if dd != 0:
            step = 1 if dd > 0 else -1
            jog_idx = max(0, min(len(JOG_SPEED_PRESETS) - 1, jog_idx + step))
            c["jog_speed_idx"] = jog_idx

        # -- jog encoder push button cycles through the presets (wraps)
        jog_btn = int(c["jogEncBtn"])
        if jog_btn != prev_jog_btn:
            if jog_btn == 1:
                jog_idx = (jog_idx + 1) % len(JOG_SPEED_PRESETS)
                c["jog_speed_idx"] = jog_idx
            prev_jog_btn = jog_btn

        # -- d-pad -> jog signals
        # Each output is written exactly once per cycle.  Writing 0 and then
        # 1 while a button is held would expose a momentary low level to any
        # reader that samples the pin between the two writes.
        c["jog_up"] = bool(c["dpad_up"])
        c["jog_down"] = bool(c["dpad_down"])
        c["jog_cw"] = bool(c["dpad_right"])
        c["jog_ccw"] = bool(c["dpad_left"])

        # -- coolant: the output follows the level of the hardware switch
        hw_cool = int(c["coolant_hw"])
        if hw_cool != coolant_state:
            coolant_state = hw_cool
            c["coolant_toggle"] = coolant_state

        # -- housekeeping: expire event pulses
        update_pulses(now)

        # sleep for the remainder of the cycle
        dt = LOOP_DT - (time.time() - now)
        if dt > 0:
            time.sleep(dt)

except KeyboardInterrupt:
    # Normal way to stop the component; cleanup happens below.
    pass
finally:
    # Leave every event, request and jog output low on the way out
    # (including jog_cw/jog_ccw), whether the loop ended by Ctrl-C or by an
    # error, then remove the component from HAL.
    for _pin in (
        "single_press",
        "double_press",
        "triple_press",
        "probe_request",
        "set_zero_request",
        "ref_request",
        "spindle_dir",
        "jog_up",
        "jog_down",
        "jog_cw",
        "jog_ccw",
    ):
        c[_pin] = False
    c["spindle_speed"] = 0.0
    c["jog_speed_idx"] = 0
    # hal.component is shut down with exit(); it has no halt() method.
    c.exit()
