Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions taskiq/abc/broker.py
Original file line number Diff line number Diff line change
Expand Up @@ -533,3 +533,11 @@ def _register_task(
if task.broker != self:
raise TaskBrokerMismatchError(broker=task.broker)
self.local_task_registry[task_name] = task

async def __aenter__(self) -> None:
"""Starts the broker as ctx manager."""
await self.startup()

async def __aexit__(self, *args: object, **kwargs: Any) -> None:
"""Shuts down the broker as ctx manager."""
await self.shutdown()
107 changes: 107 additions & 0 deletions tests/abc/test_broker.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,13 @@
from collections.abc import AsyncGenerator
from copy import copy

import pytest

from taskiq.abc.broker import AsyncBroker
from taskiq.decor import AsyncTaskiqDecoratedTask
from taskiq.events import TaskiqEvents
from taskiq.message import BrokerMessage
from taskiq.state import TaskiqState


class _TestBroker(AsyncBroker):
Expand Down Expand Up @@ -76,3 +80,106 @@ async def test_task() -> None: ...
assert "another_label" in test_kicker.labels

assert test_task.labels == old_labels


@pytest.mark.anyio
async def test_async_context_manager_enter() -> None:
"""Test that __aenter__ calls startup."""
broker = _TestBroker()
startup_called = False

@broker.on_event(TaskiqEvents.CLIENT_STARTUP)
async def track_startup(state: TaskiqState) -> None:
nonlocal startup_called
startup_called = True

async with broker:
assert startup_called is True


@pytest.mark.anyio
async def test_async_context_manager_exit() -> None:
"""Test that __aexit__ calls shutdown."""
broker = _TestBroker()
shutdown_called = False

@broker.on_event(TaskiqEvents.CLIENT_SHUTDOWN)
async def track_shutdown(state: TaskiqState) -> None:
nonlocal shutdown_called
shutdown_called = True

async with broker:
pass

assert shutdown_called is True


@pytest.mark.anyio
async def test_async_context_manager_enter_worker() -> None:
"""Test that __aenter__ calls worker startup when is_worker_process is True."""
broker = _TestBroker()
broker.is_worker_process = True
startup_called = False

@broker.on_event(TaskiqEvents.WORKER_STARTUP)
async def track_startup(state: TaskiqState) -> None:
nonlocal startup_called
startup_called = True

async with broker:
assert startup_called is True


@pytest.mark.anyio
async def test_async_context_manager_exit_worker() -> None:
"""Test that __aexit__ calls worker shutdown when is_worker_process is True."""
broker = _TestBroker()
broker.is_worker_process = True
shutdown_called = False

@broker.on_event(TaskiqEvents.WORKER_SHUTDOWN)
async def track_shutdown(state: TaskiqState) -> None:
nonlocal shutdown_called
shutdown_called = True

async with broker:
pass

assert shutdown_called is True


@pytest.mark.anyio
async def test_async_context_manager_exit_on_exception() -> None:
"""Test that __aexit__ calls shutdown even if exception is raised."""
broker = _TestBroker()
shutdown_called = False

@broker.on_event(TaskiqEvents.CLIENT_SHUTDOWN)
async def track_shutdown(state: TaskiqState) -> None:
nonlocal shutdown_called
shutdown_called = True

with pytest.raises(ValueError, match="Test exception"):
async with broker:
raise ValueError("Test exception")

assert shutdown_called is True


@pytest.mark.anyio
async def test_async_context_manager_exit_worker_on_exception() -> None:
"""Test that __aexit__ calls worker shutdown even if exception is raised."""
broker = _TestBroker()
broker.is_worker_process = True
shutdown_called = False

@broker.on_event(TaskiqEvents.WORKER_SHUTDOWN)
async def track_shutdown(state: TaskiqState) -> None:
nonlocal shutdown_called
shutdown_called = True

with pytest.raises(ValueError, match="Test exception"):
async with broker:
raise ValueError("Test exception")

assert shutdown_called is True
Loading