#!/usr/bin/env python3
"""Demonstrate reading an AkuSense MLD25 sensor over Modbus RTU.

The MLD25 manual lists the default RS-485 settings as Modbus address 0x01 and
115200 baud. A Waveshare USB-to-RS485 converter usually appears on Linux as
`/dev/ttyUSB0`, though `--port` may be used to select any tty device.

Usage:
    ./mld25_linux_demo.py --port /dev/ttyUSB0

The script prints device identity once, then refreshes the measured distance on
one stdout line until any key is pressed.
"""

from __future__ import annotations

import argparse
from dataclasses import dataclass
import glob
import os
import select
import sys
import termios
import time
import tty
from types import TracebackType
from typing import Self

# Modbus function code used for every request this script sends.
k_modbus_function_read_holding_registers = 0x03
# Defaults for --address and --baud; the module docstring cites the MLD25
# manual for both values.
k_default_slave_address = 0x01
k_default_baud = 115200
# Default for --timeout: serial read/write timeout in seconds.
k_default_timeout_seconds = 0.5
# Default for --interval: pause between measurement polls in seconds.
k_default_poll_interval_seconds = 0.1
# Default for --scale: the raw measurement is divided by this to print mm.
k_default_measurement_scale = 10000.0

# Start register and register count of the measurement read, taken from the
# manual's measurement example (2 registers = 4 data bytes).
k_measurement_register = 0x2512
k_measurement_register_count = 2

# The manual calls these values "Deposit Device (low)", but they are the low
# byte of AkuSense's table address, not a Modbus term.
k_hardware_version_manual_address_low_byte = 0x1F
k_software_version_manual_address_low_byte = 0x20
k_product_number_manual_address_low_byte = 0x21
k_product_serial_number_manual_address_low_byte = 0x2A
# NOTE(review): these two counts are not referenced by the code visible in
# this file; the lookup table below carries the same values (9 and 12).
k_product_number_register_count = 9
k_product_serial_number_register_count = 12


@dataclass(frozen=True)
class Akusense_Table_Register:
    """One row of the manual-address to Modbus-register lookup table."""

    # Low byte of the table address as printed in the AkuSense manual.
    manual_address_low_byte: int
    # First Modbus holding register to request for this entry.
    register: int
    # Number of consecutive 16-bit registers that make up the entry.
    register_count: int


@dataclass(frozen=True)
class Identity_Registers:
    """Register blocks for the four identity fields the script prints."""

    # Each field is the lookup-table row for the matching identity value.
    hardware_version: Akusense_Table_Register
    software_version: Akusense_Table_Register
    product_number: Akusense_Table_Register
    product_serial_number: Akusense_Table_Register


# Maps the manual's table address low byte to the Modbus register range that
# the device actually accepts on the wire. Each triple below is
# (manual address low byte, first Modbus register, register count).
k_manual_address_to_modbus_registers = tuple(
    Akusense_Table_Register(
        manual_address_low_byte=table_low_byte,
        register=first_register,
        register_count=span,
    )
    for table_low_byte, first_register, span in (
        (0x00, 0x2511, 1),
        (0x01, 0x2512, 2),
        (0x02, 0x2514, 1),
        (0x03, 0x2515, 1),
        (0x04, 0x2516, 1),
        (0x05, 0x2517, 1),
        (0x06, 0x2518, 1),
        (0x07, 0x2519, 1),
        (0x08, 0x251A, 1),
        (0x09, 0x251B, 1),
        (0x0A, 0x251C, 1),
        (0x0B, 0x251D, 1),
        (0x0C, 0x251E, 1),
        (0x0D, 0x251F, 1),
        (0x0E, 0x2520, 1),
        (0x0F, 0x2521, 1),
        (0x10, 0x2522, 1),
        (0x11, 0x2523, 1),
        (0x12, 0x2524, 1),
        (0x15, 0x2525, 1),
        (0x16, 0x2526, 1),
        (0x18, 0x2527, 1),
        (0x19, 0x2528, 1),
        (0x1A, 0x2529, 1),
        (0x1F, 0x2542, 1),
        (0x20, 0x2543, 1),
        (0x21, 0x2544, 9),
        (0x2A, 0x254D, 12),
    )
)

# Glob patterns that detect_single_serial_port() searches when --port is not
# given on the command line.
k_tty_candidates = (
    "/dev/ttyUSB*",
    "/dev/ttyACM*",
)


class Mld25_Error(Exception):
    """Raised for user-visible serial or Modbus failures.

    ``main`` catches this type, prints the message to stderr with an
    ``error:`` prefix, and returns exit status 1.
    """


@dataclass(frozen=True)
class Device_Info:
    """Identity values read from the sensor, already formatted for display."""

    # Versions are rendered by format_version() as "Vmajor.minor".
    hardware_version: str
    software_version: str
    # Product number and serial number are decoded from registers as ASCII.
    product_number: str
    product_serial_number: str


def build_baud_table() -> dict[int, int]:
    """Return the baud rates this platform's ``termios`` module supports.

    Keys are numeric baud rates; values are the matching ``termios.B<rate>``
    speed constants. Rates without a ``B<rate>`` attribute on this OS are
    left out.
    """
    candidate_rates = (
        9600,
        19200,
        38400,
        57600,
        115200,
        128000,
        230400,
        256000,
        460800,
        500000,
        512000,
        576000,
        600000,
        750000,
        921600,
        1000000,
        1152000,
        1250000,
    )
    return {
        rate: getattr(termios, f"B{rate}")
        for rate in candidate_rates
        if hasattr(termios, f"B{rate}")
    }


# Built once at import time; Serial_Port.configure() looks baud rates up here.
k_baud_table = build_baud_table()


def register_from_manual_address_low_byte(
    manual_address_low_byte: int,
) -> Akusense_Table_Register:
    """Look up the table row for a manual address low byte.

    Returns the first row of ``k_manual_address_to_modbus_registers`` whose
    ``manual_address_low_byte`` matches.

    Raises:
        Mld25_Error: if no row has that low byte.
    """
    found = next(
        (
            entry
            for entry in k_manual_address_to_modbus_registers
            if entry.manual_address_low_byte == manual_address_low_byte
        ),
        None,
    )
    if found is None:
        raise Mld25_Error(
            f"Unknown MLD25 manual address low byte 0x{manual_address_low_byte:02X}"
        )
    return found


def identity_registers() -> Identity_Registers:
    """Resolve the four identity fields to their Modbus register blocks.

    Each field's manual address low byte is looked up in the register table;
    an unknown low byte surfaces as ``Mld25_Error`` from the lookup.
    """
    lookup = register_from_manual_address_low_byte
    hardware = lookup(k_hardware_version_manual_address_low_byte)
    software = lookup(k_software_version_manual_address_low_byte)
    number = lookup(k_product_number_manual_address_low_byte)
    serial = lookup(k_product_serial_number_manual_address_low_byte)
    return Identity_Registers(
        hardware_version=hardware,
        software_version=software,
        product_number=number,
        product_serial_number=serial,
    )


def modbus_crc(frame: bytes) -> int:
    """Compute the 16-bit Modbus RTU CRC of ``frame``.

    Bit-by-bit implementation: initial value 0xFFFF, reflected polynomial
    0xA001, no final XOR. Returns the CRC as an integer in 0..0xFFFF.
    """
    register = 0xFFFF
    for octet in frame:
        register ^= octet
        for _ in range(8):
            shifted_out = register & 0x0001
            register >>= 1
            if shifted_out:
                register ^= 0xA001
    return register & 0xFFFF


def append_modbus_crc(frame: bytes) -> bytes:
    """Return ``frame`` followed by its two CRC bytes in wire order."""
    # Modbus RTU transmits the low CRC byte first. The MLD25 manual labels
    # these two bytes as high/low in examples, but the bytes match RTU order.
    checksum = modbus_crc(frame)
    return frame + checksum.to_bytes(2, byteorder="little")


def verify_modbus_crc(frame: bytes) -> bool:
    """Check the trailing little-endian CRC of a received frame.

    Frames shorter than three bytes cannot carry a payload plus CRC and are
    reported as invalid.
    """
    if len(frame) < 3:
        return False
    body, trailer = frame[:-2], frame[-2:]
    # Low CRC byte arrives first on the wire.
    received_crc = trailer[0] | (trailer[1] << 8)
    return modbus_crc(body) == received_crc


def parse_read_holding_registers_response(
    frame: bytes,
    slave_address: int,
    register_count: int,
) -> bytes:
    """Validate a function 0x03 response and return its register bytes.

    Checks, in order: minimum length, CRC, slave address, function code
    (including the exception variant with the high bit set), byte count and
    total frame length.

    Args:
        frame: Complete response including the two CRC bytes.
        slave_address: Address the request was sent to.
        register_count: Number of registers that were requested.

    Returns:
        The ``register_count * 2`` data bytes between the header and the CRC.

    Raises:
        Mld25_Error: on any failed check.
    """
    if len(frame) < 5:
        raise Mld25_Error(f"Short Modbus response: {frame.hex(' ')}")
    if not verify_modbus_crc(frame):
        raise Mld25_Error(f"Bad Modbus CRC in response: {frame.hex(' ')}")

    responder, function_code, third_byte = frame[0], frame[1], frame[2]
    if responder != slave_address:
        raise Mld25_Error(
            f"Unexpected Modbus address 0x{responder:02X}; expected 0x{slave_address:02X}"
        )

    # An exception response echoes the function code with the top bit set and
    # carries the exception code in the third byte.
    exception_function = k_modbus_function_read_holding_registers | 0x80
    if function_code == exception_function:
        raise Mld25_Error(f"Device returned Modbus exception 0x{third_byte:02X}")
    if function_code != k_modbus_function_read_holding_registers:
        raise Mld25_Error(f"Unexpected Modbus function 0x{function_code:02X}")

    payload_length = register_count * 2
    if third_byte != payload_length:
        raise Mld25_Error(
            f"Unexpected byte count {third_byte}; expected {payload_length}"
        )

    # Header (3 bytes) + payload + CRC (2 bytes).
    total_length = 3 + payload_length + 2
    if len(frame) != total_length:
        raise Mld25_Error(
            f"Unexpected response length {len(frame)}; expected {total_length}"
        )
    return frame[3 : 3 + payload_length]


def format_version(raw_version: int) -> str:
    """Render a 16-bit version register as ``Vmajor.minor``.

    The high byte is the major number and the low byte the minor number; each
    is printed in decimal, zero-padded to at least two digits.
    """
    major, minor = divmod(raw_version & 0xFFFF, 0x100)
    return f"V{major:02d}.{minor:02d}"


def decode_ascii_registers(register_bytes: bytes) -> str:
    """Decode an identity field stored as ASCII in consecutive registers.

    Trailing NUL, 0xFF and space bytes are treated as padding and removed
    before decoding; a field that is all padding yields an empty string.

    Raises:
        Mld25_Error: if the remaining bytes are not valid ASCII.
    """
    trimmed = register_bytes.rstrip(b"\x00\xff ")
    try:
        text = trimmed.decode("ascii")
    except UnicodeDecodeError as exc:
        raise Mld25_Error(
            f"Identity field is not ASCII: {register_bytes.hex(' ')}"
        ) from exc
    return text


def decode_measurement_mm(register_bytes: bytes, scale: float) -> float:
    """Convert the four measurement bytes to a scaled value.

    The bytes are read as one unsigned big-endian 32-bit integer and divided
    by ``scale``.

    Raises:
        Mld25_Error: if ``register_bytes`` is not exactly four bytes long.
    """
    byte_total = len(register_bytes)
    if byte_total != 4:
        raise Mld25_Error(f"Measurement response has {byte_total} bytes")
    # Fold the bytes most-significant first into one unsigned integer.
    raw_value = 0
    for octet in register_bytes:
        raw_value = (raw_value << 8) | octet
    return raw_value / scale


class Serial_Port:
    """Raw, non-blocking serial port on a POSIX tty, used as a context manager.

    Entering opens the device, remembers its current termios attributes and
    switches it to raw 8-bit mode at the requested baud rate. Leaving restores
    the remembered attributes (best effort) and closes the descriptor.

    All serial failures, including ``termios.error``, are reported as
    ``Mld25_Error`` so that callers only need to handle one exception type.
    """

    def __init__(self, path: str, baud: int, timeout_seconds: float):
        self.path = path
        self.baud = baud
        # Per-select timeout for writes and overall deadline for each read.
        self.timeout_seconds = timeout_seconds
        self.fd: int | None = None
        self.original_attrs: list[object] | None = None

    def __enter__(self) -> Self:
        """Open and configure the port.

        Raises:
            Mld25_Error: if the device cannot be opened, is not a tty, or
                cannot be configured (including an unsupported baud rate).
        """
        try:
            self.fd = os.open(self.path, os.O_RDWR | os.O_NOCTTY | os.O_NONBLOCK)
            self.original_attrs = termios.tcgetattr(self.fd)
            self.configure()
        except BaseException as exc:
            # __exit__ is not called when __enter__ raises, so the descriptor
            # has to be released here for every failure, not just OSError.
            self._release()
            # termios.error is not an OSError subclass; translate both.
            if isinstance(exc, (OSError, termios.error)):
                raise Mld25_Error(
                    f"Could not open or configure {self.path}: {exc}"
                ) from exc
            raise
        return self

    def __exit__(
        self,
        exc_type: type[BaseException] | None,
        exc_value: BaseException | None,
        traceback: TracebackType | None,
    ) -> None:
        """Restore the original tty attributes and close the port."""
        self._release()

    def _release(self) -> None:
        """Restore saved attributes if possible, then always close the fd."""
        fd, self.fd = self.fd, None
        if fd is None:
            return
        try:
            if self.original_attrs is not None:
                termios.tcsetattr(fd, termios.TCSANOW, self.original_attrs)
        except termios.error:
            # Best effort: the adapter may already be unplugged. Failing here
            # must neither skip the close below nor mask an in-flight error.
            pass
        finally:
            os.close(fd)

    def configure(self) -> None:
        """Put the open port into raw 8-bit mode at ``self.baud``.

        Raises:
            Mld25_Error: if the port is not open or the baud rate has no
                termios constant on this OS.
            termios.error: if the tty rejects the settings (``__enter__``
                converts this to ``Mld25_Error``).
        """
        if self.fd is None:
            raise Mld25_Error("Serial port is not open")
        if self.baud not in k_baud_table:
            supported_bauds = ", ".join(str(baud) for baud in sorted(k_baud_table))
            raise Mld25_Error(
                f"Baud {self.baud} is not supported by this OS. Supported: {supported_bauds}"
            )

        speed = k_baud_table[self.baud]
        attrs = termios.tcgetattr(self.fd)
        attrs[0] = 0  # iflag: no input translation
        attrs[1] = 0  # oflag: no output processing
        # cflag: 8 data bits, receiver on, ignore modem control lines. The
        # speed is passed through ispeed/ospeed below; OR-ing the B* constant
        # into cflag is unnecessary and wrong where B* are plain numbers.
        attrs[2] = termios.CLOCAL | termios.CREAD | termios.CS8
        attrs[3] = 0  # lflag: no canonical mode, echo or signals
        attrs[4] = speed  # ispeed
        attrs[5] = speed  # ospeed
        # Reads return immediately; timing is handled with select().
        attrs[6][termios.VMIN] = 0
        attrs[6][termios.VTIME] = 0
        termios.tcsetattr(self.fd, termios.TCSANOW, attrs)
        termios.tcflush(self.fd, termios.TCIOFLUSH)

    def flush_input(self) -> None:
        """Discard bytes that were received but not yet read.

        Raises:
            Mld25_Error: if the port is not open or the flush fails.
        """
        if self.fd is None:
            raise Mld25_Error("Serial port is not open")
        try:
            termios.tcflush(self.fd, termios.TCIFLUSH)
        except termios.error as exc:
            raise Mld25_Error(f"Could not flush {self.path}: {exc}") from exc

    def write_all(self, data: bytes) -> None:
        """Write every byte of ``data``, waiting for writability as needed.

        Raises:
            Mld25_Error: if the port is not open, does not become writable
                within ``timeout_seconds``, or the write fails.
        """
        if self.fd is None:
            raise Mld25_Error("Serial port is not open")
        total_written = 0
        while total_written < len(data):
            ready_to_write = select.select([], [self.fd], [], self.timeout_seconds)[1]
            if not ready_to_write:
                raise Mld25_Error(f"Timed out writing to {self.path}")
            try:
                total_written += os.write(self.fd, data[total_written:])
            except OSError as exc:
                raise Mld25_Error(f"Could not write to {self.path}: {exc}") from exc

    def read_exact(self, byte_count: int) -> bytes:
        """Read exactly ``byte_count`` bytes before the timeout expires.

        Raises:
            Mld25_Error: if the port is not open, a read fails, or fewer than
                ``byte_count`` bytes arrive within ``timeout_seconds``. The
                timeout message includes whatever was received.
        """
        if self.fd is None:
            raise Mld25_Error("Serial port is not open")
        deadline = time.monotonic() + self.timeout_seconds
        chunks: list[bytes] = []
        remaining = byte_count
        while remaining > 0:
            timeout = deadline - time.monotonic()
            if timeout <= 0:
                received = b"".join(chunks)
                raise Mld25_Error(
                    f"Timed out reading {byte_count} bytes from {self.path}; "
                    f"received {len(received)} bytes: {received.hex(' ')}"
                )
            ready_to_read = select.select([self.fd], [], [], timeout)[0]
            if not ready_to_read:
                continue
            try:
                chunk = os.read(self.fd, remaining)
            except OSError as exc:
                raise Mld25_Error(f"Could not read from {self.path}: {exc}") from exc
            if not chunk:
                # With VMIN=0/VTIME=0 an empty read only means "nothing yet";
                # keep waiting until the deadline.
                continue
            chunks.append(chunk)
            remaining -= len(chunk)
        return b"".join(chunks)


class Mld25_Client:
    """Modbus RTU client for the MLD25 values this script reads.

    Sends function 0x03 (read holding registers) requests to one slave
    address over a ``Serial_Port`` and decodes the replies.
    """

    def __init__(self, serial_port: Serial_Port, slave_address: int, debug: bool):
        self.serial_port = serial_port
        self.slave_address = slave_address
        # When true, raw request/response bytes are echoed to stderr.
        self.debug = debug

    def read_holding_registers(self, start_register: int, register_count: int) -> bytes:
        """Read ``register_count`` registers starting at ``start_register``.

        Returns:
            The raw register data bytes (two per register, as received).

        Raises:
            Mld25_Error: on serial timeouts or failures, a Modbus exception
                response, or any malformed/unexpected response.
        """
        request = bytes(
            (
                self.slave_address,
                k_modbus_function_read_holding_registers,
                (start_register >> 8) & 0xFF,
                start_register & 0xFF,
                (register_count >> 8) & 0xFF,
                register_count & 0xFF,
            )
        )
        request_with_crc = append_modbus_crc(request)
        self.debug_print(f"TX {request_with_crc.hex(' ')}")
        # Drop stale bytes so the reply is parsed from its first byte.
        self.serial_port.flush_input()
        self.serial_port.write_all(request_with_crc)

        header = self.serial_port.read_exact(3)
        if header[1] == (k_modbus_function_read_holding_registers | 0x80):
            # Exception reply: address, function|0x80, exception code, CRC.
            # Only the two CRC bytes remain; the parser below always raises
            # for such a frame.
            remaining_length = 2
        else:
            # Reject a bad header before trusting its byte count; otherwise a
            # garbage count makes the next read wait for up to 257 bytes and
            # fail with a misleading timeout.
            self._check_response_header(header, register_count)
            remaining_length = header[2] + 2
        response = header + self.serial_port.read_exact(remaining_length)
        self.debug_print(f"RX {response.hex(' ')}")
        return parse_read_holding_registers_response(
            response,
            self.slave_address,
            register_count,
        )

    def _check_response_header(self, header: bytes, register_count: int) -> None:
        """Raise Mld25_Error unless the 3-byte header fits the request sent."""
        address, function_code, byte_count = header[0], header[1], header[2]
        expected_byte_count = register_count * 2
        if address != self.slave_address:
            problem = (
                f"Unexpected Modbus address 0x{address:02X}; "
                f"expected 0x{self.slave_address:02X}"
            )
        elif function_code != k_modbus_function_read_holding_registers:
            problem = f"Unexpected Modbus function 0x{function_code:02X}"
        elif byte_count != expected_byte_count:
            problem = (
                f"Unexpected byte count {byte_count}; expected {expected_byte_count}"
            )
        else:
            return
        self.debug_print(f"RX {header.hex(' ')}")
        raise Mld25_Error(problem)

    def debug_print(self, message: str) -> None:
        """Print ``message`` to stderr when debug output is enabled."""
        if not self.debug:
            return
        print(f"debug: {message}", file=sys.stderr)

    def read_u16(self, start_register: int) -> int:
        """Read one register and return it as an unsigned big-endian integer."""
        register_bytes = self.read_holding_registers(start_register, 1)
        return int.from_bytes(register_bytes, byteorder="big", signed=False)

    def read_device_info(self) -> Device_Info:
        """Read and format the four identity fields.

        Versions are single registers rendered by ``format_version``; product
        number and serial number are multi-register ASCII fields.
        """
        registers = identity_registers()

        hardware_version = format_version(
            self.read_u16(registers.hardware_version.register)
        )
        software_version = format_version(
            self.read_u16(registers.software_version.register)
        )
        product_number = decode_ascii_registers(
            self.read_holding_registers(
                registers.product_number.register,
                registers.product_number.register_count,
            )
        )
        product_serial_number = decode_ascii_registers(
            self.read_holding_registers(
                registers.product_serial_number.register,
                registers.product_serial_number.register_count,
            )
        )
        return Device_Info(
            hardware_version=hardware_version,
            software_version=software_version,
            product_number=product_number,
            product_serial_number=product_serial_number,
        )

    def read_measurement_mm(self, scale: float) -> float:
        """Read the measurement registers and return the value divided by ``scale``."""
        register_bytes = self.read_holding_registers(
            k_measurement_register,
            k_measurement_register_count,
        )
        return decode_measurement_mm(register_bytes, scale)


class Terminal_Key_Stopper:
    """Context manager that lets a loop poll stdin for "any key".

    Entering switches the controlling terminal to cbreak mode so single key
    presses are delivered without waiting for Enter; leaving restores the
    attributes that were in effect before.
    """

    def __init__(self) -> None:
        self.fd = sys.stdin.fileno()
        self.original_attrs: list[object] | None = None

    def __enter__(self) -> Self:
        """Save the terminal attributes and enable cbreak mode.

        Raises:
            Mld25_Error: if stdin is not a terminal.
        """
        if not sys.stdin.isatty():
            raise Mld25_Error("stdin is not a tty; cannot stop on any key")
        self.original_attrs = termios.tcgetattr(self.fd)
        tty.setcbreak(self.fd)
        return self

    def __exit__(
        self,
        exc_type: type[BaseException] | None,
        exc_value: BaseException | None,
        traceback: TracebackType | None,
    ) -> None:
        """Restore the terminal attributes saved on entry, if any."""
        if self.original_attrs is not None:
            termios.tcsetattr(self.fd, termios.TCSADRAIN, self.original_attrs)

    def key_was_pressed(self) -> bool:
        """Return True once input is pending on stdin, consuming that input."""
        # Poll the same descriptor that is read below (zero timeout).
        ready_to_read = select.select([self.fd], [], [], 0)[0]
        if not ready_to_read:
            return False
        os.read(self.fd, 1)
        # Keys such as arrows or function keys send several bytes. Discard
        # what is still queued so the rest of the sequence is not handed to
        # the shell after the program exits.
        termios.tcflush(self.fd, termios.TCIFLUSH)
        return True


def detect_single_serial_port() -> str:
    """Return the only serial device matching ``k_tty_candidates``.

    Raises:
        Mld25_Error: if no device matches, or if more than one does (the
            message lists the matches so the user can pick one with --port).
    """
    matches = sorted(
        {path for pattern in k_tty_candidates for path in glob.glob(pattern)}
    )
    if not matches:
        raise Mld25_Error("No USB serial port found; pass --port /dev/ttyUSB0")
    if len(matches) > 1:
        joined_candidates = ", ".join(matches)
        raise Mld25_Error(
            f"Multiple USB serial ports found: {joined_candidates}; pass --port"
        )
    return matches[0]


def parse_args(argv: list[str] | None = None) -> argparse.Namespace:
    """Parse the command line.

    Args:
        argv: Argument list to parse. ``None`` (the default) makes argparse
            read ``sys.argv[1:]``, which is what the script entry point uses.

    Returns:
        The parsed namespace. Range checks are done by ``validate_args``.
    """

    def integer(text: str) -> int:
        # Base 0 accepts decimal as well as 0x/0o/0b prefixed literals. A
        # named function (not a lambda) makes argparse report
        # "invalid integer value" instead of "invalid <lambda> value".
        return int(text, 0)

    parser = argparse.ArgumentParser(
        description="Read AkuSense MLD25 device info and stream measurement values.",
    )
    parser.add_argument("--port", help="Serial tty path, for example /dev/ttyUSB0")
    parser.add_argument("--baud", type=int, default=k_default_baud)
    parser.add_argument(
        "--address",
        type=integer,
        default=k_default_slave_address,
        help="Modbus slave address, decimal or hex. Default: 0x01",
    )
    parser.add_argument(
        "--interval",
        type=float,
        default=k_default_poll_interval_seconds,
        help="Polling interval in seconds.",
    )
    parser.add_argument(
        "--scale",
        type=float,
        default=k_default_measurement_scale,
        help="Measurement divisor. MLD25-100 values use 10000 by default.",
    )
    parser.add_argument(
        "--timeout",
        type=float,
        default=k_default_timeout_seconds,
        help="Serial read/write timeout in seconds.",
    )
    parser.add_argument(
        "--debug",
        action="store_true",
        help="Print raw Modbus request and response bytes to stderr.",
    )
    parser.add_argument(
        "--probe-identity",
        action="store_true",
        help="Probe likely identity register offsets and exit.",
    )
    parser.add_argument(
        "--read-once",
        action="store_true",
        help="Read device info and one measured value, then exit.",
    )
    return parser.parse_args(argv)


def validate_args(args: argparse.Namespace) -> None:
    """Reject argument values the rest of the script cannot work with.

    Checks that ``--address`` is within 1..247 and that ``--interval``,
    ``--scale`` and ``--timeout`` are finite numbers greater than zero.

    Raises:
        Mld25_Error: describing the first offending option.
    """
    import math

    if not 1 <= args.address <= 247:
        raise Mld25_Error("--address must be in the Modbus range 1..247")
    positive_options = (
        ("--interval", args.interval),
        ("--scale", args.scale),
        ("--timeout", args.timeout),
    )
    for option, value in positive_options:
        if value <= 0:
            raise Mld25_Error(f"{option} must be greater than zero")
        # float() accepts "nan" and "inf"; both pass the comparison above but
        # later break time.sleep(), select() or the printed measurement.
        if not math.isfinite(value):
            raise Mld25_Error(f"{option} must be a finite number")


def print_device_info(device_info: Device_Info) -> None:
    """Print the identity block to stdout, followed by one blank line."""
    report_lines = (
        "Device info",
        f"  Hardware version:      {device_info.hardware_version}",
        f"  Software version:      {device_info.software_version}",
        f"  Product number:        {device_info.product_number}",
        f"  Product serial number: {device_info.product_serial_number}",
        "",
    )
    for line in report_lines:
        print(line)


def try_read_register(
    client: Mld25_Client,
    label: str,
    register: int,
    register_count: int,
) -> None:
    """Read one register block and print the outcome on a single line.

    The line starts with the label, start register and count; it ends with
    either the received bytes in hex or the ``Mld25_Error`` message. Errors
    are reported, not raised, so a probe run can continue.
    """
    line_prefix = f"{label:26} 0x{register:04X} x{register_count:<2} "
    try:
        payload = client.read_holding_registers(register, register_count)
    except Mld25_Error as exc:
        print(f"{line_prefix}error: {exc}")
    else:
        print(f"{line_prefix}bytes: {payload.hex(' ')}")


def probe_identity_registers(client: Mld25_Client) -> None:
    """Read each identity register block and print its raw bytes or error.

    Fields are probed in the order hardware version, software version,
    product number, product serial number.
    """
    registers = identity_registers()
    labelled_blocks = (
        ("hardware version", registers.hardware_version),
        ("software version", registers.software_version),
        ("product number", registers.product_number),
        ("product serial number", registers.product_serial_number),
    )
    for label, block in labelled_blocks:
        try_read_register(client, label, block.register, block.register_count)


def stream_measurements(client: Mld25_Client, scale: float, interval: float) -> None:
    """Poll the sensor and rewrite one stdout line until a key is pressed.

    Each pass checks for a key press first, then reads a measurement, prints
    it over the previous one using a carriage return, and sleeps for
    ``interval`` seconds. A final newline is printed after the loop ends.
    """
    print("Press any key to stop.")
    with Terminal_Key_Stopper() as stopper:
        while True:
            if stopper.key_was_pressed():
                break
            reading_mm = client.read_measurement_mm(scale)
            print(f"\rMeasured value: {reading_mm:10.4f} mm", end="", flush=True)
            time.sleep(interval)
    print()


def print_single_measurement(client: Mld25_Client, scale: float) -> None:
    """Read one measurement and print it with four decimal places."""
    print(f"Measured value: {client.read_measurement_mm(scale):.4f} mm")


def main() -> int:
    """Run the demo and return the process exit status.

    Returns 0 on success, 1 after an ``Mld25_Error`` (message printed to
    stderr), and 130 when interrupted with Ctrl-C.
    """
    args = parse_args()
    exit_status = 0
    try:
        validate_args(args)
        port = args.port if args.port else detect_single_serial_port()
        with Serial_Port(port, args.baud, args.timeout) as serial_port:
            client = Mld25_Client(serial_port, args.address, args.debug)
            if args.probe_identity:
                # Probe mode skips the identity printout and measurements.
                probe_identity_registers(client)
            else:
                print_device_info(client.read_device_info())
                if args.read_once:
                    print_single_measurement(client, args.scale)
                else:
                    stream_measurements(client, args.scale, args.interval)
    except KeyboardInterrupt:
        # Finish the line Ctrl-C interrupted before returning to the shell.
        print()
        exit_status = 130
    except Mld25_Error as exc:
        print(f"error: {exc}", file=sys.stderr)
        exit_status = 1
    return exit_status


# Run main() only when executed as a script; its return value becomes the
# process exit status.
if __name__ == "__main__":
    raise SystemExit(main())
