diff --git a/plugboard-schemas/plugboard_schemas/_validation.py b/plugboard-schemas/plugboard_schemas/_validation.py index 98eb886c..e7b6bf93 100644 --- a/plugboard-schemas/plugboard_schemas/_validation.py +++ b/plugboard-schemas/plugboard_schemas/_validation.py @@ -18,6 +18,9 @@ from ._validator_registry import validator +_SYSTEM_STOP_EVENT = "system_stop" + + def _build_component_graph( connectors: dict[str, dict[str, _t.Any]], ) -> dict[str, set[str]]: @@ -100,6 +103,9 @@ def validate_all_inputs_connected( all_inputs = set(io.get("inputs", [])) connected = connected_inputs.get(comp_name, set()) unconnected = all_inputs - connected + if unconnected: + event_covered_fields = set().union(*(v for k, v in io.get("event_field_coverage", {}).items() if k != _SYSTEM_STOP_EVENT)) + unconnected -= event_covered_fields if unconnected: errors.append(f"Component '{comp_name}' has unconnected inputs: {sorted(unconnected)}") return errors diff --git a/plugboard/component/component.py b/plugboard/component/component.py index 6fe0ad20..608689b1 100644 --- a/plugboard/component/component.py +++ b/plugboard/component/component.py @@ -95,6 +95,7 @@ def __init__( initial_values=self._initial_values, input_events=self.__class__.io.input_events, output_events=self.__class__.io.output_events, + event_field_coverage=self.__class__.io.event_field_coverage, namespace=self.name, component=self, ) @@ -356,7 +357,7 @@ async def _wrapper() -> None: raise e self._bind_outputs() await self.io.write() - self._field_inputs_ready = False + self._reset_input_trackers() await self._set_status(Status.WAITING, publish=not self._is_running) return _wrapper @@ -365,6 +366,11 @@ async def _wrapper() -> None: def _has_field_inputs(self) -> bool: return len(self.io.inputs) > 0 + @property + def _has_connected_field_inputs(self) -> bool: + """Whether any declared field inputs are connected via input channels.""" + return self.io.has_connected_field_inputs + @cached_property def _has_event_inputs(self) -> bool: input_events = set([evt.safe_type() 
for evt in self.io.input_events]) @@ -409,7 +415,7 @@ async def _io_read_with_status_check(self) -> None: task.cancel() for task in done: exc = task.exception() - if isinstance(exc, EventStreamClosedError) and len(self.io.inputs) == 0: + if isinstance(exc, EventStreamClosedError) and not self._has_connected_field_inputs: await self.io.close() # Call close for final wait and flush event buffer elif exc is not None: raise exc @@ -422,7 +428,7 @@ async def _periodic_status_check(self) -> None: # TODO : Eventually producer graph update will be event driven. For now, # : the update is performed periodically, so it's called here along # : with the status check. - if len(self.io.inputs) == 0: + if not self._has_connected_field_inputs: await self._update_producer_graph() async def _status_check(self) -> None: @@ -455,8 +461,11 @@ def _bind_inputs(self) -> None: for field in self.io.inputs: field_default = getattr(self, field, None) value = self._field_inputs.get(field, field_default) - setattr(self, field, value) + super().__setattr__(field, value) + + def _reset_input_trackers(self) -> None: self._field_inputs = {} + self._field_inputs_ready = False def _bind_outputs(self) -> None: """Binds component fields to output fields.""" diff --git a/plugboard/component/io_controller.py b/plugboard/component/io_controller.py index 7500aee2..870c49c6 100644 --- a/plugboard/component/io_controller.py +++ b/plugboard/component/io_controller.py @@ -38,6 +38,7 @@ def __init__( initial_values: _t.Optional[dict[str, _t.Iterable]] = None, input_events: _t.Optional[list[_t.Type[Event]]] = None, output_events: _t.Optional[list[_t.Type[Event]]] = None, + event_field_coverage: _t.Optional[dict[str, list[str]]] = None, namespace: str = IO_NS_UNSET, component: _t.Optional[Component] = None, ) -> None: @@ -47,6 +48,7 @@ def __init__( self.initial_values = initial_values or {} self.input_events = input_events or [] self.output_events = output_events or [] + self.event_field_coverage = 
event_field_coverage or {} if set(self.initial_values.keys()) - set(self.inputs): raise ValueError("Initial values must be for input fields only.") self._component = component @@ -86,8 +88,9 @@ def is_closed(self) -> bool: """Returns `True` if the `IOController` is closed, `False` otherwise.""" return self._is_closed - @cached_property - def _has_field_inputs(self) -> bool: + @property + def has_connected_field_inputs(self) -> bool: + """Returns whether any field inputs are connected via channels.""" return len(self._input_channels) > 0 @cached_property @@ -96,7 +99,7 @@ def _has_event_inputs(self) -> bool: @cached_property def _has_inputs(self) -> bool: - return self._has_field_inputs or self._has_event_inputs + return self.has_connected_field_inputs or self._has_event_inputs async def read(self, timeout: float | None = None) -> None: """Reads data and/or events from input channels. @@ -139,7 +142,7 @@ async def read(self, timeout: float | None = None) -> None: def _set_read_tasks(self) -> list[asyncio.Task]: read_tasks: list[asyncio.Task] = [] - if self._has_field_inputs: + if self.has_connected_field_inputs: if _fields_read_task not in self._read_tasks: read_fields_task = asyncio.create_task(self._read_fields(), name=_fields_read_task) self._read_tasks[_fields_read_task] = read_fields_task @@ -374,7 +377,7 @@ def _add_channel_for_event( def _create_input_field_group_tasks(self) -> None: """Groups input field channels by field name and launches read tasks for group inputs.""" - if not self._has_field_inputs: + if not self.has_connected_field_inputs: return field_channels: dict[str, list[tuple[_t_field_key, Channel]]] = defaultdict(list) for key, chan in self._input_channels.items(): @@ -410,6 +413,7 @@ def dict(self) -> dict[str, _t.Any]: # noqa: D102 "input_events": [e.safe_type() for e in self.input_events], "output_events": [e.safe_type() for e in self.output_events], "initial_values": {k: list(v) for k, v in self._initial_values.items()}, + 
"event_field_coverage": {k: list(v) for k, v in self.event_field_coverage.items()}, } diff --git a/plugboard/events/event.py b/plugboard/events/event.py index 2becb56d..858e26a5 100644 --- a/plugboard/events/event.py +++ b/plugboard/events/event.py @@ -75,9 +75,28 @@ def safe_type(cls, event_type: _t.Optional[str] = None) -> str: """Returns a safe event type string for use in broker topic strings.""" return (event_type or cls.type).replace(".", "_").replace("-", "_") + @_t.overload @classmethod - def handler(cls, method: AsyncCallable) -> AsyncCallable: + def handler(cls, method: AsyncCallable) -> AsyncCallable: ... + + @_t.overload + @classmethod + def handler( + cls, *, populates_fields: _t.Optional[list[str]] = None + ) -> _t.Callable[[AsyncCallable], AsyncCallable]: ... + + @classmethod + def handler( + cls, + method: _t.Optional[AsyncCallable] = None, + *, + populates_fields: _t.Optional[list[str]] = None, + ) -> _t.Union[AsyncCallable, _t.Callable[[AsyncCallable], AsyncCallable]]: """Registers a class method as an event handler.""" + if method is None: + # Invoked as @Event.handler(populates_fields=[...]) + return EventHandlers.add(cls, populates_fields=populates_fields) + # Invoked as @Event.handler return EventHandlers.add(cls)(method) diff --git a/plugboard/events/event_handlers.py b/plugboard/events/event_handlers.py index 344522ce..3eb8223d 100644 --- a/plugboard/events/event_handlers.py +++ b/plugboard/events/event_handlers.py @@ -18,11 +18,16 @@ class EventHandlers: # pragma: no cover _handlers: _t.ClassVar[dict[str, dict[str, AsyncCallable]]] = defaultdict(dict) @classmethod - def add(cls, event: _t.Type[Event] | Event) -> _t.Callable[[AsyncCallable], AsyncCallable]: + def add( + cls, + event: _t.Type[Event] | Event, + populates_fields: _t.Optional[list[str]] = None, + ) -> _t.Callable[[AsyncCallable], AsyncCallable]: """Decorator that registers class methods as handlers for specific event types. 
Args: event: Event class this handler processes + populates_fields: Optional list of fields that the handler populates Returns: Callable: Decorated method @@ -31,6 +36,13 @@ def add(cls, event: _t.Type[Event] | Event) -> _t.Callable[[AsyncCallable], Asyn def decorator(method: AsyncCallable) -> AsyncCallable: class_path = cls._get_class_path_for_method(method) cls._handlers[class_path][event.type] = method + if populates_fields is not None: + comp_io = __import__("sys")._getframe(1).f_locals.get("io")  # decoration runs inside the class body; the class object does not exist yet, but its `io` does + if comp_io is None: + raise ValueError( + "populates_fields must be specified on method of Component subclass." + ) + comp_io.event_field_coverage[event.type] = populates_fields return method return decorator @@ -43,6 +55,12 @@ def _get_class_path_for_method(method: AsyncCallable) -> str: class_name = qualname_parts[-2] # Last part is the method name return f"{module_name}.{class_name}" + @staticmethod + def _iter_mro(_class: _t.Type) -> _t.Iterator[str]: + """Iterate over class MRO, yielding fully qualified class paths.""" + for base_class in _class.__mro__: + yield f"{base_class.__module__}.{base_class.__name__}" + @classmethod def get(cls, _class: _t.Type, event: _t.Type[Event] | Event) -> AsyncCallable: """Retrieve a handler for a specific class and event type. 
@@ -57,10 +75,9 @@ def get(cls, _class: _t.Type, event: _t.Type[Event] | Event) -> AsyncCallable: Raises: KeyError: If no handler found for class or event type """ - for base_class in _class.__mro__: - base_path = f"{base_class.__module__}.{base_class.__name__}" - if base_path in cls._handlers and event.type in cls._handlers[base_path]: - return cls._handlers[base_path][event.type] + for base_path in cls._iter_mro(_class): + if base_path in cls._handlers and event.type in cls._handlers[base_path]: + return cls._handlers[base_path][event.type] raise KeyError( f"No handler found for class '{_class.__name__}' and event type '{event.type}'" ) diff --git a/plugboard/library/data_writer.py b/plugboard/library/data_writer.py index 96d14538..f281a531 100644 --- a/plugboard/library/data_writer.py +++ b/plugboard/library/data_writer.py @@ -43,6 +43,7 @@ def __init__( **kwargs: Additional keyword arguments for [`Component`][plugboard.component.Component]. """ super().__init__(**kwargs) + # Use a single buffer to track everything self._buffer: dict[str, deque] = defaultdict(deque) self._chunk_size = chunk_size self.io = IOController( @@ -76,18 +77,39 @@ async def _convert(self, data: dict[str, deque]) -> _t.Any: def _bind_inputs(self) -> None: """Binds input fields to component fields and append to internal buffer.""" super()._bind_inputs() - for field in self.io.inputs: + for field in self._field_inputs: value = getattr(self, field, None) self._buffer[field].append(value) + @property + def _completed_rows(self) -> int: + """Calculates how many fully formed rows exist in the buffer.""" + if not self.io.inputs: + return 0 + return min((len(self._buffer[f]) for f in self.io.inputs), default=0) + + @property + def _can_step(self) -> bool: + """We can step if we have at least one fully formed row.""" + return self._completed_rows > 0 + async def _save_chunk(self) -> None: - """Write data from the buffer.""" + """Write completed data rows from the buffer.""" + completed_rows = self._completed_rows + if completed_rows == 0: + 
return + if self._task is not None: await self._task - # Create task to save next chunk of data - chunk = await self._convert(self._buffer) + + # Extract only the completed rows into a new chunk + chunk_data = { + field: deque([self._buffer[field].popleft() for _ in range(completed_rows)]) + for field in self.io.inputs + } + + chunk = await self._convert(chunk_data) self._task = asyncio.create_task(self._save(chunk)) - self._buffer = defaultdict(deque) async def step(self) -> None: """Trigger save when buffer is at target size.""" diff --git a/tests/integration/test_process_with_components_run.py b/tests/integration/test_process_with_components_run.py index fe047ae8..2a16c41f 100644 --- a/tests/integration/test_process_with_components_run.py +++ b/tests/integration/test_process_with_components_run.py @@ -23,6 +23,7 @@ ) from plugboard.events import Event from plugboard.exceptions import ConstraintError, NotInitialisedError, ProcessStatusError +from plugboard.library import FileWriter from plugboard.process import LocalProcess, Process, RayProcess from plugboard.schemas import ConnectorSpec, Status from tests.conftest import ComponentTestHelper, zmq_connector_cls @@ -459,6 +460,99 @@ async def test_event_driven_process_shutdown( await process.destroy() +class MessageEventData(BaseModel): + """Data for a message event.""" + + message: str + + +class MessageEvent(Event): + """Event carrying a file-writer message.""" + + type: _t.ClassVar[str] = "message_event" + data: MessageEventData + + +class MessageEventGenerator(ComponentTestHelper): + """Produces a fixed number of message events.""" + + io = IO(output_events=[MessageEvent]) + + def __init__( + self, + iters: int, + *args: _t.Any, + delay: float = 0.0, + start: int = 0, + stride: int = 1, + **kwargs: _t.Any, + ) -> None: + super().__init__(*args, **kwargs) + self._iters = iters + self._delay = delay + self._start = start + self._stride = stride + + async def init(self) -> None: + await super().init() + self._seq 
= iter(range(self._start, self._start + self._iters * self._stride, self._stride)) + + async def step(self) -> None: + # Optional delay to simulate staggered event arrival + if self._delay > 0.0: + await asyncio.sleep(self._delay) + try: + idx = next(self._seq) + except StopIteration: + await self.io.close() + else: + evt = MessageEvent( + source=self.name, + data=MessageEventData(message=f"Message {idx}"), + ) + self.io.queue_event(evt) + await super().step() + + +class EventReaderFileWriter(FileWriter): + """`FileWriter` variant that adds event handling instead of a connector for `message`.""" + + io = IO(input_events=[MessageEvent]) + + @MessageEvent.handler + async def handle_message(self, event: MessageEvent) -> None: + self.message = event.data.message + + +@pytest.mark.asyncio +async def test_event_driven_file_writer_reuse(tmp_path: Path) -> None: + """Test that field-input components can be reused in event-driven processes.""" + output_path = tmp_path / "output_messages.csv" + components = [ + MessageEventGenerator(iters=3, name="message_event_generator"), + EventReaderFileWriter( + path=output_path, + name="event_reader_file_writer", + field_names=["message"], + ), + ] + event_connectors = AsyncioConnector.builder().build_event_connectors(components) + process = LocalProcess(components=components, connectors=event_connectors) + + await process.init() + await process.run() + + assert process.status == Status.COMPLETED + assert output_path.read_text().splitlines() == [ + "message", + "Message 0", + "Message 1", + "Message 2", + ] + + await process.destroy() + + _SHORT_TIMEOUT = 0.1 @@ -536,3 +630,88 @@ async def test_constraint_error_stops_background_status_check() -> None: ) await process.destroy() + + +class StaggeredEventFileWriter(FileWriter): + """`FileWriter` variant that adds event handling instead of a connector for `message`.""" + + io = IO(input_events=[MessageEvent]) + + def __init__(self, *args: _t.Any, field_names: list[str], **kwargs: _t.Any) 
-> None: + super().__init__(*args, field_names=field_names, **kwargs) + self.step_count: int = 0 + self.step_for_message: dict[str, int] = {} + + @MessageEvent.handler + async def handle_message(self, event: MessageEvent) -> None: + msg = event.data.message + match event.source: + case "mg1": + self.mg1 = msg + case "mg2": + self.mg2 = msg + case "mg3": + self.mg3 = msg + case _: + raise ValueError(f"Unexpected event source: {event.source}") + self.step_for_message[msg] = self.step_count + self.step_count += 1 + + +@pytest.mark.asyncio +@pytest_cases.parametrize( + "process_cls, connector_cls", + [ + (LocalProcess, AsyncioConnector), + ], +) +async def test_data_writer_handles_staggered_input_events( + process_cls: type[Process], connector_cls: type[Connector], tmp_path: Path, ray_ctx: None +) -> None: + """Test that a FileWriter can handle input events arriving in different steps. + + Input messages with data for different fields may arrive in different steps. The FileWriter + should write out a new row only when all required fields have received data, and should not + overwrite field values if only a subset of fields receive new data in a step. 
+ """ + output_path = tmp_path / "staggered_output_messages.csv" + + writer = StaggeredEventFileWriter( + path=output_path, field_names=["mg1", "mg2", "mg3"], name="writer" + ) + components = [ + # 3 inputs with different delays + MessageEventGenerator(iters=10, delay=0.005, start=0, stride=3, name="mg1"), + MessageEventGenerator(iters=10, delay=0.010, start=1, stride=3, name="mg2"), + MessageEventGenerator(iters=10, delay=0.020, start=2, stride=3, name="mg3"), + writer, + ] + + async with process_cls( + components=components, + connectors=AsyncioConnector.builder().build_event_connectors(components), + ) as process: + await process.run() + + with output_path.open() as f: + content = f.read().splitlines() + + assert len(content) == 11 # header + 10 rows of data + assert content[0] == "mg1,mg2,mg3" + assert content[1] == "Message 0,Message 1,Message 2" + assert content[2] == "Message 3,Message 4,Message 5" + assert content[3] == "Message 6,Message 7,Message 8" + assert content[4] == "Message 9,Message 10,Message 11" + assert content[5] == "Message 12,Message 13,Message 14" + assert content[6] == "Message 15,Message 16,Message 17" + assert content[7] == "Message 18,Message 19,Message 20" + assert content[8] == "Message 21,Message 22,Message 23" + assert content[9] == "Message 24,Message 25,Message 26" + assert content[10] == "Message 27,Message 28,Message 29" + + # Verify that messages from different generators were received in different steps + assert writer.step_count == 30 + assert len(writer.step_for_message) == 30 + assert len(set(writer.step_for_message.values())) == 30, ( + "Expected each message to be received in a different step" + ) diff --git a/tests/unit/test_process_validation.py b/tests/unit/test_process_validation.py index 02e0a4d2..df20132e 100644 --- a/tests/unit/test_process_validation.py +++ b/tests/unit/test_process_validation.py @@ -95,6 +95,7 @@ def _make_component( outputs: list[str] | None = None, input_events: list[str] | None = None, 
output_events: list[str] | None = None, + event_field_coverage: dict[str, list[str]] | None = None, initial_values: dict[str, _t.Any] | None = None, ) -> dict[str, _t.Any]: """Build a component dict matching process.dict() format.""" @@ -108,6 +109,7 @@ def _make_component( "outputs": outputs or [], "input_events": input_events or [], "output_events": output_events or [], + "event_field_coverage": event_field_coverage or {}, "initial_values": initial_values or {}, }, } @@ -303,6 +305,22 @@ def test_no_inputs_no_errors(self) -> None: errors = validate_all_inputs_connected(pd) assert errors == [] + def test_event_covered_fields(self) -> None: + """Unconnected inputs are allowed when non-system input events can populate them.""" + pd = _make_process_dict( + components={ + "producer": _make_component("producer", output_events=["message_event"]), + "writer": _make_component( + "writer", + inputs=["message"], + input_events=["system_stop", "message_event"], + event_field_coverage={"message_event": ["message"]}, + ), + }, + ) + errors = validate_all_inputs_connected(pd) + assert errors == [] + # --------------------------------------------------------------------------- # Tests for validate_input_events