From f50eb759a0ae6f49e5be5612b8c77033a5a9fb52 Mon Sep 17 00:00:00 2001 From: Adi Berkowitz Date: Tue, 23 Dec 2025 00:12:20 -0500 Subject: [PATCH 1/4] Add support for jitter so that we can ensure evenly distributed tasks don't cause all workers to restart at same time --- taskiq/api/receiver.py | 3 +++ taskiq/brokers/inmemory_broker.py | 2 ++ taskiq/cli/worker/args.py | 9 +++++++++ taskiq/cli/worker/run.py | 1 + taskiq/receiver/receiver.py | 8 +++++++- 5 files changed, 22 insertions(+), 1 deletion(-) diff --git a/taskiq/api/receiver.py b/taskiq/api/receiver.py index 72c6cbcf..03b84ca3 100644 --- a/taskiq/api/receiver.py +++ b/taskiq/api/receiver.py @@ -15,6 +15,7 @@ async def run_receiver_task( sync_workers: int | None = None, validate_params: bool = True, max_async_tasks: int = 100, + max_async_tasks_jitter: int = 0, max_prefetch: int = 0, propagate_exceptions: bool = True, run_startup: bool = False, @@ -43,6 +44,7 @@ async def run_receiver_task( or processes in processpool that runs sync tasks. :param validate_params: whether to validate params or not. :param max_async_tasks: maximum number of simultaneous async tasks. + :param max_async_tasks_jitter: random jitter to add to max_async_tasks. :param max_prefetch: maximum number of tasks to prefetch. :param propagate_exceptions: whether to propagate exceptions in generators or not. :param run_startup: whether to run startup function or not. @@ -79,6 +81,7 @@ def on_exit(_: Receiver) -> None: run_startup=run_startup, validate_params=validate_params, max_async_tasks=max_async_tasks, + max_async_tasks_jitter=max_async_tasks_jitter, max_prefetch=max_prefetch, propagate_exceptions=propagate_exceptions, on_exit=on_exit, diff --git a/taskiq/brokers/inmemory_broker.py b/taskiq/brokers/inmemory_broker.py index b7d1e67e..0a7cc98e 100644 --- a/taskiq/brokers/inmemory_broker.py +++ b/taskiq/brokers/inmemory_broker.py @@ -127,6 +127,7 @@ def __init__( max_stored_results: int = 100, cast_types: bool = True, max_async_tasks: int = 30, + max_async_tasks_jitter: int = 0, propagate_exceptions: bool = True, await_inplace: bool = False, ) -> None: @@ -140,6 +141,7 @@ def __init__( executor=self.executor, validate_params=cast_types, max_async_tasks=max_async_tasks, + max_async_tasks_jitter=max_async_tasks_jitter, propagate_exceptions=propagate_exceptions, ) self.await_inplace = await_inplace diff --git a/taskiq/cli/worker/args.py b/taskiq/cli/worker/args.py index fa922d7d..df3eb957 100644 --- a/taskiq/cli/worker/args.py +++ b/taskiq/cli/worker/args.py @@ -44,6 +44,7 @@ class WorkerArgs: reload_dirs: list[str] = field(default_factory=list) no_gitignore: bool = False max_async_tasks: int = 100 + max_async_tasks_jitter: int = 0 receiver: str = "taskiq.receiver:Receiver" receiver_arg: list[tuple[str, str]] = field(default_factory=list) max_prefetch: int = 0 @@ -210,6 +211,14 @@ def from_cli( default=100, help="Maximum simultaneous async tasks per worker process. ", ) + parser.add_argument( + "--max-async-tasks-jitter", + type=int, + dest="max_async_tasks_jitter", + default=0, + help="Add random jitter (0 to this value) to max-async-tasks to prevent " + "all workers from closing at the same time. ", + ) parser.add_argument( "--max-prefetch", type=int, diff --git a/taskiq/cli/worker/run.py b/taskiq/cli/worker/run.py index 53cef7c0..24d8f8db 100644 --- a/taskiq/cli/worker/run.py +++ b/taskiq/cli/worker/run.py @@ -165,6 +165,7 @@ def interrupt_handler(signum: int, _frame: Any) -> None: executor=pool, validate_params=not args.no_parse, max_async_tasks=args.max_async_tasks, + max_async_tasks_jitter=args.max_async_tasks_jitter, max_prefetch=args.max_prefetch, propagate_exceptions=not args.no_propagate_errors, ack_type=args.ack_type, diff --git a/taskiq/receiver/receiver.py b/taskiq/receiver/receiver.py index f54f2259..2411669a 100644 --- a/taskiq/receiver/receiver.py +++ b/taskiq/receiver/receiver.py @@ -2,6 +2,7 @@ import contextvars import functools import inspect +import random import sys from collections.abc import Callable from concurrent.futures import Executor, ProcessPoolExecutor @@ -55,6 +56,7 @@ def __init__( executor: Executor | None = None, validate_params: bool = True, max_async_tasks: "int | None" = None, + max_async_tasks_jitter: int = 0, max_prefetch: int = 0, propagate_exceptions: bool = True, run_startup: bool = True, @@ -80,7 +82,11 @@ def __init__( self._prepare_task(task.task_name, task.original_func) self.sem: asyncio.Semaphore | None = None if max_async_tasks is not None and max_async_tasks > 0: - self.sem = asyncio.Semaphore(max_async_tasks) + # Apply jitter to prevent all workers from hitting the limit simultaneously + actual_limit = max_async_tasks + if max_async_tasks_jitter > 0: + actual_limit = max_async_tasks + random.randint(0, max_async_tasks_jitter) + self.sem = asyncio.Semaphore(actual_limit) else: logger.warning( "Setting unlimited number of async tasks " From dd8596d9e2c8518d6e2d214291072898971b0bfc Mon Sep 17 00:00:00 2001 From: Adi Berkowitz Date: Tue, 23 Dec 2025 00:24:41 -0500 Subject: [PATCH 2/4] Add test covg --- taskiq/receiver/receiver.py | 6 +++- tests/receiver/test_receiver.py | 58 ++++++++++++++++++++++++++++++++- 2 files changed, 62 insertions(+), 2 deletions(-) diff --git a/taskiq/receiver/receiver.py b/taskiq/receiver/receiver.py index 2411669a..c4b06071 100644 --- a/taskiq/receiver/receiver.py +++ b/taskiq/receiver/receiver.py @@ -85,7 +85,11 @@ def __init__( # Apply jitter to prevent all workers from hitting the limit simultaneously actual_limit = max_async_tasks if max_async_tasks_jitter > 0: - actual_limit = max_async_tasks + random.randint(0, max_async_tasks_jitter) + # Using standard random for load distribution, not cryptography + actual_limit = max_async_tasks + random.randint( # noqa: S311 + 0, + max_async_tasks_jitter, + ) self.sem = asyncio.Semaphore(actual_limit) else: logger.warning( diff --git a/tests/receiver/test_receiver.py b/tests/receiver/test_receiver.py index 0b0e976a..eeb29c11 100644 --- a/tests/receiver/test_receiver.py +++ b/tests/receiver/test_receiver.py @@ -2,6 +2,7 @@ import contextvars import random import time +import unittest.mock from collections.abc import Generator from concurrent.futures import ThreadPoolExecutor from functools import wraps @@ -24,13 +25,15 @@ def get_receiver( broker: AsyncBroker | None = None, no_parse: bool = False, max_async_tasks: int | None = None, + max_async_tasks_jitter: int = 0, ) -> Receiver: """ Returns receiver with custom broker and args. :param broker: broker, defaults to None :param no_parse: parameter to taskiq_args, defaults to False - :param cli_args: Taskiq worker CLI arguments. + :param max_async_tasks: maximum number of simultaneous async tasks. + :param max_async_tasks_jitter: random jitter to add to max_async_tasks. :return: new receiver. """ if broker is None: @@ -40,6 +43,7 @@ def get_receiver( executor=ThreadPoolExecutor(max_workers=10), validate_params=not no_parse, max_async_tasks=max_async_tasks, + max_async_tasks_jitter=max_async_tasks_jitter, ) @@ -544,3 +548,55 @@ async def task_no_result() -> str: assert resp.return_value == "some value" assert not broker._running_tasks assert wrapper_call is True + + +async def test_jitter_applied_to_semaphore() -> None: + """Test that jitter is correctly applied to max_async_tasks semaphore.""" + max_async_tasks = 100 + max_async_tasks_jitter = 10 + + # Test with jitter value of 0 (minimum) + with unittest.mock.patch("random.randint", return_value=0): + receiver = get_receiver( + max_async_tasks=max_async_tasks, + max_async_tasks_jitter=max_async_tasks_jitter, + ) + assert receiver.sem is not None + assert receiver.sem._value == max_async_tasks + + # Test with jitter value of 5 (middle) + with unittest.mock.patch("random.randint", return_value=5): + receiver = get_receiver( + max_async_tasks=max_async_tasks, + max_async_tasks_jitter=max_async_tasks_jitter, + ) + assert receiver.sem is not None + assert receiver.sem._value == max_async_tasks + 5 + + # Test with jitter value of 10 (maximum) + with unittest.mock.patch("random.randint", return_value=10): + receiver = get_receiver( + max_async_tasks=max_async_tasks, + max_async_tasks_jitter=max_async_tasks_jitter, + ) + assert receiver.sem is not None + assert receiver.sem._value == max_async_tasks + 10 + + +async def test_jitter_zero_no_randomization() -> None: + """Test with zero jitter, semaphore value matches max_async_tasks.""" + max_async_tasks = 50 + + receiver = get_receiver( + max_async_tasks=max_async_tasks, + max_async_tasks_jitter=0, + ) + + assert receiver.sem is not None + assert receiver.sem._value == max_async_tasks + + +async def test_no_semaphore_without_max_async_tasks() -> None: + """Test that semaphore is None when max_async_tasks is not set.""" + receiver = get_receiver(max_async_tasks=None) + assert receiver.sem is None From 5acb282e8f85b77b4c323d9ce1c984f00de51b90 Mon Sep 17 00:00:00 2001 From: Adi Berkowitz Date: Wed, 24 Dec 2025 13:57:07 -0500 Subject: [PATCH 3/4] Update cli description to include jitter doc --- docs/guide/cli.md | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/docs/guide/cli.md b/docs/guide/cli.md index b70ad70a..dae3b6bb 100644 --- a/docs/guide/cli.md +++ b/docs/guide/cli.md @@ -138,6 +138,7 @@ The number of signals before a hard kill can be configured with the `--hardkill- * `--log-level` is used to set a log level (default `INFO`). * `--log-format` is used to set a log format (default `%(asctime)s][%(name)s][%(levelname)-7s][%(processName)s] %(message)s`). * `--max-async-tasks` - maximum number of simultaneously running async tasks. +* `--max-async-tasks-jitter` – Randomly varies the max async task limit between --max-async-tasks and a jittered value, helping prevent simultaneous worker restarts. * `--max-prefetch` - number of tasks to be prefetched before execution. (Useful for systems with high message rates, but brokers should support acknowledgements). * `--max-threadpool-threads` - number of threads for sync function execution. * `--no-propagate-errors` - if this parameter is enabled, exceptions won't be thrown in generator dependencies. @@ -149,7 +150,7 @@ The number of signals before a hard kill can be configured with the `--hardkill- * `--shutdown-timeout` - maximum amount of time for graceful broker's shutdown in seconds (default 5). * `--wait-tasks-timeout` - if cannot read new messages from the broker or maximum number of tasks is reached, worker will wait for all current tasks to finish. This parameter sets the maximum amount of time to wait until shutdown. * `--hardkill-count` - Number of termination signals to the main process before performing a hardkill. - +* ## Scheduler Scheduler is used to schedule tasks as described in [Scheduling tasks](./scheduling-tasks.md) section. From a5c7546ae7545dbfab8eb020ba346eda6c867be7 Mon Sep 17 00:00:00 2001 From: Adi Berkowitz Date: Wed, 24 Dec 2025 13:57:30 -0500 Subject: [PATCH 4/4] Update cli.md --- docs/guide/cli.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/guide/cli.md b/docs/guide/cli.md index dae3b6bb..66c8510d 100644 --- a/docs/guide/cli.md +++ b/docs/guide/cli.md @@ -150,7 +150,7 @@ The number of signals before a hard kill can be configured with the `--hardkill- * `--shutdown-timeout` - maximum amount of time for graceful broker's shutdown in seconds (default 5). * `--wait-tasks-timeout` - if cannot read new messages from the broker or maximum number of tasks is reached, worker will wait for all current tasks to finish. This parameter sets the maximum amount of time to wait until shutdown. * `--hardkill-count` - Number of termination signals to the main process before performing a hardkill. -* + ## Scheduler Scheduler is used to schedule tasks as described in [Scheduling tasks](./scheduling-tasks.md) section.