diff --git a/pyomnilogic_local/api/__init__.py b/pyomnilogic_local/api/__init__.py
index b9f9ff1..928fccb 100644
--- a/pyomnilogic_local/api/__init__.py
+++ b/pyomnilogic_local/api/__init__.py
@@ -8,7 +8,29 @@
from __future__ import annotations
from .api import OmniLogicAPI
+from .exceptions import (
+ OmniCommandError,
+ OmniConnectionError,
+ OmniFragmentationError,
+ OmniLogicError,
+ OmniMessageFormatError,
+ OmniProtocolError,
+ OmniTimeoutError,
+ OmniValidationError,
+)
+from .message import OmniLogicMessage
+from .protocol import OmniLogicProtocol
__all__ = [
+ "OmniCommandError",
+ "OmniConnectionError",
+ "OmniFragmentationError",
"OmniLogicAPI",
+ "OmniLogicError",
+ "OmniLogicMessage",
+ "OmniLogicProtocol",
+ "OmniMessageFormatError",
+ "OmniProtocolError",
+ "OmniTimeoutError",
+ "OmniValidationError",
]
diff --git a/pyomnilogic_local/api/api.py b/pyomnilogic_local/api/api.py
index 6e6f399..8834de0 100644
--- a/pyomnilogic_local/api/api.py
+++ b/pyomnilogic_local/api/api.py
@@ -111,30 +111,39 @@ def __init__(
self.controller_port = controller_port
self.response_timeout = response_timeout
- @overload
- async def async_send_message(self, message_type: MessageType, message: str | None, need_response: Literal[True]) -> str: ...
- @overload
- async def async_send_message(self, message_type: MessageType, message: str | None, need_response: Literal[False]) -> None: ...
- async def async_send_message(self, message_type: MessageType, message: str | None, need_response: bool = False) -> str | None:
+ async def async_send(self, message_type: MessageType, message: str) -> None:
"""Send a message via the Hayward Omni UDP protocol along with properly handling timeouts and responses.
Args:
message_type (MessageType): A selection from MessageType indicating what type of communication you are sending
- message (str | None): The XML body of the message to deliver
- need_response (bool, optional): Should a response be received and returned to the caller. Defaults to False.
+ message (str): The XML body of the message to deliver
+
+ Returns:
+ None
+ """
+ loop = asyncio.get_running_loop()
+ transport, protocol = await loop.create_datagram_endpoint(OmniLogicProtocol, remote_addr=(self.controller_ip, self.controller_port))
+
+ try:
+ await protocol.async_send(message_type, message)
+ finally:
+ transport.close()
+
+ async def async_send_and_receive(self, message_type: MessageType, message: str) -> str:
+ """Convenience method to send a message and receive a response without needing to specify need_response every time.
+
+ Args:
+ message_type (MessageType): A selection from MessageType indicating what type of communication you are sending
+ message (str): The message payload to send.
Returns:
- str | None: The response body sent from the Omni if need_response indicates that a response will be sent
+ str: The response body sent from the Omni
"""
loop = asyncio.get_running_loop()
transport, protocol = await loop.create_datagram_endpoint(OmniLogicProtocol, remote_addr=(self.controller_ip, self.controller_port))
- resp: str | None = None
try:
- if need_response:
- resp = await protocol.send_and_receive(message_type, message, response_timeout=self.response_timeout)
- else:
- await protocol.send_message(message_type, message)
+ resp = await protocol.async_send_and_receive(message_type, message)
finally:
transport.close()
@@ -164,7 +173,7 @@ async def async_get_mspconfig(self, raw: bool = False) -> MSPConfig | str:
_LOGGER.debug("Sending RequestConfiguration with body: %s", req_body)
- resp = await self.async_send_message(MessageType.REQUEST_CONFIGURATION, req_body, True)
+ resp = await self.async_send_and_receive(MessageType.REQUEST_CONFIGURATION, req_body)
_LOGGER.debug("Received response for RequestConfiguration: %s", resp)
@@ -206,7 +215,7 @@ async def async_get_filter_diagnostics(self, pool_id: int, equipment_id: int, ra
_LOGGER.debug("Sending GetUIFilterDiagnosticInfo with body: %s", req_body)
- resp = await self.async_send_message(MessageType.GET_FILTER_DIAGNOSTIC_INFO, req_body, True)
+ resp = await self.async_send_and_receive(MessageType.GET_FILTER_DIAGNOSTIC_INFO, req_body)
_LOGGER.debug("Received response for GetUIFilterDiagnosticInfo: %s", resp)
@@ -235,7 +244,7 @@ async def async_get_telemetry(self, raw: bool = False) -> Telemetry | str:
_LOGGER.debug("Sending RequestTelemetryData with body: %s", req_body)
- resp = await self.async_send_message(MessageType.GET_TELEMETRY, req_body, True)
+ resp = await self.async_send_and_receive(MessageType.GET_TELEMETRY, req_body)
_LOGGER.debug("Received response for RequestTelemetryData: %s", resp)
@@ -276,7 +285,7 @@ async def async_set_heater(
_LOGGER.debug("Sending SetUIHeaterCmd with body: %s", req_body)
- return await self.async_send_message(MessageType.SET_HEATER_COMMAND, req_body, False)
+ return await self.async_send(MessageType.SET_HEATER_COMMAND, req_body)
async def async_set_solar_heater(
self,
@@ -311,7 +320,7 @@ async def async_set_solar_heater(
_LOGGER.debug("Sending SetUISolarSetPointCmd with body: %s", req_body)
- return await self.async_send_message(MessageType.SET_SOLAR_SET_POINT_COMMAND, req_body, False)
+ return await self.async_send(MessageType.SET_SOLAR_SET_POINT_COMMAND, req_body)
async def async_set_heater_mode(
self,
@@ -346,7 +355,7 @@ async def async_set_heater_mode(
_LOGGER.debug("Sending SetUIHeaterModeCmd with body: %s", req_body)
- return await self.async_send_message(MessageType.SET_HEATER_MODE_COMMAND, req_body, False)
+ return await self.async_send(MessageType.SET_HEATER_MODE_COMMAND, req_body)
async def async_set_heater_enable(
self,
@@ -381,7 +390,7 @@ async def async_set_heater_enable(
_LOGGER.debug("Sending SetHeaterEnable with body: %s", req_body)
- return await self.async_send_message(MessageType.SET_HEATER_ENABLED, req_body, False)
+ return await self.async_send(MessageType.SET_HEATER_ENABLED, req_body)
async def async_set_equipment(
self,
@@ -443,7 +452,7 @@ async def async_set_equipment(
_LOGGER.debug("Sending SetUIEquipmentCmd with body: %s", req_body)
- return await self.async_send_message(MessageType.SET_EQUIPMENT, req_body, False)
+ return await self.async_send(MessageType.SET_EQUIPMENT, req_body)
async def async_set_filter_speed(self, pool_id: int, equipment_id: int, speed: int) -> None:
"""Set the speed for a variable speed filter/pump.
@@ -471,7 +480,7 @@ async def async_set_filter_speed(self, pool_id: int, equipment_id: int, speed: i
_LOGGER.debug("Sending SetUIFilterSpeedCmd with body: %s", req_body)
- return await self.async_send_message(MessageType.SET_FILTER_SPEED, req_body, False)
+ return await self.async_send(MessageType.SET_FILTER_SPEED, req_body)
async def async_set_light_show(
self,
@@ -543,7 +552,7 @@ async def async_set_light_show(
_LOGGER.debug("Sending SetStandAloneLightShow with body: %s", req_body)
- return await self.async_send_message(MessageType.SET_STANDALONE_LIGHT_SHOW, req_body, False)
+ return await self.async_send(MessageType.SET_STANDALONE_LIGHT_SHOW, req_body)
async def async_set_chlorinator_enable(self, pool_id: int, enabled: int | bool) -> None:
body_element = ET.Element("Request", {"xmlns": XML_NAMESPACE})
@@ -561,7 +570,7 @@ async def async_set_chlorinator_enable(self, pool_id: int, enabled: int | bool)
_LOGGER.debug("Sending SetCHLOREnable with body: %s", req_body)
- return await self.async_send_message(MessageType.SET_CHLOR_ENABLED, req_body, False)
+ return await self.async_send(MessageType.SET_CHLOR_ENABLED, req_body)
# This is used to set the ORP target value on a CSAD
async def async_set_csad_orp_target_level(
@@ -587,7 +596,7 @@ async def async_set_csad_orp_target_level(
_LOGGER.debug("Sending SetUICSADORPTargetLevel with body: %s", req_body)
- return await self.async_send_message(MessageType.SET_CSAD_ORP_TARGET, req_body, False)
+ return await self.async_send(MessageType.SET_CSAD_ORP_TARGET, req_body)
# This is used to set the pH target value on a CSAD
async def async_set_csad_target_value(
@@ -613,7 +622,7 @@ async def async_set_csad_target_value(
_LOGGER.debug("Sending UISetCSADTargetValue with body: %s", req_body)
- return await self.async_send_message(MessageType.SET_CSAD_TARGET_VALUE, req_body, False)
+ return await self.async_send(MessageType.SET_CSAD_TARGET_VALUE, req_body)
async def async_set_chlorinator_params(
self,
@@ -656,7 +665,7 @@ async def async_set_chlorinator_params(
_LOGGER.debug("Sending SetCHLORParams with body: %s", req_body)
- return await self.async_send_message(MessageType.SET_CHLOR_PARAMS, req_body, False)
+ return await self.async_send(MessageType.SET_CHLOR_PARAMS, req_body)
async def async_set_chlorinator_superchlorinate(
self,
@@ -681,7 +690,7 @@ async def async_set_chlorinator_superchlorinate(
_LOGGER.debug("Sending SetUISuperCHLORCmd with body: %s", req_body)
- return await self.async_send_message(MessageType.SET_SUPERCHLORINATE, req_body, False)
+ return await self.async_send(MessageType.SET_SUPERCHLORINATE, req_body)
async def async_restore_idle_state(self) -> None:
body_element = ET.Element("Request", {"xmlns": XML_NAMESPACE})
@@ -695,7 +704,7 @@ async def async_restore_idle_state(self) -> None:
_LOGGER.debug("Sending RestoreIdleState with body: %s", req_body)
- return await self.async_send_message(MessageType.RESTORE_IDLE_STATE, req_body, False)
+ return await self.async_send(MessageType.RESTORE_IDLE_STATE, req_body)
async def async_set_spillover(
self,
@@ -738,7 +747,7 @@ async def async_set_spillover(
_LOGGER.debug("Sending SetUISpilloverCmd with body: %s", req_body)
- return await self.async_send_message(MessageType.SET_SPILLOVER, req_body, False)
+ return await self.async_send(MessageType.SET_SPILLOVER, req_body)
async def async_set_group_enable(
self,
@@ -781,7 +790,7 @@ async def async_set_group_enable(
_LOGGER.debug("Sending RunGroupCmd with body: %s", req_body)
- return await self.async_send_message(MessageType.RUN_GROUP_CMD, req_body, False)
+ return await self.async_send(MessageType.RUN_GROUP_CMD, req_body)
async def async_edit_schedule(
self,
@@ -854,4 +863,4 @@ async def async_edit_schedule(
_LOGGER.debug("Sending EditUIScheduleCmd with body: %s", req_body)
- return await self.async_send_message(MessageType.EDIT_SCHEDULE, req_body, False)
+ return await self.async_send(MessageType.EDIT_SCHEDULE, req_body)
diff --git a/pyomnilogic_local/api/constants.py b/pyomnilogic_local/api/constants.py
index 5526a21..25a0626 100644
--- a/pyomnilogic_local/api/constants.py
+++ b/pyomnilogic_local/api/constants.py
@@ -11,9 +11,9 @@
BLOCK_MESSAGE_HEADER_OFFSET = 8 # Offset to skip block message header and get to payload
# Timing Constants (in seconds)
-OMNI_RETRANSMIT_TIME = 2.1 # Time Omni waits before retransmitting a packet
+OMNI_RETRANSMIT_TIME = 2 # Time Omni waits before retransmitting a packet
OMNI_RETRANSMIT_COUNT = 5 # Number of retransmit attempts (6 total including initial)
-ACK_WAIT_TIMEOUT = 1 # Timeout waiting for ACK response, 0.5 showed to be just a tad too short in some cases.
+ACK_WAIT_TIMEOUT = OMNI_RETRANSMIT_TIME * 2 # Timeout waiting for ACK response, 0.5 showed to be just a tad too short in some cases.
DEFAULT_RESPONSE_TIMEOUT = OMNI_RETRANSMIT_TIME * OMNI_RETRANSMIT_COUNT # Default timeout for receiving responses
# Network Constants
diff --git a/pyomnilogic_local/api/message.py b/pyomnilogic_local/api/message.py
new file mode 100644
index 0000000..a8fee1a
--- /dev/null
+++ b/pyomnilogic_local/api/message.py
@@ -0,0 +1,143 @@
+from __future__ import annotations
+
+import logging
+import struct
+import time
+from typing import Self
+
+from pyomnilogic_local.omnitypes import ClientType, MessageType
+
+from .constants import (
+ PROTOCOL_HEADER_FORMAT,
+ PROTOCOL_HEADER_SIZE,
+ PROTOCOL_VERSION,
+)
+from .exceptions import OmniMessageFormatError
+
+_LOGGER = logging.getLogger(__name__)
+
+
+class OmniLogicMessage:
+ """A protocol message for communication with the OmniLogic controller.
+
+ Handles serialization and deserialization of message headers and payloads.
+ """
+
+ header_format = PROTOCOL_HEADER_FORMAT
+ id: int
+ type: MessageType
+ payload: bytes
+ client_type: ClientType = ClientType.SIMPLE
+ version: str = PROTOCOL_VERSION
+ timestamp: int
+ reserved_1: int = 0
+ compressed: bool = False
+ reserved_2: int = 0
+
+ def __init__(
+ self,
+ msg_id: int,
+ msg_type: MessageType,
+ payload: str | None = None,
+ version: str = PROTOCOL_VERSION,
+ timestamp: int | None = None,
+ ) -> None:
+ """Initialize a new OmniLogicMessage.
+
+ Args:
+ msg_id: Unique message identifier.
+ msg_type: Type of message being sent.
+ payload: Optional string payload (XML or command body).
+ version: Protocol version string.
+ timestamp: Optional timestamp for the message.
+ """
+ self.id = msg_id
+ self.type = msg_type
+ # If we are speaking the XML API, it seems like we need client_type 0, otherwise we need client_type 1
+ self.client_type = ClientType.XML if payload is not None else ClientType.SIMPLE
+ # The Hayward API terminates it's messages with a null character
+ payload = f"{payload}\x00" if payload is not None else ""
+ self.payload = bytes(payload, "utf-8")
+
+ self.version = version
+ self.timestamp = timestamp if timestamp is not None else int(time.time())
+
+ def __bytes__(self) -> bytes:
+ """Serialize the message to bytes for UDP transmission.
+
+ Returns:
+ Byte representation of the message.
+ """
+ header = struct.pack(
+ self.header_format,
+ self.id, # Msg id
+ self.timestamp,
+ bytes(self.version, "ascii"), # version string
+ self.type.value, # OpID/msgType
+ self.client_type.value, # Client type
+ 0, # reserved
+ self.compressed, # compressed
+ 0, # reserved
+ )
+ return header + self.payload
+
+ def __repr__(self) -> str:
+ """Return a string representation of the message for debugging."""
+ if self.compressed or self.type is MessageType.MSP_BLOCKMESSAGE:
+ return f"ID: {self.id}, Type: {self.type.name}, Compressed: {self.compressed}, Client: {self.client_type.name}"
+ return (
+ f"ID: {self.id}, Type: {self.type.name}, Compressed: {self.compressed}, Client: {self.client_type.name}, "
+ f"Body: {self.payload[:-1].decode('utf-8')}"
+ )
+
+ @classmethod
+ def from_bytes(cls, data: bytes) -> Self:
+ """Parse a message from its byte representation.
+
+ Args:
+ data: Byte data received from the controller.
+
+ Returns:
+ OmniLogicMessage instance.
+
+ Raises:
+ OmniMessageFormatException: If the message format is invalid.
+ """
+ if len(data) < PROTOCOL_HEADER_SIZE:
+ msg = f"Message too short: {len(data)} bytes, expected at least {PROTOCOL_HEADER_SIZE}"
+ raise OmniMessageFormatError(msg)
+
+ # split the header and data
+ header = data[:PROTOCOL_HEADER_SIZE]
+ rdata: bytes = data[PROTOCOL_HEADER_SIZE:]
+
+ try:
+ (msg_id, tstamp, vers, msg_type, client_type, res1, compressed, res2) = struct.unpack(cls.header_format, header)
+ except struct.error as exc:
+ msg = f"Failed to unpack message header: {exc}"
+ raise OmniMessageFormatError(msg) from exc
+
+ # Validate message type
+ try:
+ message_type_enum = MessageType(msg_type)
+ except ValueError as exc:
+ msg = f"Unknown message type: {msg_type}: {exc}"
+ raise OmniMessageFormatError(msg) from exc
+
+ # Validate client type
+ try:
+ client_type_enum = ClientType(int(client_type))
+ except ValueError as exc:
+ msg = f"Unknown client type: {client_type}: {exc}"
+ raise OmniMessageFormatError(msg) from exc
+
+ message = cls(msg_id=msg_id, msg_type=message_type_enum, version=vers.decode("utf-8"))
+ message.timestamp = tstamp
+ message.client_type = client_type_enum
+ message.reserved_1 = res1
+ # There are some messages that are ALWAYS compressed although they do not return a 1 in their LeadMessage
+ message.compressed = compressed == 1 or message.type in [MessageType.MSP_TELEMETRY_UPDATE]
+ message.reserved_2 = res2
+ message.payload = rdata
+
+ return message
diff --git a/pyomnilogic_local/api/protocol.py b/pyomnilogic_local/api/protocol.py
index ba31a96..c816b85 100644
--- a/pyomnilogic_local/api/protocol.py
+++ b/pyomnilogic_local/api/protocol.py
@@ -1,16 +1,16 @@
+"""Asyncio UDP datagram protocol for communication with the OmniLogic controller."""
+
from __future__ import annotations
import asyncio
import logging
import random
-import struct
-import time
import xml.etree.ElementTree as ET
import zlib
-from typing import Any, Self, cast
+from typing import cast
from pyomnilogic_local.models.leadmessage import LeadMessage
-from pyomnilogic_local.omnitypes import ClientType, MessageType
+from pyomnilogic_local.omnitypes import MessageType
from .constants import (
ACK_WAIT_TIMEOUT,
@@ -19,455 +19,330 @@
MAX_FRAGMENT_WAIT_TIME,
MAX_QUEUE_SIZE,
OMNI_RETRANSMIT_COUNT,
- OMNI_RETRANSMIT_TIME,
- PROTOCOL_HEADER_FORMAT,
- PROTOCOL_HEADER_SIZE,
- PROTOCOL_VERSION,
- XML_ENCODING,
XML_NAMESPACE,
)
-from .exceptions import OmniFragmentationError, OmniMessageFormatError, OmniTimeoutError
+from .exceptions import OmniConnectionError, OmniMessageFormatError, OmniTimeoutError
+from .message import OmniLogicMessage
_LOGGER = logging.getLogger(__name__)
+_ACK_PAYLOAD = f'\nAck\n'
-class OmniLogicMessage:
- """A protocol message for communication with the OmniLogic controller.
-
- Handles serialization and deserialization of message headers and payloads.
- """
+_ACK_TYPES = frozenset({MessageType.ACK, MessageType.XML_ACK})
- header_format = PROTOCOL_HEADER_FORMAT
- id: int
- type: MessageType
- payload: bytes
- client_type: ClientType = ClientType.SIMPLE
- version: str = PROTOCOL_VERSION
- timestamp: int
- reserved_1: int = 0
- compressed: bool = False
- reserved_2: int = 0
-
- def __init__(
- self,
- msg_id: int,
- msg_type: MessageType,
- payload: str | None = None,
- version: str = PROTOCOL_VERSION,
- timestamp: int | None = None,
- ) -> None:
- """Initialize a new OmniLogicMessage.
+# Type alias for items placed on the receive queue: either a parsed message or a parse error.
+_QueueItem = OmniLogicMessage | OmniMessageFormatError
- Args:
- msg_id: Unique message identifier.
- msg_type: Type of message being sent.
- payload: Optional string payload (XML or command body).
- version: Protocol version string.
- timestamp: Optional timestamp for the message.
- """
- self.id = msg_id
- self.type = msg_type
- # If we are speaking the XML API, it seems like we need client_type 0, otherwise we need client_type 1
- self.client_type = ClientType.XML if payload is not None else ClientType.SIMPLE
- # The Hayward API terminates it's messages with a null character
- payload = f"{payload}\x00" if payload is not None else ""
- self.payload = bytes(payload, "utf-8")
- self.version = version
- self.timestamp = timestamp if timestamp is not None else int(time.time())
+class OmniLogicProtocol(asyncio.DatagramProtocol):
+ """Asyncio UDP datagram protocol for communication with the OmniLogic controller.
- def __bytes__(self) -> bytes:
- """Serialize the message to bytes for UDP transmission.
+ Handles message framing, acknowledgement, retransmission, and multi-part
+ response reassembly for the Hayward OmniLogic local UDP protocol.
- Returns:
- Byte representation of the message.
- """
- header = struct.pack(
- self.header_format,
- self.id, # Msg id
- self.timestamp,
- bytes(self.version, "ascii"), # version string
- self.type.value, # OpID/msgType
- self.client_type.value, # Client type
- 0, # reserved
- self.compressed, # compressed
- 0, # reserved
- )
- return header + self.payload
-
- def __repr__(self) -> str:
- """Return a string representation of the message for debugging."""
- if self.compressed or self.type is MessageType.MSP_BLOCKMESSAGE:
- return f"ID: {self.id}, Type: {self.type.name}, Compressed: {self.compressed}, Client: {self.client_type.name}"
- return (
- f"ID: {self.id}, Type: {self.type.name}, Compressed: {self.compressed}, Client: {self.client_type.name}, "
- f"Body: {self.payload[:-1].decode('utf-8')}"
+ Example:
+ loop = asyncio.get_running_loop()
+ transport, protocol = await loop.create_datagram_endpoint(
+ OmniLogicProtocol, remote_addr=(controller_ip, controller_port)
)
+ try:
+ response = await protocol.async_send_and_receive(MessageType.GET_TELEMETRY, xml_body)
+ finally:
+ transport.close()
+ """
- @classmethod
- def from_bytes(cls, data: bytes) -> Self:
- """Parse a message from its byte representation.
+ def __init__(self) -> None:
+ self._transport: asyncio.DatagramTransport | None = None
+ # Seed with a random value so each protocol instance (one per API call) uses distinct IDs.
+ # Message ID is an unsigned 32-bit integer in the wire format, so cap at 2**16 to
+ # leave room for plenty of increments.
+ self._msg_counter: int = random.randint(1, 2**16)
+ self._ack_futures: dict[int, asyncio.Future[OmniLogicMessage]] = {}
+ self._receive_queue: asyncio.Queue[_QueueItem] = asyncio.Queue(maxsize=MAX_QUEUE_SIZE)
+
+ # -------------------------------------------------------------------------
+ # asyncio.DatagramProtocol callbacks
+ # -------------------------------------------------------------------------
- Args:
- data: Byte data received from the controller.
+ def connection_made(self, transport: asyncio.BaseTransport) -> None:
+ self._transport = cast("asyncio.DatagramTransport", transport)
+ _LOGGER.debug("connection established")
- Returns:
- OmniLogicMessage instance.
+ def datagram_received(self, data: bytes, addr: tuple[str, int]) -> None:
+ try:
+ msg = OmniLogicMessage.from_bytes(data)
+ except OmniMessageFormatError as exc:
+ _LOGGER.warning("received unparsable datagram from %s: %s", addr, exc)
+ self._receive_queue.put_nowait(exc)
+ return
- Raises:
- OmniMessageFormatException: If the message format is invalid.
- """
- if len(data) < PROTOCOL_HEADER_SIZE:
- msg = f"Message too short: {len(data)} bytes, expected at least {PROTOCOL_HEADER_SIZE}"
- raise OmniMessageFormatError(msg)
+ _LOGGER.debug("received from %s: %r", addr, msg)
- # split the header and data
- header = data[:PROTOCOL_HEADER_SIZE]
- rdata: bytes = data[PROTOCOL_HEADER_SIZE:]
+ if msg.type in _ACK_TYPES:
+ self._resolve_ack(msg)
+ else:
+ self._send_xml_ack(msg.id)
+ self._receive_queue.put_nowait(msg)
- try:
- (msg_id, tstamp, vers, msg_type, client_type, res1, compressed, res2) = struct.unpack(cls.header_format, header)
- except struct.error as exc:
- msg = f"Failed to unpack message header: {exc}"
- raise OmniMessageFormatError(msg) from exc
+ def error_received(self, exc: Exception) -> None:
+ _LOGGER.error("transport error: %s", exc)
- # Validate message type
- try:
- message_type_enum = MessageType(msg_type)
- except ValueError as exc:
- msg = f"Unknown message type: {msg_type}: {exc}"
- raise OmniMessageFormatError(msg) from exc
+ def connection_lost(self, exc: Exception | None) -> None:
+ _LOGGER.debug("connection lost: %s", exc)
+
+ # -------------------------------------------------------------------------
+ # Internal helpers
+ # -------------------------------------------------------------------------
+
+ def _next_msg_id(self) -> int:
+ """Return the next sequential message ID."""
+ self._msg_counter += 1
+ return self._msg_counter
+
+ def _resolve_ack(self, msg: OmniLogicMessage) -> None:
+ """Resolve the pending ACK future for the given message ID."""
+ future = self._ack_futures.get(msg.id)
+ if future is not None and not future.done():
+ future.set_result(msg)
+
+ def _send_xml_ack(self, msg_id: int) -> None:
+ """Transmit an XML_ACK for a received message."""
+ if self._transport is None:
+ _LOGGER.warning("cannot send ACK for ID %d, transport unavailable", msg_id)
+ return
+ ack = OmniLogicMessage(msg_id=msg_id, msg_type=MessageType.XML_ACK, payload=_ACK_PAYLOAD)
+ self._transport.sendto(bytes(ack))
+ _LOGGER.debug("sent XML_ACK for message ID %d", msg_id)
+
+ def _build_request(self, msg_type: MessageType, payload: str) -> OmniLogicMessage:
+ """Build a new outgoing request message with a fresh ID."""
+ return OmniLogicMessage(msg_id=self._next_msg_id(), msg_type=msg_type, payload=payload)
+
+ async def _send_with_retry(self, msg_type: MessageType, payload: str) -> None:
+ """Transmit a message and wait for acknowledgement, retransmitting as needed.
+
+ A fresh message ID is used on each attempt so the controller treats every
+ retransmission as a new request (an ACK only confirms receipt/parse, not
+ that the controller will act on or re-respond to the same ID again).
- # Validate client type
- try:
- client_type_enum = ClientType(int(client_type))
- except ValueError as exc:
- msg = f"Unknown client type: {client_type}: {exc}"
- raise OmniMessageFormatError(msg) from exc
+ Args:
+ msg_type: The type of message to send.
+ payload: The XML payload string.
- message = cls(msg_id=msg_id, msg_type=message_type_enum, version=vers.decode("utf-8"))
- message.timestamp = tstamp
- message.client_type = client_type_enum
- message.reserved_1 = res1
- # There are some messages that are ALWAYS compressed although they do not return a 1 in their LeadMessage
- message.compressed = compressed == 1 or message.type in [MessageType.MSP_TELEMETRY_UPDATE]
- message.reserved_2 = res2
- message.payload = rdata
+ Raises:
+ OmniConnectionError: If the transport is not available.
+ OmniTimeoutError: If no ACK is received after all retransmission attempts.
+ """
+ if self._transport is None:
+ msg = f"Cannot send message type {msg_type.name}, transport not available"
+ raise OmniConnectionError(msg)
- return message
+ loop = asyncio.get_running_loop()
+ for attempt in range(OMNI_RETRANSMIT_COUNT + 1):
+ message = self._build_request(msg_type, payload)
+ ack_future: asyncio.Future[OmniLogicMessage] = loop.create_future()
+ self._ack_futures[message.id] = ack_future
-class OmniLogicProtocol(asyncio.DatagramProtocol):
- """Asyncio DatagramProtocol implementation for OmniLogic UDP communication.
+ try:
+ _LOGGER.debug(
+ "transmitting message ID: %d, type: %s (attempt %d/%d)",
+ message.id,
+ message.type.name,
+ attempt + 1,
+ OMNI_RETRANSMIT_COUNT + 1,
+ )
+ self._transport.sendto(bytes(message))
+ try:
+ await asyncio.wait_for(asyncio.shield(ack_future), timeout=ACK_WAIT_TIMEOUT)
+ except TimeoutError:
+ if attempt < OMNI_RETRANSMIT_COUNT:
+ _LOGGER.debug("no ACK for message ID %d, will retry", message.id)
+ else:
+ return
+ finally:
+ self._ack_futures.pop(message.id, None)
+ if not ack_future.done():
+ ack_future.cancel()
- Handles message sending, receiving, retries, and block message reassembly.
- """
+ msg = f"No ACK received for message type {msg_type.name} after {OMNI_RETRANSMIT_COUNT + 1} attempts"
+ raise OmniTimeoutError(msg)
- transport: asyncio.DatagramTransport
- # The omni will re-transmit a packet every 2 seconds if it does not receive an ACK. We pad that just a touch to be safe
- _omni_retransmit_time = OMNI_RETRANSMIT_TIME
- # The omni will re-transmit 5 times (a total of 6 attempts including the initial) if it does not receive an ACK
- _omni_retransmit_count = OMNI_RETRANSMIT_COUNT
+ async def _receive_next_message(self) -> OmniLogicMessage:
+ """Wait for and return the next incoming (non-ACK) message from the queue.
- data_queue: asyncio.Queue[OmniLogicMessage]
- error_queue: asyncio.Queue[Exception]
+ Raises:
+ OmniTimeoutError: If no message arrives within MAX_FRAGMENT_WAIT_TIME seconds.
+ OmniMessageFormatError: If the queued item is a parse error.
+ """
+ try:
+ async with asyncio.timeout(MAX_FRAGMENT_WAIT_TIME):
+ item: _QueueItem = await self._receive_queue.get()
+ except TimeoutError as exc:
+ msg = "Timed out waiting for response message from controller"
+ raise OmniTimeoutError(msg) from exc
- def __init__(self) -> None:
- """Initialize the protocol handler and message queue."""
- self.data_queue = asyncio.Queue(maxsize=MAX_QUEUE_SIZE)
- self.error_queue = asyncio.Queue(maxsize=MAX_QUEUE_SIZE)
+ if isinstance(item, OmniMessageFormatError):
+ raise item
- def connection_made(self, transport: asyncio.BaseTransport) -> None:
- """Called when a UDP connection is made."""
- self.transport = cast("asyncio.DatagramTransport", transport)
+ return item
- def connection_lost(self, exc: Exception | None) -> None:
- """Called when the UDP connection is lost or closed."""
- if exc:
- raise exc
+ # -------------------------------------------------------------------------
+ # Response assembly
+ # -------------------------------------------------------------------------
+
+ async def _receive_response(self) -> tuple[bytes, bool]:
+ """Receive and return the full response payload for a prior request.
- def datagram_received(self, data: bytes, addr: tuple[str | Any, int]) -> None:
- """Called when a datagram is received from the controller.
+ Handles both single-message responses and multi-part lead/block responses.
- Parses the message and puts it on the queue. Handles corrupt or unexpected data gracefully.
+ Returns:
+ Tuple of (raw_payload_bytes, compressed_flag).
"""
- try:
- message = OmniLogicMessage.from_bytes(data)
- _LOGGER.debug("Received Message %s from %s", str(message), addr)
- try:
- self.data_queue.put_nowait(message)
- except asyncio.QueueFull:
- _LOGGER.exception("Data queue is full. Dropping message: %s", str(message))
- except OmniMessageFormatError as exc:
- _LOGGER.exception("Failed to parse incoming datagram from %s", addr)
- self.error_queue.put_nowait(exc)
- except Exception as exc:
- _LOGGER.exception("Unexpected error processing datagram from %s", addr)
- self.error_queue.put_nowait(exc)
+ while True:
+ msg = await self._receive_next_message()
- def error_received(self, exc: Exception) -> None:
- """Called when a UDP error is received.
+ if msg.type == MessageType.MSP_LEADMESSAGE:
+ return await self._reassemble_multipart(msg)
- Store the error so it can be handled by awaiting coroutines.
- """
- self.error_queue.put_nowait(exc)
+ if msg.type == MessageType.MSP_BLOCKMESSAGE:
+ _LOGGER.warning("received block message ID %d before any lead message, ignoring", msg.id)
+ continue
- async def _wait_for_ack(self, ack_id: int) -> None:
- """Wait for an ACK message with the given ID.
+ return msg.payload, msg.compressed
- Handles dropped or out-of-order ACKs.
+ async def _reassemble_multipart(self, lead_msg: OmniLogicMessage) -> tuple[bytes, bool]:
+ """Reassemble a multi-part response from a lead message and its block messages.
Args:
- ack_id: The message ID to wait for an ACK.
+ lead_msg: The initial MSP_LEADMESSAGE received from the controller.
- Raises:
- OmniTimeoutException: If no ACK is received.
- Exception: If a protocol error occurs.
+ Returns:
+ Tuple of (concatenated_block_payload_bytes, compressed_flag).
"""
- # Wait for either an ACK message or an error
- # Race condition: datagram_received() calls put_nowait() synchronously, so data_task may
- # already be done when wait_for fires its timeout CancelledError. In that case we catch
- # the cancellation, skip re-looping, and fall through to check the result below. If the
- # result is our ACK we return normally, suppressing the CancelledError so wait_for treats
- # the call as successful. If it isn't, we re-raise after the loop.
- cancelled: asyncio.CancelledError | None = None
- retry = True
- while retry:
- # Wait for either a message or an error
- data_task = asyncio.create_task(self.data_queue.get())
- error_task = asyncio.create_task(self.error_queue.get())
- try:
- done, pending = await asyncio.wait([data_task, error_task], return_when=asyncio.FIRST_COMPLETED)
- except asyncio.CancelledError as exc:
- retry = False
- cancelled = exc
- done = {t for t in (data_task, error_task) if t.done()}
- pending = {t for t in (data_task, error_task) if not t.done()}
-
- # Cancel any pending tasks to avoid "Task was destroyed but it is pending" warnings
- for task in pending:
- task.cancel()
-
- if error_task in done:
- err = error_task.result()
- if isinstance(err, Exception):
- raise err
- _LOGGER.error("Unknown error occurred during communication with OmniLogic: %s", err)
- if data_task in done:
- message = data_task.result()
- if message.id == ack_id:
- _LOGGER.debug("Received ACK for message ID %s", ack_id)
- return
- _LOGGER.debug("We received a message that is not our ACK, it appears the ACK was dropped")
- if message.type in {MessageType.MSP_LEADMESSAGE, MessageType.MSP_TELEMETRY_UPDATE}:
- _LOGGER.debug("Omni has sent a new message, continuing on with the communication")
- await self.data_queue.put(message)
- return
+ lead = self._parse_lead_message(lead_msg)
+ compressed = lead_msg.compressed
+ seen_lead_ids: set[int] = {lead_msg.id}
+
+ _LOGGER.debug(
+ "reassembling %d-block response (compressed=%s)",
+ lead.msg_block_count,
+ compressed,
+ )
- if cancelled is not None:
- raise cancelled
+ payload_data = b""
+ received_block_count = 0
+ seen_block_ids: set[int] = set()
- async def _ensure_sent(
- self,
- message: OmniLogicMessage,
- max_attempts: int = 5,
- ) -> None:
- """Send a message and ensure it is acknowledged, retrying if necessary.
+ while received_block_count < lead.msg_block_count:
+ msg = await self._receive_next_message()
- Args:
- message: The message to send.
- max_attempts: Maximum number of send attempts.
+ if msg.type == MessageType.MSP_LEADMESSAGE:
+ if msg.id not in seen_lead_ids:
+ _LOGGER.warning("received unexpected secondary lead message ID %d", msg.id)
+ seen_lead_ids.add(msg.id)
+ else:
+ _LOGGER.debug("received duplicate lead message ID %d, re-ACK already sent", msg.id)
+ continue
- Raises:
- OmniTimeoutException: If no ACK is received after retries.
- """
- for attempt in range(max_attempts):
- self.transport.sendto(bytes(message))
- _LOGGER.debug("Sent message ID %s (attempt %d/%d)", message.id, attempt + 1, max_attempts)
+ if msg.type != MessageType.MSP_BLOCKMESSAGE:
+ _LOGGER.warning("expected block message but got %s (ID %d), ignoring", msg.type.name, msg.id)
+ continue
- # If the message that we just sent is an ACK, we do not need to wait to receive an ACK, we are done
- if message.type in [MessageType.XML_ACK, MessageType.ACK]:
- return
+ if msg.id in seen_block_ids:
+ _LOGGER.debug("received duplicate block message ID %d, re-ACK already sent", msg.id)
+ continue
- # Wait for a bit to either receive an ACK for our message, otherwise, we retry delivery
- try:
- await asyncio.wait_for(self._wait_for_ack(message.id), ACK_WAIT_TIMEOUT)
- except TimeoutError as exc:
- if attempt < max_attempts - 1:
- _LOGGER.warning(
- "ACK not received for message type %s (ID: %s), attempt %d/%d. Retrying...",
- message.type.name,
- message.id,
- attempt + 1,
- max_attempts,
- )
- else:
- _LOGGER.exception(
- "Failed to receive ACK for message type %s (ID: %s) after %d attempts.", message.type.name, message.id, max_attempts
- )
- msg = f"Failed to receive acknowledgement of command, max retries exceeded: {exc}"
- raise OmniTimeoutError(msg) from exc
- else:
- return
-
- async def send_and_receive(
- self,
- msg_type: MessageType,
- payload: str | None,
- msg_id: int | None = None,
- response_timeout: float = DEFAULT_RESPONSE_TIMEOUT,
- ) -> str:
- """Send a message and wait for a response, returning the response payload as a string.
+ seen_block_ids.add(msg.id)
+ payload_data += msg.payload[BLOCK_MESSAGE_HEADER_OFFSET:]
+ received_block_count += 1
+ _LOGGER.debug("received block %d/%d", received_block_count, lead.msg_block_count)
+
+ return payload_data, compressed
+
+ def _parse_lead_message(self, msg: OmniLogicMessage) -> LeadMessage:
+ """Parse the XML payload of an MSP_LEADMESSAGE into a LeadMessage model.
Args:
- msg_type: Type of message to send.
- payload: Optional payload string.
- msg_id: Optional message ID.
- response_timeout: Timeout in seconds to wait for the response.
+ msg: The MSP_LEADMESSAGE to parse.
Returns:
- Response payload as a string.
+ Parsed LeadMessage model.
"""
- await self.send_message(msg_type, payload, msg_id)
- return await self._receive_file(response_timeout=response_timeout)
-
- # Send a message that you do NOT need a response to
- async def send_message(
- self,
- msg_type: MessageType,
- payload: str | None,
- msg_id: int | None = None,
- ) -> None:
- """Send a message that does not require a response.
+ payload_str = msg.payload.decode("utf-8").strip("\x00")
+ root = ET.fromstring(payload_str)
+ return LeadMessage.model_validate(root)
+
+ def _decode_payload(self, data: bytes, compressed: bool) -> str:
+ """Decode a raw response payload, decompressing if necessary.
Args:
- msg_type: Type of message to send.
- payload: Optional payload string.
- msg_id: Optional message ID.
+ data: Raw payload bytes.
+ compressed: Whether the payload is zlib-compressed.
+
+ Returns:
+ Decoded UTF-8 string with leading/trailing null bytes stripped.
"""
- # If we aren't sending a specific msg_id, lets randomize it
- if not msg_id:
- msg_id = random.randrange(2**32)
+ if compressed:
+ data = zlib.decompress(data.rstrip(b"\x00"))
+ return data.decode("utf-8").strip("\x00")
- message = OmniLogicMessage(msg_id, msg_type, payload)
+ # -------------------------------------------------------------------------
+ # Public API
+ # -------------------------------------------------------------------------
- _LOGGER.debug("Sending Message %s", str(message))
+ async def async_send(self, msg_type: MessageType, payload: str) -> None:
+ """Send a message to the controller and wait for acknowledgement.
- await self._ensure_sent(message)
+ The message is retransmitted up to OMNI_RETRANSMIT_COUNT times if no ACK
+ is received within ACK_WAIT_TIMEOUT seconds.
- async def _send_ack(self, msg_id: int) -> None:
- """Send an ACK message for the given message ID."""
- body_element = ET.Element("Request", {"xmlns": XML_NAMESPACE})
- name_element = ET.SubElement(body_element, "Name")
- name_element.text = "Ack"
+ Args:
+ msg_type: The type of message to send.
+ payload: The XML payload string.
- req_body = ET.tostring(body_element, xml_declaration=True, encoding=XML_ENCODING)
- await self.send_message(MessageType.XML_ACK, req_body, msg_id)
+ Raises:
+ OmniConnectionError: If the transport is not available.
+ OmniTimeoutError: If no ACK is received after all retransmission attempts.
+ """
+ await self._send_with_retry(msg_type, payload)
- async def _receive_file(self, response_timeout: float = DEFAULT_RESPONSE_TIMEOUT) -> str:
- """Wait for and reassemble a full response from the controller.
+ async def async_send_and_receive(self, msg_type: MessageType, payload: str) -> str:
+ """Send a message and receive the controller's response payload.
- Handles single and multi-block (LeadMessage/BlockMessage) responses.
+ Handles the full send → ACK → response (single or lead/block) flow,
+ including retransmission on ACK timeout and decompression of compressed
+ responses. If an ACK is received but the controller never sends the
+ expected follow-up response within DEFAULT_RESPONSE_TIMEOUT, the entire
+ send+receive cycle is retried up to OMNI_RETRANSMIT_COUNT times.
Args:
- response_timeout: Timeout in seconds to wait for the initial response.
+ msg_type: The type of message to send.
+ payload: The XML payload string.
Returns:
- Response payload as a string.
+ The decoded UTF-8 response string from the controller.
Raises:
- OmniTimeoutException: If a block message is not received in time.
- OmniFragmentationException: If fragment reassembly fails.
+ OmniConnectionError: If the transport is not available.
+ OmniTimeoutError: If no ACK or response is received within the allowed time.
+ OmniMessageFormatError: If an unparsable response datagram is received.
"""
- # wait for the initial packet.
- try:
- message = await asyncio.wait_for(self.data_queue.get(), response_timeout)
- except TimeoutError as exc:
- msg = f"Timeout waiting for response from controller: {exc}"
- raise OmniTimeoutError(msg) from exc
-
- # If messages have to be re-transmitted, we can sometimes receive multiple ACKs. The first one would be handled by
- # self._ensure_sent, but if any subsequent ACKs are sent to us, we need to dump them and wait for a "real" message.
- while message.type in [MessageType.ACK, MessageType.XML_ACK]:
- _LOGGER.debug("Skipping duplicate ACK message")
+ for attempt in range(OMNI_RETRANSMIT_COUNT + 1):
+ await self._send_with_retry(msg_type, payload)
try:
- message = await asyncio.wait_for(self.data_queue.get(), response_timeout)
- except TimeoutError as exc:
- msg = f"Timeout waiting for response from controller: {exc}"
- raise OmniTimeoutError(msg) from exc
-
- await self._send_ack(message.id)
-
- # If the response is too large, the controller will send a LeadMessage indicating how many follow-up messages will be sent
- if message.type is MessageType.MSP_LEADMESSAGE:
- try:
- leadmsg = LeadMessage.model_validate(ET.fromstring(message.payload[:-1]))
- except Exception as exc:
- msg = f"Failed to parse LeadMessage: {exc}"
- raise OmniFragmentationError(msg) from exc
-
- _LOGGER.debug("Will receive %s blockmessages for fragmented response", leadmsg.msg_block_count)
-
- # Wait for the block data data
- retval: bytes = b""
- # If we received a LeadMessage, continue to receive messages until we have all of our data
- # Fragments of data may arrive out of order, so we store them in a buffer as they arrive and sort them after
- data_fragments: dict[int, bytes] = {}
- fragment_start_time = time.time()
-
- while len(data_fragments) < leadmsg.msg_block_count:
- # Check if we've been waiting too long for fragments
- if time.time() - fragment_start_time > MAX_FRAGMENT_WAIT_TIME:
- _LOGGER.error(
- "Timeout waiting for fragments: received %d/%d after %ds",
- len(data_fragments),
- leadmsg.msg_block_count,
- MAX_FRAGMENT_WAIT_TIME,
- )
- msg = (
- f"Timeout waiting for fragments: received {len(data_fragments)}/{leadmsg.msg_block_count} "
- f"after {MAX_FRAGMENT_WAIT_TIME}s"
+ async with asyncio.timeout(DEFAULT_RESPONSE_TIMEOUT):
+ raw_data, compressed = await self._receive_response()
+ return self._decode_payload(raw_data, compressed)
+ except TimeoutError:
+ if attempt < OMNI_RETRANSMIT_COUNT:
+ _LOGGER.debug(
+ "no response received for %s within %ds, retrying (attempt %d/%d)",
+ msg_type.name,
+ DEFAULT_RESPONSE_TIMEOUT,
+ attempt + 1,
+ OMNI_RETRANSMIT_COUNT + 1,
)
- raise OmniFragmentationError(msg)
-
- # We need to wait long enough for the Omni to get through all of it's retries before we bail out.
- try:
- resp = await asyncio.wait_for(self.data_queue.get(), response_timeout)
- except TimeoutError as exc:
- msg = f"Timeout receiving fragment: got {len(data_fragments)}/{leadmsg.msg_block_count} fragments: {exc}"
- raise OmniFragmentationError(msg) from exc
-
- # We only want to collect blockmessages here
- if resp.type is not MessageType.MSP_BLOCKMESSAGE:
- _LOGGER.debug("Received a message other than a blockmessage during fragmentation: %s", resp.type)
- continue
-
- await self._send_ack(resp.id)
-
- # remove an 8 byte header to get to the payload data
- data_fragments[resp.id] = resp.payload[BLOCK_MESSAGE_HEADER_OFFSET:]
- _LOGGER.debug("Received fragment %d/%d", len(data_fragments), leadmsg.msg_block_count)
- # Reassemble the fragmets in order
- for _, data in sorted(data_fragments.items()):
- retval += data
-
- _LOGGER.debug("Successfully reassembled %d fragments into %d bytes", leadmsg.msg_block_count, len(retval))
-
- # We did not receive a LeadMessage, so our payload is just this one packet
- else:
- retval = message.payload
-
- # Decompress the returned data if necessary
- if message.compressed:
- _LOGGER.debug("Decompressing response payload")
- try:
- comp_bytes = bytes.fromhex(retval.hex())
- retval = zlib.decompress(comp_bytes)
- _LOGGER.debug("Decompressed %d bytes to %d bytes", len(comp_bytes), len(retval))
- except zlib.error as exc:
- msg = f"Failed to decompress message: {exc}"
- raise OmniMessageFormatError(msg) from exc
-
- # For some API calls, the Omni null terminates the response, we are stripping that here to make parsing it later easier
- return retval.decode("utf-8").strip("\x00")
+ msg = f"No response received for {msg_type.name} after {OMNI_RETRANSMIT_COUNT + 1} attempts"
+ raise OmniTimeoutError(msg)
diff --git a/pyomnilogic_local/cli/pcap_utils.py b/pyomnilogic_local/cli/pcap_utils.py
index cdfa036..f4e91c9 100644
--- a/pyomnilogic_local/cli/pcap_utils.py
+++ b/pyomnilogic_local/cli/pcap_utils.py
@@ -15,7 +15,7 @@
from scapy.layers.inet import UDP
from scapy.utils import rdpcap
-from pyomnilogic_local.api.protocol import OmniLogicMessage
+from pyomnilogic_local.api.message import OmniLogicMessage
from pyomnilogic_local.models.leadmessage import LeadMessage
from pyomnilogic_local.omnitypes import MessageType
diff --git a/pyomnilogic_local/models/leadmessage.py b/pyomnilogic_local/models/leadmessage.py
index d3f39f4..8817ef6 100644
--- a/pyomnilogic_local/models/leadmessage.py
+++ b/pyomnilogic_local/models/leadmessage.py
@@ -1,3 +1,4 @@
+# ruff: noqa: TC001 # pydantic relies on the omnitypes imports at runtime
from __future__ import annotations
from typing import Any
@@ -5,6 +6,8 @@
from pydantic import BaseModel, ConfigDict, Field, model_validator
+from pyomnilogic_local.omnitypes import MessageType
+
from .const import XML_NS
# Example Lead Message XML:
@@ -24,7 +27,7 @@
class LeadMessage(BaseModel):
model_config = ConfigDict(from_attributes=True)
- source_op_id: int = Field(alias="SourceOpId")
+ source_op_id: MessageType = Field(alias="SourceOpId")
msg_size: int = Field(alias="MsgSize")
msg_block_count: int = Field(alias="MsgBlockCount")
type: int = Field(alias="Type")
diff --git a/pyomnilogic_local/omnitypes.py b/pyomnilogic_local/omnitypes.py
index 6a16062..1831bf9 100644
--- a/pyomnilogic_local/omnitypes.py
+++ b/pyomnilogic_local/omnitypes.py
@@ -32,8 +32,8 @@ class MessageType(PrettyEnum, IntEnum):
GET_FILTER_DIAGNOSTIC_INFO = 386
HANDSHAKE = 1000
ACK = 1002
- MSP_TELEMETRY_UPDATE = 1004
MSP_CONFIGURATIONUPDATE = 1003
+ MSP_TELEMETRY_UPDATE = 1004
MSP_ALARM_LIST_RESPONSE = 1304
MSP_LEADMESSAGE = 1998
MSP_BLOCKMESSAGE = 1999
@@ -324,6 +324,7 @@ class ZodiacShow(PrettyEnum, IntEnum):
class ColorLogicPowerState(PrettyEnum, IntEnum):
OFF = 0
POWERING_OFF = 1
+ INITIALIZING = 2 # The app shows this as 15 seconds of white, but this state seems to happen when the Omni first powers up
CHANGING_SHOW = 3
FIFTEEN_SECONDS_WHITE = 4
ACTIVE = 6
diff --git a/tests/test_api.py b/tests/test_api.py
index 3ed3775..019b087 100644
--- a/tests/test_api.py
+++ b/tests/test_api.py
@@ -186,7 +186,7 @@ async def test_async_get_mspconfig_generates_valid_xml() -> None:
"""Test that async_get_mspconfig generates valid XML request."""
api = OmniLogicAPI("192.168.1.100")
- with patch.object(api, "async_send_message", new_callable=AsyncMock) as mock_send:
+ with patch.object(api, "async_send_and_receive", new_callable=AsyncMock) as mock_send:
mock_send.return_value = 'Configuration'
await api.async_get_mspconfig(raw=True)
@@ -206,7 +206,7 @@ async def test_async_get_telemetry_generates_valid_xml() -> None:
"""Test that async_get_telemetry generates valid XML request."""
api = OmniLogicAPI("192.168.1.100")
- with patch.object(api, "async_send_message", new_callable=AsyncMock) as mock_send:
+ with patch.object(api, "async_send_and_receive", new_callable=AsyncMock) as mock_send:
mock_send.return_value = 'Telemetry'
await api.async_get_telemetry(raw=True)
@@ -226,7 +226,7 @@ async def test_async_get_filter_diagnostics_generates_valid_xml() -> None:
"""Test that async_get_filter_diagnostics generates valid XML with correct parameters."""
api = OmniLogicAPI("192.168.1.100")
- with patch.object(api, "async_send_message", new_callable=AsyncMock) as mock_send:
+ with patch.object(api, "async_send_and_receive", new_callable=AsyncMock) as mock_send:
mock_send.return_value = 'FilterDiagnostics'
await api.async_get_filter_diagnostics(pool_id=1, equipment_id=2, raw=True)
@@ -248,7 +248,7 @@ async def test_async_set_heater_generates_valid_xml() -> None:
"""Test that async_set_heater generates valid XML with correct parameters."""
api = OmniLogicAPI("192.168.1.100")
- with patch.object(api, "async_send_message", new_callable=AsyncMock) as mock_send:
+ with patch.object(api, "async_send", new_callable=AsyncMock) as mock_send:
mock_send.return_value = None
await api.async_set_heater(pool_id=1, equipment_id=2, temperature=75)
@@ -273,7 +273,7 @@ async def test_async_set_filter_speed_generates_valid_xml() -> None:
"""Test that async_set_filter_speed generates valid XML with correct parameters."""
api = OmniLogicAPI("192.168.1.100")
- with patch.object(api, "async_send_message", new_callable=AsyncMock) as mock_send:
+ with patch.object(api, "async_send", new_callable=AsyncMock) as mock_send:
mock_send.return_value = None
await api.async_set_filter_speed(pool_id=1, equipment_id=2, speed=75)
@@ -296,7 +296,7 @@ async def test_async_set_equipment_generates_valid_xml() -> None:
"""Test that async_set_equipment generates valid XML with correct parameters."""
api = OmniLogicAPI("192.168.1.100")
- with patch.object(api, "async_send_message", new_callable=AsyncMock) as mock_send:
+ with patch.object(api, "async_send", new_callable=AsyncMock) as mock_send:
mock_send.return_value = None
await api.async_set_equipment(
@@ -339,7 +339,7 @@ async def test_async_set_heater_mode_generates_valid_xml() -> None:
"""Test that async_set_heater_mode generates valid XML with correct enum values."""
api = OmniLogicAPI("192.168.1.100")
- with patch.object(api, "async_send_message", new_callable=AsyncMock) as mock_send:
+ with patch.object(api, "async_send", new_callable=AsyncMock) as mock_send:
mock_send.return_value = None
await api.async_set_heater_mode(pool_id=1, equipment_id=2, mode=HeaterMode.HEAT)
@@ -360,7 +360,7 @@ async def test_async_set_light_show_generates_valid_xml() -> None:
"""Test that async_set_light_show generates valid XML with correct enum values."""
api = OmniLogicAPI("192.168.1.100")
- with patch.object(api, "async_send_message", new_callable=AsyncMock) as mock_send:
+ with patch.object(api, "async_send", new_callable=AsyncMock) as mock_send:
mock_send.return_value = None
await api.async_set_light_show(
@@ -399,7 +399,7 @@ async def test_async_set_chlorinator_enable_boolean_conversion(subtests: pytest.
]
for enabled, expected, description in test_cases:
- with subtests.test(msg=description), patch.object(api, "async_send_message", new_callable=AsyncMock) as mock_send:
+ with subtests.test(msg=description), patch.object(api, "async_send", new_callable=AsyncMock) as mock_send:
mock_send.return_value = None
await api.async_set_chlorinator_enable(pool_id=1, enabled=enabled)
@@ -424,7 +424,7 @@ async def test_async_set_heater_enable_boolean_conversion(subtests: pytest.Subte
]
for enabled, expected, description in test_cases:
- with subtests.test(msg=description), patch.object(api, "async_send_message", new_callable=AsyncMock) as mock_send:
+ with subtests.test(msg=description), patch.object(api, "async_send", new_callable=AsyncMock) as mock_send:
mock_send.return_value = None
await api.async_set_heater_enable(pool_id=1, equipment_id=2, enabled=enabled)
@@ -441,7 +441,7 @@ async def test_async_set_chlorinator_params_generates_valid_xml() -> None:
"""Test that async_set_chlorinator_params generates valid XML with all parameters."""
api = OmniLogicAPI("192.168.1.100")
- with patch.object(api, "async_send_message", new_callable=AsyncMock) as mock_send:
+ with patch.object(api, "async_send", new_callable=AsyncMock) as mock_send:
mock_send.return_value = None
await api.async_set_chlorinator_params(
@@ -489,12 +489,12 @@ async def test_async_send_message_creates_transport() -> None:
mock_transport = MagicMock()
mock_protocol = AsyncMock()
- mock_protocol.send_message = AsyncMock()
+ mock_protocol.async_send = AsyncMock()
with patch("asyncio.get_running_loop") as mock_loop:
mock_loop.return_value.create_datagram_endpoint = AsyncMock(return_value=(mock_transport, mock_protocol))
- await api.async_send_message(MessageType.REQUEST_CONFIGURATION, "test", need_response=False)
+ await api.async_send(MessageType.REQUEST_CONFIGURATION, "test")
# Verify endpoint was created with correct parameters
mock_loop.return_value.create_datagram_endpoint.assert_called_once()
@@ -512,15 +512,15 @@ async def test_async_send_message_with_response() -> None:
mock_transport = MagicMock()
mock_protocol = AsyncMock()
- mock_protocol.send_and_receive = AsyncMock(return_value="test response")
+ mock_protocol.async_send_and_receive = AsyncMock(return_value="test response")
with patch("asyncio.get_running_loop") as mock_loop:
mock_loop.return_value.create_datagram_endpoint = AsyncMock(return_value=(mock_transport, mock_protocol))
- result = await api.async_send_message(MessageType.REQUEST_CONFIGURATION, "test", need_response=True)
+ result = await api.async_send_and_receive(MessageType.REQUEST_CONFIGURATION, "test")
assert result == "test response"
- mock_protocol.send_and_receive.assert_called_once()
+ mock_protocol.async_send_and_receive.assert_called_once()
mock_transport.close.assert_called_once()
@@ -531,15 +531,15 @@ async def test_async_send_message_without_response() -> None:
mock_transport = MagicMock()
mock_protocol = AsyncMock()
- mock_protocol.send_message = AsyncMock()
+ mock_protocol.async_send = AsyncMock()
with patch("asyncio.get_running_loop") as mock_loop:
mock_loop.return_value.create_datagram_endpoint = AsyncMock(return_value=(mock_transport, mock_protocol))
- result = await api.async_send_message(MessageType.REQUEST_CONFIGURATION, "test", need_response=False) # type: ignore[func-returns-value]
+ result = await api.async_send(MessageType.REQUEST_CONFIGURATION, "test") # type: ignore[func-returns-value]
assert result is None
- mock_protocol.send_message.assert_called_once()
+ mock_protocol.async_send.assert_called_once()
mock_transport.close.assert_called_once()
@@ -550,13 +550,13 @@ async def test_async_send_message_closes_transport_on_error() -> None:
mock_transport = MagicMock()
mock_protocol = AsyncMock()
- mock_protocol.send_message = AsyncMock(side_effect=Exception("Test error"))
+ mock_protocol.async_send = AsyncMock(side_effect=Exception("Test error"))
with patch("asyncio.get_running_loop") as mock_loop:
mock_loop.return_value.create_datagram_endpoint = AsyncMock(return_value=(mock_transport, mock_protocol))
with pytest.raises(Exception, match="Test error"):
- await api.async_send_message(MessageType.REQUEST_CONFIGURATION, "test", need_response=False)
+ await api.async_send(MessageType.REQUEST_CONFIGURATION, "test")
# Verify transport was still closed despite the error
mock_transport.close.assert_called_once()
diff --git a/tests/test_protocol.py b/tests/test_protocol.py
index d152b39..93bf179 100644
--- a/tests/test_protocol.py
+++ b/tests/test_protocol.py
@@ -14,14 +14,15 @@
import struct
import time
import zlib
-from typing import Any
from unittest.mock import AsyncMock, MagicMock, patch
from xml.etree import ElementTree as ET
import pytest
-from pyomnilogic_local.api.exceptions import OmniFragmentationError, OmniMessageFormatError, OmniTimeoutError
-from pyomnilogic_local.api.protocol import OmniLogicMessage, OmniLogicProtocol
+from pyomnilogic_local.api.constants import OMNI_RETRANSMIT_COUNT
+from pyomnilogic_local.api.exceptions import OmniMessageFormatError, OmniTimeoutError
+from pyomnilogic_local.api.message import OmniLogicMessage
+from pyomnilogic_local.api.protocol import OmniLogicProtocol
from pyomnilogic_local.omnitypes import ClientType, MessageType
# ============================================================================
@@ -187,8 +188,7 @@ def test_message_payload_null_termination() -> None:
def test_protocol_initialization() -> None:
"""Test that protocol initializes with correct queue sizes."""
protocol = OmniLogicProtocol()
- assert protocol.data_queue.maxsize == 100
- assert protocol.error_queue.maxsize == 100
+ assert protocol._receive_queue.maxsize == 100
def test_protocol_connection_made() -> None:
@@ -198,16 +198,7 @@ def test_protocol_connection_made() -> None:
protocol.connection_made(mock_transport)
- assert protocol.transport is mock_transport
-
-
-def test_protocol_connection_lost_with_exception() -> None:
- """Test that connection_lost raises exception if provided."""
- protocol = OmniLogicProtocol()
- test_exception = RuntimeError("Connection error")
-
- with pytest.raises(RuntimeError, match="Connection error"):
- protocol.connection_lost(test_exception)
+ assert protocol._transport is mock_transport
def test_protocol_connection_lost_without_exception() -> None:
@@ -224,14 +215,16 @@ def test_protocol_connection_lost_without_exception() -> None:
def test_datagram_received_valid_message() -> None:
"""Test that valid messages are added to the queue."""
protocol = OmniLogicProtocol()
- valid_data = bytes(OmniLogicMessage(123, MessageType.ACK))
+ # ACK/XML_ACK messages resolve futures, not the queue; use a non-ACK type
+ valid_data = bytes(OmniLogicMessage(123, MessageType.MSP_CONFIGURATIONUPDATE))
protocol.datagram_received(valid_data, ("127.0.0.1", 12345))
- assert protocol.data_queue.qsize() == 1
- message = protocol.data_queue.get_nowait()
+ assert protocol._receive_queue.qsize() == 1
+ message = protocol._receive_queue.get_nowait()
+ assert isinstance(message, OmniLogicMessage)
assert message.id == 123
- assert message.type == MessageType.ACK
+ assert message.type == MessageType.MSP_CONFIGURATIONUPDATE
def test_datagram_received_with_corrupt_data(caplog: pytest.LogCaptureFixture) -> None:
@@ -239,51 +232,48 @@ def test_datagram_received_with_corrupt_data(caplog: pytest.LogCaptureFixture) -
protocol = OmniLogicProtocol()
corrupt_data = b"short"
- with caplog.at_level("ERROR"):
+ with caplog.at_level("WARNING"):
protocol.datagram_received(corrupt_data, ("127.0.0.1", 12345))
- assert any("Failed to parse incoming datagram" in r.message for r in caplog.records)
- assert protocol.error_queue.qsize() == 1
+ assert any("received unparsable datagram" in r.message for r in caplog.records)
+ # The parse error is placed on the receive queue as an OmniMessageFormatError
+ assert protocol._receive_queue.qsize() == 1
+ assert isinstance(protocol._receive_queue.get_nowait(), OmniMessageFormatError)
-def test_datagram_received_queue_overflow(caplog: pytest.LogCaptureFixture) -> None:
- """Test that queue overflow is handled and logged."""
+def test_datagram_received_queue_overflow() -> None:
+ """Test that QueueFull is raised when the receive queue is full."""
protocol = OmniLogicProtocol()
- protocol.data_queue = asyncio.Queue(maxsize=1)
- protocol.data_queue.put_nowait(OmniLogicMessage(1, MessageType.ACK))
+ protocol._receive_queue = asyncio.Queue(maxsize=1)
+ # ACKs resolve futures, not the queue; fill with a non-ACK message
+ protocol._receive_queue.put_nowait(OmniLogicMessage(1, MessageType.MSP_CONFIGURATIONUPDATE))
- valid_data = bytes(OmniLogicMessage(2, MessageType.ACK))
- with caplog.at_level("ERROR"):
+ valid_data = bytes(OmniLogicMessage(2, MessageType.MSP_CONFIGURATIONUPDATE))
+ with pytest.raises(asyncio.QueueFull):
protocol.datagram_received(valid_data, ("127.0.0.1", 12345))
- assert any("Data queue is full" in r.message for r in caplog.records)
-
-def test_datagram_received_unexpected_exception(caplog: pytest.LogCaptureFixture) -> None:
- """Test that unexpected exceptions during datagram processing are handled."""
+def test_datagram_received_unexpected_exception() -> None:
+ """Test that unexpected exceptions during datagram processing propagate to the caller."""
protocol = OmniLogicProtocol()
- # Patch OmniLogicMessage.from_bytes to raise an unexpected exception
+ # Only OmniMessageFormatError is caught; any other exception propagates unhandled
with (
patch("pyomnilogic_local.api.protocol.OmniLogicMessage.from_bytes", side_effect=RuntimeError("Unexpected")),
- caplog.at_level("ERROR"),
+ pytest.raises(RuntimeError, match="Unexpected"),
):
protocol.datagram_received(b"data", ("127.0.0.1", 12345))
- assert any("Unexpected error processing datagram" in r.message for r in caplog.records)
- assert protocol.error_queue.qsize() == 1
-
-def test_error_received() -> None:
- """Test that error_received puts errors in the error queue."""
+def test_error_received(caplog: pytest.LogCaptureFixture) -> None:
+ """Test that error_received logs transport errors."""
protocol = OmniLogicProtocol()
test_error = RuntimeError("UDP error")
- protocol.error_received(test_error)
+ with caplog.at_level("ERROR"):
+ protocol.error_received(test_error)
- assert protocol.error_queue.qsize() == 1
- error = protocol.error_queue.get_nowait()
- assert error is test_error
+ assert any("transport error" in r.message for r in caplog.records)
# ============================================================================
@@ -293,91 +283,99 @@ def test_error_received() -> None:
@pytest.mark.asyncio
async def test_wait_for_ack_success() -> None:
- """Test successful ACK waiting."""
+ """Test that an ACK for the correct ID resolves the pending future."""
+ loop = asyncio.get_running_loop()
protocol = OmniLogicProtocol()
- protocol.transport = MagicMock()
- # Put an ACK message in the queue
+ ack_future: asyncio.Future[OmniLogicMessage] = loop.create_future()
+ protocol._ack_futures[123] = ack_future
+
ack_message = OmniLogicMessage(123, MessageType.ACK)
- await protocol.data_queue.put(ack_message)
+ protocol._resolve_ack(ack_message)
- # Should return without raising
- await protocol._wait_for_ack(123)
+ assert ack_future.done()
+ assert ack_future.result() is ack_message
@pytest.mark.asyncio
async def test_wait_for_ack_wrong_id_continues_waiting() -> None:
- """Test that wrong ACK IDs are consumed and waiting continues for the correct one."""
+ """Test that an ACK for a different ID does not resolve the pending future."""
+ loop = asyncio.get_running_loop()
protocol = OmniLogicProtocol()
- protocol.transport = MagicMock()
- # Put wrong ID first, then correct ID
- wrong_ack = OmniLogicMessage(999, MessageType.ACK)
- correct_ack = OmniLogicMessage(123, MessageType.ACK)
+ ack_future: asyncio.Future[OmniLogicMessage] = loop.create_future()
+ protocol._ack_futures[123] = ack_future
- await protocol.data_queue.put(wrong_ack)
- await protocol.data_queue.put(correct_ack)
+ wrong_ack = OmniLogicMessage(999, MessageType.ACK)
+ protocol._resolve_ack(wrong_ack)
- await protocol._wait_for_ack(123)
- # Queue should be empty after consuming both messages
- assert protocol.data_queue.qsize() == 0
+ assert not ack_future.done()
@pytest.mark.asyncio
-async def test_wait_for_ack_leadmessage_instead(caplog: pytest.LogCaptureFixture) -> None:
- """Test that LeadMessage with matching ID is accepted (ACK was dropped)."""
+async def test_wait_for_ack_leadmessage_instead() -> None:
+ """Test that a LeadMessage with a matching ID goes to the receive queue, not the ACK future."""
+ loop = asyncio.get_running_loop()
protocol = OmniLogicProtocol()
- protocol.transport = MagicMock()
- # Put a LeadMessage with matching ID (simulating dropped ACK)
- leadmsg = OmniLogicMessage(123, MessageType.MSP_LEADMESSAGE)
- await protocol.data_queue.put(leadmsg)
+ ack_future: asyncio.Future[OmniLogicMessage] = loop.create_future()
+ protocol._ack_futures[123] = ack_future
- with caplog.at_level("DEBUG"):
- await protocol._wait_for_ack(123)
+ # Build a valid LeadMessage datagram and deliver it via datagram_received
+ leadmsg_payload = (
+ ''
+ "LeadMessage"
+ '1003'
+ '10'
+ '1'
+ '0'
+ ""
+ )
+ leadmsg = OmniLogicMessage(123, MessageType.MSP_LEADMESSAGE, payload=leadmsg_payload)
+ protocol.datagram_received(bytes(leadmsg), ("127.0.0.1", 12345))
- # With matching ID, it's treated as the ACK we're looking for
- assert any("Received ACK for message ID 123" in r.message for r in caplog.records)
- # LeadMessage should NOT be in queue since IDs matched
- assert protocol.data_queue.qsize() == 0
+ # LeadMessages are not ACKs — the future should remain unresolved
+ assert not ack_future.done()
+ # The LeadMessage is placed on the receive queue for response assembly
+ assert protocol._receive_queue.qsize() == 1
@pytest.mark.asyncio
-async def test_wait_for_ack_leadmessage_wrong_id(caplog: pytest.LogCaptureFixture) -> None:
- """Test that LeadMessage with wrong ID is put back in queue and waiting continues."""
+async def test_wait_for_ack_leadmessage_wrong_id() -> None:
+ """Test that a LeadMessage with a different ID also goes to the receive queue, not the ACK future."""
+ loop = asyncio.get_running_loop()
protocol = OmniLogicProtocol()
- protocol.transport = MagicMock()
- # Put a LeadMessage with wrong ID, then correct ACK
- leadmsg = OmniLogicMessage(999, MessageType.MSP_LEADMESSAGE)
- correct_ack = OmniLogicMessage(123, MessageType.ACK)
+ ack_future: asyncio.Future[OmniLogicMessage] = loop.create_future()
+ protocol._ack_futures[123] = ack_future
- await protocol.data_queue.put(leadmsg)
- await protocol.data_queue.put(correct_ack)
-
- with caplog.at_level("DEBUG"):
- await protocol._wait_for_ack(123)
+ leadmsg_payload = (
+ ''
+ "LeadMessage"
+ '1003'
+ '10'
+ '1'
+ '0'
+ ""
+ )
+ leadmsg = OmniLogicMessage(999, MessageType.MSP_LEADMESSAGE, payload=leadmsg_payload)
+ protocol.datagram_received(bytes(leadmsg), ("127.0.0.1", 12345))
- # Should log that ACK was dropped and put LeadMessage back
- assert any("ACK was dropped" in r.message for r in caplog.records)
- # Both messages were consumed and LeadMessage was put back, so queue should have 1 item
- # But the ACK was also consumed, so we actually end up with just the LeadMessage back
- # Actually, looking at the code: LeadMessage gets put back, then we return
- # So BOTH the correct ACK and the LeadMessage should be in the queue
- assert protocol.data_queue.qsize() == 2 # LeadMessage put back, correct ACK also still there
+ # Future for ID 123 should remain unresolved
+ assert not ack_future.done()
+ # The LeadMessage goes to the receive queue regardless of ID
+ assert protocol._receive_queue.qsize() == 1
@pytest.mark.asyncio
async def test_wait_for_ack_error_in_queue() -> None:
- """Test that errors from error queue are raised."""
+ """Test that _send_with_retry raises OmniTimeoutError when no ACK is received."""
protocol = OmniLogicProtocol()
- protocol.transport = MagicMock()
-
- test_error = RuntimeError("Test error")
- await protocol.error_queue.put(test_error)
+ protocol._transport = MagicMock()
- with pytest.raises(RuntimeError, match="Test error"):
- await protocol._wait_for_ack(123)
+ # Patch wait_for to always raise TimeoutError so every attempt times out immediately
+ with patch("asyncio.wait_for", side_effect=TimeoutError), pytest.raises(OmniTimeoutError):
+ await protocol._send_with_retry(MessageType.REQUEST_CONFIGURATION, "")
# ============================================================================
@@ -387,64 +385,61 @@ async def test_wait_for_ack_error_in_queue() -> None:
@pytest.mark.asyncio
async def test_ensure_sent_ack_message() -> None:
- """Test that ACK messages don't wait for ACK."""
+ """Test that _send_xml_ack transmits directly without waiting for an ACK."""
protocol = OmniLogicProtocol()
- protocol.transport = MagicMock()
+ protocol._transport = MagicMock()
- ack_message = OmniLogicMessage(123, MessageType.ACK)
-
- # Should return immediately without waiting
- await protocol._ensure_sent(ack_message)
+ protocol._send_xml_ack(123)
- protocol.transport.sendto.assert_called_once()
+ protocol._transport.sendto.assert_called_once()
@pytest.mark.asyncio
async def test_ensure_sent_xml_ack_message() -> None:
- """Test that XML_ACK messages don't wait for ACK."""
+ """Test that _send_xml_ack sends the correct XML_ACK message format."""
protocol = OmniLogicProtocol()
- protocol.transport = MagicMock()
-
- xml_ack_message = OmniLogicMessage(123, MessageType.XML_ACK, payload="")
+ protocol._transport = MagicMock()
- await protocol._ensure_sent(xml_ack_message)
+ protocol._send_xml_ack(456)
- protocol.transport.sendto.assert_called_once()
+ protocol._transport.sendto.assert_called_once()
+ sent_bytes = protocol._transport.sendto.call_args[0][0]
+ # Verify the sent bytes parse back to an XML_ACK message with the correct ID
+ parsed = OmniLogicMessage.from_bytes(sent_bytes)
+ assert parsed.type == MessageType.XML_ACK
+ assert parsed.id == 456
@pytest.mark.asyncio
async def test_ensure_sent_success_first_attempt() -> None:
- """Test successful send on first attempt."""
+ """Test successful send on first attempt with _send_with_retry."""
protocol = OmniLogicProtocol()
- protocol.transport = MagicMock()
+ protocol._transport = MagicMock()
- # Mock _wait_for_ack to succeed immediately
- with patch.object(protocol, "_wait_for_ack", new_callable=AsyncMock) as mock_wait:
- message = OmniLogicMessage(123, MessageType.REQUEST_CONFIGURATION)
- await protocol._ensure_sent(message, max_attempts=3)
+ # Simulate an immediate ACK by resolving the future when sendto is called
+ def resolve_ack_on_send(data: bytes) -> None:
+ msg = OmniLogicMessage.from_bytes(data)
+ protocol._resolve_ack(OmniLogicMessage(msg.id, MessageType.ACK))
- protocol.transport.sendto.assert_called_once()
- mock_wait.assert_called_once_with(123)
+ protocol._transport.sendto.side_effect = resolve_ack_on_send
+
+ await protocol._send_with_retry(MessageType.REQUEST_CONFIGURATION, "")
+
+ protocol._transport.sendto.assert_called_once()
@pytest.mark.asyncio
async def test_ensure_sent_timeout_and_retry_logs(caplog: pytest.LogCaptureFixture) -> None:
- """Test that _ensure_sent logs retries and raises on repeated timeout."""
+ """Test that _send_with_retry logs retries and raises OmniTimeoutError after all attempts."""
protocol = OmniLogicProtocol()
- protocol.transport = MagicMock()
+ protocol._transport = MagicMock()
- async def always_timeout(*args: object, **kwargs: object) -> None: # noqa: ARG001
- await asyncio.sleep(0)
- raise TimeoutError
+ with patch("asyncio.wait_for", side_effect=TimeoutError), caplog.at_level("DEBUG"), pytest.raises(OmniTimeoutError):
+ await protocol._send_with_retry(MessageType.REQUEST_CONFIGURATION, "")
- message = OmniLogicMessage(123, MessageType.REQUEST_CONFIGURATION)
- with patch.object(protocol, "_wait_for_ack", always_timeout), caplog.at_level("WARNING"), pytest.raises(OmniTimeoutError):
- await protocol._ensure_sent(message, max_attempts=3)
-
- assert any("attempt 1/3" in r.message for r in caplog.records)
- assert any("attempt 2/3" in r.message for r in caplog.records)
- assert any("after 3 attempts" in r.message for r in caplog.records)
- assert protocol.transport.sendto.call_count == 3
+ retry_logs = [r for r in caplog.records if "no ACK" in r.message]
+ assert len(retry_logs) == OMNI_RETRANSMIT_COUNT
+ assert protocol._transport.sendto.call_count == OMNI_RETRANSMIT_COUNT + 1
# ============================================================================
@@ -454,47 +449,63 @@ async def always_timeout(*args: object, **kwargs: object) -> None: # noqa: ARG0
@pytest.mark.asyncio
async def test_send_message_generates_random_id() -> None:
- """Test that send_message generates a random ID when none provided."""
+ """Test that _send_with_retry generates a non-zero message ID."""
protocol = OmniLogicProtocol()
- protocol.transport = MagicMock()
+ protocol._transport = MagicMock()
+
+ captured_ids: list[int] = []
- with patch.object(protocol, "_ensure_sent", new_callable=AsyncMock) as mock_ensure:
- await protocol.send_message(MessageType.REQUEST_CONFIGURATION, None, msg_id=None)
+ def capture_id(data: bytes) -> None:
+ msg = OmniLogicMessage.from_bytes(data)
+ captured_ids.append(msg.id)
+ protocol._resolve_ack(OmniLogicMessage(msg.id, MessageType.ACK))
- mock_ensure.assert_called_once()
- sent_message = mock_ensure.call_args[0][0]
- assert sent_message.id != 0 # Should have a random ID
+ protocol._transport.sendto.side_effect = capture_id
+
+ await protocol._send_with_retry(MessageType.REQUEST_CONFIGURATION, "")
+
+ assert len(captured_ids) == 1
+ assert captured_ids[0] != 0
@pytest.mark.asyncio
-async def test_send_message_uses_provided_id() -> None:
- """Test that send_message uses provided ID."""
+async def test_send_message_uses_incrementing_id() -> None:
+ """Test that each _send_with_retry call uses a distinct, incrementing message ID."""
protocol = OmniLogicProtocol()
- protocol.transport = MagicMock()
+ protocol._transport = MagicMock()
+
+ captured_ids: list[int] = []
- with patch.object(protocol, "_ensure_sent", new_callable=AsyncMock) as mock_ensure:
- await protocol.send_message(MessageType.REQUEST_CONFIGURATION, None, msg_id=12345)
+ def capture_and_ack(data: bytes) -> None:
+ msg = OmniLogicMessage.from_bytes(data)
+ captured_ids.append(msg.id)
+ protocol._resolve_ack(OmniLogicMessage(msg.id, MessageType.ACK))
- mock_ensure.assert_called_once()
- sent_message = mock_ensure.call_args[0][0]
- assert sent_message.id == 12345
+ protocol._transport.sendto.side_effect = capture_and_ack
+
+ await protocol._send_with_retry(MessageType.REQUEST_CONFIGURATION, "")
+ await protocol._send_with_retry(MessageType.REQUEST_CONFIGURATION, "")
+
+ assert len(captured_ids) == 2
+ assert captured_ids[0] != captured_ids[1]
@pytest.mark.asyncio
async def test_send_and_receive() -> None:
- """Test send_and_receive calls send_message and _receive_file."""
+ """Test async_send_and_receive calls _send_with_retry and _receive_response."""
protocol = OmniLogicProtocol()
- protocol.transport = MagicMock()
+ protocol._transport = MagicMock()
with (
- patch.object(protocol, "send_message", new_callable=AsyncMock) as mock_send,
- patch.object(protocol, "_receive_file", new_callable=AsyncMock) as mock_receive,
+ patch.object(protocol, "_send_with_retry", new_callable=AsyncMock) as mock_send,
+ patch.object(protocol, "_receive_response", new_callable=AsyncMock) as mock_receive,
+ patch.object(protocol, "_decode_payload", return_value="test response"),
):
- mock_receive.return_value = "test response"
+ mock_receive.return_value = (b"raw", False)
- result = await protocol.send_and_receive(MessageType.REQUEST_CONFIGURATION, "payload", 123)
+ result = await protocol.async_send_and_receive(MessageType.REQUEST_CONFIGURATION, "")
- mock_send.assert_called_once_with(MessageType.REQUEST_CONFIGURATION, "payload", 123)
+ mock_send.assert_called_once_with(MessageType.REQUEST_CONFIGURATION, "")
mock_receive.assert_called_once()
assert result == "test response"
@@ -506,22 +517,22 @@ async def test_send_and_receive() -> None:
@pytest.mark.asyncio
async def test_send_ack_generates_xml() -> None:
- """Test that _send_ack generates proper XML ACK message."""
+ """Test that _send_xml_ack transmits a well-formed XML_ACK message."""
protocol = OmniLogicProtocol()
- protocol.transport = MagicMock()
+ protocol._transport = MagicMock()
- with patch.object(protocol, "send_message", new_callable=AsyncMock) as mock_send:
- await protocol._send_ack(12345)
+ protocol._send_xml_ack(12345)
- mock_send.assert_called_once()
- call_args = mock_send.call_args
- assert call_args[0][0] == MessageType.XML_ACK
- assert call_args[0][2] == 12345
+ protocol._transport.sendto.assert_called_once()
+ sent_bytes = protocol._transport.sendto.call_args[0][0]
- # Verify XML structure
- xml_payload = call_args[0][1]
+ parsed = OmniLogicMessage.from_bytes(sent_bytes)
+ assert parsed.type == MessageType.XML_ACK
+ assert parsed.id == 12345
+
+ # Verify XML structure contains the expected Ack name element
+ xml_payload = parsed.payload.rstrip(b"\x00").decode("utf-8")
root = ET.fromstring(xml_payload)
- assert "Request" in root.tag
name_elem = root.find(".//{http://nextgen.hayward.com/api}Name")
assert name_elem is not None
assert name_elem.text == "Ack"
@@ -534,81 +545,62 @@ async def test_send_ack_generates_xml() -> None:
@pytest.mark.asyncio
async def test_receive_file_simple_response() -> None:
- """Test receiving a simple (non-fragmented) response."""
+ """Test receiving a simple (non-fragmented) response via _receive_response."""
protocol = OmniLogicProtocol()
- protocol.transport = MagicMock()
- # Create a simple response message
- response_msg = OmniLogicMessage(123, MessageType.GET_TELEMETRY, payload="")
- await protocol.data_queue.put(response_msg)
+ response_msg = OmniLogicMessage(123, MessageType.MSP_TELEMETRY_UPDATE)
+ response_msg.payload = b""
+ response_msg.compressed = False
+ await protocol._receive_queue.put(response_msg)
- with patch.object(protocol, "_send_ack", new_callable=AsyncMock) as mock_ack:
- result = await protocol._receive_file()
+ raw_data, compressed = await protocol._receive_response()
- mock_ack.assert_called_once_with(123)
- assert result == ""
+ assert raw_data == b""
+ assert compressed is False
@pytest.mark.asyncio
async def test_receive_file_skips_duplicate_acks(caplog: pytest.LogCaptureFixture) -> None:
- """Test that duplicate ACKs are skipped."""
+ """Test that orphaned block messages are skipped and the real response is returned."""
protocol = OmniLogicProtocol()
- protocol.transport = MagicMock()
- # Put duplicate ACKs followed by real message
- ack1 = OmniLogicMessage(111, MessageType.ACK)
- ack2 = OmniLogicMessage(222, MessageType.XML_ACK)
- response = OmniLogicMessage(333, MessageType.GET_TELEMETRY, payload="")
+ # Stray block message (no preceding lead message) should be skipped
+ stray_block = OmniLogicMessage(111, MessageType.MSP_BLOCKMESSAGE)
+ stray_block.payload = b"\x00" * 8 + b"stray"
+ response = OmniLogicMessage(333, MessageType.MSP_CONFIGURATIONUPDATE)
+ response.payload = b""
+ response.compressed = False
- await protocol.data_queue.put(ack1)
- await protocol.data_queue.put(ack2)
- await protocol.data_queue.put(response)
+ await protocol._receive_queue.put(stray_block)
+ await protocol._receive_queue.put(response)
- with patch.object(protocol, "_send_ack", new_callable=AsyncMock), caplog.at_level("DEBUG"):
- result = await protocol._receive_file()
+ with caplog.at_level("WARNING"):
+ raw_data, _compressed = await protocol._receive_response()
- assert any("Skipping duplicate ACK" in r.message for r in caplog.records)
- assert result == ""
+ assert any("block message" in r.message for r in caplog.records)
+ assert raw_data == b""
@pytest.mark.asyncio
async def test_receive_file_decompresses_data() -> None:
- """Test that compressed responses are decompressed."""
+ """Test that _decode_payload decompresses compressed responses."""
protocol = OmniLogicProtocol()
- protocol.transport = MagicMock()
- # Create compressed payload
original = b"This is test data that will be compressed"
- compressed = zlib.compress(original)
-
- # Create message with compressed payload
- response_msg = OmniLogicMessage(123, MessageType.GET_TELEMETRY)
- response_msg.compressed = True
- response_msg.payload = compressed
+ compressed_data = zlib.compress(original)
- await protocol.data_queue.put(response_msg)
-
- with patch.object(protocol, "_send_ack", new_callable=AsyncMock):
- result = await protocol._receive_file()
+ result = protocol._decode_payload(compressed_data, compressed=True)
assert result == original.decode("utf-8")
@pytest.mark.asyncio
async def test_receive_file_decompression_error() -> None:
- """Test that decompression errors are handled."""
+ """Test that _decode_payload raises OmniMessageFormatError for invalid compressed data."""
protocol = OmniLogicProtocol()
- protocol.transport = MagicMock()
-
- # Create message with invalid compressed data
- response_msg = OmniLogicMessage(123, MessageType.GET_TELEMETRY)
- response_msg.compressed = True
- response_msg.payload = b"invalid compressed data"
-
- await protocol.data_queue.put(response_msg)
- with patch.object(protocol, "_send_ack", new_callable=AsyncMock), pytest.raises(OmniMessageFormatError, match="Failed to decompress"):
- await protocol._receive_file()
+ with pytest.raises(zlib.error):
+ protocol._decode_payload(b"invalid compressed data", compressed=True)
# ============================================================================
@@ -618,11 +610,9 @@ async def test_receive_file_decompression_error() -> None:
@pytest.mark.asyncio
async def test_receive_file_fragmented_response() -> None:
- """Test receiving a fragmented response with LeadMessage and BlockMessages."""
+ """Test receiving a fragmented response with LeadMessage and BlockMessages via _receive_response."""
protocol = OmniLogicProtocol()
- protocol.transport = MagicMock()
- # Create LeadMessage
leadmsg_payload = (
'LeadMessage'
'100324'
@@ -631,30 +621,26 @@ async def test_receive_file_fragmented_response() -> None:
)
leadmsg = OmniLogicMessage(100, MessageType.MSP_LEADMESSAGE, payload=leadmsg_payload)
- # Create BlockMessages with 8-byte header
block1 = OmniLogicMessage(101, MessageType.MSP_BLOCKMESSAGE)
block1.payload = b"\x00\x00\x00\x00\x00\x00\x00\x00" + b"first_part"
block2 = OmniLogicMessage(102, MessageType.MSP_BLOCKMESSAGE)
block2.payload = b"\x00\x00\x00\x00\x00\x00\x00\x00" + b"second_part"
- await protocol.data_queue.put(leadmsg)
- await protocol.data_queue.put(block1)
- await protocol.data_queue.put(block2)
+ await protocol._receive_queue.put(leadmsg)
+ await protocol._receive_queue.put(block1)
+ await protocol._receive_queue.put(block2)
- with patch.object(protocol, "_send_ack", new_callable=AsyncMock) as mock_ack:
- result = await protocol._receive_file()
+ raw_data, compressed = await protocol._receive_response()
- # Should send ACK for LeadMessage and each BlockMessage
- assert mock_ack.call_count == 3
- assert result == "first_partsecond_part"
+ assert raw_data == b"first_partsecond_part"
+ assert compressed is False
@pytest.mark.asyncio
async def test_receive_file_fragmented_out_of_order() -> None:
- """Test that fragments received out of order are reassembled correctly."""
+ """Test that fragments are concatenated in receive order."""
protocol = OmniLogicProtocol()
- protocol.transport = MagicMock()
leadmsg_payload = (
'LeadMessage'
@@ -664,7 +650,7 @@ async def test_receive_file_fragmented_out_of_order() -> None:
)
leadmsg = OmniLogicMessage(100, MessageType.MSP_LEADMESSAGE, payload=leadmsg_payload)
- # Create blocks out of order (IDs: 102, 100, 101)
+ # Enqueue blocks out of ID order: 102, 100, 101
block2 = OmniLogicMessage(102, MessageType.MSP_BLOCKMESSAGE)
block2.payload = b"\x00\x00\x00\x00\x00\x00\x00\x00" + b"third"
@@ -674,40 +660,33 @@ async def test_receive_file_fragmented_out_of_order() -> None:
block1 = OmniLogicMessage(101, MessageType.MSP_BLOCKMESSAGE)
block1.payload = b"\x00\x00\x00\x00\x00\x00\x00\x00" + b"second"
- await protocol.data_queue.put(leadmsg)
- await protocol.data_queue.put(block2) # Out of order
- await protocol.data_queue.put(block0)
- await protocol.data_queue.put(block1)
+ await protocol._receive_queue.put(leadmsg)
+ await protocol._receive_queue.put(block2)
+ await protocol._receive_queue.put(block0)
+ await protocol._receive_queue.put(block1)
- with patch.object(protocol, "_send_ack", new_callable=AsyncMock):
- result = await protocol._receive_file()
+ raw_data, _compressed = await protocol._receive_response()
- # Should be reassembled in ID order
- assert result == "firstsecondthird"
+ # Blocks are assembled in receive order, not by ID
+ assert raw_data == b"thirdfirstsecond"
@pytest.mark.asyncio
async def test_receive_file_fragmented_invalid_leadmessage() -> None:
- """Test that invalid LeadMessage XML raises error."""
+ """Test that invalid LeadMessage XML raises a parse error."""
protocol = OmniLogicProtocol()
- protocol.transport = MagicMock()
- # Create LeadMessage with invalid XML
leadmsg = OmniLogicMessage(100, MessageType.MSP_LEADMESSAGE, payload="invalid xml")
- await protocol.data_queue.put(leadmsg)
+ await protocol._receive_queue.put(leadmsg)
- with (
- patch.object(protocol, "_send_ack", new_callable=AsyncMock),
- pytest.raises(OmniFragmentationError, match="Failed to parse LeadMessage"),
- ):
- await protocol._receive_file()
+ with pytest.raises(ET.ParseError):
+ await protocol._receive_response()
@pytest.mark.asyncio
async def test_receive_file_fragmented_timeout_waiting() -> None:
- """Test timeout while waiting for fragments."""
+ """Test that OmniTimeoutError is raised when no fragment arrives within the timeout."""
protocol = OmniLogicProtocol()
- protocol.transport = MagicMock()
leadmsg_payload = (
'LeadMessage'
@@ -716,46 +695,27 @@ async def test_receive_file_fragmented_timeout_waiting() -> None:
""
)
leadmsg = OmniLogicMessage(100, MessageType.MSP_LEADMESSAGE, payload=leadmsg_payload)
+ await protocol._receive_queue.put(leadmsg)
+ # No block messages — _receive_next_message will time out
- await protocol.data_queue.put(leadmsg)
- # Don't put any BlockMessages - will timeout
-
- with (
- patch.object(protocol, "_send_ack", new_callable=AsyncMock),
- pytest.raises(OmniFragmentationError, match="Timeout receiving fragment"),
- ):
- await protocol._receive_file()
+ with pytest.raises(OmniTimeoutError):
+ await protocol._receive_response()
@pytest.mark.asyncio
async def test_receive_file_fragmented_max_wait_time_exceeded() -> None:
- """Test that MAX_FRAGMENT_WAIT_TIME timeout is enforced."""
+ """Test that _receive_next_message raises OmniTimeoutError when MAX_FRAGMENT_WAIT_TIME elapses."""
protocol = OmniLogicProtocol()
- protocol.transport = MagicMock()
- leadmsg_payload = (
- 'LeadMessage'
- '100324'
- '20'
- ""
- )
- leadmsg = OmniLogicMessage(100, MessageType.MSP_LEADMESSAGE, payload=leadmsg_payload)
-
- await protocol.data_queue.put(leadmsg)
-
- # Mock time to simulate timeout
- with patch.object(protocol, "_send_ack", new_callable=AsyncMock), patch("time.time") as mock_time:
- mock_time.side_effect = [0, 31] # Start at 0, then 31 seconds later (> 30s max)
-
- with pytest.raises(OmniFragmentationError, match="Timeout waiting for fragments"):
- await protocol._receive_file()
+ # Patch asyncio.timeout to raise TimeoutError immediately, simulating the deadline being exceeded
+ with patch("asyncio.timeout", side_effect=TimeoutError), pytest.raises(OmniTimeoutError):
+ await protocol._receive_next_message()
@pytest.mark.asyncio
async def test_receive_file_fragmented_ignores_non_block_messages(caplog: pytest.LogCaptureFixture) -> None:
- """Test that non-BlockMessages during fragmentation are ignored."""
+ """Test that non-BlockMessages during fragment reassembly are warned and ignored."""
protocol = OmniLogicProtocol()
- protocol.transport = MagicMock()
leadmsg_payload = (
'LeadMessage'
@@ -765,55 +725,37 @@ async def test_receive_file_fragmented_ignores_non_block_messages(caplog: pytest
)
leadmsg = OmniLogicMessage(100, MessageType.MSP_LEADMESSAGE, payload=leadmsg_payload)
- # Put LeadMessage, then an ACK (should be ignored), then the actual block
- ack_msg = OmniLogicMessage(999, MessageType.ACK)
+ # A configuration update message (not a block) should be skipped during reassembly
+ interloper = OmniLogicMessage(999, MessageType.MSP_CONFIGURATIONUPDATE)
+ interloper.payload = b""
+
block1 = OmniLogicMessage(101, MessageType.MSP_BLOCKMESSAGE)
block1.payload = b"\x00\x00\x00\x00\x00\x00\x00\x00" + b"data"
- await protocol.data_queue.put(leadmsg)
- await protocol.data_queue.put(ack_msg)
- await protocol.data_queue.put(block1)
+ await protocol._receive_queue.put(leadmsg)
+ await protocol._receive_queue.put(interloper)
+ await protocol._receive_queue.put(block1)
- with patch.object(protocol, "_send_ack", new_callable=AsyncMock), caplog.at_level("DEBUG"):
- result = await protocol._receive_file()
+ with caplog.at_level("WARNING"):
+ raw_data, _compressed = await protocol._receive_response()
- assert any("other than a blockmessage" in r.message for r in caplog.records)
- assert result == "data"
+ assert any("expected block message" in r.message for r in caplog.records)
+ assert raw_data == b"data"
@pytest.mark.asyncio
async def test_wait_for_ack_cancels_pending_tasks() -> None:
- """Test that pending tasks are properly cancelled in _wait_for_ack to avoid warnings."""
+ """Test that _ack_futures are cleaned up after _send_with_retry completes."""
protocol = OmniLogicProtocol()
- protocol.transport = MagicMock()
-
- # Track tasks created during _wait_for_ack
- created_tasks: list[asyncio.Task[Any]] = []
- original_create_task = asyncio.create_task
-
- def track_create_task(coro: Any) -> asyncio.Task[Any]:
- task: asyncio.Task[Any] = original_create_task(coro)
- created_tasks.append(task)
- return task
-
- # Queue up an ACK message
- ack_msg = OmniLogicMessage(42, MessageType.ACK)
- await protocol.data_queue.put(ack_msg)
-
- # Patch create_task to track tasks
- with patch("asyncio.create_task", side_effect=track_create_task):
- await protocol._wait_for_ack(42)
+ protocol._transport = MagicMock()
- # Give the event loop a chance to process cancellation
- await asyncio.sleep(0)
+ def resolve_ack_on_send(data: bytes) -> None:
+ msg = OmniLogicMessage.from_bytes(data)
+ protocol._resolve_ack(OmniLogicMessage(msg.id, MessageType.ACK))
- # Should have created 2 tasks (data_task and error_task)
- assert len(created_tasks) == 2
+ protocol._transport.sendto.side_effect = resolve_ack_on_send
- # One should be done (the data_task that got the ACK)
- # One should be cancelled (the error_task that was waiting)
- done_tasks = [t for t in created_tasks if t.done() and not t.cancelled()]
- cancelled_tasks = [t for t in created_tasks if t.cancelled()]
+ await protocol._send_with_retry(MessageType.REQUEST_CONFIGURATION, "")
- assert len(done_tasks) == 1, "Expected exactly one task to complete normally"
- assert len(cancelled_tasks) == 1, "Expected exactly one task to be cancelled"
+ # After a successful send, all futures should be cleaned up from _ack_futures
+ assert len(protocol._ack_futures) == 0