From 323d09dd122f0dac5f272810f0aad1c16b8a2391 Mon Sep 17 00:00:00 2001 From: Pavel Kirilin Date: Mon, 4 May 2026 12:11:42 +0200 Subject: [PATCH 1/2] Added __aexit__ and __aenter__ support. --- taskiq/abc/broker.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/taskiq/abc/broker.py b/taskiq/abc/broker.py index c0902371..1c2105c7 100644 --- a/taskiq/abc/broker.py +++ b/taskiq/abc/broker.py @@ -533,3 +533,11 @@ def _register_task( if task.broker != self: raise TaskBrokerMismatchError(broker=task.broker) self.local_task_registry[task_name] = task + + async def __aenter__(self) -> None: + """Satarts the broker as ctx manager.""" + await self.startup() + + async def __aexit__(self, *args: object, **kwargs: Any) -> None: + """Shuts down the broker as ctx manager.""" + await self.shutdown() From c7deafd123bbc7aae139a3c5f201bfeb14d44f05 Mon Sep 17 00:00:00 2001 From: Pavel Kirilin Date: Mon, 4 May 2026 12:21:46 +0200 Subject: [PATCH 2/2] Added tests for broker ctx manager. --- taskiq/abc/broker.py | 2 +- tests/abc/test_broker.py | 107 +++++++++++++++++++++++++++++++++++++++ 2 files changed, 108 insertions(+), 1 deletion(-) diff --git a/taskiq/abc/broker.py b/taskiq/abc/broker.py index 1c2105c7..ea2e86c0 100644 --- a/taskiq/abc/broker.py +++ b/taskiq/abc/broker.py @@ -535,7 +535,7 @@ def _register_task( self.local_task_registry[task_name] = task async def __aenter__(self) -> None: - """Satarts the broker as ctx manager.""" + """Starts the broker as ctx manager.""" await self.startup() async def __aexit__(self, *args: object, **kwargs: Any) -> None: diff --git a/tests/abc/test_broker.py b/tests/abc/test_broker.py index 5c536cfb..8d39ea50 100644 --- a/tests/abc/test_broker.py +++ b/tests/abc/test_broker.py @@ -1,9 +1,13 @@ from collections.abc import AsyncGenerator from copy import copy +import pytest + from taskiq.abc.broker import AsyncBroker from taskiq.decor import AsyncTaskiqDecoratedTask +from taskiq.events import TaskiqEvents from taskiq.message import BrokerMessage +from taskiq.state import TaskiqState class _TestBroker(AsyncBroker): @@ -76,3 +80,106 @@ async def test_task() -> None: ... assert "another_label" in test_kicker.labels assert test_task.labels == old_labels + + +@pytest.mark.anyio +async def test_async_context_manager_enter() -> None: + """Test that __aenter__ calls startup.""" + broker = _TestBroker() + startup_called = False + + @broker.on_event(TaskiqEvents.CLIENT_STARTUP) + async def track_startup(state: TaskiqState) -> None: + nonlocal startup_called + startup_called = True + + async with broker: + assert startup_called is True + + +@pytest.mark.anyio +async def test_async_context_manager_exit() -> None: + """Test that __aexit__ calls shutdown.""" + broker = _TestBroker() + shutdown_called = False + + @broker.on_event(TaskiqEvents.CLIENT_SHUTDOWN) + async def track_shutdown(state: TaskiqState) -> None: + nonlocal shutdown_called + shutdown_called = True + + async with broker: + pass + + assert shutdown_called is True + + +@pytest.mark.anyio +async def test_async_context_manager_enter_worker() -> None: + """Test that __aenter__ calls worker startup when is_worker_process is True.""" + broker = _TestBroker() + broker.is_worker_process = True + startup_called = False + + @broker.on_event(TaskiqEvents.WORKER_STARTUP) + async def track_startup(state: TaskiqState) -> None: + nonlocal startup_called + startup_called = True + + async with broker: + assert startup_called is True + + +@pytest.mark.anyio +async def test_async_context_manager_exit_worker() -> None: + """Test that __aexit__ calls worker shutdown when is_worker_process is True.""" + broker = _TestBroker() + broker.is_worker_process = True + shutdown_called = False + + @broker.on_event(TaskiqEvents.WORKER_SHUTDOWN) + async def track_shutdown(state: TaskiqState) -> None: + nonlocal shutdown_called + shutdown_called = True + + async with broker: + pass + + assert shutdown_called is True + + +@pytest.mark.anyio +async def test_async_context_manager_exit_on_exception() -> None: + """Test that __aexit__ calls shutdown even if exception is raised.""" + broker = _TestBroker() + shutdown_called = False + + @broker.on_event(TaskiqEvents.CLIENT_SHUTDOWN) + async def track_shutdown(state: TaskiqState) -> None: + nonlocal shutdown_called + shutdown_called = True + + with pytest.raises(ValueError, match="Test exception"): + async with broker: + raise ValueError("Test exception") + + assert shutdown_called is True + + +@pytest.mark.anyio +async def test_async_context_manager_exit_worker_on_exception() -> None: + """Test that __aexit__ calls worker shutdown even if exception is raised.""" + broker = _TestBroker() + broker.is_worker_process = True + shutdown_called = False + + @broker.on_event(TaskiqEvents.WORKER_SHUTDOWN) + async def track_shutdown(state: TaskiqState) -> None: + nonlocal shutdown_called + shutdown_called = True + + with pytest.raises(ValueError, match="Test exception"): + async with broker: + raise ValueError("Test exception") + + assert shutdown_called is True