diff --git a/pyomnilogic_local/api/__init__.py b/pyomnilogic_local/api/__init__.py index b9f9ff1..928fccb 100644 --- a/pyomnilogic_local/api/__init__.py +++ b/pyomnilogic_local/api/__init__.py @@ -8,7 +8,29 @@ from __future__ import annotations from .api import OmniLogicAPI +from .exceptions import ( + OmniCommandError, + OmniConnectionError, + OmniFragmentationError, + OmniLogicError, + OmniMessageFormatError, + OmniProtocolError, + OmniTimeoutError, + OmniValidationError, +) +from .message import OmniLogicMessage +from .protocol import OmniLogicProtocol __all__ = [ + "OmniCommandError", + "OmniConnectionError", + "OmniFragmentationError", "OmniLogicAPI", + "OmniLogicError", + "OmniLogicMessage", + "OmniLogicProtocol", + "OmniMessageFormatError", + "OmniProtocolError", + "OmniTimeoutError", + "OmniValidationError", ] diff --git a/pyomnilogic_local/api/api.py b/pyomnilogic_local/api/api.py index 6e6f399..8834de0 100644 --- a/pyomnilogic_local/api/api.py +++ b/pyomnilogic_local/api/api.py @@ -111,30 +111,39 @@ def __init__( self.controller_port = controller_port self.response_timeout = response_timeout - @overload - async def async_send_message(self, message_type: MessageType, message: str | None, need_response: Literal[True]) -> str: ... - @overload - async def async_send_message(self, message_type: MessageType, message: str | None, need_response: Literal[False]) -> None: ... - async def async_send_message(self, message_type: MessageType, message: str | None, need_response: bool = False) -> str | None: + async def async_send(self, message_type: MessageType, message: str) -> None: """Send a message via the Hayward Omni UDP protocol along with properly handling timeouts and responses. Args: message_type (MessageType): A selection from MessageType indicating what type of communication you are sending - message (str | None): The XML body of the message to deliver - need_response (bool, optional): Should a response be received and returned to the caller. Defaults to False. + message (str): The XML body of the message to deliver + + Returns: + None + """ + loop = asyncio.get_running_loop() + transport, protocol = await loop.create_datagram_endpoint(OmniLogicProtocol, remote_addr=(self.controller_ip, self.controller_port)) + + try: + await protocol.async_send(message_type, message) + finally: + transport.close() + + async def async_send_and_receive(self, message_type: MessageType, message: str) -> str: + """Convenience method to send a message and receive a response without needing to specify need_response every time. + + Args: + message_type (MessageType): A selection from MessageType indicating what type of communication you are sending + message (str): The message payload to send. Returns: - str | None: The response body sent from the Omni if need_response indicates that a response will be sent + str: The response body sent from the Omni """ loop = asyncio.get_running_loop() transport, protocol = await loop.create_datagram_endpoint(OmniLogicProtocol, remote_addr=(self.controller_ip, self.controller_port)) - resp: str | None = None try: - if need_response: - resp = await protocol.send_and_receive(message_type, message, response_timeout=self.response_timeout) - else: - await protocol.send_message(message_type, message) + resp = await protocol.async_send_and_receive(message_type, message) finally: transport.close() @@ -164,7 +173,7 @@ async def async_get_mspconfig(self, raw: bool = False) -> MSPConfig | str: _LOGGER.debug("Sending RequestConfiguration with body: %s", req_body) - resp = await self.async_send_message(MessageType.REQUEST_CONFIGURATION, req_body, True) + resp = await self.async_send_and_receive(MessageType.REQUEST_CONFIGURATION, req_body) _LOGGER.debug("Received response for RequestConfiguration: %s", resp) @@ -206,7 +215,7 @@ async def async_get_filter_diagnostics(self, pool_id: int, equipment_id: int, ra _LOGGER.debug("Sending GetUIFilterDiagnosticInfo with body: %s", req_body) - resp = await self.async_send_message(MessageType.GET_FILTER_DIAGNOSTIC_INFO, req_body, True) + resp = await self.async_send_and_receive(MessageType.GET_FILTER_DIAGNOSTIC_INFO, req_body) _LOGGER.debug("Received response for GetUIFilterDiagnosticInfo: %s", resp) @@ -235,7 +244,7 @@ async def async_get_telemetry(self, raw: bool = False) -> Telemetry | str: _LOGGER.debug("Sending RequestTelemetryData with body: %s", req_body) - resp = await self.async_send_message(MessageType.GET_TELEMETRY, req_body, True) + resp = await self.async_send_and_receive(MessageType.GET_TELEMETRY, req_body) _LOGGER.debug("Received response for RequestTelemetryData: %s", resp) @@ -276,7 +285,7 @@ async def async_set_heater( _LOGGER.debug("Sending SetUIHeaterCmd with body: %s", req_body) - return await self.async_send_message(MessageType.SET_HEATER_COMMAND, req_body, False) + return await self.async_send(MessageType.SET_HEATER_COMMAND, req_body) async def async_set_solar_heater( self, @@ -311,7 +320,7 @@ async def async_set_solar_heater( _LOGGER.debug("Sending SetUISolarSetPointCmd with body: %s", req_body) - return await self.async_send_message(MessageType.SET_SOLAR_SET_POINT_COMMAND, req_body, False) + return await self.async_send(MessageType.SET_SOLAR_SET_POINT_COMMAND, req_body) async def async_set_heater_mode( self, @@ -346,7 +355,7 @@ async def async_set_heater_mode( _LOGGER.debug("Sending SetUIHeaterModeCmd with body: %s", req_body) - return await self.async_send_message(MessageType.SET_HEATER_MODE_COMMAND, req_body, False) + return await self.async_send(MessageType.SET_HEATER_MODE_COMMAND, req_body) async def async_set_heater_enable( self, @@ -381,7 +390,7 @@ async def async_set_heater_enable( _LOGGER.debug("Sending SetHeaterEnable with body: %s", req_body) - return await self.async_send_message(MessageType.SET_HEATER_ENABLED, req_body, False) + return await self.async_send(MessageType.SET_HEATER_ENABLED, req_body) async def async_set_equipment( self, @@ -443,7 +452,7 @@ async def async_set_equipment( _LOGGER.debug("Sending SetUIEquipmentCmd with body: %s", req_body) - return await self.async_send_message(MessageType.SET_EQUIPMENT, req_body, False) + return await self.async_send(MessageType.SET_EQUIPMENT, req_body) async def async_set_filter_speed(self, pool_id: int, equipment_id: int, speed: int) -> None: """Set the speed for a variable speed filter/pump. @@ -471,7 +480,7 @@ async def async_set_filter_speed(self, pool_id: int, equipment_id: int, speed: i _LOGGER.debug("Sending SetUIFilterSpeedCmd with body: %s", req_body) - return await self.async_send_message(MessageType.SET_FILTER_SPEED, req_body, False) + return await self.async_send(MessageType.SET_FILTER_SPEED, req_body) async def async_set_light_show( self, @@ -543,7 +552,7 @@ async def async_set_light_show( _LOGGER.debug("Sending SetStandAloneLightShow with body: %s", req_body) - return await self.async_send_message(MessageType.SET_STANDALONE_LIGHT_SHOW, req_body, False) + return await self.async_send(MessageType.SET_STANDALONE_LIGHT_SHOW, req_body) async def async_set_chlorinator_enable(self, pool_id: int, enabled: int | bool) -> None: body_element = ET.Element("Request", {"xmlns": XML_NAMESPACE}) @@ -561,7 +570,7 @@ async def async_set_chlorinator_enable(self, pool_id: int, enabled: int | bool) _LOGGER.debug("Sending SetCHLOREnable with body: %s", req_body) - return await self.async_send_message(MessageType.SET_CHLOR_ENABLED, req_body, False) + return await self.async_send(MessageType.SET_CHLOR_ENABLED, req_body) # This is used to set the ORP target value on a CSAD async def async_set_csad_orp_target_level( @@ -587,7 +596,7 @@ async def async_set_csad_orp_target_level( _LOGGER.debug("Sending SetUICSADORPTargetLevel with body: %s", req_body) - return await self.async_send_message(MessageType.SET_CSAD_ORP_TARGET, req_body, False) + return await self.async_send(MessageType.SET_CSAD_ORP_TARGET, req_body) # This is used to set the pH target value on a CSAD async def async_set_csad_target_value( @@ -613,7 +622,7 @@ async def async_set_csad_target_value( _LOGGER.debug("Sending UISetCSADTargetValue with body: %s", req_body) - return await self.async_send_message(MessageType.SET_CSAD_TARGET_VALUE, req_body, False) + return await self.async_send(MessageType.SET_CSAD_TARGET_VALUE, req_body) async def async_set_chlorinator_params( self, @@ -656,7 +665,7 @@ async def async_set_chlorinator_params( _LOGGER.debug("Sending SetCHLORParams with body: %s", req_body) - return await self.async_send_message(MessageType.SET_CHLOR_PARAMS, req_body, False) + return await self.async_send(MessageType.SET_CHLOR_PARAMS, req_body) async def async_set_chlorinator_superchlorinate( self, @@ -681,7 +690,7 @@ async def async_set_chlorinator_superchlorinate( _LOGGER.debug("Sending SetUISuperCHLORCmd with body: %s", req_body) - return await self.async_send_message(MessageType.SET_SUPERCHLORINATE, req_body, False) + return await self.async_send(MessageType.SET_SUPERCHLORINATE, req_body) async def async_restore_idle_state(self) -> None: body_element = ET.Element("Request", {"xmlns": XML_NAMESPACE}) @@ -695,7 +704,7 @@ async def async_restore_idle_state(self) -> None: _LOGGER.debug("Sending RestoreIdleState with body: %s", req_body) - return await self.async_send_message(MessageType.RESTORE_IDLE_STATE, req_body, False) + return await self.async_send(MessageType.RESTORE_IDLE_STATE, req_body) async def async_set_spillover( self, @@ -738,7 +747,7 @@ async def async_set_spillover( _LOGGER.debug("Sending SetUISpilloverCmd with body: %s", req_body) - return await self.async_send_message(MessageType.SET_SPILLOVER, req_body, False) + return await self.async_send(MessageType.SET_SPILLOVER, req_body) async def async_set_group_enable( self, @@ -781,7 +790,7 @@ async def async_set_group_enable( _LOGGER.debug("Sending RunGroupCmd with body: %s", req_body) - return await self.async_send_message(MessageType.RUN_GROUP_CMD, req_body, False) + return await self.async_send(MessageType.RUN_GROUP_CMD, req_body) async def async_edit_schedule( self, @@ -854,4 +863,4 @@ async def async_edit_schedule( _LOGGER.debug("Sending EditUIScheduleCmd with body: %s", req_body) - return await self.async_send_message(MessageType.EDIT_SCHEDULE, req_body, False) + return await self.async_send(MessageType.EDIT_SCHEDULE, req_body) diff --git a/pyomnilogic_local/api/constants.py b/pyomnilogic_local/api/constants.py index 5526a21..25a0626 100644 --- a/pyomnilogic_local/api/constants.py +++ b/pyomnilogic_local/api/constants.py @@ -11,9 +11,9 @@ BLOCK_MESSAGE_HEADER_OFFSET = 8 # Offset to skip block message header and get to payload # Timing Constants (in seconds) -OMNI_RETRANSMIT_TIME = 2.1 # Time Omni waits before retransmitting a packet +OMNI_RETRANSMIT_TIME = 2 # Time Omni waits before retransmitting a packet OMNI_RETRANSMIT_COUNT = 5 # Number of retransmit attempts (6 total including initial) -ACK_WAIT_TIMEOUT = 1 # Timeout waiting for ACK response, 0.5 showed to be just a tad too short in some cases. +ACK_WAIT_TIMEOUT = OMNI_RETRANSMIT_TIME * 2 # Timeout waiting for ACK response, 0.5 showed to be just a tad too short in some cases. DEFAULT_RESPONSE_TIMEOUT = OMNI_RETRANSMIT_TIME * OMNI_RETRANSMIT_COUNT # Default timeout for receiving responses # Network Constants diff --git a/pyomnilogic_local/api/message.py b/pyomnilogic_local/api/message.py new file mode 100644 index 0000000..a8fee1a --- /dev/null +++ b/pyomnilogic_local/api/message.py @@ -0,0 +1,143 @@ +from __future__ import annotations + +import logging +import struct +import time +from typing import Self + +from pyomnilogic_local.omnitypes import ClientType, MessageType + +from .constants import ( + PROTOCOL_HEADER_FORMAT, + PROTOCOL_HEADER_SIZE, + PROTOCOL_VERSION, +) +from .exceptions import OmniMessageFormatError + +_LOGGER = logging.getLogger(__name__) + + +class OmniLogicMessage: + """A protocol message for communication with the OmniLogic controller. + + Handles serialization and deserialization of message headers and payloads. + """ + + header_format = PROTOCOL_HEADER_FORMAT + id: int + type: MessageType + payload: bytes + client_type: ClientType = ClientType.SIMPLE + version: str = PROTOCOL_VERSION + timestamp: int + reserved_1: int = 0 + compressed: bool = False + reserved_2: int = 0 + + def __init__( + self, + msg_id: int, + msg_type: MessageType, + payload: str | None = None, + version: str = PROTOCOL_VERSION, + timestamp: int | None = None, + ) -> None: + """Initialize a new OmniLogicMessage. + + Args: + msg_id: Unique message identifier. + msg_type: Type of message being sent. + payload: Optional string payload (XML or command body). + version: Protocol version string. + timestamp: Optional timestamp for the message. + """ + self.id = msg_id + self.type = msg_type + # If we are speaking the XML API, it seems like we need client_type 0, otherwise we need client_type 1 + self.client_type = ClientType.XML if payload is not None else ClientType.SIMPLE + # The Hayward API terminates it's messages with a null character + payload = f"{payload}\x00" if payload is not None else "" + self.payload = bytes(payload, "utf-8") + + self.version = version + self.timestamp = timestamp if timestamp is not None else int(time.time()) + + def __bytes__(self) -> bytes: + """Serialize the message to bytes for UDP transmission. + + Returns: + Byte representation of the message. + """ + header = struct.pack( + self.header_format, + self.id, # Msg id + self.timestamp, + bytes(self.version, "ascii"), # version string + self.type.value, # OpID/msgType + self.client_type.value, # Client type + 0, # reserved + self.compressed, # compressed + 0, # reserved + ) + return header + self.payload + + def __repr__(self) -> str: + """Return a string representation of the message for debugging.""" + if self.compressed or self.type is MessageType.MSP_BLOCKMESSAGE: + return f"ID: {self.id}, Type: {self.type.name}, Compressed: {self.compressed}, Client: {self.client_type.name}" + return ( + f"ID: {self.id}, Type: {self.type.name}, Compressed: {self.compressed}, Client: {self.client_type.name}, " + f"Body: {self.payload[:-1].decode('utf-8')}" + ) + + @classmethod + def from_bytes(cls, data: bytes) -> Self: + """Parse a message from its byte representation. + + Args: + data: Byte data received from the controller. + + Returns: + OmniLogicMessage instance. + + Raises: + OmniMessageFormatException: If the message format is invalid. + """ + if len(data) < PROTOCOL_HEADER_SIZE: + msg = f"Message too short: {len(data)} bytes, expected at least {PROTOCOL_HEADER_SIZE}" + raise OmniMessageFormatError(msg) + + # split the header and data + header = data[:PROTOCOL_HEADER_SIZE] + rdata: bytes = data[PROTOCOL_HEADER_SIZE:] + + try: + (msg_id, tstamp, vers, msg_type, client_type, res1, compressed, res2) = struct.unpack(cls.header_format, header) + except struct.error as exc: + msg = f"Failed to unpack message header: {exc}" + raise OmniMessageFormatError(msg) from exc + + # Validate message type + try: + message_type_enum = MessageType(msg_type) + except ValueError as exc: + msg = f"Unknown message type: {msg_type}: {exc}" + raise OmniMessageFormatError(msg) from exc + + # Validate client type + try: + client_type_enum = ClientType(int(client_type)) + except ValueError as exc: + msg = f"Unknown client type: {client_type}: {exc}" + raise OmniMessageFormatError(msg) from exc + + message = cls(msg_id=msg_id, msg_type=message_type_enum, version=vers.decode("utf-8")) + message.timestamp = tstamp + message.client_type = client_type_enum + message.reserved_1 = res1 + # There are some messages that are ALWAYS compressed although they do not return a 1 in their LeadMessage + message.compressed = compressed == 1 or message.type in [MessageType.MSP_TELEMETRY_UPDATE] + message.reserved_2 = res2 + message.payload = rdata + + return message diff --git a/pyomnilogic_local/api/protocol.py b/pyomnilogic_local/api/protocol.py index ba31a96..c816b85 100644 --- a/pyomnilogic_local/api/protocol.py +++ b/pyomnilogic_local/api/protocol.py @@ -1,16 +1,16 @@ +"""Asyncio UDP datagram protocol for communication with the OmniLogic controller.""" + from __future__ import annotations import asyncio import logging import random -import struct -import time import xml.etree.ElementTree as ET import zlib -from typing import Any, Self, cast +from typing import cast from pyomnilogic_local.models.leadmessage import LeadMessage -from pyomnilogic_local.omnitypes import ClientType, MessageType +from pyomnilogic_local.omnitypes import MessageType from .constants import ( ACK_WAIT_TIMEOUT, @@ -19,455 +19,330 @@ MAX_FRAGMENT_WAIT_TIME, MAX_QUEUE_SIZE, OMNI_RETRANSMIT_COUNT, - OMNI_RETRANSMIT_TIME, - PROTOCOL_HEADER_FORMAT, - PROTOCOL_HEADER_SIZE, - PROTOCOL_VERSION, - XML_ENCODING, XML_NAMESPACE, ) -from .exceptions import OmniFragmentationError, OmniMessageFormatError, OmniTimeoutError +from .exceptions import OmniConnectionError, OmniMessageFormatError, OmniTimeoutError +from .message import OmniLogicMessage _LOGGER = logging.getLogger(__name__) +_ACK_PAYLOAD = f'\nAck\n' -class OmniLogicMessage: - """A protocol message for communication with the OmniLogic controller. - - Handles serialization and deserialization of message headers and payloads. - """ +_ACK_TYPES = frozenset({MessageType.ACK, MessageType.XML_ACK}) - header_format = PROTOCOL_HEADER_FORMAT - id: int - type: MessageType - payload: bytes - client_type: ClientType = ClientType.SIMPLE - version: str = PROTOCOL_VERSION - timestamp: int - reserved_1: int = 0 - compressed: bool = False - reserved_2: int = 0 - - def __init__( - self, - msg_id: int, - msg_type: MessageType, - payload: str | None = None, - version: str = PROTOCOL_VERSION, - timestamp: int | None = None, - ) -> None: - """Initialize a new OmniLogicMessage. +# Type alias for items placed on the receive queue: either a parsed message or a parse error. +_QueueItem = OmniLogicMessage | OmniMessageFormatError - Args: - msg_id: Unique message identifier. - msg_type: Type of message being sent. - payload: Optional string payload (XML or command body). - version: Protocol version string. - timestamp: Optional timestamp for the message. - """ - self.id = msg_id - self.type = msg_type - # If we are speaking the XML API, it seems like we need client_type 0, otherwise we need client_type 1 - self.client_type = ClientType.XML if payload is not None else ClientType.SIMPLE - # The Hayward API terminates it's messages with a null character - payload = f"{payload}\x00" if payload is not None else "" - self.payload = bytes(payload, "utf-8") - self.version = version - self.timestamp = timestamp if timestamp is not None else int(time.time()) +class OmniLogicProtocol(asyncio.DatagramProtocol): + """Asyncio UDP datagram protocol for communication with the OmniLogic controller. - def __bytes__(self) -> bytes: - """Serialize the message to bytes for UDP transmission. + Handles message framing, acknowledgement, retransmission, and multi-part + response reassembly for the Hayward OmniLogic local UDP protocol. - Returns: - Byte representation of the message. - """ - header = struct.pack( - self.header_format, - self.id, # Msg id - self.timestamp, - bytes(self.version, "ascii"), # version string - self.type.value, # OpID/msgType - self.client_type.value, # Client type - 0, # reserved - self.compressed, # compressed - 0, # reserved - ) - return header + self.payload - - def __repr__(self) -> str: - """Return a string representation of the message for debugging.""" - if self.compressed or self.type is MessageType.MSP_BLOCKMESSAGE: - return f"ID: {self.id}, Type: {self.type.name}, Compressed: {self.compressed}, Client: {self.client_type.name}" - return ( - f"ID: {self.id}, Type: {self.type.name}, Compressed: {self.compressed}, Client: {self.client_type.name}, " - f"Body: {self.payload[:-1].decode('utf-8')}" + Example: + loop = asyncio.get_running_loop() + transport, protocol = await loop.create_datagram_endpoint( + OmniLogicProtocol, remote_addr=(controller_ip, controller_port) ) + try: + response = await protocol.async_send_and_receive(MessageType.GET_TELEMETRY, xml_body) + finally: + transport.close() + """ - @classmethod - def from_bytes(cls, data: bytes) -> Self: - """Parse a message from its byte representation. + def __init__(self) -> None: + self._transport: asyncio.DatagramTransport | None = None + # Seed with a random value so each protocol instance (one per API call) uses distinct IDs. + # Message ID is an unsigned 32-bit integer in the wire format, so cap at 2**16 to + # leave room for plenty of increments. + self._msg_counter: int = random.randint(1, 2**16) + self._ack_futures: dict[int, asyncio.Future[OmniLogicMessage]] = {} + self._receive_queue: asyncio.Queue[_QueueItem] = asyncio.Queue(maxsize=MAX_QUEUE_SIZE) + + # ------------------------------------------------------------------------- + # asyncio.DatagramProtocol callbacks + # ------------------------------------------------------------------------- - Args: - data: Byte data received from the controller. + def connection_made(self, transport: asyncio.BaseTransport) -> None: + self._transport = cast("asyncio.DatagramTransport", transport) + _LOGGER.debug("connection established") - Returns: - OmniLogicMessage instance. + def datagram_received(self, data: bytes, addr: tuple[str, int]) -> None: + try: + msg = OmniLogicMessage.from_bytes(data) + except OmniMessageFormatError as exc: + _LOGGER.warning("received unparsable datagram from %s: %s", addr, exc) + self._receive_queue.put_nowait(exc) + return - Raises: - OmniMessageFormatException: If the message format is invalid. - """ - if len(data) < PROTOCOL_HEADER_SIZE: - msg = f"Message too short: {len(data)} bytes, expected at least {PROTOCOL_HEADER_SIZE}" - raise OmniMessageFormatError(msg) + _LOGGER.debug("received from %s: %r", addr, msg) - # split the header and data - header = data[:PROTOCOL_HEADER_SIZE] - rdata: bytes = data[PROTOCOL_HEADER_SIZE:] + if msg.type in _ACK_TYPES: + self._resolve_ack(msg) + else: + self._send_xml_ack(msg.id) + self._receive_queue.put_nowait(msg) - try: - (msg_id, tstamp, vers, msg_type, client_type, res1, compressed, res2) = struct.unpack(cls.header_format, header) - except struct.error as exc: - msg = f"Failed to unpack message header: {exc}" - raise OmniMessageFormatError(msg) from exc + def error_received(self, exc: Exception) -> None: + _LOGGER.error("transport error: %s", exc) - # Validate message type - try: - message_type_enum = MessageType(msg_type) - except ValueError as exc: - msg = f"Unknown message type: {msg_type}: {exc}" - raise OmniMessageFormatError(msg) from exc + def connection_lost(self, exc: Exception | None) -> None: + _LOGGER.debug("connection lost: %s", exc) + + # ------------------------------------------------------------------------- + # Internal helpers + # ------------------------------------------------------------------------- + + def _next_msg_id(self) -> int: + """Return the next sequential message ID.""" + self._msg_counter += 1 + return self._msg_counter + + def _resolve_ack(self, msg: OmniLogicMessage) -> None: + """Resolve the pending ACK future for the given message ID.""" + future = self._ack_futures.get(msg.id) + if future is not None and not future.done(): + future.set_result(msg) + + def _send_xml_ack(self, msg_id: int) -> None: + """Transmit an XML_ACK for a received message.""" + if self._transport is None: + _LOGGER.warning("cannot send ACK for ID %d, transport unavailable", msg_id) + return + ack = OmniLogicMessage(msg_id=msg_id, msg_type=MessageType.XML_ACK, payload=_ACK_PAYLOAD) + self._transport.sendto(bytes(ack)) + _LOGGER.debug("sent XML_ACK for message ID %d", msg_id) + + def _build_request(self, msg_type: MessageType, payload: str) -> OmniLogicMessage: + """Build a new outgoing request message with a fresh ID.""" + return OmniLogicMessage(msg_id=self._next_msg_id(), msg_type=msg_type, payload=payload) + + async def _send_with_retry(self, msg_type: MessageType, payload: str) -> None: + """Transmit a message and wait for acknowledgement, retransmitting as needed. + + A fresh message ID is used on each attempt so the controller treats every + retransmission as a new request (an ACK only confirms receipt/parse, not + that the controller will act on or re-respond to the same ID again). - # Validate client type - try: - client_type_enum = ClientType(int(client_type)) - except ValueError as exc: - msg = f"Unknown client type: {client_type}: {exc}" - raise OmniMessageFormatError(msg) from exc + Args: + msg_type: The type of message to send. + payload: The XML payload string. - message = cls(msg_id=msg_id, msg_type=message_type_enum, version=vers.decode("utf-8")) - message.timestamp = tstamp - message.client_type = client_type_enum - message.reserved_1 = res1 - # There are some messages that are ALWAYS compressed although they do not return a 1 in their LeadMessage - message.compressed = compressed == 1 or message.type in [MessageType.MSP_TELEMETRY_UPDATE] - message.reserved_2 = res2 - message.payload = rdata + Raises: + OmniConnectionError: If the transport is not available. + OmniTimeoutError: If no ACK is received after all retransmission attempts. + """ + if self._transport is None: + msg = f"Cannot send message type {msg_type.name}, transport not available" + raise OmniConnectionError(msg) - return message + loop = asyncio.get_running_loop() + for attempt in range(OMNI_RETRANSMIT_COUNT + 1): + message = self._build_request(msg_type, payload) + ack_future: asyncio.Future[OmniLogicMessage] = loop.create_future() + self._ack_futures[message.id] = ack_future -class OmniLogicProtocol(asyncio.DatagramProtocol): - """Asyncio DatagramProtocol implementation for OmniLogic UDP communication. + try: + _LOGGER.debug( + "transmitting message ID: %d, type: %s (attempt %d/%d)", + message.id, + message.type.name, + attempt + 1, + OMNI_RETRANSMIT_COUNT + 1, + ) + self._transport.sendto(bytes(message)) + try: + await asyncio.wait_for(asyncio.shield(ack_future), timeout=ACK_WAIT_TIMEOUT) + except TimeoutError: + if attempt < OMNI_RETRANSMIT_COUNT: + _LOGGER.debug("no ACK for message ID %d, will retry", message.id) + else: + return + finally: + self._ack_futures.pop(message.id, None) + if not ack_future.done(): + ack_future.cancel() - Handles message sending, receiving, retries, and block message reassembly. - """ + msg = f"No ACK received for message type {msg_type.name} after {OMNI_RETRANSMIT_COUNT + 1} attempts" + raise OmniTimeoutError(msg) - transport: asyncio.DatagramTransport - # The omni will re-transmit a packet every 2 seconds if it does not receive an ACK. We pad that just a touch to be safe - _omni_retransmit_time = OMNI_RETRANSMIT_TIME - # The omni will re-transmit 5 times (a total of 6 attempts including the initial) if it does not receive an ACK - _omni_retransmit_count = OMNI_RETRANSMIT_COUNT + async def _receive_next_message(self) -> OmniLogicMessage: + """Wait for and return the next incoming (non-ACK) message from the queue. - data_queue: asyncio.Queue[OmniLogicMessage] - error_queue: asyncio.Queue[Exception] + Raises: + OmniTimeoutError: If no message arrives within MAX_FRAGMENT_WAIT_TIME seconds. + OmniMessageFormatError: If the queued item is a parse error. + """ + try: + async with asyncio.timeout(MAX_FRAGMENT_WAIT_TIME): + item: _QueueItem = await self._receive_queue.get() + except TimeoutError as exc: + msg = "Timed out waiting for response message from controller" + raise OmniTimeoutError(msg) from exc - def __init__(self) -> None: - """Initialize the protocol handler and message queue.""" - self.data_queue = asyncio.Queue(maxsize=MAX_QUEUE_SIZE) - self.error_queue = asyncio.Queue(maxsize=MAX_QUEUE_SIZE) + if isinstance(item, OmniMessageFormatError): + raise item - def connection_made(self, transport: asyncio.BaseTransport) -> None: - """Called when a UDP connection is made.""" - self.transport = cast("asyncio.DatagramTransport", transport) + return item - def connection_lost(self, exc: Exception | None) -> None: - """Called when the UDP connection is lost or closed.""" - if exc: - raise exc + # ------------------------------------------------------------------------- + # Response assembly + # ------------------------------------------------------------------------- + + async def _receive_response(self) -> tuple[bytes, bool]: + """Receive and return the full response payload for a prior request. - def datagram_received(self, data: bytes, addr: tuple[str | Any, int]) -> None: - """Called when a datagram is received from the controller. + Handles both single-message responses and multi-part lead/block responses. - Parses the message and puts it on the queue. Handles corrupt or unexpected data gracefully. + Returns: + Tuple of (raw_payload_bytes, compressed_flag). """ - try: - message = OmniLogicMessage.from_bytes(data) - _LOGGER.debug("Received Message %s from %s", str(message), addr) - try: - self.data_queue.put_nowait(message) - except asyncio.QueueFull: - _LOGGER.exception("Data queue is full. Dropping message: %s", str(message)) - except OmniMessageFormatError as exc: - _LOGGER.exception("Failed to parse incoming datagram from %s", addr) - self.error_queue.put_nowait(exc) - except Exception as exc: - _LOGGER.exception("Unexpected error processing datagram from %s", addr) - self.error_queue.put_nowait(exc) + while True: + msg = await self._receive_next_message() - def error_received(self, exc: Exception) -> None: - """Called when a UDP error is received. + if msg.type == MessageType.MSP_LEADMESSAGE: + return await self._reassemble_multipart(msg) - Store the error so it can be handled by awaiting coroutines. - """ - self.error_queue.put_nowait(exc) + if msg.type == MessageType.MSP_BLOCKMESSAGE: + _LOGGER.warning("received block message ID %d before any lead message, ignoring", msg.id) + continue - async def _wait_for_ack(self, ack_id: int) -> None: - """Wait for an ACK message with the given ID. + return msg.payload, msg.compressed - Handles dropped or out-of-order ACKs. + async def _reassemble_multipart(self, lead_msg: OmniLogicMessage) -> tuple[bytes, bool]: + """Reassemble a multi-part response from a lead message and its block messages. Args: - ack_id: The message ID to wait for an ACK. + lead_msg: The initial MSP_LEADMESSAGE received from the controller. - Raises: - OmniTimeoutException: If no ACK is received. - Exception: If a protocol error occurs. + Returns: + Tuple of (concatenated_block_payload_bytes, compressed_flag). """ - # Wait for either an ACK message or an error - # Race condition: datagram_received() calls put_nowait() synchronously, so data_task may - # already be done when wait_for fires its timeout CancelledError. In that case we catch - # the cancellation, skip re-looping, and fall through to check the result below. If the - # result is our ACK we return normally, suppressing the CancelledError so wait_for treats - # the call as successful. If it isn't, we re-raise after the loop. - cancelled: asyncio.CancelledError | None = None - retry = True - while retry: - # Wait for either a message or an error - data_task = asyncio.create_task(self.data_queue.get()) - error_task = asyncio.create_task(self.error_queue.get()) - try: - done, pending = await asyncio.wait([data_task, error_task], return_when=asyncio.FIRST_COMPLETED) - except asyncio.CancelledError as exc: - retry = False - cancelled = exc - done = {t for t in (data_task, error_task) if t.done()} - pending = {t for t in (data_task, error_task) if not t.done()} - - # Cancel any pending tasks to avoid "Task was destroyed but it is pending" warnings - for task in pending: - task.cancel() - - if error_task in done: - err = error_task.result() - if isinstance(err, Exception): - raise err - _LOGGER.error("Unknown error occurred during communication with OmniLogic: %s", err) - if data_task in done: - message = data_task.result() - if message.id == ack_id: - _LOGGER.debug("Received ACK for message ID %s", ack_id) - return - _LOGGER.debug("We received a message that is not our ACK, it appears the ACK was dropped") - if message.type in {MessageType.MSP_LEADMESSAGE, MessageType.MSP_TELEMETRY_UPDATE}: - _LOGGER.debug("Omni has sent a new message, continuing on with the communication") - await self.data_queue.put(message) - return + lead = self._parse_lead_message(lead_msg) + compressed = lead_msg.compressed + seen_lead_ids: set[int] = {lead_msg.id} + + _LOGGER.debug( + "reassembling %d-block response (compressed=%s)", + lead.msg_block_count, + compressed, + ) - if cancelled is not None: - raise cancelled + payload_data = b"" + received_block_count = 0 + seen_block_ids: set[int] = set() - async def _ensure_sent( - self, - message: OmniLogicMessage, - max_attempts: int = 5, - ) -> None: - """Send a message and ensure it is acknowledged, retrying if necessary. + while received_block_count < lead.msg_block_count: + msg = await self._receive_next_message() - Args: - message: The message to send. - max_attempts: Maximum number of send attempts. + if msg.type == MessageType.MSP_LEADMESSAGE: + if msg.id not in seen_lead_ids: + _LOGGER.warning("received unexpected secondary lead message ID %d", msg.id) + seen_lead_ids.add(msg.id) + else: + _LOGGER.debug("received duplicate lead message ID %d, re-ACK already sent", msg.id) + continue - Raises: - OmniTimeoutException: If no ACK is received after retries. - """ - for attempt in range(max_attempts): - self.transport.sendto(bytes(message)) - _LOGGER.debug("Sent message ID %s (attempt %d/%d)", message.id, attempt + 1, max_attempts) + if msg.type != MessageType.MSP_BLOCKMESSAGE: + _LOGGER.warning("expected block message but got %s (ID %d), ignoring", msg.type.name, msg.id) + continue - # If the message that we just sent is an ACK, we do not need to wait to receive an ACK, we are done - if message.type in [MessageType.XML_ACK, MessageType.ACK]: - return + if msg.id in seen_block_ids: + _LOGGER.debug("received duplicate block message ID %d, re-ACK already sent", msg.id) + continue - # Wait for a bit to either receive an ACK for our message, otherwise, we retry delivery - try: - await asyncio.wait_for(self._wait_for_ack(message.id), ACK_WAIT_TIMEOUT) - except TimeoutError as exc: - if attempt < max_attempts - 1: - _LOGGER.warning( - "ACK not received for message type %s (ID: %s), attempt %d/%d. Retrying...", - message.type.name, - message.id, - attempt + 1, - max_attempts, - ) - else: - _LOGGER.exception( - "Failed to receive ACK for message type %s (ID: %s) after %d attempts.", message.type.name, message.id, max_attempts - ) - msg = f"Failed to receive acknowledgement of command, max retries exceeded: {exc}" - raise OmniTimeoutError(msg) from exc - else: - return - - async def send_and_receive( - self, - msg_type: MessageType, - payload: str | None, - msg_id: int | None = None, - response_timeout: float = DEFAULT_RESPONSE_TIMEOUT, - ) -> str: - """Send a message and wait for a response, returning the response payload as a string. + seen_block_ids.add(msg.id) + payload_data += msg.payload[BLOCK_MESSAGE_HEADER_OFFSET:] + received_block_count += 1 + _LOGGER.debug("received block %d/%d", received_block_count, lead.msg_block_count) + + return payload_data, compressed + + def _parse_lead_message(self, msg: OmniLogicMessage) -> LeadMessage: + """Parse the XML payload of an MSP_LEADMESSAGE into a LeadMessage model. Args: - msg_type: Type of message to send. - payload: Optional payload string. - msg_id: Optional message ID. - response_timeout: Timeout in seconds to wait for the response. + msg: The MSP_LEADMESSAGE to parse. Returns: - Response payload as a string. + Parsed LeadMessage model. """ - await self.send_message(msg_type, payload, msg_id) - return await self._receive_file(response_timeout=response_timeout) - - # Send a message that you do NOT need a response to - async def send_message( - self, - msg_type: MessageType, - payload: str | None, - msg_id: int | None = None, - ) -> None: - """Send a message that does not require a response. + payload_str = msg.payload.decode("utf-8").strip("\x00") + root = ET.fromstring(payload_str) + return LeadMessage.model_validate(root) + + def _decode_payload(self, data: bytes, compressed: bool) -> str: + """Decode a raw response payload, decompressing if necessary. Args: - msg_type: Type of message to send. - payload: Optional payload string. - msg_id: Optional message ID. + data: Raw payload bytes. + compressed: Whether the payload is zlib-compressed. + + Returns: + Decoded UTF-8 string with leading/trailing null bytes stripped. """ - # If we aren't sending a specific msg_id, lets randomize it - if not msg_id: - msg_id = random.randrange(2**32) + if compressed: + data = zlib.decompress(data.rstrip(b"\x00")) + return data.decode("utf-8").strip("\x00") - message = OmniLogicMessage(msg_id, msg_type, payload) + # ------------------------------------------------------------------------- + # Public API + # ------------------------------------------------------------------------- - _LOGGER.debug("Sending Message %s", str(message)) + async def async_send(self, msg_type: MessageType, payload: str) -> None: + """Send a message to the controller and wait for acknowledgement. - await self._ensure_sent(message) + The message is retransmitted up to OMNI_RETRANSMIT_COUNT times if no ACK + is received within ACK_WAIT_TIMEOUT seconds. - async def _send_ack(self, msg_id: int) -> None: - """Send an ACK message for the given message ID.""" - body_element = ET.Element("Request", {"xmlns": XML_NAMESPACE}) - name_element = ET.SubElement(body_element, "Name") - name_element.text = "Ack" + Args: + msg_type: The type of message to send. + payload: The XML payload string. - req_body = ET.tostring(body_element, xml_declaration=True, encoding=XML_ENCODING) - await self.send_message(MessageType.XML_ACK, req_body, msg_id) + Raises: + OmniConnectionError: If the transport is not available. + OmniTimeoutError: If no ACK is received after all retransmission attempts. + """ + await self._send_with_retry(msg_type, payload) - async def _receive_file(self, response_timeout: float = DEFAULT_RESPONSE_TIMEOUT) -> str: - """Wait for and reassemble a full response from the controller. + async def async_send_and_receive(self, msg_type: MessageType, payload: str) -> str: + """Send a message and receive the controller's response payload. - Handles single and multi-block (LeadMessage/BlockMessage) responses. + Handles the full send → ACK → response (single or lead/block) flow, + including retransmission on ACK timeout and decompression of compressed + responses. If an ACK is received but the controller never sends the + expected follow-up response within DEFAULT_RESPONSE_TIMEOUT, the entire + send+receive cycle is retried up to OMNI_RETRANSMIT_COUNT times. Args: - response_timeout: Timeout in seconds to wait for the initial response. + msg_type: The type of message to send. + payload: The XML payload string. Returns: - Response payload as a string. + The decoded UTF-8 response string from the controller. Raises: - OmniTimeoutException: If a block message is not received in time. - OmniFragmentationException: If fragment reassembly fails. + OmniConnectionError: If the transport is not available. + OmniTimeoutError: If no ACK or response is received within the allowed time. + OmniMessageFormatError: If an unparsable response datagram is received. """ - # wait for the initial packet. - try: - message = await asyncio.wait_for(self.data_queue.get(), response_timeout) - except TimeoutError as exc: - msg = f"Timeout waiting for response from controller: {exc}" - raise OmniTimeoutError(msg) from exc - - # If messages have to be re-transmitted, we can sometimes receive multiple ACKs. The first one would be handled by - # self._ensure_sent, but if any subsequent ACKs are sent to us, we need to dump them and wait for a "real" message. - while message.type in [MessageType.ACK, MessageType.XML_ACK]: - _LOGGER.debug("Skipping duplicate ACK message") + for attempt in range(OMNI_RETRANSMIT_COUNT + 1): + await self._send_with_retry(msg_type, payload) try: - message = await asyncio.wait_for(self.data_queue.get(), response_timeout) - except TimeoutError as exc: - msg = f"Timeout waiting for response from controller: {exc}" - raise OmniTimeoutError(msg) from exc - - await self._send_ack(message.id) - - # If the response is too large, the controller will send a LeadMessage indicating how many follow-up messages will be sent - if message.type is MessageType.MSP_LEADMESSAGE: - try: - leadmsg = LeadMessage.model_validate(ET.fromstring(message.payload[:-1])) - except Exception as exc: - msg = f"Failed to parse LeadMessage: {exc}" - raise OmniFragmentationError(msg) from exc - - _LOGGER.debug("Will receive %s blockmessages for fragmented response", leadmsg.msg_block_count) - - # Wait for the block data data - retval: bytes = b"" - # If we received a LeadMessage, continue to receive messages until we have all of our data - # Fragments of data may arrive out of order, so we store them in a buffer as they arrive and sort them after - data_fragments: dict[int, bytes] = {} - fragment_start_time = time.time() - - while len(data_fragments) < leadmsg.msg_block_count: - # Check if we've been waiting too long for fragments - if time.time() - fragment_start_time > MAX_FRAGMENT_WAIT_TIME: - _LOGGER.error( - "Timeout waiting for fragments: received %d/%d after %ds", - len(data_fragments), - leadmsg.msg_block_count, - MAX_FRAGMENT_WAIT_TIME, - ) - msg = ( - f"Timeout waiting for fragments: received {len(data_fragments)}/{leadmsg.msg_block_count} " - f"after {MAX_FRAGMENT_WAIT_TIME}s" + async with asyncio.timeout(DEFAULT_RESPONSE_TIMEOUT): + raw_data, compressed = await self._receive_response() + return self._decode_payload(raw_data, compressed) + except TimeoutError: + if attempt < OMNI_RETRANSMIT_COUNT: + _LOGGER.debug( + "no response received for %s within %ds, retrying (attempt %d/%d)", + msg_type.name, + DEFAULT_RESPONSE_TIMEOUT, + attempt + 1, + OMNI_RETRANSMIT_COUNT + 1, ) - raise OmniFragmentationError(msg) - - # We need to wait long enough for the Omni to get through all of it's retries before we bail out. - try: - resp = await asyncio.wait_for(self.data_queue.get(), response_timeout) - except TimeoutError as exc: - msg = f"Timeout receiving fragment: got {len(data_fragments)}/{leadmsg.msg_block_count} fragments: {exc}" - raise OmniFragmentationError(msg) from exc - - # We only want to collect blockmessages here - if resp.type is not MessageType.MSP_BLOCKMESSAGE: - _LOGGER.debug("Received a message other than a blockmessage during fragmentation: %s", resp.type) - continue - - await self._send_ack(resp.id) - - # remove an 8 byte header to get to the payload data - data_fragments[resp.id] = resp.payload[BLOCK_MESSAGE_HEADER_OFFSET:] - _LOGGER.debug("Received fragment %d/%d", len(data_fragments), leadmsg.msg_block_count) - # Reassemble the fragmets in order - for _, data in sorted(data_fragments.items()): - retval += data - - _LOGGER.debug("Successfully reassembled %d fragments into %d bytes", leadmsg.msg_block_count, len(retval)) - - # We did not receive a LeadMessage, so our payload is just this one packet - else: - retval = message.payload - - # Decompress the returned data if necessary - if message.compressed: - _LOGGER.debug("Decompressing response payload") - try: - comp_bytes = bytes.fromhex(retval.hex()) - retval = zlib.decompress(comp_bytes) - _LOGGER.debug("Decompressed %d bytes to %d bytes", len(comp_bytes), len(retval)) - except zlib.error as exc: - msg = f"Failed to decompress message: {exc}" - raise OmniMessageFormatError(msg) from exc - - # For some API calls, the Omni null terminates the response, we are stripping that here to make parsing it later easier - return retval.decode("utf-8").strip("\x00") + msg = f"No response received for {msg_type.name} after {OMNI_RETRANSMIT_COUNT + 1} attempts" + raise OmniTimeoutError(msg) diff --git a/pyomnilogic_local/cli/pcap_utils.py b/pyomnilogic_local/cli/pcap_utils.py index cdfa036..f4e91c9 100644 --- a/pyomnilogic_local/cli/pcap_utils.py +++ b/pyomnilogic_local/cli/pcap_utils.py @@ -15,7 +15,7 @@ from scapy.layers.inet import UDP from scapy.utils import rdpcap -from pyomnilogic_local.api.protocol import OmniLogicMessage +from pyomnilogic_local.api.message import OmniLogicMessage from pyomnilogic_local.models.leadmessage import LeadMessage from pyomnilogic_local.omnitypes import MessageType diff --git a/pyomnilogic_local/models/leadmessage.py b/pyomnilogic_local/models/leadmessage.py index d3f39f4..8817ef6 100644 --- a/pyomnilogic_local/models/leadmessage.py +++ b/pyomnilogic_local/models/leadmessage.py @@ -1,3 +1,4 @@ +# ruff: noqa: TC001 # pydantic relies on the omnitypes imports at runtime from __future__ import annotations from typing import Any @@ -5,6 +6,8 @@ from pydantic import BaseModel, ConfigDict, Field, model_validator +from pyomnilogic_local.omnitypes import MessageType + from .const import XML_NS # Example Lead Message XML: @@ -24,7 +27,7 @@ class LeadMessage(BaseModel): model_config = ConfigDict(from_attributes=True) - source_op_id: int = Field(alias="SourceOpId") + source_op_id: MessageType = Field(alias="SourceOpId") msg_size: int = Field(alias="MsgSize") msg_block_count: int = Field(alias="MsgBlockCount") type: int = Field(alias="Type") diff --git a/pyomnilogic_local/omnitypes.py b/pyomnilogic_local/omnitypes.py index 6a16062..1831bf9 100644 --- a/pyomnilogic_local/omnitypes.py +++ b/pyomnilogic_local/omnitypes.py @@ -32,8 +32,8 @@ class MessageType(PrettyEnum, IntEnum): GET_FILTER_DIAGNOSTIC_INFO = 386 HANDSHAKE = 1000 ACK = 1002 - MSP_TELEMETRY_UPDATE = 1004 MSP_CONFIGURATIONUPDATE = 1003 + MSP_TELEMETRY_UPDATE = 1004 MSP_ALARM_LIST_RESPONSE = 1304 MSP_LEADMESSAGE = 1998 MSP_BLOCKMESSAGE = 1999 @@ -324,6 +324,7 @@ class ZodiacShow(PrettyEnum, IntEnum): class ColorLogicPowerState(PrettyEnum, IntEnum): OFF = 0 POWERING_OFF = 1 + INITIALIZING = 2 # The app shows this as 15 seconds of white, but this state seems to happen when the Omni first powers up CHANGING_SHOW = 3 FIFTEEN_SECONDS_WHITE = 4 ACTIVE = 6 diff --git a/tests/test_api.py b/tests/test_api.py index 3ed3775..019b087 100644 --- a/tests/test_api.py +++ b/tests/test_api.py @@ -186,7 +186,7 @@ async def test_async_get_mspconfig_generates_valid_xml() -> None: """Test that async_get_mspconfig generates valid XML request.""" api = OmniLogicAPI("192.168.1.100") - with patch.object(api, "async_send_message", new_callable=AsyncMock) as mock_send: + with patch.object(api, "async_send_and_receive", new_callable=AsyncMock) as mock_send: mock_send.return_value = 'Configuration' await api.async_get_mspconfig(raw=True) @@ -206,7 +206,7 @@ async def test_async_get_telemetry_generates_valid_xml() -> None: """Test that async_get_telemetry generates valid XML request.""" api = OmniLogicAPI("192.168.1.100") - with patch.object(api, "async_send_message", new_callable=AsyncMock) as mock_send: + with patch.object(api, "async_send_and_receive", new_callable=AsyncMock) as mock_send: mock_send.return_value = 'Telemetry' await api.async_get_telemetry(raw=True) @@ -226,7 +226,7 @@ async def test_async_get_filter_diagnostics_generates_valid_xml() -> None: """Test that async_get_filter_diagnostics generates valid XML with correct parameters.""" api = OmniLogicAPI("192.168.1.100") - with patch.object(api, "async_send_message", new_callable=AsyncMock) as mock_send: + with patch.object(api, "async_send_and_receive", new_callable=AsyncMock) as mock_send: mock_send.return_value = 'FilterDiagnostics' await api.async_get_filter_diagnostics(pool_id=1, equipment_id=2, raw=True) @@ -248,7 +248,7 @@ async def test_async_set_heater_generates_valid_xml() -> None: """Test that async_set_heater generates valid XML with correct parameters.""" api = OmniLogicAPI("192.168.1.100") - with patch.object(api, "async_send_message", new_callable=AsyncMock) as mock_send: + with patch.object(api, "async_send", new_callable=AsyncMock) as mock_send: mock_send.return_value = None await api.async_set_heater(pool_id=1, equipment_id=2, temperature=75) @@ -273,7 +273,7 @@ async def test_async_set_filter_speed_generates_valid_xml() -> None: """Test that async_set_filter_speed generates valid XML with correct parameters.""" api = OmniLogicAPI("192.168.1.100") - with patch.object(api, "async_send_message", new_callable=AsyncMock) as mock_send: + with patch.object(api, "async_send", new_callable=AsyncMock) as mock_send: mock_send.return_value = None await api.async_set_filter_speed(pool_id=1, equipment_id=2, speed=75) @@ -296,7 +296,7 @@ async def test_async_set_equipment_generates_valid_xml() -> None: """Test that async_set_equipment generates valid XML with correct parameters.""" api = OmniLogicAPI("192.168.1.100") - with patch.object(api, "async_send_message", new_callable=AsyncMock) as mock_send: + with patch.object(api, "async_send", new_callable=AsyncMock) as mock_send: mock_send.return_value = None await api.async_set_equipment( @@ -339,7 +339,7 @@ async def test_async_set_heater_mode_generates_valid_xml() -> None: """Test that async_set_heater_mode generates valid XML with correct enum values.""" api = OmniLogicAPI("192.168.1.100") - with patch.object(api, "async_send_message", new_callable=AsyncMock) as mock_send: + with patch.object(api, "async_send", new_callable=AsyncMock) as mock_send: mock_send.return_value = None await api.async_set_heater_mode(pool_id=1, equipment_id=2, mode=HeaterMode.HEAT) @@ -360,7 +360,7 @@ async def test_async_set_light_show_generates_valid_xml() -> None: """Test that async_set_light_show generates valid XML with correct enum values.""" api = OmniLogicAPI("192.168.1.100") - with patch.object(api, "async_send_message", new_callable=AsyncMock) as mock_send: + with patch.object(api, "async_send", new_callable=AsyncMock) as mock_send: mock_send.return_value = None await api.async_set_light_show( @@ -399,7 +399,7 @@ async def test_async_set_chlorinator_enable_boolean_conversion(subtests: pytest. ] for enabled, expected, description in test_cases: - with subtests.test(msg=description), patch.object(api, "async_send_message", new_callable=AsyncMock) as mock_send: + with subtests.test(msg=description), patch.object(api, "async_send", new_callable=AsyncMock) as mock_send: mock_send.return_value = None await api.async_set_chlorinator_enable(pool_id=1, enabled=enabled) @@ -424,7 +424,7 @@ async def test_async_set_heater_enable_boolean_conversion(subtests: pytest.Subte ] for enabled, expected, description in test_cases: - with subtests.test(msg=description), patch.object(api, "async_send_message", new_callable=AsyncMock) as mock_send: + with subtests.test(msg=description), patch.object(api, "async_send", new_callable=AsyncMock) as mock_send: mock_send.return_value = None await api.async_set_heater_enable(pool_id=1, equipment_id=2, enabled=enabled) @@ -441,7 +441,7 @@ async def test_async_set_chlorinator_params_generates_valid_xml() -> None: """Test that async_set_chlorinator_params generates valid XML with all parameters.""" api = OmniLogicAPI("192.168.1.100") - with patch.object(api, "async_send_message", new_callable=AsyncMock) as mock_send: + with patch.object(api, "async_send", new_callable=AsyncMock) as mock_send: mock_send.return_value = None await api.async_set_chlorinator_params( @@ -489,12 +489,12 @@ async def test_async_send_message_creates_transport() -> None: mock_transport = MagicMock() mock_protocol = AsyncMock() - mock_protocol.send_message = AsyncMock() + mock_protocol.async_send = AsyncMock() with patch("asyncio.get_running_loop") as mock_loop: mock_loop.return_value.create_datagram_endpoint = AsyncMock(return_value=(mock_transport, mock_protocol)) - await api.async_send_message(MessageType.REQUEST_CONFIGURATION, "test", need_response=False) + await api.async_send(MessageType.REQUEST_CONFIGURATION, "test") # Verify endpoint was created with correct parameters mock_loop.return_value.create_datagram_endpoint.assert_called_once() @@ -512,15 +512,15 @@ async def test_async_send_message_with_response() -> None: mock_transport = MagicMock() mock_protocol = AsyncMock() - mock_protocol.send_and_receive = AsyncMock(return_value="test response") + mock_protocol.async_send_and_receive = AsyncMock(return_value="test response") with patch("asyncio.get_running_loop") as mock_loop: mock_loop.return_value.create_datagram_endpoint = AsyncMock(return_value=(mock_transport, mock_protocol)) - result = await api.async_send_message(MessageType.REQUEST_CONFIGURATION, "test", need_response=True) + result = await api.async_send_and_receive(MessageType.REQUEST_CONFIGURATION, "test") assert result == "test response" - mock_protocol.send_and_receive.assert_called_once() + mock_protocol.async_send_and_receive.assert_called_once() mock_transport.close.assert_called_once() @@ -531,15 +531,15 @@ async def test_async_send_message_without_response() -> None: mock_transport = MagicMock() mock_protocol = AsyncMock() - mock_protocol.send_message = AsyncMock() + mock_protocol.async_send = AsyncMock() with patch("asyncio.get_running_loop") as mock_loop: mock_loop.return_value.create_datagram_endpoint = AsyncMock(return_value=(mock_transport, mock_protocol)) - result = await api.async_send_message(MessageType.REQUEST_CONFIGURATION, "test", need_response=False) # type: ignore[func-returns-value] + result = await api.async_send(MessageType.REQUEST_CONFIGURATION, "test") # type: ignore[func-returns-value] assert result is None - mock_protocol.send_message.assert_called_once() + mock_protocol.async_send.assert_called_once() mock_transport.close.assert_called_once() @@ -550,13 +550,13 @@ async def test_async_send_message_closes_transport_on_error() -> None: mock_transport = MagicMock() mock_protocol = AsyncMock() - mock_protocol.send_message = AsyncMock(side_effect=Exception("Test error")) + mock_protocol.async_send = AsyncMock(side_effect=Exception("Test error")) with patch("asyncio.get_running_loop") as mock_loop: mock_loop.return_value.create_datagram_endpoint = AsyncMock(return_value=(mock_transport, mock_protocol)) with pytest.raises(Exception, match="Test error"): - await api.async_send_message(MessageType.REQUEST_CONFIGURATION, "test", need_response=False) + await api.async_send(MessageType.REQUEST_CONFIGURATION, "test") # Verify transport was still closed despite the error mock_transport.close.assert_called_once() diff --git a/tests/test_protocol.py b/tests/test_protocol.py index d152b39..93bf179 100644 --- a/tests/test_protocol.py +++ b/tests/test_protocol.py @@ -14,14 +14,15 @@ import struct import time import zlib -from typing import Any from unittest.mock import AsyncMock, MagicMock, patch from xml.etree import ElementTree as ET import pytest -from pyomnilogic_local.api.exceptions import OmniFragmentationError, OmniMessageFormatError, OmniTimeoutError -from pyomnilogic_local.api.protocol import OmniLogicMessage, OmniLogicProtocol +from pyomnilogic_local.api.constants import OMNI_RETRANSMIT_COUNT +from pyomnilogic_local.api.exceptions import OmniMessageFormatError, OmniTimeoutError +from pyomnilogic_local.api.message import OmniLogicMessage +from pyomnilogic_local.api.protocol import OmniLogicProtocol from pyomnilogic_local.omnitypes import ClientType, MessageType # ============================================================================ @@ -187,8 +188,7 @@ def test_message_payload_null_termination() -> None: def test_protocol_initialization() -> None: """Test that protocol initializes with correct queue sizes.""" protocol = OmniLogicProtocol() - assert protocol.data_queue.maxsize == 100 - assert protocol.error_queue.maxsize == 100 + assert protocol._receive_queue.maxsize == 100 def test_protocol_connection_made() -> None: @@ -198,16 +198,7 @@ def test_protocol_connection_made() -> None: protocol.connection_made(mock_transport) - assert protocol.transport is mock_transport - - -def test_protocol_connection_lost_with_exception() -> None: - """Test that connection_lost raises exception if provided.""" - protocol = OmniLogicProtocol() - test_exception = RuntimeError("Connection error") - - with pytest.raises(RuntimeError, match="Connection error"): - protocol.connection_lost(test_exception) + assert protocol._transport is mock_transport def test_protocol_connection_lost_without_exception() -> None: @@ -224,14 +215,16 @@ def test_protocol_connection_lost_without_exception() -> None: def test_datagram_received_valid_message() -> None: """Test that valid messages are added to the queue.""" protocol = OmniLogicProtocol() - valid_data = bytes(OmniLogicMessage(123, MessageType.ACK)) + # ACK/XML_ACK messages resolve futures, not the queue; use a non-ACK type + valid_data = bytes(OmniLogicMessage(123, MessageType.MSP_CONFIGURATIONUPDATE)) protocol.datagram_received(valid_data, ("127.0.0.1", 12345)) - assert protocol.data_queue.qsize() == 1 - message = protocol.data_queue.get_nowait() + assert protocol._receive_queue.qsize() == 1 + message = protocol._receive_queue.get_nowait() + assert isinstance(message, OmniLogicMessage) assert message.id == 123 - assert message.type == MessageType.ACK + assert message.type == MessageType.MSP_CONFIGURATIONUPDATE def test_datagram_received_with_corrupt_data(caplog: pytest.LogCaptureFixture) -> None: @@ -239,51 +232,48 @@ def test_datagram_received_with_corrupt_data(caplog: pytest.LogCaptureFixture) - protocol = OmniLogicProtocol() corrupt_data = b"short" - with caplog.at_level("ERROR"): + with caplog.at_level("WARNING"): protocol.datagram_received(corrupt_data, ("127.0.0.1", 12345)) - assert any("Failed to parse incoming datagram" in r.message for r in caplog.records) - assert protocol.error_queue.qsize() == 1 + assert any("received unparsable datagram" in r.message for r in caplog.records) + # The parse error is placed on the receive queue as an OmniMessageFormatError + assert protocol._receive_queue.qsize() == 1 + assert isinstance(protocol._receive_queue.get_nowait(), OmniMessageFormatError) -def test_datagram_received_queue_overflow(caplog: pytest.LogCaptureFixture) -> None: - """Test that queue overflow is handled and logged.""" +def test_datagram_received_queue_overflow() -> None: + """Test that QueueFull is raised when the receive queue is full.""" protocol = OmniLogicProtocol() - protocol.data_queue = asyncio.Queue(maxsize=1) - protocol.data_queue.put_nowait(OmniLogicMessage(1, MessageType.ACK)) + protocol._receive_queue = asyncio.Queue(maxsize=1) + # ACKs resolve futures, not the queue; fill with a non-ACK message + protocol._receive_queue.put_nowait(OmniLogicMessage(1, MessageType.MSP_CONFIGURATIONUPDATE)) - valid_data = bytes(OmniLogicMessage(2, MessageType.ACK)) - with caplog.at_level("ERROR"): + valid_data = bytes(OmniLogicMessage(2, MessageType.MSP_CONFIGURATIONUPDATE)) + with pytest.raises(asyncio.QueueFull): protocol.datagram_received(valid_data, ("127.0.0.1", 12345)) - assert any("Data queue is full" in r.message for r in caplog.records) - -def test_datagram_received_unexpected_exception(caplog: pytest.LogCaptureFixture) -> None: - """Test that unexpected exceptions during datagram processing are handled.""" +def test_datagram_received_unexpected_exception() -> None: + """Test that unexpected exceptions during datagram processing propagate to the caller.""" protocol = OmniLogicProtocol() - # Patch OmniLogicMessage.from_bytes to raise an unexpected exception + # Only OmniMessageFormatError is caught; any other exception propagates unhandled with ( patch("pyomnilogic_local.api.protocol.OmniLogicMessage.from_bytes", side_effect=RuntimeError("Unexpected")), - caplog.at_level("ERROR"), + pytest.raises(RuntimeError, match="Unexpected"), ): protocol.datagram_received(b"data", ("127.0.0.1", 12345)) - assert any("Unexpected error processing datagram" in r.message for r in caplog.records) - assert protocol.error_queue.qsize() == 1 - -def test_error_received() -> None: - """Test that error_received puts errors in the error queue.""" +def test_error_received(caplog: pytest.LogCaptureFixture) -> None: + """Test that error_received logs transport errors.""" protocol = OmniLogicProtocol() test_error = RuntimeError("UDP error") - protocol.error_received(test_error) + with caplog.at_level("ERROR"): + protocol.error_received(test_error) - assert protocol.error_queue.qsize() == 1 - error = protocol.error_queue.get_nowait() - assert error is test_error + assert any("transport error" in r.message for r in caplog.records) # ============================================================================ @@ -293,91 +283,99 @@ def test_error_received() -> None: @pytest.mark.asyncio async def test_wait_for_ack_success() -> None: - """Test successful ACK waiting.""" + """Test that an ACK for the correct ID resolves the pending future.""" + loop = asyncio.get_running_loop() protocol = OmniLogicProtocol() - protocol.transport = MagicMock() - # Put an ACK message in the queue + ack_future: asyncio.Future[OmniLogicMessage] = loop.create_future() + protocol._ack_futures[123] = ack_future + ack_message = OmniLogicMessage(123, MessageType.ACK) - await protocol.data_queue.put(ack_message) + protocol._resolve_ack(ack_message) - # Should return without raising - await protocol._wait_for_ack(123) + assert ack_future.done() + assert ack_future.result() is ack_message @pytest.mark.asyncio async def test_wait_for_ack_wrong_id_continues_waiting() -> None: - """Test that wrong ACK IDs are consumed and waiting continues for the correct one.""" + """Test that an ACK for a different ID does not resolve the pending future.""" + loop = asyncio.get_running_loop() protocol = OmniLogicProtocol() - protocol.transport = MagicMock() - # Put wrong ID first, then correct ID - wrong_ack = OmniLogicMessage(999, MessageType.ACK) - correct_ack = OmniLogicMessage(123, MessageType.ACK) + ack_future: asyncio.Future[OmniLogicMessage] = loop.create_future() + protocol._ack_futures[123] = ack_future - await protocol.data_queue.put(wrong_ack) - await protocol.data_queue.put(correct_ack) + wrong_ack = OmniLogicMessage(999, MessageType.ACK) + protocol._resolve_ack(wrong_ack) - await protocol._wait_for_ack(123) - # Queue should be empty after consuming both messages - assert protocol.data_queue.qsize() == 0 + assert not ack_future.done() @pytest.mark.asyncio -async def test_wait_for_ack_leadmessage_instead(caplog: pytest.LogCaptureFixture) -> None: - """Test that LeadMessage with matching ID is accepted (ACK was dropped).""" +async def test_wait_for_ack_leadmessage_instead() -> None: + """Test that a LeadMessage with a matching ID goes to the receive queue, not the ACK future.""" + loop = asyncio.get_running_loop() protocol = OmniLogicProtocol() - protocol.transport = MagicMock() - # Put a LeadMessage with matching ID (simulating dropped ACK) - leadmsg = OmniLogicMessage(123, MessageType.MSP_LEADMESSAGE) - await protocol.data_queue.put(leadmsg) + ack_future: asyncio.Future[OmniLogicMessage] = loop.create_future() + protocol._ack_futures[123] = ack_future - with caplog.at_level("DEBUG"): - await protocol._wait_for_ack(123) + # Build a valid LeadMessage datagram and deliver it via datagram_received + leadmsg_payload = ( + '' + "LeadMessage" + '1003' + '10' + '1' + '0' + "" + ) + leadmsg = OmniLogicMessage(123, MessageType.MSP_LEADMESSAGE, payload=leadmsg_payload) + protocol.datagram_received(bytes(leadmsg), ("127.0.0.1", 12345)) - # With matching ID, it's treated as the ACK we're looking for - assert any("Received ACK for message ID 123" in r.message for r in caplog.records) - # LeadMessage should NOT be in queue since IDs matched - assert protocol.data_queue.qsize() == 0 + # LeadMessages are not ACKs — the future should remain unresolved + assert not ack_future.done() + # The LeadMessage is placed on the receive queue for response assembly + assert protocol._receive_queue.qsize() == 1 @pytest.mark.asyncio -async def test_wait_for_ack_leadmessage_wrong_id(caplog: pytest.LogCaptureFixture) -> None: - """Test that LeadMessage with wrong ID is put back in queue and waiting continues.""" +async def test_wait_for_ack_leadmessage_wrong_id() -> None: + """Test that a LeadMessage with a different ID also goes to the receive queue, not the ACK future.""" + loop = asyncio.get_running_loop() protocol = OmniLogicProtocol() - protocol.transport = MagicMock() - # Put a LeadMessage with wrong ID, then correct ACK - leadmsg = OmniLogicMessage(999, MessageType.MSP_LEADMESSAGE) - correct_ack = OmniLogicMessage(123, MessageType.ACK) + ack_future: asyncio.Future[OmniLogicMessage] = loop.create_future() + protocol._ack_futures[123] = ack_future - await protocol.data_queue.put(leadmsg) - await protocol.data_queue.put(correct_ack) - - with caplog.at_level("DEBUG"): - await protocol._wait_for_ack(123) + leadmsg_payload = ( + '' + "LeadMessage" + '1003' + '10' + '1' + '0' + "" + ) + leadmsg = OmniLogicMessage(999, MessageType.MSP_LEADMESSAGE, payload=leadmsg_payload) + protocol.datagram_received(bytes(leadmsg), ("127.0.0.1", 12345)) - # Should log that ACK was dropped and put LeadMessage back - assert any("ACK was dropped" in r.message for r in caplog.records) - # Both messages were consumed and LeadMessage was put back, so queue should have 1 item - # But the ACK was also consumed, so we actually end up with just the LeadMessage back - # Actually, looking at the code: LeadMessage gets put back, then we return - # So BOTH the correct ACK and the LeadMessage should be in the queue - assert protocol.data_queue.qsize() == 2 # LeadMessage put back, correct ACK also still there + # Future for ID 123 should remain unresolved + assert not ack_future.done() + # The LeadMessage goes to the receive queue regardless of ID + assert protocol._receive_queue.qsize() == 1 @pytest.mark.asyncio async def test_wait_for_ack_error_in_queue() -> None: - """Test that errors from error queue are raised.""" + """Test that _send_with_retry raises OmniTimeoutError when no ACK is received.""" protocol = OmniLogicProtocol() - protocol.transport = MagicMock() - - test_error = RuntimeError("Test error") - await protocol.error_queue.put(test_error) + protocol._transport = MagicMock() - with pytest.raises(RuntimeError, match="Test error"): - await protocol._wait_for_ack(123) + # Patch wait_for to always raise TimeoutError so every attempt times out immediately + with patch("asyncio.wait_for", side_effect=TimeoutError), pytest.raises(OmniTimeoutError): + await protocol._send_with_retry(MessageType.REQUEST_CONFIGURATION, "") # ============================================================================ @@ -387,64 +385,61 @@ async def test_wait_for_ack_error_in_queue() -> None: @pytest.mark.asyncio async def test_ensure_sent_ack_message() -> None: - """Test that ACK messages don't wait for ACK.""" + """Test that _send_xml_ack transmits directly without waiting for an ACK.""" protocol = OmniLogicProtocol() - protocol.transport = MagicMock() + protocol._transport = MagicMock() - ack_message = OmniLogicMessage(123, MessageType.ACK) - - # Should return immediately without waiting - await protocol._ensure_sent(ack_message) + protocol._send_xml_ack(123) - protocol.transport.sendto.assert_called_once() + protocol._transport.sendto.assert_called_once() @pytest.mark.asyncio async def test_ensure_sent_xml_ack_message() -> None: - """Test that XML_ACK messages don't wait for ACK.""" + """Test that _send_xml_ack sends the correct XML_ACK message format.""" protocol = OmniLogicProtocol() - protocol.transport = MagicMock() - - xml_ack_message = OmniLogicMessage(123, MessageType.XML_ACK, payload="") + protocol._transport = MagicMock() - await protocol._ensure_sent(xml_ack_message) + protocol._send_xml_ack(456) - protocol.transport.sendto.assert_called_once() + protocol._transport.sendto.assert_called_once() + sent_bytes = protocol._transport.sendto.call_args[0][0] + # Verify the sent bytes parse back to an XML_ACK message with the correct ID + parsed = OmniLogicMessage.from_bytes(sent_bytes) + assert parsed.type == MessageType.XML_ACK + assert parsed.id == 456 @pytest.mark.asyncio async def test_ensure_sent_success_first_attempt() -> None: - """Test successful send on first attempt.""" + """Test successful send on first attempt with _send_with_retry.""" protocol = OmniLogicProtocol() - protocol.transport = MagicMock() + protocol._transport = MagicMock() - # Mock _wait_for_ack to succeed immediately - with patch.object(protocol, "_wait_for_ack", new_callable=AsyncMock) as mock_wait: - message = OmniLogicMessage(123, MessageType.REQUEST_CONFIGURATION) - await protocol._ensure_sent(message, max_attempts=3) + # Simulate an immediate ACK by resolving the future when sendto is called + def resolve_ack_on_send(data: bytes) -> None: + msg = OmniLogicMessage.from_bytes(data) + protocol._resolve_ack(OmniLogicMessage(msg.id, MessageType.ACK)) - protocol.transport.sendto.assert_called_once() - mock_wait.assert_called_once_with(123) + protocol._transport.sendto.side_effect = resolve_ack_on_send + + await protocol._send_with_retry(MessageType.REQUEST_CONFIGURATION, "") + + protocol._transport.sendto.assert_called_once() @pytest.mark.asyncio async def test_ensure_sent_timeout_and_retry_logs(caplog: pytest.LogCaptureFixture) -> None: - """Test that _ensure_sent logs retries and raises on repeated timeout.""" + """Test that _send_with_retry logs retries and raises OmniTimeoutError after all attempts.""" protocol = OmniLogicProtocol() - protocol.transport = MagicMock() + protocol._transport = MagicMock() - async def always_timeout(*args: object, **kwargs: object) -> None: # noqa: ARG001 - await asyncio.sleep(0) - raise TimeoutError + with patch("asyncio.wait_for", side_effect=TimeoutError), caplog.at_level("DEBUG"), pytest.raises(OmniTimeoutError): + await protocol._send_with_retry(MessageType.REQUEST_CONFIGURATION, "") - message = OmniLogicMessage(123, MessageType.REQUEST_CONFIGURATION) - with patch.object(protocol, "_wait_for_ack", always_timeout), caplog.at_level("WARNING"), pytest.raises(OmniTimeoutError): - await protocol._ensure_sent(message, max_attempts=3) - - assert any("attempt 1/3" in r.message for r in caplog.records) - assert any("attempt 2/3" in r.message for r in caplog.records) - assert any("after 3 attempts" in r.message for r in caplog.records) - assert protocol.transport.sendto.call_count == 3 + retry_logs = [r for r in caplog.records if "no ACK" in r.message] + assert len(retry_logs) == OMNI_RETRANSMIT_COUNT + assert protocol._transport.sendto.call_count == OMNI_RETRANSMIT_COUNT + 1 # ============================================================================ @@ -454,47 +449,63 @@ async def always_timeout(*args: object, **kwargs: object) -> None: # noqa: ARG0 @pytest.mark.asyncio async def test_send_message_generates_random_id() -> None: - """Test that send_message generates a random ID when none provided.""" + """Test that _send_with_retry generates a non-zero message ID.""" protocol = OmniLogicProtocol() - protocol.transport = MagicMock() + protocol._transport = MagicMock() + + captured_ids: list[int] = [] - with patch.object(protocol, "_ensure_sent", new_callable=AsyncMock) as mock_ensure: - await protocol.send_message(MessageType.REQUEST_CONFIGURATION, None, msg_id=None) + def capture_id(data: bytes) -> None: + msg = OmniLogicMessage.from_bytes(data) + captured_ids.append(msg.id) + protocol._resolve_ack(OmniLogicMessage(msg.id, MessageType.ACK)) - mock_ensure.assert_called_once() - sent_message = mock_ensure.call_args[0][0] - assert sent_message.id != 0 # Should have a random ID + protocol._transport.sendto.side_effect = capture_id + + await protocol._send_with_retry(MessageType.REQUEST_CONFIGURATION, "") + + assert len(captured_ids) == 1 + assert captured_ids[0] != 0 @pytest.mark.asyncio -async def test_send_message_uses_provided_id() -> None: - """Test that send_message uses provided ID.""" +async def test_send_message_uses_incrementing_id() -> None: + """Test that each _send_with_retry call uses a distinct, incrementing message ID.""" protocol = OmniLogicProtocol() - protocol.transport = MagicMock() + protocol._transport = MagicMock() + + captured_ids: list[int] = [] - with patch.object(protocol, "_ensure_sent", new_callable=AsyncMock) as mock_ensure: - await protocol.send_message(MessageType.REQUEST_CONFIGURATION, None, msg_id=12345) + def capture_and_ack(data: bytes) -> None: + msg = OmniLogicMessage.from_bytes(data) + captured_ids.append(msg.id) + protocol._resolve_ack(OmniLogicMessage(msg.id, MessageType.ACK)) - mock_ensure.assert_called_once() - sent_message = mock_ensure.call_args[0][0] - assert sent_message.id == 12345 + protocol._transport.sendto.side_effect = capture_and_ack + + await protocol._send_with_retry(MessageType.REQUEST_CONFIGURATION, "") + await protocol._send_with_retry(MessageType.REQUEST_CONFIGURATION, "") + + assert len(captured_ids) == 2 + assert captured_ids[0] != captured_ids[1] @pytest.mark.asyncio async def test_send_and_receive() -> None: - """Test send_and_receive calls send_message and _receive_file.""" + """Test async_send_and_receive calls _send_with_retry and _receive_response.""" protocol = OmniLogicProtocol() - protocol.transport = MagicMock() + protocol._transport = MagicMock() with ( - patch.object(protocol, "send_message", new_callable=AsyncMock) as mock_send, - patch.object(protocol, "_receive_file", new_callable=AsyncMock) as mock_receive, + patch.object(protocol, "_send_with_retry", new_callable=AsyncMock) as mock_send, + patch.object(protocol, "_receive_response", new_callable=AsyncMock) as mock_receive, + patch.object(protocol, "_decode_payload", return_value="test response"), ): - mock_receive.return_value = "test response" + mock_receive.return_value = (b"raw", False) - result = await protocol.send_and_receive(MessageType.REQUEST_CONFIGURATION, "payload", 123) + result = await protocol.async_send_and_receive(MessageType.REQUEST_CONFIGURATION, "") - mock_send.assert_called_once_with(MessageType.REQUEST_CONFIGURATION, "payload", 123) + mock_send.assert_called_once_with(MessageType.REQUEST_CONFIGURATION, "") mock_receive.assert_called_once() assert result == "test response" @@ -506,22 +517,22 @@ async def test_send_and_receive() -> None: @pytest.mark.asyncio async def test_send_ack_generates_xml() -> None: - """Test that _send_ack generates proper XML ACK message.""" + """Test that _send_xml_ack transmits a well-formed XML_ACK message.""" protocol = OmniLogicProtocol() - protocol.transport = MagicMock() + protocol._transport = MagicMock() - with patch.object(protocol, "send_message", new_callable=AsyncMock) as mock_send: - await protocol._send_ack(12345) + protocol._send_xml_ack(12345) - mock_send.assert_called_once() - call_args = mock_send.call_args - assert call_args[0][0] == MessageType.XML_ACK - assert call_args[0][2] == 12345 + protocol._transport.sendto.assert_called_once() + sent_bytes = protocol._transport.sendto.call_args[0][0] - # Verify XML structure - xml_payload = call_args[0][1] + parsed = OmniLogicMessage.from_bytes(sent_bytes) + assert parsed.type == MessageType.XML_ACK + assert parsed.id == 12345 + + # Verify XML structure contains the expected Ack name element + xml_payload = parsed.payload.rstrip(b"\x00").decode("utf-8") root = ET.fromstring(xml_payload) - assert "Request" in root.tag name_elem = root.find(".//{http://nextgen.hayward.com/api}Name") assert name_elem is not None assert name_elem.text == "Ack" @@ -534,81 +545,62 @@ async def test_send_ack_generates_xml() -> None: @pytest.mark.asyncio async def test_receive_file_simple_response() -> None: - """Test receiving a simple (non-fragmented) response.""" + """Test receiving a simple (non-fragmented) response via _receive_response.""" protocol = OmniLogicProtocol() - protocol.transport = MagicMock() - # Create a simple response message - response_msg = OmniLogicMessage(123, MessageType.GET_TELEMETRY, payload="") - await protocol.data_queue.put(response_msg) + response_msg = OmniLogicMessage(123, MessageType.MSP_TELEMETRY_UPDATE) + response_msg.payload = b"" + response_msg.compressed = False + await protocol._receive_queue.put(response_msg) - with patch.object(protocol, "_send_ack", new_callable=AsyncMock) as mock_ack: - result = await protocol._receive_file() + raw_data, compressed = await protocol._receive_response() - mock_ack.assert_called_once_with(123) - assert result == "" + assert raw_data == b"" + assert compressed is False @pytest.mark.asyncio async def test_receive_file_skips_duplicate_acks(caplog: pytest.LogCaptureFixture) -> None: - """Test that duplicate ACKs are skipped.""" + """Test that orphaned block messages are skipped and the real response is returned.""" protocol = OmniLogicProtocol() - protocol.transport = MagicMock() - # Put duplicate ACKs followed by real message - ack1 = OmniLogicMessage(111, MessageType.ACK) - ack2 = OmniLogicMessage(222, MessageType.XML_ACK) - response = OmniLogicMessage(333, MessageType.GET_TELEMETRY, payload="") + # Stray block message (no preceding lead message) should be skipped + stray_block = OmniLogicMessage(111, MessageType.MSP_BLOCKMESSAGE) + stray_block.payload = b"\x00" * 8 + b"stray" + response = OmniLogicMessage(333, MessageType.MSP_CONFIGURATIONUPDATE) + response.payload = b"" + response.compressed = False - await protocol.data_queue.put(ack1) - await protocol.data_queue.put(ack2) - await protocol.data_queue.put(response) + await protocol._receive_queue.put(stray_block) + await protocol._receive_queue.put(response) - with patch.object(protocol, "_send_ack", new_callable=AsyncMock), caplog.at_level("DEBUG"): - result = await protocol._receive_file() + with caplog.at_level("WARNING"): + raw_data, _compressed = await protocol._receive_response() - assert any("Skipping duplicate ACK" in r.message for r in caplog.records) - assert result == "" + assert any("block message" in r.message for r in caplog.records) + assert raw_data == b"" @pytest.mark.asyncio async def test_receive_file_decompresses_data() -> None: - """Test that compressed responses are decompressed.""" + """Test that _decode_payload decompresses compressed responses.""" protocol = OmniLogicProtocol() - protocol.transport = MagicMock() - # Create compressed payload original = b"This is test data that will be compressed" - compressed = zlib.compress(original) - - # Create message with compressed payload - response_msg = OmniLogicMessage(123, MessageType.GET_TELEMETRY) - response_msg.compressed = True - response_msg.payload = compressed + compressed_data = zlib.compress(original) - await protocol.data_queue.put(response_msg) - - with patch.object(protocol, "_send_ack", new_callable=AsyncMock): - result = await protocol._receive_file() + result = protocol._decode_payload(compressed_data, compressed=True) assert result == original.decode("utf-8") @pytest.mark.asyncio async def test_receive_file_decompression_error() -> None: - """Test that decompression errors are handled.""" + """Test that _decode_payload raises OmniMessageFormatError for invalid compressed data.""" protocol = OmniLogicProtocol() - protocol.transport = MagicMock() - - # Create message with invalid compressed data - response_msg = OmniLogicMessage(123, MessageType.GET_TELEMETRY) - response_msg.compressed = True - response_msg.payload = b"invalid compressed data" - - await protocol.data_queue.put(response_msg) - with patch.object(protocol, "_send_ack", new_callable=AsyncMock), pytest.raises(OmniMessageFormatError, match="Failed to decompress"): - await protocol._receive_file() + with pytest.raises(zlib.error): + protocol._decode_payload(b"invalid compressed data", compressed=True) # ============================================================================ @@ -618,11 +610,9 @@ async def test_receive_file_decompression_error() -> None: @pytest.mark.asyncio async def test_receive_file_fragmented_response() -> None: - """Test receiving a fragmented response with LeadMessage and BlockMessages.""" + """Test receiving a fragmented response with LeadMessage and BlockMessages via _receive_response.""" protocol = OmniLogicProtocol() - protocol.transport = MagicMock() - # Create LeadMessage leadmsg_payload = ( 'LeadMessage' '100324' @@ -631,30 +621,26 @@ async def test_receive_file_fragmented_response() -> None: ) leadmsg = OmniLogicMessage(100, MessageType.MSP_LEADMESSAGE, payload=leadmsg_payload) - # Create BlockMessages with 8-byte header block1 = OmniLogicMessage(101, MessageType.MSP_BLOCKMESSAGE) block1.payload = b"\x00\x00\x00\x00\x00\x00\x00\x00" + b"first_part" block2 = OmniLogicMessage(102, MessageType.MSP_BLOCKMESSAGE) block2.payload = b"\x00\x00\x00\x00\x00\x00\x00\x00" + b"second_part" - await protocol.data_queue.put(leadmsg) - await protocol.data_queue.put(block1) - await protocol.data_queue.put(block2) + await protocol._receive_queue.put(leadmsg) + await protocol._receive_queue.put(block1) + await protocol._receive_queue.put(block2) - with patch.object(protocol, "_send_ack", new_callable=AsyncMock) as mock_ack: - result = await protocol._receive_file() + raw_data, compressed = await protocol._receive_response() - # Should send ACK for LeadMessage and each BlockMessage - assert mock_ack.call_count == 3 - assert result == "first_partsecond_part" + assert raw_data == b"first_partsecond_part" + assert compressed is False @pytest.mark.asyncio async def test_receive_file_fragmented_out_of_order() -> None: - """Test that fragments received out of order are reassembled correctly.""" + """Test that fragments are concatenated in receive order.""" protocol = OmniLogicProtocol() - protocol.transport = MagicMock() leadmsg_payload = ( 'LeadMessage' @@ -664,7 +650,7 @@ async def test_receive_file_fragmented_out_of_order() -> None: ) leadmsg = OmniLogicMessage(100, MessageType.MSP_LEADMESSAGE, payload=leadmsg_payload) - # Create blocks out of order (IDs: 102, 100, 101) + # Enqueue blocks out of ID order: 102, 100, 101 block2 = OmniLogicMessage(102, MessageType.MSP_BLOCKMESSAGE) block2.payload = b"\x00\x00\x00\x00\x00\x00\x00\x00" + b"third" @@ -674,40 +660,33 @@ async def test_receive_file_fragmented_out_of_order() -> None: block1 = OmniLogicMessage(101, MessageType.MSP_BLOCKMESSAGE) block1.payload = b"\x00\x00\x00\x00\x00\x00\x00\x00" + b"second" - await protocol.data_queue.put(leadmsg) - await protocol.data_queue.put(block2) # Out of order - await protocol.data_queue.put(block0) - await protocol.data_queue.put(block1) + await protocol._receive_queue.put(leadmsg) + await protocol._receive_queue.put(block2) + await protocol._receive_queue.put(block0) + await protocol._receive_queue.put(block1) - with patch.object(protocol, "_send_ack", new_callable=AsyncMock): - result = await protocol._receive_file() + raw_data, _compressed = await protocol._receive_response() - # Should be reassembled in ID order - assert result == "firstsecondthird" + # Blocks are assembled in receive order, not by ID + assert raw_data == b"thirdfirstsecond" @pytest.mark.asyncio async def test_receive_file_fragmented_invalid_leadmessage() -> None: - """Test that invalid LeadMessage XML raises error.""" + """Test that invalid LeadMessage XML raises a parse error.""" protocol = OmniLogicProtocol() - protocol.transport = MagicMock() - # Create LeadMessage with invalid XML leadmsg = OmniLogicMessage(100, MessageType.MSP_LEADMESSAGE, payload="invalid xml") - await protocol.data_queue.put(leadmsg) + await protocol._receive_queue.put(leadmsg) - with ( - patch.object(protocol, "_send_ack", new_callable=AsyncMock), - pytest.raises(OmniFragmentationError, match="Failed to parse LeadMessage"), - ): - await protocol._receive_file() + with pytest.raises(ET.ParseError): + await protocol._receive_response() @pytest.mark.asyncio async def test_receive_file_fragmented_timeout_waiting() -> None: - """Test timeout while waiting for fragments.""" + """Test that OmniTimeoutError is raised when no fragment arrives within the timeout.""" protocol = OmniLogicProtocol() - protocol.transport = MagicMock() leadmsg_payload = ( 'LeadMessage' @@ -716,46 +695,27 @@ async def test_receive_file_fragmented_timeout_waiting() -> None: "" ) leadmsg = OmniLogicMessage(100, MessageType.MSP_LEADMESSAGE, payload=leadmsg_payload) + await protocol._receive_queue.put(leadmsg) + # No block messages — _receive_next_message will time out - await protocol.data_queue.put(leadmsg) - # Don't put any BlockMessages - will timeout - - with ( - patch.object(protocol, "_send_ack", new_callable=AsyncMock), - pytest.raises(OmniFragmentationError, match="Timeout receiving fragment"), - ): - await protocol._receive_file() + with pytest.raises(OmniTimeoutError): + await protocol._receive_response() @pytest.mark.asyncio async def test_receive_file_fragmented_max_wait_time_exceeded() -> None: - """Test that MAX_FRAGMENT_WAIT_TIME timeout is enforced.""" + """Test that _receive_next_message raises OmniTimeoutError when MAX_FRAGMENT_WAIT_TIME elapses.""" protocol = OmniLogicProtocol() - protocol.transport = MagicMock() - leadmsg_payload = ( - 'LeadMessage' - '100324' - '20' - "" - ) - leadmsg = OmniLogicMessage(100, MessageType.MSP_LEADMESSAGE, payload=leadmsg_payload) - - await protocol.data_queue.put(leadmsg) - - # Mock time to simulate timeout - with patch.object(protocol, "_send_ack", new_callable=AsyncMock), patch("time.time") as mock_time: - mock_time.side_effect = [0, 31] # Start at 0, then 31 seconds later (> 30s max) - - with pytest.raises(OmniFragmentationError, match="Timeout waiting for fragments"): - await protocol._receive_file() + # Patch asyncio.timeout to raise TimeoutError immediately, simulating the deadline being exceeded + with patch("asyncio.timeout", side_effect=TimeoutError), pytest.raises(OmniTimeoutError): + await protocol._receive_next_message() @pytest.mark.asyncio async def test_receive_file_fragmented_ignores_non_block_messages(caplog: pytest.LogCaptureFixture) -> None: - """Test that non-BlockMessages during fragmentation are ignored.""" + """Test that non-BlockMessages during fragment reassembly are warned and ignored.""" protocol = OmniLogicProtocol() - protocol.transport = MagicMock() leadmsg_payload = ( 'LeadMessage' @@ -765,55 +725,37 @@ async def test_receive_file_fragmented_ignores_non_block_messages(caplog: pytest ) leadmsg = OmniLogicMessage(100, MessageType.MSP_LEADMESSAGE, payload=leadmsg_payload) - # Put LeadMessage, then an ACK (should be ignored), then the actual block - ack_msg = OmniLogicMessage(999, MessageType.ACK) + # A configuration update message (not a block) should be skipped during reassembly + interloper = OmniLogicMessage(999, MessageType.MSP_CONFIGURATIONUPDATE) + interloper.payload = b"" + block1 = OmniLogicMessage(101, MessageType.MSP_BLOCKMESSAGE) block1.payload = b"\x00\x00\x00\x00\x00\x00\x00\x00" + b"data" - await protocol.data_queue.put(leadmsg) - await protocol.data_queue.put(ack_msg) - await protocol.data_queue.put(block1) + await protocol._receive_queue.put(leadmsg) + await protocol._receive_queue.put(interloper) + await protocol._receive_queue.put(block1) - with patch.object(protocol, "_send_ack", new_callable=AsyncMock), caplog.at_level("DEBUG"): - result = await protocol._receive_file() + with caplog.at_level("WARNING"): + raw_data, _compressed = await protocol._receive_response() - assert any("other than a blockmessage" in r.message for r in caplog.records) - assert result == "data" + assert any("expected block message" in r.message for r in caplog.records) + assert raw_data == b"data" @pytest.mark.asyncio async def test_wait_for_ack_cancels_pending_tasks() -> None: - """Test that pending tasks are properly cancelled in _wait_for_ack to avoid warnings.""" + """Test that _ack_futures are cleaned up after _send_with_retry completes.""" protocol = OmniLogicProtocol() - protocol.transport = MagicMock() - - # Track tasks created during _wait_for_ack - created_tasks: list[asyncio.Task[Any]] = [] - original_create_task = asyncio.create_task - - def track_create_task(coro: Any) -> asyncio.Task[Any]: - task: asyncio.Task[Any] = original_create_task(coro) - created_tasks.append(task) - return task - - # Queue up an ACK message - ack_msg = OmniLogicMessage(42, MessageType.ACK) - await protocol.data_queue.put(ack_msg) - - # Patch create_task to track tasks - with patch("asyncio.create_task", side_effect=track_create_task): - await protocol._wait_for_ack(42) + protocol._transport = MagicMock() - # Give the event loop a chance to process cancellation - await asyncio.sleep(0) + def resolve_ack_on_send(data: bytes) -> None: + msg = OmniLogicMessage.from_bytes(data) + protocol._resolve_ack(OmniLogicMessage(msg.id, MessageType.ACK)) - # Should have created 2 tasks (data_task and error_task) - assert len(created_tasks) == 2 + protocol._transport.sendto.side_effect = resolve_ack_on_send - # One should be done (the data_task that got the ACK) - # One should be cancelled (the error_task that was waiting) - done_tasks = [t for t in created_tasks if t.done() and not t.cancelled()] - cancelled_tasks = [t for t in created_tasks if t.cancelled()] + await protocol._send_with_retry(MessageType.REQUEST_CONFIGURATION, "") - assert len(done_tasks) == 1, "Expected exactly one task to complete normally" - assert len(cancelled_tasks) == 1, "Expected exactly one task to be cancelled" + # After a successful send, all futures should be cleaned up from _ack_futures + assert len(protocol._ack_futures) == 0