From 8212945bce5db23296ca4b3090ef02d6f107a456 Mon Sep 17 00:00:00 2001 From: Bernd Verst Date: Thu, 23 Apr 2026 17:19:41 -0700 Subject: [PATCH 01/28] Add gRPC resiliency design spec Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- .../2026-04-23-grpc-resiliency-design.md | 321 ++++++++++++++++++ 1 file changed, 321 insertions(+) create mode 100644 docs/superpowers/specs/2026-04-23-grpc-resiliency-design.md diff --git a/docs/superpowers/specs/2026-04-23-grpc-resiliency-design.md b/docs/superpowers/specs/2026-04-23-grpc-resiliency-design.md new file mode 100644 index 0000000..a2710e2 --- /dev/null +++ b/docs/superpowers/specs/2026-04-23-grpc-resiliency-design.md @@ -0,0 +1,321 @@ +# gRPC connection resiliency design + +## Problem statement + +`durabletask-python` already has basic gRPC retry policy configuration, +keepalive channel settings, and worker-side reconnect logic. It does not yet +have the stronger connection-healing behavior added in +`durabletask-dotnet` PR 708: + +- worker-side silent-disconnect detection for long-lived work-item streams +- consistent transport-failure classification +- client-side channel recreation after repeated transport failures +- shared backoff and threshold logic across connection-owning components + +The current gap shows up most clearly when a channel becomes stale or +half-open. The worker may continue retrying around a poisoned stream without a +clear distinction between graceful close and silent disconnect, and clients may +keep reusing a bad channel until the application recreates the client. + +## Goals + +- Detect and heal stale or silently disconnected gRPC connections in the worker + and in sync and async clients. +- Enable the new behavior by default with conservative values and explicit + override and disable knobs. +- Preserve existing protocol behavior and support for caller-supplied channels. +- Keep low-level gRPC channel options separate from SDK-managed resiliency + policy. +- Add focused regression tests for failure classification, backoff, and channel + recreation. + +## Non-goals + +- Redesign the public orchestration APIs or the sidecar protocol. +- Add general channel pooling or multi-endpoint load-balancing support. +- Automatically recreate caller-supplied channels in this iteration. +- Expand every possible raw gRPC channel knob as part of this work. + +## Proposed public API + +Add two new option dataclasses in `durabletask.grpc_options`. + +### `GrpcWorkerResiliencyOptions` + +Used by `TaskHubGrpcWorker` and Azure Managed worker wrappers. + +| Field | Default | Meaning | +| --- | --- | --- | +| `hello_timeout_seconds` | `30.0` | Deadline for the initial `Hello` handshake on a fresh connection. | +| `silent_disconnect_timeout_seconds` | `120.0` | Maximum idle period on the `GetWorkItems` stream before the worker treats the connection as stale. A value `<= 0` disables silent-disconnect detection. | +| `channel_recreate_failure_threshold` | `5` | Number of consecutive transport-shaped failures before the worker recreates an SDK-owned channel. A value `<= 0` disables recreation. | +| `reconnect_backoff_base_seconds` | `1.0` | Base delay for reconnect backoff. | +| `reconnect_backoff_cap_seconds` | `30.0` | Maximum reconnect delay. | + +### `GrpcClientResiliencyOptions` + +Used by `TaskHubGrpcClient`, `AsyncTaskHubGrpcClient`, and Azure Managed client +wrappers. 
+ +| Field | Default | Meaning | +| --- | --- | --- | +| `channel_recreate_failure_threshold` | `5` | Number of consecutive transport-shaped unary RPC failures before recreating an SDK-owned channel. A value `<= 0` disables recreation. | +| `min_recreate_interval_seconds` | `30.0` | Minimum interval between channel recreation attempts. | + +### Constructor changes + +Add a new optional `resiliency_options` parameter to these constructors: + +- `TaskHubGrpcWorker` +- `TaskHubGrpcClient` +- `AsyncTaskHubGrpcClient` +- `DurableTaskSchedulerWorker` +- `DurableTaskSchedulerClient` +- `AsyncDurableTaskSchedulerClient` + +If the parameter is omitted, the SDK uses the defaults above. This keeps the +new behavior enabled by default while still allowing targeted disablement. + +`GrpcChannelOptions` remains the place for raw gRPC transport settings such as +keepalive and retry service config. Resiliency policy stays separate because it +controls SDK behavior, not just channel construction. + +## Runtime design + +### Shared internal helpers + +Add a small internal module dedicated to resiliency primitives. It should stay +transport-focused and reusable by the worker and clients. + +Planned responsibilities: + +- full-jitter exponential backoff calculation +- transport-failure classification helpers +- consecutive-failure tracking with reset semantics +- small immutable state objects where atomic swap is needed + +The worker and client should share the same definition of +"transport-shaped failure" instead of maintaining separate ad hoc rules. + +### Worker behavior + +The worker keeps its current high-level reconnect loop but replaces the +connection-health logic with clearer internal pieces. + +#### Fresh connection establishment + +When the worker creates an SDK-owned channel, it: + +1. builds the channel and stub as it does today +2. sends `Hello` with `hello_timeout_seconds` +3. treats `UNAVAILABLE` and `DEADLINE_EXCEEDED` on that handshake as transport + failures + +Successful `Hello` resets the worker reconnect attempt counter. + +#### Stream monitoring + +Wrap the `GetWorkItems` stream in an internal monitor that tracks two things: + +- whether any message has ever been observed on the stream +- whether the stream has remained idle longer than + `silent_disconnect_timeout_seconds` + +The monitor reports one of these outcomes: + +- `shutdown`: worker shutdown was requested +- `message_received`: at least one message arrived and normal processing + continues +- `graceful_close_before_first_message`: peer closed the stream before the + worker observed any message +- `graceful_close_after_message`: peer closed the stream after at least one + message was observed +- `silent_disconnect`: the stream remained idle past the configured timeout + +The outer worker loop uses those outcomes as follows: + +- `message_received`: reset health counters +- `graceful_close_before_first_message`: count as channel poison +- `graceful_close_after_message`: reconnect immediately without poisoning the + channel +- `silent_disconnect`: count as channel poison +- `shutdown`: exit cleanly + +This keeps rolling upgrades and normal peer-driven reconnects from being +treated the same as a stale half-open stream. 
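To make the outcome model concrete, here is a minimal sketch of the monitor
loop. The queue-based hand-off from the stream reader and the `None` close
sentinel are assumptions made for illustration, not the final worker wiring:

```python
import queue
from enum import Enum
from typing import Callable


class StreamOutcome(Enum):
    SHUTDOWN = "shutdown"
    MESSAGE_RECEIVED = "message_received"
    GRACEFUL_CLOSE_BEFORE_FIRST_MESSAGE = "graceful_close_before_first_message"
    GRACEFUL_CLOSE_AFTER_MESSAGE = "graceful_close_after_message"
    SILENT_DISCONNECT = "silent_disconnect"


def watch_stream(items: queue.Queue, idle_timeout: float,
                 is_shutdown: Callable[[], bool]) -> StreamOutcome:
    """Drain work items until the stream closes, goes idle, or shutdown is requested."""
    saw_message = False
    while not is_shutdown():
        try:
            # A non-positive idle_timeout disables silent-disconnect detection.
            item = items.get(timeout=idle_timeout if idle_timeout > 0 else None)
        except queue.Empty:
            return StreamOutcome.SILENT_DISCONNECT
        if item is None:  # sentinel enqueued when the peer closes the stream
            if saw_message:
                return StreamOutcome.GRACEFUL_CLOSE_AFTER_MESSAGE
            return StreamOutcome.GRACEFUL_CLOSE_BEFORE_FIRST_MESSAGE
        # Observing any message maps to the MESSAGE_RECEIVED outcome: normal
        # processing continues and health counters reset.
        saw_message = True
    return StreamOutcome.SHUTDOWN
```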
+ +#### Failure counting and recreation + +The worker increments the consecutive-failure counter only for +transport-shaped failures: + +- `UNAVAILABLE` +- `Hello` `DEADLINE_EXCEEDED` +- explicit silent-disconnect timeout +- graceful stream close before the first message + +It does not increment the counter for errors that channel recreation is +unlikely to fix, such as: + +- `UNAUTHENTICATED` +- `NOT_FOUND` +- orchestration or activity execution failures + +When the threshold is reached and the worker owns the channel, it recreates the +channel and stub. When the worker does not own the channel, it keeps retrying +the existing transport and logs that the channel could not be recreated. + +### Client behavior + +Both sync and async clients route unary RPCs through a small internal invoker +helper instead of calling generated stub methods directly. + +The helper: + +- invokes the target unary RPC +- classifies the outcome +- updates a shared failure counter +- schedules channel recreation when the threshold is crossed + +#### Counted failures + +Count these failures toward the client recreation threshold: + +- `UNAVAILABLE` +- `DEADLINE_EXCEEDED` for ordinary unary calls + +Do not count deadline failures for long-poll methods because those calls are +expected to wait: + +- `wait_for_orchestration_start` +- `wait_for_orchestration_completion` +- async variants of those methods + +Successful replies and application-level RPC errors reset the failure counter, +because they prove the underlying transport is still usable. + +#### Channel recreation mechanics + +When the threshold is reached and the client owns the channel: + +1. enforce `min_recreate_interval_seconds` +2. allow only one recreation in flight at a time +3. build a fresh channel and stub with the existing host, interceptors, secure + channel flag, and `GrpcChannelOptions` +4. atomically swap the active channel and stub +5. retire the previous channel after a short grace period + +The failing RPC still fails normally. The recreated channel benefits later RPCs. + +If the caller supplied the channel, the client still tracks and logs transport +failures but does not attempt replacement. + +### Retiring replaced channels + +Closing the old channel immediately after a successful swap risks interrupting +in-flight work that captured the old stub before the swap. To avoid that, the +SDK keeps replaced SDK-owned channels alive for a short grace period and then +closes them. + +The implementation can use a small internal scheduler that is appropriate for +the transport: + +- sync clients and the worker: daemon timer or background thread +- async clients: background task plus `asyncio.sleep` + +All retired channels are also closed during final client or worker shutdown. 
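A minimal sketch of the sync-side retirement helper, assuming a 30-second
grace window and the daemon-timer option named above:

```python
import threading


def retire_channel_later(old_channel, grace_seconds: float = 30.0) -> None:
    # Let in-flight calls that captured the old stub drain before closing.
    timer = threading.Timer(grace_seconds, old_channel.close)
    timer.daemon = True  # a pending close must never keep the process alive
    timer.start()
```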
+ +## File-level implementation plan + +### `durabletask/grpc_options.py` + +- add `GrpcWorkerResiliencyOptions` +- add `GrpcClientResiliencyOptions` +- add validation for positive durations when enabled + +### `durabletask/internal/grpc_resiliency.py` + +Add shared internals for: + +- backoff calculation +- failure classification +- failure-threshold tracking +- small helper types used by worker and client code + +### `durabletask/worker.py` + +- accept `resiliency_options` +- replace the current ad hoc reconnect bookkeeping with the shared helpers +- add hello deadline handling +- add stream-outcome monitoring +- recreate SDK-owned channels when the threshold is crossed + +### `durabletask/client.py` + +- accept `resiliency_options` in sync and async clients +- centralize unary RPC invocation through internal helpers +- add single-flight channel recreation and cooldown logic +- retain current ownership semantics for caller-supplied channels + +### Azure Managed wrappers + +Thread the new `resiliency_options` parameter through: + +- `DurableTaskSchedulerWorker` +- `DurableTaskSchedulerClient` +- `AsyncDurableTaskSchedulerClient` + +No Azure-specific recreation behavior is required in this iteration because the +wrappers already build SDK-owned channels through the base client and worker +constructors. + +## Testing strategy + +Add focused unit tests for the new behavior. + +### Options and helper tests + +- new resiliency option validation +- full-jitter backoff bounds and cap behavior +- failure counter reset and threshold logic +- transport-failure classification rules + +### Worker tests + +- hello deadline failure counts toward recreation +- silent-disconnect timeout is detected and classified +- graceful close before the first message poisons the channel +- graceful close after a message triggers reconnect without poisoning +- user-supplied channels are not recreated + +### Client tests + +- repeated `UNAVAILABLE` failures trigger recreation for SDK-owned channels +- long-poll `DEADLINE_EXCEEDED` does not count toward recreation +- application-level RPC errors reset the counter +- recreation is single-flight and cooldown-limited +- replaced channels are closed after the grace period +- caller-supplied channels are observed but not replaced + +### Regression coverage + +Existing client and worker tests should continue to pass without requiring users +to opt into the new behavior. + +## Compatibility and rollout + +- The change is backward compatible because all new constructor parameters are + optional. +- The new behavior is enabled by default for SDK-owned channels only. +- Caller-supplied channels preserve existing ownership and lifecycle behavior. +- No protocol changes are required between the Python SDK and the sidecar. +- The changelog should describe the new automatic healing of stale gRPC worker + and client connections and the new resiliency option types. + +## Decision summary + +Implement parity-inspired connection healing from `durabletask-dotnet` PR 708 +by adding explicit worker stream monitoring, shared failure classification, and +client-side channel recreation for SDK-owned channels. Keep raw gRPC channel +configuration separate from SDK resiliency policy and leave broader channel +pooling and user-supplied channel recreation out of this iteration. 
From 8b25b56fda9e75b9ccb13c10eb9f9ddd3af41bc4 Mon Sep 17 00:00:00 2001 From: Bernd Verst Date: Thu, 23 Apr 2026 17:28:15 -0700 Subject: [PATCH 02/28] Ignore local worktrees Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- .gitignore | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/.gitignore b/.gitignore index 1bccc90..d789df9 100644 --- a/.gitignore +++ b/.gitignore @@ -131,5 +131,6 @@ dmypy.json # IDEs .idea +.worktrees/ -coverage.lcov \ No newline at end of file +coverage.lcov From fd07d1ab355cda798c3a9e0015828e97fc319b8d Mon Sep 17 00:00:00 2001 From: Bernd Verst Date: Thu, 23 Apr 2026 17:28:23 -0700 Subject: [PATCH 03/28] Add gRPC resiliency implementation plan Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- .../plans/2026-04-23-grpc-resiliency.md | 852 ++++++++++++++++++ 1 file changed, 852 insertions(+) create mode 100644 docs/superpowers/plans/2026-04-23-grpc-resiliency.md diff --git a/docs/superpowers/plans/2026-04-23-grpc-resiliency.md b/docs/superpowers/plans/2026-04-23-grpc-resiliency.md new file mode 100644 index 0000000..36f17a0 --- /dev/null +++ b/docs/superpowers/plans/2026-04-23-grpc-resiliency.md @@ -0,0 +1,852 @@ +# gRPC Resiliency Implementation Plan + +> **For agentic workers:** REQUIRED SUB-SKILL: Use superpowers:subagent-driven-development (recommended) or superpowers:executing-plans to implement this plan task-by-task. Steps use checkbox (`- [ ]`) syntax for tracking. + +**Goal:** Implement automatic healing of stale gRPC worker streams and client channels in `durabletask-python`, aligned with the behavior added in `durabletask-dotnet` PR 708. + +**Architecture:** Add explicit public resiliency option types plus a shared internal transport helper module, then wire those pieces into the worker loop and the sync and async clients. SDK-owned channels will be recreated after repeated transport failures, while caller-owned channels keep their existing ownership model and are only observed and logged. 
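As a usage sketch of the target surface (constructor and option names follow
the spec above; the values shown are illustrative, not recommended defaults):

```python
from durabletask.grpc_options import GrpcWorkerResiliencyOptions
from durabletask.worker import TaskHubGrpcWorker

# Detect stale streams faster than the defaults; a threshold of 0 would
# disable SDK-owned channel recreation entirely.
worker = TaskHubGrpcWorker(
    host_address="localhost:4001",
    resiliency_options=GrpcWorkerResiliencyOptions(
        silent_disconnect_timeout_seconds=60.0,
        channel_recreate_failure_threshold=3,
    ),
)
```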
+ +**Tech Stack:** Python 3.10+, grpc, grpc.aio, pytest, unittest.mock, flake8 + +--- + +## File map + +- `durabletask/grpc_options.py` - public resiliency option dataclasses and validation +- `durabletask/internal/grpc_resiliency.py` - shared backoff, failure tracking, and transport-failure classification helpers +- `durabletask/client.py` - sync and async client transport state, unary invocation helpers, and channel recreation logic +- `durabletask/worker.py` - hello timeout, stream-outcome classification, worker reconnect policy, and SDK-owned channel recreation +- `durabletask-azuremanaged/durabletask/azuremanaged/client.py` - pass `resiliency_options` through Azure Managed client wrappers +- `durabletask-azuremanaged/durabletask/azuremanaged/worker.py` - pass `resiliency_options` through Azure Managed worker wrapper +- `tests/durabletask/test_grpc_resiliency.py` - option validation and shared helper tests +- `tests/durabletask/test_worker_resiliency.py` - worker stream monitoring and reconnect behavior +- `tests/durabletask/test_client.py` - sync and async client constructor and channel recreation tests +- `tests/durabletask-azuremanaged/test_grpc_resiliency.py` - wrapper pass-through tests for the new option surfaces +- `CHANGELOG.md` - user-facing changelog entry for the core SDK +- `durabletask-azuremanaged/CHANGELOG.md` - user-facing changelog entry for Azure Managed wrappers + +### Task 1: Add public resiliency option types + +**Files:** +- Modify: `durabletask/grpc_options.py` +- Create: `tests/durabletask/test_grpc_resiliency.py` + +- [ ] **Step 1: Write the failing option tests** + +```python +import pytest + +from durabletask.grpc_options import ( + GrpcClientResiliencyOptions, + GrpcWorkerResiliencyOptions, +) + + +def test_worker_resiliency_defaults_are_enabled(): + options = GrpcWorkerResiliencyOptions() + assert options.hello_timeout_seconds == 30.0 + assert options.silent_disconnect_timeout_seconds == 120.0 + assert options.channel_recreate_failure_threshold == 5 + assert options.reconnect_backoff_base_seconds == 1.0 + assert options.reconnect_backoff_cap_seconds == 30.0 + + +def test_worker_resiliency_allows_disabling_timeout_and_threshold(): + options = GrpcWorkerResiliencyOptions( + silent_disconnect_timeout_seconds=0.0, + channel_recreate_failure_threshold=0, + ) + assert options.silent_disconnect_timeout_seconds == 0.0 + assert options.channel_recreate_failure_threshold == 0 + + +def test_worker_resiliency_rejects_invalid_durations(): + with pytest.raises(ValueError, match="hello_timeout_seconds must be > 0"): + GrpcWorkerResiliencyOptions(hello_timeout_seconds=0.0) + with pytest.raises(ValueError, match="reconnect_backoff_cap_seconds must be >= reconnect_backoff_base_seconds"): + GrpcWorkerResiliencyOptions( + reconnect_backoff_base_seconds=5.0, + reconnect_backoff_cap_seconds=1.0, + ) + + +def test_client_resiliency_defaults_are_enabled(): + options = GrpcClientResiliencyOptions() + assert options.channel_recreate_failure_threshold == 5 + assert options.min_recreate_interval_seconds == 30.0 + + +def test_client_resiliency_rejects_negative_cooldown(): + with pytest.raises(ValueError, match="min_recreate_interval_seconds must be >= 0"): + GrpcClientResiliencyOptions(min_recreate_interval_seconds=-1.0) +``` + +- [ ] **Step 2: Run the test to verify it fails** + +Run: `python -m pytest tests/durabletask/test_grpc_resiliency.py -v` + +Expected: FAIL with `ImportError` or `AttributeError` because the new option classes do not exist yet. 
- [ ] **Step 3: Write the minimal implementation**

```python
from dataclasses import dataclass


@dataclass
class GrpcWorkerResiliencyOptions:
    hello_timeout_seconds: float = 30.0
    silent_disconnect_timeout_seconds: float = 120.0
    channel_recreate_failure_threshold: int = 5
    reconnect_backoff_base_seconds: float = 1.0
    reconnect_backoff_cap_seconds: float = 30.0

    def __post_init__(self) -> None:
        if self.hello_timeout_seconds <= 0:
            raise ValueError("hello_timeout_seconds must be > 0")
        if self.silent_disconnect_timeout_seconds < 0:
            raise ValueError("silent_disconnect_timeout_seconds must be >= 0")
        if self.channel_recreate_failure_threshold < 0:
            raise ValueError("channel_recreate_failure_threshold must be >= 0")
        if self.reconnect_backoff_base_seconds <= 0:
            raise ValueError("reconnect_backoff_base_seconds must be > 0")
        if self.reconnect_backoff_cap_seconds <= 0:
            raise ValueError("reconnect_backoff_cap_seconds must be > 0")
        if self.reconnect_backoff_cap_seconds < self.reconnect_backoff_base_seconds:
            raise ValueError(
                "reconnect_backoff_cap_seconds must be >= reconnect_backoff_base_seconds"
            )


@dataclass
class GrpcClientResiliencyOptions:
    channel_recreate_failure_threshold: int = 5
    min_recreate_interval_seconds: float = 30.0

    def __post_init__(self) -> None:
        if self.channel_recreate_failure_threshold < 0:
            raise ValueError("channel_recreate_failure_threshold must be >= 0")
        if self.min_recreate_interval_seconds < 0:
            raise ValueError("min_recreate_interval_seconds must be >= 0")
```

- [ ] **Step 4: Run the tests to verify they pass**

Run: `python -m pytest tests/durabletask/test_grpc_resiliency.py -v`

Expected: PASS for the new option validation tests.
+ +- [ ] **Step 5: Commit** + +```bash +git add durabletask/grpc_options.py tests/durabletask/test_grpc_resiliency.py +git commit -m "Add gRPC resiliency option types" +``` + +### Task 2: Thread resiliency options through constructors + +**Files:** +- Modify: `durabletask/client.py` +- Modify: `durabletask/worker.py` +- Modify: `durabletask-azuremanaged/durabletask/azuremanaged/client.py` +- Modify: `durabletask-azuremanaged/durabletask/azuremanaged/worker.py` +- Modify: `tests/durabletask/test_client.py` +- Create: `tests/durabletask-azuremanaged/test_grpc_resiliency.py` + +- [ ] **Step 1: Write the failing constructor and wrapper tests** + +```python +from unittest.mock import MagicMock, patch + +from durabletask.client import AsyncTaskHubGrpcClient, TaskHubGrpcClient +from durabletask.grpc_options import ( + GrpcClientResiliencyOptions, + GrpcWorkerResiliencyOptions, +) +from durabletask.worker import TaskHubGrpcWorker +from durabletask.azuremanaged.client import DurableTaskSchedulerClient +from durabletask.azuremanaged.worker import DurableTaskSchedulerWorker + + +def test_client_stores_resiliency_options_for_recreation(): + resiliency = GrpcClientResiliencyOptions(channel_recreate_failure_threshold=7) + with patch("durabletask.client.shared.get_grpc_channel", return_value=MagicMock()), patch( + "durabletask.client.stubs.TaskHubSidecarServiceStub", return_value=MagicMock() + ): + client = TaskHubGrpcClient( + host_address="localhost:4001", + resiliency_options=resiliency, + ) + assert client._resiliency_options is resiliency + assert client._host_address == "localhost:4001" + + +def test_async_client_stores_resolved_transport_inputs(): + resiliency = GrpcClientResiliencyOptions() + with patch("durabletask.client.shared.get_async_grpc_channel", return_value=MagicMock()), patch( + "durabletask.client.stubs.TaskHubSidecarServiceStub", return_value=MagicMock() + ): + client = AsyncTaskHubGrpcClient( + host_address="localhost:4001", + resiliency_options=resiliency, + ) + assert client._resiliency_options is resiliency + assert client._host_address == "localhost:4001" + + +def test_worker_stores_resiliency_options(): + resiliency = GrpcWorkerResiliencyOptions(channel_recreate_failure_threshold=9) + worker = TaskHubGrpcWorker(resiliency_options=resiliency) + assert worker._resiliency_options is resiliency + + +def test_dts_client_passes_resiliency_options_to_base_client(): + resiliency = GrpcClientResiliencyOptions() + with patch("durabletask.azuremanaged.client.TaskHubGrpcClient.__init__", return_value=None) as mock_init: + DurableTaskSchedulerClient( + host_address="localhost:4001", + taskhub="hub", + token_credential=None, + resiliency_options=resiliency, + ) + assert mock_init.call_args.kwargs["resiliency_options"] is resiliency + + +def test_dts_worker_passes_resiliency_options_to_base_worker(): + resiliency = GrpcWorkerResiliencyOptions() + with patch("durabletask.azuremanaged.worker.TaskHubGrpcWorker.__init__", return_value=None) as mock_init: + DurableTaskSchedulerWorker( + host_address="localhost:4001", + taskhub="hub", + token_credential=None, + resiliency_options=resiliency, + ) + assert mock_init.call_args.kwargs["resiliency_options"] is resiliency +``` + +- [ ] **Step 2: Run the tests to verify they fail** + +Run: `python -m pytest tests/durabletask/test_client.py tests/durabletask-azuremanaged/test_grpc_resiliency.py -v` + +Expected: FAIL because the constructors do not accept `resiliency_options` yet and do not retain enough transport state for later recreation. 
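Note on the next step: the snippets below are spliced edits, not one
contiguous block. The first fragment belongs to the sync client constructor in
`durabletask/client.py`, the `GrpcWorkerResiliencyOptions` default belongs to
the worker constructor in `durabletask/worker.py`, and the two
`super().__init__` calls belong to the Azure Managed client and worker
wrappers respectively.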
+ +- [ ] **Step 3: Write the minimal implementation** + +```python +self._host_address = host_address if host_address else shared.get_default_host_address() +self._secure_channel = secure_channel +self._channel_options = channel_options +self._resiliency_options = ( + resiliency_options if resiliency_options is not None else GrpcClientResiliencyOptions() +) +resolved_interceptors = ( + prepare_sync_interceptors(metadata, interceptors) if channel is None else interceptors +) +self._interceptors = list(resolved_interceptors) if resolved_interceptors is not None else None + +self._resiliency_options = ( + resiliency_options if resiliency_options is not None else GrpcWorkerResiliencyOptions() +) + +super().__init__( + host_address=host_address, + channel=channel, + secure_channel=secure_channel, + metadata=None, + log_handler=log_handler, + log_formatter=log_formatter, + interceptors=resolved_interceptors, + channel_options=channel_options, + resiliency_options=resiliency_options, + default_version=default_version, + payload_store=payload_store, +) + +super().__init__( + host_address=host_address, + channel=channel, + secure_channel=secure_channel, + metadata=None, + log_handler=log_handler, + log_formatter=log_formatter, + interceptors=resolved_interceptors, + channel_options=channel_options, + resiliency_options=resiliency_options, + concurrency_options=concurrency_options, + maximum_timer_interval=None, + payload_store=payload_store, +) +``` + +- [ ] **Step 4: Run the tests to verify they pass** + +Run: `python -m pytest tests/durabletask/test_client.py tests/durabletask-azuremanaged/test_grpc_resiliency.py -v` + +Expected: PASS for the new constructor and wrapper pass-through tests. + +- [ ] **Step 5: Commit** + +```bash +git add durabletask/client.py durabletask/worker.py durabletask-azuremanaged/durabletask/azuremanaged/client.py durabletask-azuremanaged/durabletask/azuremanaged/worker.py tests/durabletask/test_client.py tests/durabletask-azuremanaged/test_grpc_resiliency.py +git commit -m "Thread gRPC resiliency options through constructors" +``` + +### Task 3: Add shared internal resiliency helpers + +**Files:** +- Create: `durabletask/internal/grpc_resiliency.py` +- Modify: `tests/durabletask/test_grpc_resiliency.py` + +- [ ] **Step 1: Write the failing helper tests** + +```python +import grpc +import pytest + +from durabletask.internal.grpc_resiliency import ( + FailureTracker, + get_full_jitter_delay_seconds, + is_client_transport_failure, + is_worker_transport_failure, +) + + +def test_full_jitter_delay_is_capped(monkeypatch): + monkeypatch.setattr("durabletask.internal.grpc_resiliency.random.random", lambda: 1.0) + delay = get_full_jitter_delay_seconds(10, base_seconds=1.0, cap_seconds=30.0) + assert delay == 30.0 + + +def test_failure_tracker_trips_at_threshold(): + tracker = FailureTracker(threshold=3) + assert tracker.record_failure() is False + assert tracker.record_failure() is False + assert tracker.record_failure() is True + tracker.record_success() + assert tracker.consecutive_failures == 0 + + +def test_client_transport_failure_ignores_long_poll_deadlines(): + assert is_client_transport_failure("WaitForInstanceStart", grpc.StatusCode.DEADLINE_EXCEEDED) is False + assert is_client_transport_failure("StartInstance", grpc.StatusCode.DEADLINE_EXCEEDED) is True + assert is_client_transport_failure("GetInstance", grpc.StatusCode.UNAVAILABLE) is True + + +def test_worker_transport_failure_filters_application_errors(): + assert 
is_worker_transport_failure(grpc.StatusCode.UNAVAILABLE) is True + assert is_worker_transport_failure(grpc.StatusCode.DEADLINE_EXCEEDED) is True + assert is_worker_transport_failure(grpc.StatusCode.UNAUTHENTICATED) is False + assert is_worker_transport_failure(grpc.StatusCode.NOT_FOUND) is False +``` + +- [ ] **Step 2: Run the tests to verify they fail** + +Run: `python -m pytest tests/durabletask/test_grpc_resiliency.py -k "jitter or tracker or transport_failure" -v` + +Expected: FAIL because the shared helper module and helper functions do not exist yet. + +- [ ] **Step 3: Write the minimal implementation** + +```python +import random +from dataclasses import dataclass + +import grpc + + +LONG_POLL_METHODS = {"WaitForInstanceStart", "WaitForInstanceCompletion"} + + +def get_full_jitter_delay_seconds( + attempt: int, + *, + base_seconds: float, + cap_seconds: float, +) -> float: + capped_attempt = min(attempt, 30) + upper_bound = min(cap_seconds, base_seconds * (2 ** capped_attempt)) + return random.random() * upper_bound + + +@dataclass +class FailureTracker: + threshold: int + consecutive_failures: int = 0 + + def record_failure(self) -> bool: + if self.threshold <= 0: + return False + self.consecutive_failures += 1 + return self.consecutive_failures >= self.threshold + + def record_success(self) -> None: + self.consecutive_failures = 0 + + +def is_client_transport_failure(method_name: str, status_code: grpc.StatusCode) -> bool: + if status_code == grpc.StatusCode.UNAVAILABLE: + return True + if status_code == grpc.StatusCode.DEADLINE_EXCEEDED: + return method_name not in LONG_POLL_METHODS + return False + + +def is_worker_transport_failure(status_code: grpc.StatusCode) -> bool: + return status_code in { + grpc.StatusCode.UNAVAILABLE, + grpc.StatusCode.DEADLINE_EXCEEDED, + } +``` + +- [ ] **Step 4: Run the tests to verify they pass** + +Run: `python -m pytest tests/durabletask/test_grpc_resiliency.py -v` + +Expected: PASS for the helper and option tests together. 
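For reference, full jitter draws each delay uniformly from
`[0, min(cap_seconds, base_seconds * 2 ** attempt))`. With the defaults
(`base_seconds=1.0`, `cap_seconds=30.0`), attempt 3 yields a delay in
`[0, 8.0)` seconds, and from attempt 5 onward the upper bound stays pinned at
`30.0` seconds.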
+ +- [ ] **Step 5: Commit** + +```bash +git add durabletask/internal/grpc_resiliency.py tests/durabletask/test_grpc_resiliency.py +git commit -m "Add shared gRPC resiliency helpers" +``` + +### Task 4: Harden the worker stream lifecycle + +**Files:** +- Modify: `durabletask/worker.py` +- Create: `tests/durabletask/test_worker_resiliency.py` + +- [ ] **Step 1: Write the failing worker resiliency tests** + +```python +import grpc +from unittest.mock import MagicMock + +from durabletask.grpc_options import GrpcWorkerResiliencyOptions +from durabletask.worker import TaskHubGrpcWorker, _WorkItemStreamOutcome + + +def test_worker_classifies_graceful_close_before_first_message(): + worker = TaskHubGrpcWorker( + resiliency_options=GrpcWorkerResiliencyOptions(silent_disconnect_timeout_seconds=5.0) + ) + outcome = worker._classify_stream_outcome( + saw_message=False, + timed_out=False, + ) + assert outcome is _WorkItemStreamOutcome.GRACEFUL_CLOSE_BEFORE_FIRST_MESSAGE + + +def test_worker_classifies_graceful_close_after_message(): + worker = TaskHubGrpcWorker() + outcome = worker._classify_stream_outcome( + saw_message=True, + timed_out=False, + ) + assert outcome is _WorkItemStreamOutcome.GRACEFUL_CLOSE_AFTER_MESSAGE + + +def test_worker_counts_only_transport_failures_for_recreation(): + worker = TaskHubGrpcWorker( + resiliency_options=GrpcWorkerResiliencyOptions(channel_recreate_failure_threshold=2) + ) + assert worker._should_count_worker_failure(grpc.StatusCode.UNAVAILABLE) is True + assert worker._should_count_worker_failure(grpc.StatusCode.UNAUTHENTICATED) is False + + +def test_worker_does_not_recreate_caller_owned_channel(): + worker = TaskHubGrpcWorker(channel=MagicMock()) + assert worker._can_recreate_channel() is False +``` + +- [ ] **Step 2: Run the tests to verify they fail** + +Run: `python -m pytest tests/durabletask/test_worker_resiliency.py -v` + +Expected: FAIL because the worker does not expose explicit stream-outcome helpers yet and still uses ad hoc reconnect bookkeeping. 
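Note on the next step: the enum and the three helper methods below are new
additions to `durabletask/worker.py`, while the trailing statements (the
`Hello` timeout, the `asyncio.wait_for` queue read, the backoff delay, and the
health-ping reset) are loose edits inside the existing worker connect-and-read
loop rather than a single contiguous block.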
+ +- [ ] **Step 3: Write the minimal implementation** + +```python +class _WorkItemStreamOutcome(Enum): + SHUTDOWN = "shutdown" + GRACEFUL_CLOSE_BEFORE_FIRST_MESSAGE = "graceful_close_before_first_message" + GRACEFUL_CLOSE_AFTER_MESSAGE = "graceful_close_after_message" + SILENT_DISCONNECT = "silent_disconnect" + + +def _classify_stream_outcome(self, *, saw_message: bool, timed_out: bool) -> _WorkItemStreamOutcome: + if timed_out: + return _WorkItemStreamOutcome.SILENT_DISCONNECT + if saw_message: + return _WorkItemStreamOutcome.GRACEFUL_CLOSE_AFTER_MESSAGE + return _WorkItemStreamOutcome.GRACEFUL_CLOSE_BEFORE_FIRST_MESSAGE + + +def _should_count_worker_failure(self, status_code: grpc.StatusCode) -> bool: + return is_worker_transport_failure(status_code) + + +def _can_recreate_channel(self) -> bool: + return self._channel is None + + +hello_timeout = self._resiliency_options.hello_timeout_seconds +current_stub.Hello(empty_pb2.Empty(), timeout=hello_timeout) + +queue_timeout = self._resiliency_options.silent_disconnect_timeout_seconds or None +work_item = await asyncio.wait_for( + loop.run_in_executor(None, work_item_queue.get), + timeout=queue_timeout, +) + +delay = get_full_jitter_delay_seconds( + conn_retry_count, + base_seconds=self._resiliency_options.reconnect_backoff_base_seconds, + cap_seconds=self._resiliency_options.reconnect_backoff_cap_seconds, +) + +if work_item.HasField("healthPing"): + failure_tracker.record_success() + continue +``` + +- [ ] **Step 4: Run the worker tests** + +Run: `python -m pytest tests/durabletask/test_worker_resiliency.py -v` + +Expected: PASS for the worker classification and ownership tests. + +- [ ] **Step 5: Commit** + +```bash +git add durabletask/worker.py tests/durabletask/test_worker_resiliency.py +git commit -m "Harden worker gRPC stream reconnect behavior" +``` + +### Task 5: Add sync client channel recreation + +**Files:** +- Modify: `durabletask/client.py` +- Modify: `tests/durabletask/test_client.py` + +- [ ] **Step 1: Write the failing sync client recreation tests** + +```python +import grpc +import pytest +from unittest.mock import MagicMock, patch + +from durabletask.client import TaskHubGrpcClient +from durabletask.grpc_options import GrpcClientResiliencyOptions + + +def test_sync_client_recreates_sdk_owned_channel_after_repeated_unavailable(): + first_channel = MagicMock(name="first-channel") + second_channel = MagicMock(name="second-channel") + first_stub = MagicMock() + first_stub.GetInstance.side_effect = grpc.RpcError() + second_stub = MagicMock() + second_stub.GetInstance.return_value = MagicMock(exists=False) + + rpc_error = MagicMock(spec=grpc.RpcError) + rpc_error.code.return_value = grpc.StatusCode.UNAVAILABLE + first_stub.GetInstance.side_effect = rpc_error + + with patch("durabletask.client.shared.get_grpc_channel", side_effect=[first_channel, second_channel]), patch( + "durabletask.client.stubs.TaskHubSidecarServiceStub", side_effect=[first_stub, second_stub] + ): + client = TaskHubGrpcClient( + host_address="localhost:4001", + resiliency_options=GrpcClientResiliencyOptions( + channel_recreate_failure_threshold=1, + min_recreate_interval_seconds=0.0, + ), + ) + with pytest.raises(grpc.RpcError): + client.get_orchestration_state("abc") + client.get_orchestration_state("abc") + + assert client._channel is second_channel + + +def test_sync_client_does_not_count_long_poll_deadline(): + rpc_error = MagicMock(spec=grpc.RpcError) + rpc_error.code.return_value = grpc.StatusCode.DEADLINE_EXCEEDED + stub = MagicMock() + 
stub.WaitForInstanceStart.side_effect = rpc_error + + with patch("durabletask.client.shared.get_grpc_channel", return_value=MagicMock()), patch( + "durabletask.client.stubs.TaskHubSidecarServiceStub", return_value=stub + ): + client = TaskHubGrpcClient( + resiliency_options=GrpcClientResiliencyOptions(channel_recreate_failure_threshold=1) + ) + with pytest.raises(TimeoutError): + client.wait_for_orchestration_start("abc") + assert client._client_failure_tracker.consecutive_failures == 0 +``` + +- [ ] **Step 2: Run the tests to verify they fail** + +Run: `python -m pytest tests/durabletask/test_client.py -k "recreates_sdk_owned_channel or long_poll_deadline" -v` + +Expected: FAIL because client calls still go directly through the stub and the client has no failure tracker or channel recreation path. + +- [ ] **Step 3: Write the minimal implementation** + +```python +self._client_failure_tracker = FailureTracker( + self._resiliency_options.channel_recreate_failure_threshold +) +self._last_recreate_time = 0.0 +self._recreate_lock = threading.Lock() + +def _invoke_unary(self, method_name: str, request: Any, *, timeout: Optional[int] = None): + method = getattr(self._stub, method_name) + try: + if timeout is None: + response = method(request) + else: + response = method(request, timeout=timeout) + except grpc.RpcError as rpc_error: + if is_client_transport_failure(method_name, rpc_error.code()): + should_recreate = self._client_failure_tracker.record_failure() + if should_recreate: + self._maybe_recreate_channel() + else: + self._client_failure_tracker.record_success() + raise + else: + self._client_failure_tracker.record_success() + return response + +def _maybe_recreate_channel(self) -> None: + if not self._owns_channel: + return + with self._recreate_lock: + now = time.monotonic() + if now - self._last_recreate_time < self._resiliency_options.min_recreate_interval_seconds: + return + old_channel = self._channel + self._channel = shared.get_grpc_channel( + host_address=self._host_address, + secure_channel=self._secure_channel, + interceptors=self._interceptors, + channel_options=self._channel_options, + ) + self._stub = stubs.TaskHubSidecarServiceStub(self._channel) + self._last_recreate_time = now + self._client_failure_tracker.record_success() + threading.Timer(30.0, old_channel.close).start() +``` + +- [ ] **Step 4: Run the tests to verify they pass** + +Run: `python -m pytest tests/durabletask/test_client.py -k "recreates_sdk_owned_channel or long_poll_deadline" -v` + +Expected: PASS for both new sync client tests and no regressions in the existing client construction tests. 
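Design note on the retirement timer above: `threading.Timer` is non-daemonic
by default, so the sketch should set `timer.daemon = True` before `start()`,
matching the daemon-timer approach named in the spec, to keep a pending
channel close from blocking interpreter shutdown. Final shutdown should also
close any still-pending retired channels explicitly.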
+ +- [ ] **Step 5: Commit** + +```bash +git add durabletask/client.py tests/durabletask/test_client.py +git commit -m "Add sync client gRPC channel recreation" +``` + +### Task 6: Add async client channel recreation + +**Files:** +- Modify: `durabletask/client.py` +- Modify: `tests/durabletask/test_client.py` + +- [ ] **Step 1: Write the failing async client recreation tests** + +```python +import grpc +import pytest +from unittest.mock import AsyncMock, MagicMock, patch + +from durabletask.client import AsyncTaskHubGrpcClient +from durabletask.grpc_options import GrpcClientResiliencyOptions + + +@pytest.mark.asyncio +async def test_async_client_recreates_sdk_owned_channel_after_unavailable(): + rpc_error = MagicMock(spec=grpc.aio.AioRpcError) + rpc_error.code.return_value = grpc.StatusCode.UNAVAILABLE + + first_stub = MagicMock() + first_stub.GetInstance = AsyncMock(side_effect=rpc_error) + second_stub = MagicMock() + second_stub.GetInstance = AsyncMock(return_value=MagicMock(exists=False)) + + with patch("durabletask.client.shared.get_async_grpc_channel", side_effect=[MagicMock(), MagicMock()]), patch( + "durabletask.client.stubs.TaskHubSidecarServiceStub", side_effect=[first_stub, second_stub] + ): + client = AsyncTaskHubGrpcClient( + host_address="localhost:4001", + resiliency_options=GrpcClientResiliencyOptions( + channel_recreate_failure_threshold=1, + min_recreate_interval_seconds=0.0, + ), + ) + with pytest.raises(grpc.aio.AioRpcError): + await client.get_orchestration_state("abc") + await client.get_orchestration_state("abc") + + +@pytest.mark.asyncio +async def test_async_client_does_not_count_wait_for_orchestration_deadline(): + rpc_error = MagicMock(spec=grpc.aio.AioRpcError) + rpc_error.code.return_value = grpc.StatusCode.DEADLINE_EXCEEDED + stub = MagicMock() + stub.WaitForInstanceCompletion = AsyncMock(side_effect=rpc_error) + + with patch("durabletask.client.shared.get_async_grpc_channel", return_value=MagicMock()), patch( + "durabletask.client.stubs.TaskHubSidecarServiceStub", return_value=stub + ): + client = AsyncTaskHubGrpcClient( + resiliency_options=GrpcClientResiliencyOptions(channel_recreate_failure_threshold=1) + ) + with pytest.raises(TimeoutError): + await client.wait_for_orchestration_completion("abc") + assert client._client_failure_tracker.consecutive_failures == 0 +``` + +- [ ] **Step 2: Run the tests to verify they fail** + +Run: `python -m pytest tests/durabletask/test_client.py -k "async_client_recreates_sdk_owned_channel or async_client_does_not_count" -v` + +Expected: FAIL because the async client still awaits stub methods directly and has no async-safe recreation path. 
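Note on the next step: keep a reference to the task returned by
`asyncio.create_task` (for example in a pending-task set that discards entries
on completion). The event loop holds tasks only weakly, so an unreferenced
retirement task can be garbage-collected before it closes the old channel.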
+ +- [ ] **Step 3: Write the minimal implementation** + +```python +self._client_failure_tracker = FailureTracker( + self._resiliency_options.channel_recreate_failure_threshold +) +self._recreate_lock = asyncio.Lock() +self._last_recreate_time = 0.0 + +async def _invoke_unary(self, method_name: str, request: Any, *, timeout: Optional[int] = None): + method = getattr(self._stub, method_name) + try: + if timeout is None: + response = await method(request) + else: + response = await method(request, timeout=timeout) + except grpc.aio.AioRpcError as rpc_error: + if is_client_transport_failure(method_name, rpc_error.code()): + should_recreate = self._client_failure_tracker.record_failure() + if should_recreate: + await self._maybe_recreate_channel() + else: + self._client_failure_tracker.record_success() + raise + else: + self._client_failure_tracker.record_success() + return response + +async def _maybe_recreate_channel(self) -> None: + if not self._owns_channel: + return + async with self._recreate_lock: + now = time.monotonic() + if now - self._last_recreate_time < self._resiliency_options.min_recreate_interval_seconds: + return + old_channel = self._channel + self._channel = shared.get_async_grpc_channel( + host_address=self._host_address, + secure_channel=self._secure_channel, + interceptors=self._interceptors, + channel_options=self._channel_options, + ) + self._stub = stubs.TaskHubSidecarServiceStub(self._channel) + self._last_recreate_time = now + self._client_failure_tracker.record_success() + asyncio.create_task(self._close_retired_channel(old_channel)) + + +async def _close_retired_channel(self, channel: grpc.aio.Channel) -> None: + await asyncio.sleep(30.0) + await channel.close() +``` + +- [ ] **Step 4: Run the tests to verify they pass** + +Run: `python -m pytest tests/durabletask/test_client.py -k "async_client_recreates_sdk_owned_channel or async_client_does_not_count" -v` + +Expected: PASS for the async recreation tests and no regressions in the existing async client construction tests. + +- [ ] **Step 5: Commit** + +```bash +git add durabletask/client.py tests/durabletask/test_client.py +git commit -m "Add async client gRPC channel recreation" +``` + +### Task 7: Update changelogs and run final verification + +**Files:** +- Modify: `CHANGELOG.md` +- Modify: `durabletask-azuremanaged/CHANGELOG.md` +- Modify: `docs/superpowers/specs/2026-04-23-grpc-resiliency-design.md` (only if the implementation changed the agreed design) +- Modify: `docs/superpowers/plans/2026-04-23-grpc-resiliency.md` (check off completed steps only after execution) + +- [ ] **Step 1: Add the changelog entries** + +```markdown +## Unreleased + +### Added + +- Added automatic gRPC channel healing for SDK-owned clients and workers, with new resiliency option types for tuning hello deadlines, silent-disconnect detection, recreate thresholds, and recreate cooldowns. +``` + +```markdown +## Unreleased + +### Added + +- Added pass-through support for the new gRPC resiliency option types on Azure Managed clients and workers. +``` + +- [ ] **Step 2: Run the focused tests** + +Run: + +```bash +python -m pytest tests/durabletask/test_grpc_resiliency.py tests/durabletask/test_worker_resiliency.py tests/durabletask/test_client.py tests/durabletask-azuremanaged/test_grpc_resiliency.py -v +``` + +Expected: PASS for all new and touched unit tests. 
+ +- [ ] **Step 3: Run lint on the changed Python files** + +Run: + +```bash +python -m flake8 durabletask/grpc_options.py durabletask/internal/grpc_resiliency.py durabletask/client.py durabletask/worker.py durabletask-azuremanaged/durabletask/azuremanaged/client.py durabletask-azuremanaged/durabletask/azuremanaged/worker.py tests/durabletask/test_grpc_resiliency.py tests/durabletask/test_worker_resiliency.py tests/durabletask/test_client.py tests/durabletask-azuremanaged/test_grpc_resiliency.py +``` + +Expected: no output + +- [ ] **Step 4: Run the full test suite** + +Run: + +```bash +python -m pytest +``` + +Expected: PASS across the repository, including the existing orchestration and Azure Managed test suites. + +- [ ] **Step 5: Commit** + +```bash +git add CHANGELOG.md durabletask-azuremanaged/CHANGELOG.md durabletask/grpc_options.py durabletask/internal/grpc_resiliency.py durabletask/client.py durabletask/worker.py durabletask-azuremanaged/durabletask/azuremanaged/client.py durabletask-azuremanaged/durabletask/azuremanaged/worker.py tests/durabletask/test_grpc_resiliency.py tests/durabletask/test_worker_resiliency.py tests/durabletask/test_client.py tests/durabletask-azuremanaged/test_grpc_resiliency.py +git commit -m "Add gRPC connection resiliency" +``` From 774004a5c6665465b763c075aaa8fe5c084c06f5 Mon Sep 17 00:00:00 2001 From: Bernd Verst Date: Thu, 23 Apr 2026 18:02:52 -0700 Subject: [PATCH 04/28] Add gRPC resiliency option types Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- CHANGELOG.md | 3 ++ durabletask/grpc_options.py | 41 ++++++++++++++++ tests/durabletask/test_grpc_resiliency.py | 60 +++++++++++++++++++++++ 3 files changed, 104 insertions(+) create mode 100644 tests/durabletask/test_grpc_resiliency.py diff --git a/CHANGELOG.md b/CHANGELOG.md index 1f84e57..c5fc3cb 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -12,6 +12,9 @@ ADDED - Added `GrpcChannelOptions` and `GrpcRetryPolicyOptions` for configuring gRPC transport behavior, including message-size limits, keepalive settings, and channel-level retry policy service configuration. +- Added `GrpcWorkerResiliencyOptions` and `GrpcClientResiliencyOptions` for + configuring public gRPC reconnect, hello timeout, and channel recreation + thresholds. 
- Added optional `channel` and `channel_options` parameters to `TaskHubGrpcClient`, `AsyncTaskHubGrpcClient`, and `TaskHubGrpcWorker` to support pre-configured channel passthrough and low-level gRPC channel diff --git a/durabletask/grpc_options.py b/durabletask/grpc_options.py index 56f2236..f648104 100644 --- a/durabletask/grpc_options.py +++ b/durabletask/grpc_options.py @@ -100,3 +100,44 @@ def to_grpc_options(self) -> list[tuple[str, Any]]: options.append(("grpc.service_config", json.dumps(self.retry_policy.to_service_config()))) return options + + +@dataclass +class GrpcWorkerResiliencyOptions: + """Configuration for worker-side gRPC resiliency behavior.""" + + hello_timeout_seconds: float = 30.0 + silent_disconnect_timeout_seconds: float = 120.0 + channel_recreate_failure_threshold: int = 5 + reconnect_backoff_base_seconds: float = 1.0 + reconnect_backoff_cap_seconds: float = 30.0 + + def __post_init__(self) -> None: + if self.hello_timeout_seconds <= 0: + raise ValueError("hello_timeout_seconds must be > 0") + if self.silent_disconnect_timeout_seconds < 0: + raise ValueError("silent_disconnect_timeout_seconds must be >= 0") + if self.channel_recreate_failure_threshold < 0: + raise ValueError("channel_recreate_failure_threshold must be >= 0") + if self.reconnect_backoff_base_seconds <= 0: + raise ValueError("reconnect_backoff_base_seconds must be > 0") + if self.reconnect_backoff_cap_seconds <= 0: + raise ValueError("reconnect_backoff_cap_seconds must be > 0") + if self.reconnect_backoff_cap_seconds < self.reconnect_backoff_base_seconds: + raise ValueError( + "reconnect_backoff_cap_seconds must be >= reconnect_backoff_base_seconds" + ) + + +@dataclass +class GrpcClientResiliencyOptions: + """Configuration for client-side gRPC resiliency behavior.""" + + channel_recreate_failure_threshold: int = 5 + min_recreate_interval_seconds: float = 30.0 + + def __post_init__(self) -> None: + if self.channel_recreate_failure_threshold < 0: + raise ValueError("channel_recreate_failure_threshold must be >= 0") + if self.min_recreate_interval_seconds < 0: + raise ValueError("min_recreate_interval_seconds must be >= 0") diff --git a/tests/durabletask/test_grpc_resiliency.py b/tests/durabletask/test_grpc_resiliency.py new file mode 100644 index 0000000..35489f3 --- /dev/null +++ b/tests/durabletask/test_grpc_resiliency.py @@ -0,0 +1,60 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. 
+ +import pytest + +from durabletask.grpc_options import ( + GrpcClientResiliencyOptions, + GrpcWorkerResiliencyOptions, +) + + +def test_worker_resiliency_defaults_are_enabled(): + options = GrpcWorkerResiliencyOptions() + + assert options.hello_timeout_seconds == 30.0 + assert options.silent_disconnect_timeout_seconds == 120.0 + assert options.channel_recreate_failure_threshold == 5 + assert options.reconnect_backoff_base_seconds == 1.0 + assert options.reconnect_backoff_cap_seconds == 30.0 + + +def test_worker_resiliency_allows_disabling_timeout_and_threshold(): + options = GrpcWorkerResiliencyOptions( + silent_disconnect_timeout_seconds=0.0, + channel_recreate_failure_threshold=0, + ) + + assert options.silent_disconnect_timeout_seconds == 0.0 + assert options.channel_recreate_failure_threshold == 0 + + +def test_worker_resiliency_rejects_invalid_durations(): + with pytest.raises(ValueError, match="hello_timeout_seconds must be > 0"): + GrpcWorkerResiliencyOptions(hello_timeout_seconds=0.0) + + with pytest.raises( + ValueError, + match=( + "reconnect_backoff_cap_seconds must be >= " + "reconnect_backoff_base_seconds" + ), + ): + GrpcWorkerResiliencyOptions( + reconnect_backoff_base_seconds=5.0, + reconnect_backoff_cap_seconds=1.0, + ) + + +def test_client_resiliency_defaults_are_enabled(): + options = GrpcClientResiliencyOptions() + + assert options.channel_recreate_failure_threshold == 5 + assert options.min_recreate_interval_seconds == 30.0 + + +def test_client_resiliency_rejects_negative_cooldown(): + with pytest.raises( + ValueError, match="min_recreate_interval_seconds must be >= 0" + ): + GrpcClientResiliencyOptions(min_recreate_interval_seconds=-1.0) From 4e8d0bb6313d6f0cd00ce20c9fa4029a377e9971 Mon Sep 17 00:00:00 2001 From: Bernd Verst Date: Thu, 23 Apr 2026 18:12:57 -0700 Subject: [PATCH 05/28] Add grpc resiliency validation tests Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- tests/durabletask/test_grpc_resiliency.py | 65 ++++++++++++++++------- 1 file changed, 47 insertions(+), 18 deletions(-) diff --git a/tests/durabletask/test_grpc_resiliency.py b/tests/durabletask/test_grpc_resiliency.py index 35489f3..727a8c0 100644 --- a/tests/durabletask/test_grpc_resiliency.py +++ b/tests/durabletask/test_grpc_resiliency.py @@ -29,21 +29,39 @@ def test_worker_resiliency_allows_disabling_timeout_and_threshold(): assert options.channel_recreate_failure_threshold == 0 -def test_worker_resiliency_rejects_invalid_durations(): - with pytest.raises(ValueError, match="hello_timeout_seconds must be > 0"): - GrpcWorkerResiliencyOptions(hello_timeout_seconds=0.0) - - with pytest.raises( - ValueError, - match=( +@pytest.mark.parametrize( + ("kwargs", "message"), + [ + ({"hello_timeout_seconds": 0.0}, "hello_timeout_seconds must be > 0"), + ( + {"silent_disconnect_timeout_seconds": -1.0}, + "silent_disconnect_timeout_seconds must be >= 0", + ), + ( + {"channel_recreate_failure_threshold": -1}, + "channel_recreate_failure_threshold must be >= 0", + ), + ( + {"reconnect_backoff_base_seconds": 0.0}, + "reconnect_backoff_base_seconds must be > 0", + ), + ( + {"reconnect_backoff_cap_seconds": 0.0}, + "reconnect_backoff_cap_seconds must be > 0", + ), + ( + { + "reconnect_backoff_base_seconds": 5.0, + "reconnect_backoff_cap_seconds": 1.0, + }, "reconnect_backoff_cap_seconds must be >= " - "reconnect_backoff_base_seconds" + "reconnect_backoff_base_seconds", ), - ): - GrpcWorkerResiliencyOptions( - reconnect_backoff_base_seconds=5.0, - reconnect_backoff_cap_seconds=1.0, - ) + ], 
+) +def test_worker_resiliency_rejects_invalid_values(kwargs, message): + with pytest.raises(ValueError, match=message): + GrpcWorkerResiliencyOptions(**kwargs) def test_client_resiliency_defaults_are_enabled(): @@ -53,8 +71,19 @@ def test_client_resiliency_defaults_are_enabled(): assert options.min_recreate_interval_seconds == 30.0 -def test_client_resiliency_rejects_negative_cooldown(): - with pytest.raises( - ValueError, match="min_recreate_interval_seconds must be >= 0" - ): - GrpcClientResiliencyOptions(min_recreate_interval_seconds=-1.0) +@pytest.mark.parametrize( + ("kwargs", "message"), + [ + ( + {"channel_recreate_failure_threshold": -1}, + "channel_recreate_failure_threshold must be >= 0", + ), + ( + {"min_recreate_interval_seconds": -1.0}, + "min_recreate_interval_seconds must be >= 0", + ), + ], +) +def test_client_resiliency_rejects_invalid_values(kwargs, message): + with pytest.raises(ValueError, match=message): + GrpcClientResiliencyOptions(**kwargs) From c09def69a2a56243c4aca40a6f349ba129752e3a Mon Sep 17 00:00:00 2001 From: Bernd Verst Date: Thu, 23 Apr 2026 18:41:29 -0700 Subject: [PATCH 06/28] Thread gRPC resiliency options through constructors Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- CHANGELOG.md | 3 ++ durabletask-azuremanaged/CHANGELOG.md | 4 ++ .../durabletask/azuremanaged/client.py | 11 +++- .../durabletask/azuremanaged/worker.py | 9 +++- durabletask/client.py | 53 ++++++++++++++++--- durabletask/worker.py | 13 ++++- .../test_grpc_resiliency.py | 53 +++++++++++++++++++ tests/durabletask/test_client.py | 40 +++++++++++++- 8 files changed, 175 insertions(+), 11 deletions(-) create mode 100644 tests/durabletask-azuremanaged/test_grpc_resiliency.py diff --git a/CHANGELOG.md b/CHANGELOG.md index c5fc3cb..b63ad05 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -19,6 +19,9 @@ ADDED `TaskHubGrpcClient`, `AsyncTaskHubGrpcClient`, and `TaskHubGrpcWorker` to support pre-configured channel passthrough and low-level gRPC channel customization. +- Added optional `resiliency_options` parameters to `TaskHubGrpcClient`, + `AsyncTaskHubGrpcClient`, and `TaskHubGrpcWorker` so applications can pass + gRPC resiliency settings through constructor APIs. - Added `get_orchestration_history()` and `list_instance_ids()` to the sync and async gRPC clients. - Added in-memory backend support for `StreamInstanceHistory` and `ListInstanceIds` so local orchestration tests can retrieve history and page terminal instance IDs by completion window. diff --git a/durabletask-azuremanaged/CHANGELOG.md b/durabletask-azuremanaged/CHANGELOG.md index 639d75f..7a0408f 100644 --- a/durabletask-azuremanaged/CHANGELOG.md +++ b/durabletask-azuremanaged/CHANGELOG.md @@ -11,6 +11,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 `DurableTaskSchedulerClient`, `AsyncDurableTaskSchedulerClient`, and `DurableTaskSchedulerWorker` to allow combining custom gRPC interceptors with DTS defaults and to support pre-configured/customized gRPC channels. +- Added optional `resiliency_options` parameters to + `DurableTaskSchedulerClient`, `AsyncDurableTaskSchedulerClient`, and + `DurableTaskSchedulerWorker` so applications can pass gRPC resiliency + settings through their constructors. - Added `workerid` gRPC metadata on Durable Task Scheduler worker calls for improved worker identity and observability. 
- Improved sync access token refresh concurrency handling to avoid duplicate diff --git a/durabletask-azuremanaged/durabletask/azuremanaged/client.py b/durabletask-azuremanaged/durabletask/azuremanaged/client.py index ea30471..ed5dcc9 100644 --- a/durabletask-azuremanaged/durabletask/azuremanaged/client.py +++ b/durabletask-azuremanaged/durabletask/azuremanaged/client.py @@ -15,7 +15,10 @@ DTSDefaultClientInterceptorImpl, ) from durabletask.client import AsyncTaskHubGrpcClient, TaskHubGrpcClient -from durabletask.grpc_options import GrpcChannelOptions +from durabletask.grpc_options import ( + GrpcChannelOptions, + GrpcClientResiliencyOptions, +) import durabletask.internal.shared as shared from durabletask.payload.store import PayloadStore @@ -30,6 +33,7 @@ def __init__(self, *, secure_channel: bool = True, interceptors: Optional[Sequence[shared.ClientInterceptor]] = None, channel_options: Optional[GrpcChannelOptions] = None, + resiliency_options: Optional[GrpcClientResiliencyOptions] = None, default_version: Optional[str] = None, payload_store: Optional[PayloadStore] = None, log_handler: Optional[logging.Handler] = None, @@ -54,6 +58,7 @@ def __init__(self, *, log_formatter=log_formatter, interceptors=resolved_interceptors, channel_options=channel_options, + resiliency_options=resiliency_options, default_version=default_version, payload_store=payload_store) @@ -74,6 +79,8 @@ class AsyncDurableTaskSchedulerClient(AsyncTaskHubGrpcClient): If None, anonymous authentication will be used. secure_channel (bool, optional): Whether to use a secure gRPC channel (TLS). Defaults to True. + resiliency_options (Optional[GrpcClientResiliencyOptions], optional): Client-side + gRPC resiliency settings forwarded to the base async client. default_version (Optional[str], optional): Default version string for orchestrations. payload_store (Optional[PayloadStore], optional): A payload store for externalizing large payloads. If None, payloads are sent inline. @@ -104,6 +111,7 @@ def __init__(self, *, secure_channel: bool = True, interceptors: Optional[Sequence[shared.AsyncClientInterceptor]] = None, channel_options: Optional[GrpcChannelOptions] = None, + resiliency_options: Optional[GrpcClientResiliencyOptions] = None, default_version: Optional[str] = None, payload_store: Optional[PayloadStore] = None, log_handler: Optional[logging.Handler] = None, @@ -128,5 +136,6 @@ def __init__(self, *, log_formatter=log_formatter, interceptors=resolved_interceptors, channel_options=channel_options, + resiliency_options=resiliency_options, default_version=default_version, payload_store=payload_store) diff --git a/durabletask-azuremanaged/durabletask/azuremanaged/worker.py b/durabletask-azuremanaged/durabletask/azuremanaged/worker.py index 7e9c0ef..6956ae2 100644 --- a/durabletask-azuremanaged/durabletask/azuremanaged/worker.py +++ b/durabletask-azuremanaged/durabletask/azuremanaged/worker.py @@ -13,7 +13,10 @@ from durabletask.azuremanaged.internal.durabletask_grpc_interceptor import \ DTSDefaultClientInterceptorImpl -from durabletask.grpc_options import GrpcChannelOptions +from durabletask.grpc_options import ( + GrpcChannelOptions, + GrpcWorkerResiliencyOptions, +) import durabletask.internal.shared as shared from durabletask.payload.store import PayloadStore from durabletask.worker import ConcurrencyOptions, TaskHubGrpcWorker @@ -34,6 +37,8 @@ class DurableTaskSchedulerWorker(TaskHubGrpcWorker): If None, anonymous authentication will be used. secure_channel (bool, optional): Whether to use a secure gRPC channel (TLS). 
Defaults to True. + resiliency_options (Optional[GrpcWorkerResiliencyOptions], optional): Worker-side + gRPC resiliency settings forwarded to the base worker. concurrency_options (Optional[ConcurrencyOptions], optional): Configuration for controlling worker concurrency limits. If None, default concurrency settings will be used. @@ -74,6 +79,7 @@ def __init__(self, *, secure_channel: bool = True, interceptors: Optional[Sequence[shared.ClientInterceptor]] = None, channel_options: Optional[GrpcChannelOptions] = None, + resiliency_options: Optional[GrpcWorkerResiliencyOptions] = None, concurrency_options: Optional[ConcurrencyOptions] = None, payload_store: Optional[PayloadStore] = None, log_handler: Optional[logging.Handler] = None, @@ -101,6 +107,7 @@ def __init__(self, *, log_formatter=log_formatter, interceptors=resolved_interceptors, channel_options=channel_options, + resiliency_options=resiliency_options, concurrency_options=concurrency_options, # DTS natively supports long timers so chunking is unnecessary maximum_timer_interval=None, diff --git a/durabletask/client.py b/durabletask/client.py index 0ef223b..fa3a924 100644 --- a/durabletask/client.py +++ b/durabletask/client.py @@ -14,7 +14,10 @@ import durabletask.history as history from durabletask.entities import EntityInstanceId from durabletask.entities.entity_metadata import EntityMetadata -from durabletask.grpc_options import GrpcChannelOptions +from durabletask.grpc_options import ( + GrpcChannelOptions, + GrpcClientResiliencyOptions, +) import durabletask.internal.helpers as helpers import durabletask.internal.history_helpers as history_helpers import durabletask.internal.orchestrator_service_pb2 as pb @@ -166,16 +169,34 @@ def __init__(self, *, secure_channel: bool = False, interceptors: Optional[Sequence[shared.ClientInterceptor]] = None, channel_options: Optional[GrpcChannelOptions] = None, + resiliency_options: Optional[GrpcClientResiliencyOptions] = None, default_version: Optional[str] = None, payload_store: Optional[PayloadStore] = None): self._owns_channel = channel is None + self._host_address = ( + host_address if host_address else shared.get_default_host_address() + ) + self._secure_channel = secure_channel + self._channel_options = channel_options + self._resiliency_options = ( + resiliency_options + if resiliency_options is not None + else GrpcClientResiliencyOptions() + ) + resolved_interceptors = ( + prepare_sync_interceptors(metadata, interceptors) if channel is None else interceptors + ) + self._interceptors = ( + list(resolved_interceptors) + if resolved_interceptors is not None + else None + ) if channel is None: - interceptors = prepare_sync_interceptors(metadata, interceptors) channel = shared.get_grpc_channel( - host_address=host_address, + host_address=self._host_address, secure_channel=secure_channel, - interceptors=interceptors, + interceptors=self._interceptors, channel_options=channel_options, ) self._channel = channel @@ -496,16 +517,34 @@ def __init__(self, *, secure_channel: bool = False, interceptors: Optional[Sequence[shared.AsyncClientInterceptor]] = None, channel_options: Optional[GrpcChannelOptions] = None, + resiliency_options: Optional[GrpcClientResiliencyOptions] = None, default_version: Optional[str] = None, payload_store: Optional[PayloadStore] = None): self._owns_channel = channel is None + self._host_address = ( + host_address if host_address else shared.get_default_host_address() + ) + self._secure_channel = secure_channel + self._channel_options = channel_options + 
self._resiliency_options = ( + resiliency_options + if resiliency_options is not None + else GrpcClientResiliencyOptions() + ) + resolved_interceptors = ( + prepare_async_interceptors(metadata, interceptors) if channel is None else interceptors + ) + self._interceptors = ( + list(resolved_interceptors) + if resolved_interceptors is not None + else None + ) if channel is None: - interceptors = prepare_async_interceptors(metadata, interceptors) channel = shared.get_async_grpc_channel( - host_address=host_address, + host_address=self._host_address, secure_channel=secure_channel, - interceptors=interceptors, + interceptors=self._interceptors, channel_options=channel_options, ) self._channel = channel diff --git a/durabletask/worker.py b/durabletask/worker.py index 670a387..9dc28b3 100644 --- a/durabletask/worker.py +++ b/durabletask/worker.py @@ -21,7 +21,10 @@ import grpc from google.protobuf import empty_pb2 -from durabletask.grpc_options import GrpcChannelOptions +from durabletask.grpc_options import ( + GrpcChannelOptions, + GrpcWorkerResiliencyOptions, +) from durabletask.entities.entity_operation_failed_exception import EntityOperationFailedException from durabletask.internal import helpers from durabletask.internal.entity_state_shim import StateShim @@ -369,6 +372,8 @@ class TaskHubGrpcWorker: interceptors to apply to the channel. Defaults to None. channel_options (Optional[GrpcChannelOptions], optional): Extra low-level gRPC channel configuration including retry/service config options. + resiliency_options (Optional[GrpcWorkerResiliencyOptions], optional): Worker-side + gRPC resiliency settings retained for reconnect handling. concurrency_options (Optional[ConcurrencyOptions], optional): Configuration for controlling worker concurrency limits. If None, default settings are used. @@ -436,6 +441,7 @@ def __init__( secure_channel: bool = False, interceptors: Optional[Sequence[shared.ClientInterceptor]] = None, channel_options: Optional[GrpcChannelOptions] = None, + resiliency_options: Optional[GrpcWorkerResiliencyOptions] = None, concurrency_options: Optional[ConcurrencyOptions] = None, maximum_timer_interval: Optional[timedelta] = DEFAULT_MAXIMUM_TIMER_INTERVAL, payload_store: Optional[PayloadStore] = None, @@ -451,6 +457,11 @@ def __init__( self._secure_channel = secure_channel self._payload_store = payload_store self._channel_options = channel_options + self._resiliency_options = ( + resiliency_options + if resiliency_options is not None + else GrpcWorkerResiliencyOptions() + ) # Use provided concurrency options or create default ones self._concurrency_options = ( diff --git a/tests/durabletask-azuremanaged/test_grpc_resiliency.py b/tests/durabletask-azuremanaged/test_grpc_resiliency.py new file mode 100644 index 0000000..8d0fa5d --- /dev/null +++ b/tests/durabletask-azuremanaged/test_grpc_resiliency.py @@ -0,0 +1,53 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. 
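Every test in this new file relies on the same technique: patch the base-class initializer and inspect the keyword arguments the wrapper forwards. Here is that pattern in miniature, using hypothetical `Base` and `Wrapper` classes so the sketch stands alone from the SDK.

```python
from unittest.mock import patch


class Base:
    def __init__(self, *, resiliency_options=None):
        self.resiliency_options = resiliency_options


class Wrapper(Base):
    def __init__(self, *, resiliency_options=None):
        super().__init__(resiliency_options=resiliency_options)


# Replacing Base.__init__ with a mock records exactly what Wrapper forwards,
# without running any real construction logic (channels, credentials, etc.).
with patch(f"{__name__}.Base.__init__", return_value=None) as mock_init:
    sentinel = object()
    Wrapper(resiliency_options=sentinel)

assert mock_init.call_args.kwargs["resiliency_options"] is sentinel
```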
+ +from unittest.mock import patch + +from durabletask.azuremanaged.client import ( + AsyncDurableTaskSchedulerClient, + DurableTaskSchedulerClient, +) +from durabletask.azuremanaged.worker import DurableTaskSchedulerWorker +from durabletask.grpc_options import ( + GrpcClientResiliencyOptions, + GrpcWorkerResiliencyOptions, +) + + +def test_dts_client_passes_resiliency_options_to_base_client(): + resiliency = GrpcClientResiliencyOptions() + with patch("durabletask.azuremanaged.client.TaskHubGrpcClient.__init__", return_value=None) as mock_init: + DurableTaskSchedulerClient( + host_address="localhost:4001", + taskhub="hub", + token_credential=None, + resiliency_options=resiliency, + ) + assert mock_init.call_args.kwargs["resiliency_options"] is resiliency + + +def test_dts_worker_passes_resiliency_options_to_base_worker(): + resiliency = GrpcWorkerResiliencyOptions() + with patch("durabletask.azuremanaged.worker.TaskHubGrpcWorker.__init__", return_value=None) as mock_init: + DurableTaskSchedulerWorker( + host_address="localhost:4001", + taskhub="hub", + token_credential=None, + resiliency_options=resiliency, + ) + assert mock_init.call_args.kwargs["resiliency_options"] is resiliency + + +def test_async_dts_client_passes_resiliency_options_to_base_client(): + resiliency = GrpcClientResiliencyOptions() + with patch( + "durabletask.azuremanaged.client.AsyncTaskHubGrpcClient.__init__", + return_value=None, + ) as mock_init: + AsyncDurableTaskSchedulerClient( + host_address="localhost:4001", + taskhub="hub", + token_credential=None, + resiliency_options=resiliency, + ) + assert mock_init.call_args.kwargs["resiliency_options"] is resiliency diff --git a/tests/durabletask/test_client.py b/tests/durabletask/test_client.py index 9bb56ea..b831798 100644 --- a/tests/durabletask/test_client.py +++ b/tests/durabletask/test_client.py @@ -8,8 +8,14 @@ import durabletask.history as history import durabletask.internal.orchestrator_service_pb2 as pb from durabletask.client import AsyncTaskHubGrpcClient, OrchestrationStatus, TaskHubGrpcClient -from durabletask.grpc_options import GrpcChannelOptions, GrpcRetryPolicyOptions +from durabletask.grpc_options import ( + GrpcChannelOptions, + GrpcClientResiliencyOptions, + GrpcRetryPolicyOptions, + GrpcWorkerResiliencyOptions, +) from durabletask.payload.store import LargePayloadStorageOptions, PayloadStore +from durabletask.worker import TaskHubGrpcWorker from durabletask.internal.grpc_interceptor import ( DefaultAsyncClientInterceptorImpl, @@ -290,6 +296,38 @@ def test_async_client_uses_provided_channel_directly(): mock_get_channel.assert_not_called() +def test_client_stores_resiliency_options_for_recreation(): + resiliency = GrpcClientResiliencyOptions(channel_recreate_failure_threshold=7) + with patch("durabletask.client.shared.get_grpc_channel", return_value=MagicMock()), patch( + "durabletask.client.stubs.TaskHubSidecarServiceStub", return_value=MagicMock() + ): + client = TaskHubGrpcClient( + host_address="localhost:4001", + resiliency_options=resiliency, + ) + assert client._resiliency_options is resiliency + assert client._host_address == "localhost:4001" + + +def test_async_client_stores_resolved_transport_inputs(): + resiliency = GrpcClientResiliencyOptions() + with patch("durabletask.client.shared.get_async_grpc_channel", return_value=MagicMock()), patch( + "durabletask.client.stubs.TaskHubSidecarServiceStub", return_value=MagicMock() + ): + client = AsyncTaskHubGrpcClient( + host_address="localhost:4001", + resiliency_options=resiliency, + ) + assert 
client._resiliency_options is resiliency + assert client._host_address == "localhost:4001" + + +def test_worker_stores_resiliency_options(): + resiliency = GrpcWorkerResiliencyOptions(channel_recreate_failure_threshold=9) + worker = TaskHubGrpcWorker(resiliency_options=resiliency) + assert worker._resiliency_options is resiliency + + def test_get_orchestration_history_aggregates_chunks_and_deexternalizes_payloads(): store = FakePayloadStore() token = store.upload(b'history payload') From a6bc855f0d5f9cbec175dcf2da421554cecc7835 Mon Sep 17 00:00:00 2001 From: Bernd Verst Date: Thu, 23 Apr 2026 18:57:59 -0700 Subject: [PATCH 07/28] Strengthen retained client state tests Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- tests/durabletask/test_client.py | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/tests/durabletask/test_client.py b/tests/durabletask/test_client.py index b831798..a4a1b5b 100644 --- a/tests/durabletask/test_client.py +++ b/tests/durabletask/test_client.py @@ -298,28 +298,44 @@ def test_async_client_uses_provided_channel_directly(): def test_client_stores_resiliency_options_for_recreation(): resiliency = GrpcClientResiliencyOptions(channel_recreate_failure_threshold=7) + channel_options = GrpcChannelOptions(max_receive_message_length=1234) + interceptors = [DefaultClientInterceptorImpl(METADATA)] with patch("durabletask.client.shared.get_grpc_channel", return_value=MagicMock()), patch( "durabletask.client.stubs.TaskHubSidecarServiceStub", return_value=MagicMock() ): client = TaskHubGrpcClient( host_address="localhost:4001", + secure_channel=True, + interceptors=interceptors, + channel_options=channel_options, resiliency_options=resiliency, ) assert client._resiliency_options is resiliency assert client._host_address == "localhost:4001" + assert client._secure_channel is True + assert client._channel_options is channel_options + assert client._interceptors == interceptors def test_async_client_stores_resolved_transport_inputs(): resiliency = GrpcClientResiliencyOptions() + channel_options = GrpcChannelOptions(max_send_message_length=4321) + interceptors = [DefaultAsyncClientInterceptorImpl(METADATA)] with patch("durabletask.client.shared.get_async_grpc_channel", return_value=MagicMock()), patch( "durabletask.client.stubs.TaskHubSidecarServiceStub", return_value=MagicMock() ): client = AsyncTaskHubGrpcClient( host_address="localhost:4001", + secure_channel=True, + interceptors=interceptors, + channel_options=channel_options, resiliency_options=resiliency, ) assert client._resiliency_options is resiliency assert client._host_address == "localhost:4001" + assert client._secure_channel is True + assert client._channel_options is channel_options + assert client._interceptors == interceptors def test_worker_stores_resiliency_options(): From a2ff52bfd9129999ac014bd8a23c3186c48bc39a Mon Sep 17 00:00:00 2001 From: Bernd Verst Date: Thu, 23 Apr 2026 19:08:57 -0700 Subject: [PATCH 08/28] Add shared gRPC resiliency helpers Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- durabletask/internal/grpc_resiliency.py | 49 +++++++++++++++++ tests/durabletask/test_grpc_resiliency.py | 65 +++++++++++++++++++++++ 2 files changed, 114 insertions(+) create mode 100644 durabletask/internal/grpc_resiliency.py diff --git a/durabletask/internal/grpc_resiliency.py b/durabletask/internal/grpc_resiliency.py new file mode 100644 index 0000000..4845523 --- /dev/null +++ b/durabletask/internal/grpc_resiliency.py @@ -0,0 +1,49 @@ +# Copyright (c) 
Microsoft Corporation. +# Licensed under the MIT License. + +import random +from dataclasses import dataclass + +import grpc + +LONG_POLL_METHODS = {"WaitForInstanceStart", "WaitForInstanceCompletion"} + + +def get_full_jitter_delay_seconds( + attempt: int, + *, + base_seconds: float, + cap_seconds: float) -> float: + capped_attempt = min(attempt, 30) + upper_bound = min(cap_seconds, base_seconds * (2 ** capped_attempt)) + return random.random() * upper_bound + + +@dataclass +class FailureTracker: + threshold: int + consecutive_failures: int = 0 + + def record_failure(self) -> bool: + if self.threshold <= 0: + return False + self.consecutive_failures += 1 + return self.consecutive_failures >= self.threshold + + def record_success(self) -> None: + self.consecutive_failures = 0 + + +def is_client_transport_failure(method_name: str, status_code: grpc.StatusCode) -> bool: + if status_code == grpc.StatusCode.UNAVAILABLE: + return True + if status_code == grpc.StatusCode.DEADLINE_EXCEEDED: + return method_name not in LONG_POLL_METHODS + return False + + +def is_worker_transport_failure(status_code: grpc.StatusCode) -> bool: + return status_code in { + grpc.StatusCode.UNAVAILABLE, + grpc.StatusCode.DEADLINE_EXCEEDED, + } diff --git a/tests/durabletask/test_grpc_resiliency.py b/tests/durabletask/test_grpc_resiliency.py index 727a8c0..f990be8 100644 --- a/tests/durabletask/test_grpc_resiliency.py +++ b/tests/durabletask/test_grpc_resiliency.py @@ -1,12 +1,19 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. +import grpc import pytest from durabletask.grpc_options import ( GrpcClientResiliencyOptions, GrpcWorkerResiliencyOptions, ) +from durabletask.internal.grpc_resiliency import ( + FailureTracker, + get_full_jitter_delay_seconds, + is_client_transport_failure, + is_worker_transport_failure, +) def test_worker_resiliency_defaults_are_enabled(): @@ -87,3 +94,61 @@ def test_client_resiliency_defaults_are_enabled(): def test_client_resiliency_rejects_invalid_values(kwargs, message): with pytest.raises(ValueError, match=message): GrpcClientResiliencyOptions(**kwargs) + + +def test_full_jitter_delay_is_capped(monkeypatch): + monkeypatch.setattr( + "durabletask.internal.grpc_resiliency.random.random", + lambda: 1.0, + ) + + delay = get_full_jitter_delay_seconds( + 10, + base_seconds=1.0, + cap_seconds=30.0, + ) + + assert delay == 30.0 + + +def test_failure_tracker_trips_at_threshold(): + tracker = FailureTracker(threshold=3) + + assert tracker.record_failure() is False + assert tracker.record_failure() is False + assert tracker.record_failure() is True + + tracker.record_success() + + assert tracker.consecutive_failures == 0 + + +def test_client_transport_failure_ignores_long_poll_deadlines(): + assert ( + is_client_transport_failure( + "WaitForInstanceStart", + grpc.StatusCode.DEADLINE_EXCEEDED, + ) + is False + ) + assert ( + is_client_transport_failure( + "StartInstance", + grpc.StatusCode.DEADLINE_EXCEEDED, + ) + is True + ) + assert ( + is_client_transport_failure( + "GetInstance", + grpc.StatusCode.UNAVAILABLE, + ) + is True + ) + + +def test_worker_transport_failure_filters_application_errors(): + assert is_worker_transport_failure(grpc.StatusCode.UNAVAILABLE) is True + assert is_worker_transport_failure(grpc.StatusCode.DEADLINE_EXCEEDED) is True + assert is_worker_transport_failure(grpc.StatusCode.UNAUTHENTICATED) is False + assert is_worker_transport_failure(grpc.StatusCode.NOT_FOUND) is False From c049899802ff4caac3ad8a2792dfacbbb8e28c14 Mon Sep 17 00:00:00 2001 
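Before the next patch builds on them, here is how the helpers in `durabletask/internal/grpc_resiliency.py` compose; a runnable sketch, assuming only the definitions shown above.

```python
import grpc

from durabletask.internal.grpc_resiliency import (
    FailureTracker,
    get_full_jitter_delay_seconds,
    is_worker_transport_failure,
)

# Full jitter draws uniformly from [0, min(cap, base * 2**attempt)], so early
# retries stay short while later retries are bounded by the cap.
for attempt in range(6):
    delay = get_full_jitter_delay_seconds(attempt, base_seconds=1.0, cap_seconds=30.0)
    assert 0.0 <= delay <= min(30.0, 2.0 ** attempt)

# The tracker trips only on the Nth consecutive failure; a success resets it.
tracker = FailureTracker(threshold=3)
assert [tracker.record_failure() for _ in range(3)] == [False, False, True]
tracker.record_success()
assert tracker.consecutive_failures == 0

# Only transport-shaped status codes count toward recreation on the worker.
assert is_worker_transport_failure(grpc.StatusCode.UNAVAILABLE)
assert not is_worker_transport_failure(grpc.StatusCode.PERMISSION_DENIED)
```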
From: Bernd Verst Date: Thu, 23 Apr 2026 19:14:06 -0700 Subject: [PATCH 09/28] Add completion long-poll resiliency test Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- tests/durabletask/test_grpc_resiliency.py | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/tests/durabletask/test_grpc_resiliency.py b/tests/durabletask/test_grpc_resiliency.py index f990be8..5374672 100644 --- a/tests/durabletask/test_grpc_resiliency.py +++ b/tests/durabletask/test_grpc_resiliency.py @@ -123,10 +123,17 @@ def test_failure_tracker_trips_at_threshold(): assert tracker.consecutive_failures == 0 -def test_client_transport_failure_ignores_long_poll_deadlines(): +@pytest.mark.parametrize( + "method_name", + [ + "WaitForInstanceStart", + "WaitForInstanceCompletion", + ], +) +def test_client_transport_failure_ignores_long_poll_deadlines(method_name): assert ( is_client_transport_failure( - "WaitForInstanceStart", + method_name, grpc.StatusCode.DEADLINE_EXCEEDED, ) is False From 19e3d71ab6d522dbe79619a6009a7f39636453b4 Mon Sep 17 00:00:00 2001 From: Bernd Verst Date: Thu, 23 Apr 2026 19:18:39 -0700 Subject: [PATCH 10/28] Add grpc resiliency edge-case tests Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- tests/durabletask/test_grpc_resiliency.py | 24 +++++++++++++++++++++++ 1 file changed, 24 insertions(+) diff --git a/tests/durabletask/test_grpc_resiliency.py b/tests/durabletask/test_grpc_resiliency.py index 5374672..f94f4e1 100644 --- a/tests/durabletask/test_grpc_resiliency.py +++ b/tests/durabletask/test_grpc_resiliency.py @@ -111,6 +111,21 @@ def test_full_jitter_delay_is_capped(monkeypatch): assert delay == 30.0 +def test_full_jitter_delay_large_attempt_is_still_capped(monkeypatch): + monkeypatch.setattr( + "durabletask.internal.grpc_resiliency.random.random", + lambda: 1.0, + ) + + delay = get_full_jitter_delay_seconds( + 1_000, + base_seconds=1.0, + cap_seconds=30.0, + ) + + assert delay == 30.0 + + def test_failure_tracker_trips_at_threshold(): tracker = FailureTracker(threshold=3) @@ -123,6 +138,15 @@ def test_failure_tracker_trips_at_threshold(): assert tracker.consecutive_failures == 0 +def test_failure_tracker_threshold_zero_never_trips(): + tracker = FailureTracker(threshold=0) + + assert tracker.record_failure() is False + assert tracker.record_failure() is False + assert tracker.record_failure() is False + assert tracker.consecutive_failures == 0 + + @pytest.mark.parametrize( "method_name", [ From 008be6503b93aa25e13fafee902ed1d5074b3c78 Mon Sep 17 00:00:00 2001 From: Bernd Verst Date: Thu, 23 Apr 2026 19:55:52 -0700 Subject: [PATCH 11/28] Harden worker gRPC stream reconnect behavior Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- CHANGELOG.md | 6 + durabletask/worker.py | 172 ++++++++--- tests/durabletask/test_worker_resiliency.py | 306 ++++++++++++++++++++ 3 files changed, 449 insertions(+), 35 deletions(-) create mode 100644 tests/durabletask/test_worker_resiliency.py diff --git a/CHANGELOG.md b/CHANGELOG.md index b63ad05..021f2c6 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -25,6 +25,12 @@ ADDED - Added `get_orchestration_history()` and `list_instance_ids()` to the sync and async gRPC clients. - Added in-memory backend support for `StreamInstanceHistory` and `ListInstanceIds` so local orchestration tests can retrieve history and page terminal instance IDs by completion window. 
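Two details behind the edge-case tests in the previous patch are worth calling out: `min(attempt, 30)` keeps `2 ** attempt` bounded for pathological attempt counts, and a threshold of zero is the documented off switch. A small sketch of the disable knobs, using the option classes from this series:

```python
from durabletask.grpc_options import (
    GrpcClientResiliencyOptions,
    GrpcWorkerResiliencyOptions,
)

# Values <= 0 disable the corresponding feature rather than raising, so the
# pre-resiliency behavior remains reachable through configuration alone.
client_opts = GrpcClientResiliencyOptions(channel_recreate_failure_threshold=0)
worker_opts = GrpcWorkerResiliencyOptions(
    channel_recreate_failure_threshold=0,
    silent_disconnect_timeout_seconds=0.0,
)
```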
+FIXED + +- Hardened `TaskHubGrpcWorker` reconnect handling so configured hello timeouts + apply on fresh connections, received work items reset failure tracking, and + caller-owned channels are never recreated during worker reconnects. + ## v1.4.0 ADDED diff --git a/durabletask/worker.py b/durabletask/worker.py index 9dc28b3..fb2aa46 100644 --- a/durabletask/worker.py +++ b/durabletask/worker.py @@ -6,7 +6,6 @@ import json import logging import os -import random import time from concurrent.futures import ThreadPoolExecutor from dataclasses import dataclass, field @@ -39,6 +38,11 @@ import durabletask.internal.orchestrator_service_pb2_grpc as stubs import durabletask.internal.shared as shared import durabletask.internal.tracing as tracing +from durabletask.internal.grpc_resiliency import ( + FailureTracker, + get_full_jitter_delay_seconds, + is_worker_transport_failure, +) from durabletask.payload import helpers as payload_helpers from durabletask import task from durabletask.internal.grpc_interceptor import DefaultClientInterceptorImpl @@ -48,6 +52,7 @@ TOutput = TypeVar("TOutput") DATETIME_STRING_FORMAT = '%Y-%m-%dT%H:%M:%S.%fZ' DEFAULT_MAXIMUM_TIMER_INTERVAL = timedelta(days=3) +_STREAM_CLOSED_SENTINEL = object() class ConcurrencyOptions: @@ -118,6 +123,13 @@ class VersionFailureStrategy(Enum): FAIL = 2 +class _WorkItemStreamOutcome(Enum): + SHUTDOWN = "shutdown" + GRACEFUL_CLOSE_BEFORE_FIRST_MESSAGE = "graceful_close_before_first_message" + GRACEFUL_CLOSE_AFTER_MESSAGE = "graceful_close_after_message" + SILENT_DISCONNECT = "silent_disconnect" + + class VersioningOptions: """Configuration options for orchestrator and activity versioning. @@ -501,6 +513,27 @@ def __enter__(self): def __exit__(self, type, value, traceback): self.stop() + def _classify_stream_outcome( + self, + *, + saw_message: bool, + timed_out: bool, + ) -> _WorkItemStreamOutcome: + if timed_out: + return _WorkItemStreamOutcome.SILENT_DISCONNECT + if saw_message: + return _WorkItemStreamOutcome.GRACEFUL_CLOSE_AFTER_MESSAGE + return _WorkItemStreamOutcome.GRACEFUL_CLOSE_BEFORE_FIRST_MESSAGE + + def _should_count_worker_failure( + self, + status_code: grpc.StatusCode, + ) -> bool: + return is_worker_transport_failure(status_code) + + def _can_recreate_channel(self) -> bool: + return self._channel is None + def add_orchestrator(self, fn: task.Orchestrator[TInput, TOutput]) -> str: """Registers an orchestrator function with the worker.""" if self._is_running: @@ -602,26 +635,26 @@ def run_loop(): async def _async_run_loop(self): worker_task = asyncio.create_task(self._async_worker_manager.run()) - # Connection state management for retry fix - current_channel = None + current_channel = self._channel current_stub = None current_reader_thread = None conn_retry_count = 0 - conn_max_retry_delay = 60 + failure_tracker = FailureTracker( + threshold=self._resiliency_options.channel_recreate_failure_threshold, + ) + + def get_reconnect_delay_seconds() -> float: + return get_full_jitter_delay_seconds( + conn_retry_count, + base_seconds=self._resiliency_options.reconnect_backoff_base_seconds, + cap_seconds=self._resiliency_options.reconnect_backoff_cap_seconds, + ) def create_fresh_connection(): nonlocal current_channel, current_stub, conn_retry_count - if current_channel and self._channel is None: - try: - current_channel.close() - except Exception: - pass - current_channel = None current_stub = None try: - if self._channel is not None: - current_channel = self._channel - else: + if current_channel is None: current_channel = 
shared.get_grpc_channel( self._host_address, self._secure_channel, @@ -629,16 +662,16 @@ def create_fresh_connection(): channel_options=self._channel_options, ) current_stub = stubs.TaskHubSidecarServiceStub(current_channel) - current_stub.Hello(empty_pb2.Empty()) + hello_timeout = self._resiliency_options.hello_timeout_seconds + current_stub.Hello(empty_pb2.Empty(), timeout=hello_timeout) conn_retry_count = 0 self._logger.info(f"Created fresh connection to {self._host_address}") except Exception as e: self._logger.warning(f"Failed to create connection: {e}") - current_channel = self._channel if self._channel is not None else None current_stub = None raise - def invalidate_connection(): + def invalidate_connection(*, recreate_channel: bool = False): nonlocal current_channel, current_stub, current_reader_thread # Cancel the response stream first to signal the reader thread to stop if self._response_stream is not None: @@ -658,13 +691,12 @@ def invalidate_connection(): pass current_reader_thread = None - # Close the channel - if current_channel and self._channel is None: + if recreate_channel and current_channel is not None and self._can_recreate_channel(): try: current_channel.close() except Exception: pass - current_channel = self._channel if self._channel is not None else None + current_channel = None current_stub = None def should_invalidate_connection(rpc_error): @@ -682,12 +714,18 @@ def should_invalidate_connection(rpc_error): if current_stub is None: try: create_fresh_connection() - except Exception: + except Exception as ex: + recreate_channel = False + if isinstance(ex, grpc.RpcError): + error_code = ex.code() # type: ignore + if self._should_count_worker_failure(error_code): + recreate_channel = ( + failure_tracker.record_failure() + and self._can_recreate_channel() + ) + invalidate_connection(recreate_channel=recreate_channel) conn_retry_count += 1 - delay = min( - conn_max_retry_delay, - (2 ** min(conn_retry_count, 6)) + random.uniform(0, 1), - ) + delay = get_reconnect_delay_seconds() self._logger.warning( f"Connection failed, retrying in {delay:.2f} seconds (attempt {conn_retry_count})" ) @@ -718,15 +756,18 @@ def should_invalidate_connection(rpc_error): import queue work_item_queue = queue.Queue() + saw_message = False def stream_reader(): try: response_stream = self._response_stream if response_stream is None: + work_item_queue.put(_STREAM_CLOSED_SENTINEL) return for work_item in response_stream: work_item_queue.put(work_item) + work_item_queue.put(_STREAM_CLOSED_SENTINEL) except Exception as e: work_item_queue.put(e) @@ -735,15 +776,43 @@ def stream_reader(): current_reader_thread = threading.Thread(target=stream_reader, daemon=True) current_reader_thread.start() loop = asyncio.get_running_loop() + queue_timeout = ( + self._resiliency_options.silent_disconnect_timeout_seconds or None + ) + stream_outcome = None while not self._shutdown.is_set(): try: - work_item = await loop.run_in_executor( - None, work_item_queue.get + work_item = await asyncio.wait_for( + loop.run_in_executor(None, work_item_queue.get), + timeout=queue_timeout, + ) + except asyncio.TimeoutError: + work_item = None + stream_outcome = self._classify_stream_outcome( + saw_message=saw_message, + timed_out=True, + ) + break + + if work_item is _STREAM_CLOSED_SENTINEL: + stream_outcome = self._classify_stream_outcome( + saw_message=saw_message, + timed_out=False, ) + break + + try: if isinstance(work_item, Exception): raise work_item + + saw_message = True request_type = work_item.WhichOneof("request") 
self._logger.debug(f'Received "{request_type}" work item') + if work_item.HasField("healthPing"): + failure_tracker.record_success() + continue + + failure_tracker.record_success() if work_item.HasField("orchestratorRequest"): self._async_worker_manager.submit_orchestration( self._execute_orchestrator, @@ -776,22 +845,49 @@ def stream_reader(): stub, work_item.completionToken ) - elif work_item.HasField("healthPing"): - pass else: self._logger.warning( f"Unexpected work item type: {request_type}" ) except Exception as e: self._logger.warning(f"Error in work item stream: {e}") - raise e - current_reader_thread.join(timeout=1) - self._logger.info("Work item stream ended normally") + raise + + if stream_outcome is _WorkItemStreamOutcome.GRACEFUL_CLOSE_BEFORE_FIRST_MESSAGE: + self._logger.info( + "Work item stream closed before receiving the first message" + ) + invalidate_connection() + continue + if stream_outcome is _WorkItemStreamOutcome.GRACEFUL_CLOSE_AFTER_MESSAGE: + self._logger.info("Work item stream closed after receiving messages") + invalidate_connection() + continue + if stream_outcome is _WorkItemStreamOutcome.SILENT_DISCONNECT: + self._logger.warning( + f"Timed out waiting for work items from {self._host_address}" + ) + recreate_channel = ( + failure_tracker.record_failure() + and self._can_recreate_channel() + ) + invalidate_connection(recreate_channel=recreate_channel) + conn_retry_count += 1 + delay = get_reconnect_delay_seconds() + if self._shutdown.wait(delay): + break + continue except grpc.RpcError as rpc_error: should_invalidate = should_invalidate_connection(rpc_error) - if should_invalidate: - invalidate_connection() error_code = rpc_error.code() # type: ignore + recreate_channel = False + if should_invalidate and self._should_count_worker_failure(error_code): + recreate_channel = ( + failure_tracker.record_failure() + and self._can_recreate_channel() + ) + if should_invalidate: + invalidate_connection(recreate_channel=recreate_channel) error_details = str(rpc_error) if error_code == grpc.StatusCode.CANCELLED: @@ -815,11 +911,17 @@ def stream_reader(): self._logger.warning( f"Application-level gRPC error ({error_code}): {rpc_error}" ) - self._shutdown.wait(1) + conn_retry_count += 1 + delay = get_reconnect_delay_seconds() + if self._shutdown.wait(delay): + break except Exception as ex: invalidate_connection() self._logger.warning(f"Unexpected error: {ex}") - self._shutdown.wait(1) + conn_retry_count += 1 + delay = get_reconnect_delay_seconds() + if self._shutdown.wait(delay): + break invalidate_connection() self._logger.info("No longer listening for work items") self._async_worker_manager.shutdown() diff --git a/tests/durabletask/test_worker_resiliency.py b/tests/durabletask/test_worker_resiliency.py new file mode 100644 index 0000000..12cb8c6 --- /dev/null +++ b/tests/durabletask/test_worker_resiliency.py @@ -0,0 +1,306 @@ +import asyncio +import grpc +from unittest.mock import MagicMock + +import pytest + +from durabletask.grpc_options import GrpcWorkerResiliencyOptions +from durabletask.internal import orchestrator_service_pb2 as pb +from durabletask.worker import TaskHubGrpcWorker, _WorkItemStreamOutcome + + +class FakeRpcError(grpc.RpcError): + def __init__(self, status_code: grpc.StatusCode, details: str): + super().__init__() + self._status_code = status_code + self._details = details + + def code(self): + return self._status_code + + def details(self): + return self._details + + def __str__(self): + return self._details + + +class FakeResponseStream: + def 
__init__(self, items=(), error: grpc.RpcError | None = None): + self._items = list(items) + self._error = error + self.cancelled = False + + def __iter__(self): + yield from self._items + if self._error is not None: + raise self._error + + def cancel(self): + self.cancelled = True + + +class DummyWorkerManager: + def __init__(self): + self._shutdown_event = asyncio.Event() + self.submissions: list[tuple[str, tuple]] = [] + + async def run(self): + await self._shutdown_event.wait() + + def submit_orchestration(self, *args): + self.submissions.append(("orchestrator", args)) + + def submit_activity(self, *args): + self.submissions.append(("activity", args)) + + def submit_entity_batch(self, *args): + self.submissions.append(("entity", args)) + + def shutdown(self): + self._shutdown_event.set() + + +def _make_activity_work_item() -> pb.WorkItem: + return pb.WorkItem( + activityRequest=pb.ActivityRequest( + name="test_activity", + taskId=1, + orchestrationInstance=pb.OrchestrationInstance(instanceId="instance-id"), + ), + completionToken="token", + ) + + +def test_worker_classifies_graceful_close_before_first_message(): + worker = TaskHubGrpcWorker( + resiliency_options=GrpcWorkerResiliencyOptions(silent_disconnect_timeout_seconds=5.0) + ) + outcome = worker._classify_stream_outcome( + saw_message=False, + timed_out=False, + ) + assert outcome is _WorkItemStreamOutcome.GRACEFUL_CLOSE_BEFORE_FIRST_MESSAGE + + +def test_worker_classifies_graceful_close_after_message(): + worker = TaskHubGrpcWorker() + outcome = worker._classify_stream_outcome( + saw_message=True, + timed_out=False, + ) + assert outcome is _WorkItemStreamOutcome.GRACEFUL_CLOSE_AFTER_MESSAGE + + +def test_worker_classifies_silent_disconnect(): + worker = TaskHubGrpcWorker( + resiliency_options=GrpcWorkerResiliencyOptions(silent_disconnect_timeout_seconds=5.0) + ) + outcome = worker._classify_stream_outcome( + saw_message=False, + timed_out=True, + ) + assert outcome is _WorkItemStreamOutcome.SILENT_DISCONNECT + + +def test_worker_counts_only_transport_failures_for_recreation(): + worker = TaskHubGrpcWorker( + resiliency_options=GrpcWorkerResiliencyOptions(channel_recreate_failure_threshold=2) + ) + assert worker._should_count_worker_failure(grpc.StatusCode.UNAVAILABLE) is True + assert worker._should_count_worker_failure(grpc.StatusCode.UNAUTHENTICATED) is False + + +def test_worker_does_not_recreate_caller_owned_channel(): + worker = TaskHubGrpcWorker(channel=MagicMock()) + assert worker._can_recreate_channel() is False + + +@pytest.mark.asyncio +async def test_worker_applies_configured_hello_timeout(monkeypatch): + worker = TaskHubGrpcWorker( + resiliency_options=GrpcWorkerResiliencyOptions(hello_timeout_seconds=12.5) + ) + worker._async_worker_manager = DummyWorkerManager() + monkeypatch.setattr(worker._shutdown, "wait", lambda timeout: False) + + stub = MagicMock() + stub.GetWorkItems.side_effect = FakeRpcError(grpc.StatusCode.CANCELLED, "stop") + + monkeypatch.setattr("durabletask.worker.shared.get_grpc_channel", lambda *args, **kwargs: MagicMock()) + monkeypatch.setattr("durabletask.worker.stubs.TaskHubSidecarServiceStub", lambda channel: stub) + + await worker._async_run_loop() + + assert stub.Hello.call_args.kwargs["timeout"] == 12.5 + + +@pytest.mark.asyncio +async def test_worker_does_not_recreate_sdk_owned_channel_for_non_transport_setup_errors(monkeypatch): + worker = TaskHubGrpcWorker( + resiliency_options=GrpcWorkerResiliencyOptions(channel_recreate_failure_threshold=2) + ) + worker._async_worker_manager = 
DummyWorkerManager() + monkeypatch.setattr(worker._shutdown, "wait", lambda timeout: False) + + created_channels = [] + + def get_grpc_channel(*args, **kwargs): + channel = MagicMock(name=f"channel-{len(created_channels) + 1}") + created_channels.append(channel) + return channel + + first_stub = MagicMock() + first_stub.Hello.side_effect = RuntimeError("boom") + + second_stub = MagicMock() + second_stub.GetWorkItems.side_effect = FakeRpcError(grpc.StatusCode.CANCELLED, "stop") + + stubs = [first_stub, second_stub] + + monkeypatch.setattr("durabletask.worker.shared.get_grpc_channel", get_grpc_channel) + monkeypatch.setattr( + "durabletask.worker.stubs.TaskHubSidecarServiceStub", + lambda channel: stubs.pop(0), + ) + + await worker._async_run_loop() + + assert len(created_channels) == 1 + + +@pytest.mark.asyncio +async def test_worker_recreates_sdk_owned_channel_after_transport_failure_threshold(monkeypatch): + worker = TaskHubGrpcWorker( + resiliency_options=GrpcWorkerResiliencyOptions(channel_recreate_failure_threshold=2) + ) + worker._async_worker_manager = DummyWorkerManager() + monkeypatch.setattr(worker._shutdown, "wait", lambda timeout: False) + + created_channels = [] + + def get_grpc_channel(*args, **kwargs): + channel = MagicMock(name=f"channel-{len(created_channels) + 1}") + created_channels.append(channel) + return channel + + stubs = [ + MagicMock(GetWorkItems=MagicMock(return_value=FakeResponseStream(error=FakeRpcError( + grpc.StatusCode.UNAVAILABLE, + "first transport failure", + )))), + MagicMock(GetWorkItems=MagicMock(return_value=FakeResponseStream(error=FakeRpcError( + grpc.StatusCode.UNAVAILABLE, + "second transport failure", + )))), + MagicMock(GetWorkItems=MagicMock(side_effect=FakeRpcError( + grpc.StatusCode.CANCELLED, + "stop", + ))), + ] + stub_channels = [] + + def create_stub(channel): + stub_channels.append(channel) + return stubs.pop(0) + + monkeypatch.setattr("durabletask.worker.shared.get_grpc_channel", get_grpc_channel) + monkeypatch.setattr("durabletask.worker.stubs.TaskHubSidecarServiceStub", create_stub) + + await worker._async_run_loop() + + assert len(created_channels) == 2 + assert stub_channels[0] is created_channels[0] + assert stub_channels[1] is created_channels[0] + assert stub_channels[2] is created_channels[1] + + +@pytest.mark.asyncio +async def test_worker_never_replaces_caller_owned_channel_during_transport_failures(monkeypatch): + provided_channel = MagicMock(name="provided-channel") + worker = TaskHubGrpcWorker( + channel=provided_channel, + resiliency_options=GrpcWorkerResiliencyOptions(channel_recreate_failure_threshold=1), + ) + worker._async_worker_manager = DummyWorkerManager() + monkeypatch.setattr(worker._shutdown, "wait", lambda timeout: False) + + stub_channels = [] + stubs = [ + MagicMock(GetWorkItems=MagicMock(return_value=FakeResponseStream(error=FakeRpcError( + grpc.StatusCode.UNAVAILABLE, + "transport failure", + )))), + MagicMock(GetWorkItems=MagicMock(side_effect=FakeRpcError( + grpc.StatusCode.CANCELLED, + "stop", + ))), + ] + + def create_stub(channel): + stub_channels.append(channel) + return stubs.pop(0) + + monkeypatch.setattr( + "durabletask.worker.shared.get_grpc_channel", + lambda *args, **kwargs: pytest.fail("SDK channel factory should not run for caller-owned channels"), + ) + monkeypatch.setattr("durabletask.worker.stubs.TaskHubSidecarServiceStub", create_stub) + + await worker._async_run_loop() + + assert stub_channels == [provided_channel, provided_channel] + provided_channel.close.assert_not_called() + + 
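The caller-owned-channel test above hinges on one ownership rule: the worker only heals channels it created. A sketch of the observable behavior, assuming the constructor signatures from this patch; the private `_can_recreate_channel` helper is an implementation detail used here only for illustration.

```python
from unittest.mock import MagicMock

from durabletask.worker import TaskHubGrpcWorker

# An SDK-owned channel (no `channel` argument) may be closed and rebuilt;
# a channel supplied by the application is never touched.
sdk_owned = TaskHubGrpcWorker(host_address="localhost:4001")
caller_owned = TaskHubGrpcWorker(channel=MagicMock(name="app-channel"))

assert sdk_owned._can_recreate_channel() is True
assert caller_owned._can_recreate_channel() is False
```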
+@pytest.mark.asyncio +@pytest.mark.parametrize( + ("work_item", "expected_submissions"), + [ + (pb.WorkItem(healthPing=pb.HealthPing()), 0), + (_make_activity_work_item(), 1), + ], +) +async def test_worker_received_messages_reset_failure_tracker(monkeypatch, work_item, expected_submissions): + worker = TaskHubGrpcWorker( + resiliency_options=GrpcWorkerResiliencyOptions( + channel_recreate_failure_threshold=2, + silent_disconnect_timeout_seconds=5.0, + ) + ) + worker_manager = DummyWorkerManager() + worker._async_worker_manager = worker_manager + monkeypatch.setattr(worker._shutdown, "wait", lambda timeout: False) + + created_channels = [] + + def get_grpc_channel(*args, **kwargs): + channel = MagicMock(name=f"channel-{len(created_channels) + 1}") + created_channels.append(channel) + return channel + + stubs = [ + MagicMock(GetWorkItems=MagicMock(return_value=FakeResponseStream(error=FakeRpcError( + grpc.StatusCode.UNAVAILABLE, + "first transport failure", + )))), + MagicMock(GetWorkItems=MagicMock(return_value=FakeResponseStream( + items=[work_item], + error=FakeRpcError(grpc.StatusCode.UNAVAILABLE, "second transport failure"), + ))), + MagicMock(GetWorkItems=MagicMock(side_effect=FakeRpcError( + grpc.StatusCode.CANCELLED, + "stop", + ))), + ] + + monkeypatch.setattr("durabletask.worker.shared.get_grpc_channel", get_grpc_channel) + monkeypatch.setattr( + "durabletask.worker.stubs.TaskHubSidecarServiceStub", + lambda channel: stubs.pop(0), + ) + + await worker._async_run_loop() + + assert len(created_channels) == 1 + assert len(worker_manager.submissions) == expected_submissions From c4a98e9100960601bc7475183ec4dfe0b7eac243 Mon Sep 17 00:00:00 2001 From: Bernd Verst Date: Thu, 23 Apr 2026 20:16:41 -0700 Subject: [PATCH 12/28] Fix worker channel cleanup on teardown Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- CHANGELOG.md | 5 +-- durabletask/worker.py | 20 +++++++---- tests/durabletask/test_worker_resiliency.py | 39 +++++++++++++++++++++ 3 files changed, 56 insertions(+), 8 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 021f2c6..7efb50f 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -28,8 +28,9 @@ ADDED FIXED - Hardened `TaskHubGrpcWorker` reconnect handling so configured hello timeouts - apply on fresh connections, received work items reset failure tracking, and - caller-owned channels are never recreated during worker reconnects. + apply on fresh connections, received work items reset failure tracking, + SDK-owned channels are cleaned up on shutdown and full resets, and + caller-owned channels are never recreated or closed during worker reconnects. 
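The cleanup fix above boils down to one decision inside `invalidate_connection`: close the current channel only when the worker owns it and the reset is either a poison-driven recreation or a full teardown. A distilled restatement with a hypothetical `should_close` helper (not part of the SDK):

```python
def should_close(owns_channel: bool, *, recreate: bool, teardown: bool) -> bool:
    # Mirrors the guard in invalidate_connection(): caller-owned channels are
    # never closed; SDK-owned channels close on recreation or full teardown.
    return owns_channel and (recreate or teardown)


assert should_close(True, recreate=True, teardown=False)       # poisoned channel
assert should_close(True, recreate=False, teardown=True)       # shutdown / reset
assert not should_close(True, recreate=False, teardown=False)  # plain stream retry
assert not should_close(False, recreate=True, teardown=True)   # caller-owned
```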
## v1.4.0 diff --git a/durabletask/worker.py b/durabletask/worker.py index fb2aa46..460e5f5 100644 --- a/durabletask/worker.py +++ b/durabletask/worker.py @@ -671,7 +671,11 @@ def create_fresh_connection(): current_stub = None raise - def invalidate_connection(*, recreate_channel: bool = False): + def invalidate_connection( + *, + recreate_channel: bool = False, + close_channel: bool = False, + ): nonlocal current_channel, current_stub, current_reader_thread # Cancel the response stream first to signal the reader thread to stop if self._response_stream is not None: @@ -691,7 +695,11 @@ def invalidate_connection(*, recreate_channel: bool = False): pass current_reader_thread = None - if recreate_channel and current_channel is not None and self._can_recreate_channel(): + if ( + current_channel is not None + and self._can_recreate_channel() + and (recreate_channel or close_channel) + ): try: current_channel.close() except Exception: @@ -857,11 +865,11 @@ def stream_reader(): self._logger.info( "Work item stream closed before receiving the first message" ) - invalidate_connection() + invalidate_connection(close_channel=True) continue if stream_outcome is _WorkItemStreamOutcome.GRACEFUL_CLOSE_AFTER_MESSAGE: self._logger.info("Work item stream closed after receiving messages") - invalidate_connection() + invalidate_connection(close_channel=True) continue if stream_outcome is _WorkItemStreamOutcome.SILENT_DISCONNECT: self._logger.warning( @@ -916,13 +924,13 @@ def stream_reader(): if self._shutdown.wait(delay): break except Exception as ex: - invalidate_connection() + invalidate_connection(close_channel=True) self._logger.warning(f"Unexpected error: {ex}") conn_retry_count += 1 delay = get_reconnect_delay_seconds() if self._shutdown.wait(delay): break - invalidate_connection() + invalidate_connection(close_channel=True) self._logger.info("No longer listening for work items") self._async_worker_manager.shutdown() await worker_task diff --git a/tests/durabletask/test_worker_resiliency.py b/tests/durabletask/test_worker_resiliency.py index 12cb8c6..9b347b9 100644 --- a/tests/durabletask/test_worker_resiliency.py +++ b/tests/durabletask/test_worker_resiliency.py @@ -213,6 +213,45 @@ def create_stub(channel): assert stub_channels[0] is created_channels[0] assert stub_channels[1] is created_channels[0] assert stub_channels[2] is created_channels[1] + created_channels[0].close.assert_called_once() + created_channels[1].close.assert_called_once() + + +@pytest.mark.asyncio +async def test_worker_closes_sdk_owned_channel_on_graceful_stream_reset(monkeypatch): + worker = TaskHubGrpcWorker() + worker._async_worker_manager = DummyWorkerManager() + monkeypatch.setattr(worker._shutdown, "wait", lambda timeout: False) + + created_channels = [] + + def get_grpc_channel(*args, **kwargs): + channel = MagicMock(name=f"channel-{len(created_channels) + 1}") + created_channels.append(channel) + return channel + + stub_channels = [] + stubs = [ + MagicMock(GetWorkItems=MagicMock(return_value=FakeResponseStream())), + MagicMock(GetWorkItems=MagicMock(side_effect=FakeRpcError( + grpc.StatusCode.CANCELLED, + "stop", + ))), + ] + + def create_stub(channel): + stub_channels.append(channel) + return stubs.pop(0) + + monkeypatch.setattr("durabletask.worker.shared.get_grpc_channel", get_grpc_channel) + monkeypatch.setattr("durabletask.worker.stubs.TaskHubSidecarServiceStub", create_stub) + + await worker._async_run_loop() + + assert len(created_channels) == 2 + assert stub_channels == created_channels + 
created_channels[0].close.assert_called_once() + created_channels[1].close.assert_called_once() @pytest.mark.asyncio From 950892b9fc81dfa548d4e89ca2e084ed0e337bb6 Mon Sep 17 00:00:00 2001 From: Bernd Verst Date: Thu, 23 Apr 2026 20:29:21 -0700 Subject: [PATCH 13/28] Add worker silent disconnect tests Extend worker resiliency coverage with an end-to-end silent-disconnect recovery test and an explicit reconnect backoff assertion. Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- tests/durabletask/test_worker_resiliency.py | 131 ++++++++++++++++++++ 1 file changed, 131 insertions(+) diff --git a/tests/durabletask/test_worker_resiliency.py b/tests/durabletask/test_worker_resiliency.py index 9b347b9..9d3f596 100644 --- a/tests/durabletask/test_worker_resiliency.py +++ b/tests/durabletask/test_worker_resiliency.py @@ -1,5 +1,6 @@ import asyncio import grpc +from threading import Event from unittest.mock import MagicMock import pytest @@ -40,6 +41,22 @@ def cancel(self): self.cancelled = True +class BlockingResponseStream: + def __init__(self): + self._cancel_event = Event() + self.cancelled = False + + def __iter__(self): + if not self._cancel_event.wait(timeout=0.5): + raise AssertionError("response stream was not cancelled") + return + yield + + def cancel(self): + self.cancelled = True + self._cancel_event.set() + + class DummyWorkerManager: def __init__(self): self._shutdown_event = asyncio.Event() @@ -217,6 +234,66 @@ def create_stub(channel): created_channels[1].close.assert_called_once() +@pytest.mark.asyncio +async def test_worker_recreates_sdk_owned_channel_after_silent_disconnect(monkeypatch): + worker = TaskHubGrpcWorker( + resiliency_options=GrpcWorkerResiliencyOptions( + channel_recreate_failure_threshold=1, + silent_disconnect_timeout_seconds=0.01, + ) + ) + worker._async_worker_manager = DummyWorkerManager() + + wait_calls = [] + + def shutdown_wait(timeout): + wait_calls.append(timeout) + return False + + monkeypatch.setattr(worker._shutdown, "wait", shutdown_wait) + + delay_calls = [] + + def fake_delay(attempt, *, base_seconds, cap_seconds): + delay_calls.append((attempt, base_seconds, cap_seconds)) + return 0.25 + + created_channels = [] + + def get_grpc_channel(*args, **kwargs): + channel = MagicMock(name=f"channel-{len(created_channels) + 1}") + created_channels.append(channel) + return channel + + blocking_stream = BlockingResponseStream() + stub_channels = [] + stubs = [ + MagicMock(GetWorkItems=MagicMock(return_value=blocking_stream)), + MagicMock(GetWorkItems=MagicMock(side_effect=FakeRpcError( + grpc.StatusCode.CANCELLED, + "stop", + ))), + ] + + def create_stub(channel): + stub_channels.append(channel) + return stubs.pop(0) + + monkeypatch.setattr("durabletask.worker.get_full_jitter_delay_seconds", fake_delay) + monkeypatch.setattr("durabletask.worker.shared.get_grpc_channel", get_grpc_channel) + monkeypatch.setattr("durabletask.worker.stubs.TaskHubSidecarServiceStub", create_stub) + + await worker._async_run_loop() + + assert blocking_stream.cancelled is True + assert delay_calls == [(1, 1.0, 30.0)] + assert wait_calls == [0.25] + assert len(created_channels) == 2 + assert stub_channels == created_channels + created_channels[0].close.assert_called_once() + created_channels[1].close.assert_called_once() + + @pytest.mark.asyncio async def test_worker_closes_sdk_owned_channel_on_graceful_stream_reset(monkeypatch): worker = TaskHubGrpcWorker() @@ -254,6 +331,60 @@ def create_stub(channel): created_channels[1].close.assert_called_once() 
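The silent-disconnect tests above exercise the worker's stream monitor, whose core is a bounded wait on a queue fed by a reader thread. The shape of that classification logic, as a conceptual parallel that substitutes `asyncio.Queue` for the real thread-backed queue:

```python
import asyncio

_CLOSED = object()  # stands in for the worker's stream-closed sentinel


async def classify(stream: asyncio.Queue, idle_timeout: float, saw_message: bool) -> str:
    try:
        item = await asyncio.wait_for(stream.get(), timeout=idle_timeout)
    except asyncio.TimeoutError:
        return "silent_disconnect"
    if item is _CLOSED:
        return (
            "graceful_close_after_message"
            if saw_message
            else "graceful_close_before_first_message"
        )
    return "message_received"


async def main() -> None:
    stream: asyncio.Queue = asyncio.Queue()
    # Nothing arrives within the timeout: the stream looks silently dead.
    assert await classify(stream, 0.05, saw_message=False) == "silent_disconnect"
    # A close sentinel before any message: count it against the channel.
    await stream.put(_CLOSED)
    assert await classify(stream, 0.05, saw_message=False) == (
        "graceful_close_before_first_message"
    )


asyncio.run(main())
```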
+@pytest.mark.asyncio +async def test_worker_uses_reconnect_backoff_helper_after_connection_failure(monkeypatch): + worker = TaskHubGrpcWorker( + resiliency_options=GrpcWorkerResiliencyOptions( + reconnect_backoff_base_seconds=1.5, + reconnect_backoff_cap_seconds=9.0, + ) + ) + worker._async_worker_manager = DummyWorkerManager() + + wait_calls = [] + + def shutdown_wait(timeout): + wait_calls.append(timeout) + return False + + monkeypatch.setattr(worker._shutdown, "wait", shutdown_wait) + + delay_calls = [] + + def fake_delay(attempt, *, base_seconds, cap_seconds): + delay_calls.append((attempt, base_seconds, cap_seconds)) + return 0.75 + + channel = MagicMock(name="channel-1") + stub_channels = [] + first_stub = MagicMock() + first_stub.Hello.side_effect = FakeRpcError( + grpc.StatusCode.UNAVAILABLE, + "connect failed", + ) + second_stub = MagicMock() + second_stub.GetWorkItems.side_effect = FakeRpcError( + grpc.StatusCode.CANCELLED, + "stop", + ) + stubs = [first_stub, second_stub] + + def create_stub(current_channel): + stub_channels.append(current_channel) + return stubs.pop(0) + + monkeypatch.setattr("durabletask.worker.get_full_jitter_delay_seconds", fake_delay) + monkeypatch.setattr("durabletask.worker.shared.get_grpc_channel", lambda *args, **kwargs: channel) + monkeypatch.setattr("durabletask.worker.stubs.TaskHubSidecarServiceStub", create_stub) + + await worker._async_run_loop() + + assert delay_calls == [(1, 1.5, 9.0)] + assert wait_calls == [0.75] + assert stub_channels == [channel, channel] + channel.close.assert_called_once() + + @pytest.mark.asyncio async def test_worker_never_replaces_caller_owned_channel_during_transport_failures(monkeypatch): provided_channel = MagicMock(name="provided-channel") From 09697bbe55844abf36760ec399753e39a59bb8e5 Mon Sep 17 00:00:00 2001 From: Bernd Verst Date: Thu, 23 Apr 2026 20:44:53 -0700 Subject: [PATCH 14/28] Add sync client gRPC channel recreation Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- CHANGELOG.md | 3 + durabletask/client.py | 99 ++++++++++++++++++++----- tests/durabletask/test_client.py | 121 +++++++++++++++++++++++++++++++ 3 files changed, 206 insertions(+), 17 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 7efb50f..dbbf9f6 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -31,6 +31,9 @@ FIXED apply on fresh connections, received work items reset failure tracking, SDK-owned channels are cleaned up on shutdown and full resets, and caller-owned channels are never recreated or closed during worker reconnects. +- Fixed sync `TaskHubGrpcClient` transport resiliency so SDK-owned channels are + recreated after repeated transport failures without counting long-poll + timeout deadlines against the recreation threshold. ## v1.4.0 diff --git a/durabletask/client.py b/durabletask/client.py index fa3a924..45ab350 100644 --- a/durabletask/client.py +++ b/durabletask/client.py @@ -2,6 +2,8 @@ # Licensed under the MIT License. 
import logging +import threading +import time import uuid from dataclasses import dataclass from datetime import datetime @@ -25,6 +27,10 @@ import durabletask.internal.shared as shared import durabletask.internal.tracing as tracing from durabletask import task +from durabletask.internal.grpc_resiliency import ( + FailureTracker, + is_client_transport_failure, +) from durabletask.internal.client_helpers import ( build_query_entities_req, build_query_instances_req, @@ -201,10 +207,61 @@ def __init__(self, *, ) self._channel = channel self._stub = stubs.TaskHubSidecarServiceStub(channel) + self._client_failure_tracker = FailureTracker( + self._resiliency_options.channel_recreate_failure_threshold + ) + self._last_recreate_time = 0.0 + self._recreate_lock = threading.Lock() self._logger = shared.get_logger("client", log_handler, log_formatter) self.default_version = default_version self._payload_store = payload_store + def _invoke_unary( + self, + method_name: str, + request: Any, + *, + timeout: Optional[int] = None): + method = getattr(self._stub, method_name) + try: + if timeout is None: + response = method(request) + else: + response = method(request, timeout=timeout) + except grpc.RpcError as rpc_error: + status_code = rpc_error.code() + if is_client_transport_failure(method_name, status_code): + should_recreate = self._client_failure_tracker.record_failure() + if should_recreate: + self._maybe_recreate_channel() + elif status_code != grpc.StatusCode.DEADLINE_EXCEEDED: + self._client_failure_tracker.record_success() + raise + else: + self._client_failure_tracker.record_success() + return response + + def _maybe_recreate_channel(self) -> None: + if not self._owns_channel: + return + with self._recreate_lock: + now = time.monotonic() + if now - self._last_recreate_time < self._resiliency_options.min_recreate_interval_seconds: + return + old_channel = self._channel + self._channel = shared.get_grpc_channel( + host_address=self._host_address, + secure_channel=self._secure_channel, + interceptors=self._interceptors, + channel_options=self._channel_options, + ) + self._stub = stubs.TaskHubSidecarServiceStub(self._channel) + self._last_recreate_time = now + self._client_failure_tracker.record_success() + close_timer = threading.Timer(30.0, old_channel.close) + close_timer.daemon = True + close_timer.start() + def close(self) -> None: """Closes the underlying gRPC channel. 
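`_invoke_unary` above treats long-poll deadlines as neutral: they neither count toward recreation nor reset the tracker. The counting rules in isolation, using the shared helpers this series adds:

```python
import grpc

from durabletask.internal.grpc_resiliency import (
    FailureTracker,
    is_client_transport_failure,
)

tracker = FailureTracker(threshold=2)

# UNAVAILABLE on ordinary unary calls counts toward recreation, while a
# DEADLINE_EXCEEDED from the long-poll waiters is expected and is skipped.
calls = [
    ("GetInstance", grpc.StatusCode.UNAVAILABLE),
    ("WaitForInstanceCompletion", grpc.StatusCode.DEADLINE_EXCEEDED),
    ("StartInstance", grpc.StatusCode.UNAVAILABLE),
]
for method, code in calls:
    if is_client_transport_failure(method, code):
        if tracker.record_failure():
            print(f"{method}: threshold reached, recreate the channel")

# Two transport failures counted; the long-poll deadline changed nothing.
assert tracker.consecutive_failures == 2
```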
@@ -249,12 +306,12 @@ def schedule_new_orchestration(self, orchestrator: Union[task.Orchestrator[TInpu payload_helpers.externalize_payloads( req, self._payload_store, instance_id=req.instanceId, ) - res: pb.CreateInstanceResponse = self._stub.StartInstance(req) + res: pb.CreateInstanceResponse = self._invoke_unary("StartInstance", req) return res.instanceId def get_orchestration_state(self, instance_id: str, *, fetch_payloads: bool = True) -> Optional[OrchestrationState]: req = pb.GetInstanceRequest(instanceId=instance_id, getInputsAndOutputs=fetch_payloads) - res: pb.GetInstanceResponse = self._stub.GetInstance(req) + res: pb.GetInstanceResponse = self._invoke_unary("GetInstance", req) # De-externalize any large-payload tokens in the response if self._payload_store is not None and res.exists: payload_helpers.deexternalize_payloads(res, self._payload_store) @@ -294,7 +351,7 @@ def list_instance_ids(self, f"page_size={page_size}, " f"continuation_token={continuation_token}" ) - resp: pb.ListInstanceIdsResponse = self._stub.ListInstanceIds(req) + resp: pb.ListInstanceIdsResponse = self._invoke_unary("ListInstanceIds", req) next_token = resp.lastInstanceKey.value if resp.HasField("lastInstanceKey") else None return Page(items=list(resp.instanceIds), continuation_token=next_token) @@ -311,7 +368,7 @@ def get_all_orchestration_states(self, while True: req = build_query_instances_req(orchestration_query, _continuation_token) - resp: pb.QueryInstancesResponse = self._stub.QueryInstances(req) + resp: pb.QueryInstancesResponse = self._invoke_unary("QueryInstances", req) if self._payload_store is not None: payload_helpers.deexternalize_payloads(resp, self._payload_store) states += [parse_orchestration_state(res) for res in resp.orchestrationState] @@ -328,7 +385,11 @@ def wait_for_orchestration_start(self, instance_id: str, *, req = pb.GetInstanceRequest(instanceId=instance_id, getInputsAndOutputs=fetch_payloads) try: self._logger.info(f"Waiting up to {timeout}s for instance '{instance_id}' to start.") - res: pb.GetInstanceResponse = self._stub.WaitForInstanceStart(req, timeout=timeout) + res: pb.GetInstanceResponse = self._invoke_unary( + "WaitForInstanceStart", + req, + timeout=timeout, + ) if self._payload_store is not None and res.exists: payload_helpers.deexternalize_payloads(res, self._payload_store) return new_orchestration_state(req.instanceId, res) @@ -345,7 +406,11 @@ def wait_for_orchestration_completion(self, instance_id: str, *, req = pb.GetInstanceRequest(instanceId=instance_id, getInputsAndOutputs=fetch_payloads) try: self._logger.info(f"Waiting {timeout}s for instance '{instance_id}' to complete.") - res: pb.GetInstanceResponse = self._stub.WaitForInstanceCompletion(req, timeout=timeout) + res: pb.GetInstanceResponse = self._invoke_unary( + "WaitForInstanceCompletion", + req, + timeout=timeout, + ) if self._payload_store is not None and res.exists: payload_helpers.deexternalize_payloads(res, self._payload_store) state = new_orchestration_state(req.instanceId, res) @@ -366,7 +431,7 @@ def raise_orchestration_event(self, instance_id: str, event_name: str, *, payload_helpers.externalize_payloads( req, self._payload_store, instance_id=instance_id, ) - self._stub.RaiseEvent(req) + self._invoke_unary("RaiseEvent", req) def terminate_orchestration(self, instance_id: str, *, output: Optional[Any] = None, @@ -378,17 +443,17 @@ def terminate_orchestration(self, instance_id: str, *, payload_helpers.externalize_payloads( req, self._payload_store, instance_id=instance_id, ) - 
self._stub.TerminateInstance(req) + self._invoke_unary("TerminateInstance", req) def suspend_orchestration(self, instance_id: str) -> None: req = pb.SuspendRequest(instanceId=instance_id) self._logger.info(f"Suspending instance '{instance_id}'.") - self._stub.SuspendInstance(req) + self._invoke_unary("SuspendInstance", req) def resume_orchestration(self, instance_id: str) -> None: req = pb.ResumeRequest(instanceId=instance_id) self._logger.info(f"Resuming instance '{instance_id}'.") - self._stub.ResumeInstance(req) + self._invoke_unary("ResumeInstance", req) def restart_orchestration(self, instance_id: str, *, restart_with_new_instance_id: bool = False) -> str: @@ -407,13 +472,13 @@ def restart_orchestration(self, instance_id: str, *, restartWithNewInstanceId=restart_with_new_instance_id) self._logger.info(f"Restarting instance '{instance_id}'.") - res: pb.RestartInstanceResponse = self._stub.RestartInstance(req) + res: pb.RestartInstanceResponse = self._invoke_unary("RestartInstance", req) return res.instanceId def purge_orchestration(self, instance_id: str, recursive: bool = True) -> PurgeInstancesResult: req = pb.PurgeInstancesRequest(instanceId=instance_id, recursive=recursive) self._logger.info(f"Purging instance '{instance_id}'.") - resp: pb.PurgeInstancesResponse = self._stub.PurgeInstances(req) + resp: pb.PurgeInstancesResponse = self._invoke_unary("PurgeInstances", req) return PurgeInstancesResult(resp.deletedInstanceCount, resp.isComplete.value) def purge_orchestrations_by(self, @@ -427,7 +492,7 @@ def purge_orchestrations_by(self, f"runtime_status={[str(status) for status in runtime_status] if runtime_status else None}, " f"recursive={recursive}") req = build_purge_by_filter_req(created_time_from, created_time_to, runtime_status, recursive) - resp: pb.PurgeInstancesResponse = self._stub.PurgeInstances(req) + resp: pb.PurgeInstancesResponse = self._invoke_unary("PurgeInstances", req) return PurgeInstancesResult(resp.deletedInstanceCount, resp.isComplete.value) def signal_entity(self, @@ -440,7 +505,7 @@ def signal_entity(self, payload_helpers.externalize_payloads( req, self._payload_store, instance_id=str(entity_instance_id), ) - self._stub.SignalEntity(req, None) # TODO: Cancellation timeout? + self._invoke_unary("SignalEntity", req) # TODO: Cancellation timeout? 
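Every unary call site now funnels through `_invoke_unary`, so applications tune channel healing purely through constructor options rather than wrapping individual RPCs. A hedged usage sketch; the address and threshold values below are illustrative, not defaults:

```python
from durabletask.client import TaskHubGrpcClient
from durabletask.grpc_options import GrpcClientResiliencyOptions

# Recreate the SDK-owned channel after 3 consecutive transport-shaped
# failures, but never more often than once a minute.
client = TaskHubGrpcClient(
    host_address="localhost:4001",
    resiliency_options=GrpcClientResiliencyOptions(
        channel_recreate_failure_threshold=3,
        min_recreate_interval_seconds=60.0,
    ),
)

# A threshold <= 0 opts out of recreation entirely while keeping the rest
# of the client behavior unchanged.
opt_out = GrpcClientResiliencyOptions(channel_recreate_failure_threshold=0)
```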
def get_entity(self, entity_instance_id: EntityInstanceId, @@ -448,7 +513,7 @@ def get_entity(self, ) -> Optional[EntityMetadata]: req = pb.GetEntityRequest(instanceId=str(entity_instance_id), includeState=include_state) self._logger.info(f"Getting entity '{entity_instance_id}'.") - res: pb.GetEntityResponse = self._stub.GetEntity(req) + res: pb.GetEntityResponse = self._invoke_unary("GetEntity", req) if not res.exists: return None if self._payload_store is not None: @@ -467,7 +532,7 @@ def get_all_entities(self, while True: query_request = build_query_entities_req(entity_query, _continuation_token) - resp: pb.QueryEntitiesResponse = self._stub.QueryEntities(query_request) + resp: pb.QueryEntitiesResponse = self._invoke_unary("QueryEntities", query_request) if self._payload_store is not None: payload_helpers.deexternalize_payloads(resp, self._payload_store) entities += [EntityMetadata.from_entity_metadata(entity, query_request.query.includeState) for entity in resp.entities] @@ -493,7 +558,7 @@ def clean_entity_storage(self, releaseOrphanedLocks=release_orphaned_locks, continuationToken=_continuation_token ) - resp: pb.CleanEntityStorageResponse = self._stub.CleanEntityStorage(req) + resp: pb.CleanEntityStorageResponse = self._invoke_unary("CleanEntityStorage", req) empty_entities_removed += resp.emptyEntitiesRemoved orphaned_locks_released += resp.orphanedLocksReleased diff --git a/tests/durabletask/test_client.py b/tests/durabletask/test_client.py index a4a1b5b..ef4ce79 100644 --- a/tests/durabletask/test_client.py +++ b/tests/durabletask/test_client.py @@ -1,4 +1,5 @@ import json +import grpc import pytest from datetime import datetime, timezone from unittest.mock import ANY, AsyncMock, MagicMock, patch @@ -32,6 +33,15 @@ INTERCEPTORS = [DefaultClientInterceptorImpl(METADATA)] +class FakeRpcError(grpc.RpcError): + def __init__(self, status_code: grpc.StatusCode): + super().__init__() + self._status_code = status_code + + def code(self): + return self._status_code + + class FakePayloadStore(PayloadStore): TOKEN_PREFIX = 'fake://' @@ -317,6 +327,117 @@ def test_client_stores_resiliency_options_for_recreation(): assert client._interceptors == interceptors +def test_sync_client_recreates_sdk_owned_channel_after_repeated_unavailable(): + first_channel = MagicMock(name="first-channel") + second_channel = MagicMock(name="second-channel") + first_stub = MagicMock() + second_stub = MagicMock() + second_stub.GetInstance.return_value = MagicMock(exists=False) + + rpc_error = FakeRpcError(grpc.StatusCode.UNAVAILABLE) + first_stub.GetInstance.side_effect = rpc_error + + timer = MagicMock() + + with patch("durabletask.client.shared.get_grpc_channel", side_effect=[first_channel, second_channel]), patch( + "durabletask.client.stubs.TaskHubSidecarServiceStub", side_effect=[first_stub, second_stub] + ), patch("threading.Timer", return_value=timer) as mock_timer: + client = TaskHubGrpcClient( + host_address="localhost:4001", + resiliency_options=GrpcClientResiliencyOptions( + channel_recreate_failure_threshold=1, + min_recreate_interval_seconds=0.0, + ), + ) + with pytest.raises(FakeRpcError): + client.get_orchestration_state("abc") + client.get_orchestration_state("abc") + + assert client._channel is second_channel + mock_timer.assert_called_once_with(30.0, first_channel.close) + assert timer.daemon is True + timer.start.assert_called_once_with() + + +def test_sync_client_does_not_count_long_poll_deadline(): + stub = MagicMock() + stub.GetInstance.side_effect = FakeRpcError(grpc.StatusCode.UNAVAILABLE) + 
stub.WaitForInstanceStart.side_effect = FakeRpcError(grpc.StatusCode.DEADLINE_EXCEEDED) + + with patch("durabletask.client.shared.get_grpc_channel", return_value=MagicMock()), patch( + "durabletask.client.stubs.TaskHubSidecarServiceStub", return_value=stub + ): + client = TaskHubGrpcClient( + resiliency_options=GrpcClientResiliencyOptions(channel_recreate_failure_threshold=2) + ) + with pytest.raises(FakeRpcError): + client.get_orchestration_state("abc") + with pytest.raises(TimeoutError): + client.wait_for_orchestration_start("abc") + assert client._client_failure_tracker.consecutive_failures == 1 + + +def test_sync_client_does_not_recreate_caller_owned_channel(): + provided_channel = MagicMock(name="provided-channel") + stub = MagicMock() + stub.GetInstance.side_effect = FakeRpcError(grpc.StatusCode.UNAVAILABLE) + + with patch("durabletask.client.shared.get_grpc_channel") as mock_get_channel, patch( + "durabletask.client.stubs.TaskHubSidecarServiceStub", return_value=stub + ) as mock_stub: + client = TaskHubGrpcClient( + channel=provided_channel, + resiliency_options=GrpcClientResiliencyOptions(channel_recreate_failure_threshold=1), + ) + with pytest.raises(FakeRpcError): + client.get_orchestration_state("abc") + with pytest.raises(FakeRpcError): + client.get_orchestration_state("abc") + + assert client._channel is provided_channel + mock_get_channel.assert_not_called() + mock_stub.assert_called_once_with(provided_channel) + + +def test_sync_client_resets_failure_tracking_after_success(): + stub = MagicMock() + stub.GetInstance.side_effect = [ + FakeRpcError(grpc.StatusCode.UNAVAILABLE), + MagicMock(exists=False), + ] + + with patch("durabletask.client.shared.get_grpc_channel", return_value=MagicMock()), patch( + "durabletask.client.stubs.TaskHubSidecarServiceStub", return_value=stub + ): + client = TaskHubGrpcClient( + resiliency_options=GrpcClientResiliencyOptions(channel_recreate_failure_threshold=2) + ) + with pytest.raises(FakeRpcError): + client.get_orchestration_state("abc") + assert client.get_orchestration_state("abc") is None + assert client._client_failure_tracker.consecutive_failures == 0 + + +def test_sync_client_resets_failure_tracking_after_application_error(): + stub = MagicMock() + stub.GetInstance.side_effect = [ + FakeRpcError(grpc.StatusCode.UNAVAILABLE), + FakeRpcError(grpc.StatusCode.INVALID_ARGUMENT), + ] + + with patch("durabletask.client.shared.get_grpc_channel", return_value=MagicMock()), patch( + "durabletask.client.stubs.TaskHubSidecarServiceStub", return_value=stub + ): + client = TaskHubGrpcClient( + resiliency_options=GrpcClientResiliencyOptions(channel_recreate_failure_threshold=2) + ) + with pytest.raises(FakeRpcError): + client.get_orchestration_state("abc") + with pytest.raises(FakeRpcError): + client.get_orchestration_state("abc") + assert client._client_failure_tracker.consecutive_failures == 0 + + def test_async_client_stores_resolved_transport_inputs(): resiliency = GrpcClientResiliencyOptions() channel_options = GrpcChannelOptions(max_send_message_length=4321) From d89bc10ab91c114f5557fcac8ce099168f0ce155 Mon Sep 17 00:00:00 2001 From: Bernd Verst Date: Thu, 23 Apr 2026 21:16:01 -0700 Subject: [PATCH 15/28] Reset sync client long-poll failure tracking Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- CHANGELOG.md | 12 ++++++++---- durabletask/client.py | 2 +- tests/durabletask/test_client.py | 18 ++++++++++++++---- 3 files changed, 23 insertions(+), 9 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 
dbbf9f6..7020125 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -22,8 +22,11 @@ ADDED - Added optional `resiliency_options` parameters to `TaskHubGrpcClient`, `AsyncTaskHubGrpcClient`, and `TaskHubGrpcWorker` so applications can pass gRPC resiliency settings through constructor APIs. -- Added `get_orchestration_history()` and `list_instance_ids()` to the sync and async gRPC clients. -- Added in-memory backend support for `StreamInstanceHistory` and `ListInstanceIds` so local orchestration tests can retrieve history and page terminal instance IDs by completion window. +- Added `get_orchestration_history()` and `list_instance_ids()` to the sync + and async gRPC clients. +- Added in-memory backend support for `StreamInstanceHistory` and + `ListInstanceIds` so local orchestration tests can retrieve history and page + terminal instance IDs by completion window. FIXED @@ -32,8 +35,9 @@ FIXED SDK-owned channels are cleaned up on shutdown and full resets, and caller-owned channels are never recreated or closed during worker reconnects. - Fixed sync `TaskHubGrpcClient` transport resiliency so SDK-owned channels are - recreated after repeated transport failures without counting long-poll - timeout deadlines against the recreation threshold. + recreated after repeated transport failures while long-poll timeout + deadlines, successful replies, and application-level RPC errors reset the + failure tracker. ## v1.4.0 diff --git a/durabletask/client.py b/durabletask/client.py index 45ab350..6988dd4 100644 --- a/durabletask/client.py +++ b/durabletask/client.py @@ -234,7 +234,7 @@ def _invoke_unary( should_recreate = self._client_failure_tracker.record_failure() if should_recreate: self._maybe_recreate_channel() - elif status_code != grpc.StatusCode.DEADLINE_EXCEEDED: + else: self._client_failure_tracker.record_success() raise else: diff --git a/tests/durabletask/test_client.py b/tests/durabletask/test_client.py index ef4ce79..b20516d 100644 --- a/tests/durabletask/test_client.py +++ b/tests/durabletask/test_client.py @@ -359,10 +359,20 @@ def test_sync_client_recreates_sdk_owned_channel_after_repeated_unavailable(): timer.start.assert_called_once_with() -def test_sync_client_does_not_count_long_poll_deadline(): +@pytest.mark.parametrize( + ("stub_method_name", "client_method_name"), + [ + ("WaitForInstanceStart", "wait_for_orchestration_start"), + ("WaitForInstanceCompletion", "wait_for_orchestration_completion"), + ], +) +def test_sync_client_resets_failure_tracking_after_long_poll_deadline( + stub_method_name: str, + client_method_name: str, +): stub = MagicMock() stub.GetInstance.side_effect = FakeRpcError(grpc.StatusCode.UNAVAILABLE) - stub.WaitForInstanceStart.side_effect = FakeRpcError(grpc.StatusCode.DEADLINE_EXCEEDED) + getattr(stub, stub_method_name).side_effect = FakeRpcError(grpc.StatusCode.DEADLINE_EXCEEDED) with patch("durabletask.client.shared.get_grpc_channel", return_value=MagicMock()), patch( "durabletask.client.stubs.TaskHubSidecarServiceStub", return_value=stub @@ -373,8 +383,8 @@ def test_sync_client_does_not_count_long_poll_deadline(): with pytest.raises(FakeRpcError): client.get_orchestration_state("abc") with pytest.raises(TimeoutError): - client.wait_for_orchestration_start("abc") - assert client._client_failure_tracker.consecutive_failures == 1 + getattr(client, client_method_name)("abc") + assert client._client_failure_tracker.consecutive_failures == 0 def test_sync_client_does_not_recreate_caller_owned_channel(): From 937231d1c5c346c106d629c745695c16105534c9 Mon Sep 17 00:00:00 
2001 From: Bernd Verst Date: Thu, 23 Apr 2026 21:34:44 -0700 Subject: [PATCH 16/28] Add sync client recreation input test Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- tests/durabletask/test_client.py | 27 +++++++++++++++++++++++---- 1 file changed, 23 insertions(+), 4 deletions(-) diff --git a/tests/durabletask/test_client.py b/tests/durabletask/test_client.py index b20516d..4c8fc45 100644 --- a/tests/durabletask/test_client.py +++ b/tests/durabletask/test_client.py @@ -2,7 +2,7 @@ import grpc import pytest from datetime import datetime, timezone -from unittest.mock import ANY, AsyncMock, MagicMock, patch +from unittest.mock import ANY, AsyncMock, MagicMock, call, patch from google.protobuf import wrappers_pb2 @@ -327,23 +327,32 @@ def test_client_stores_resiliency_options_for_recreation(): assert client._interceptors == interceptors -def test_sync_client_recreates_sdk_owned_channel_after_repeated_unavailable(): +def test_sync_client_recreates_sdk_owned_channel_with_original_transport_inputs(): first_channel = MagicMock(name="first-channel") second_channel = MagicMock(name="second-channel") first_stub = MagicMock() second_stub = MagicMock() second_stub.GetInstance.return_value = MagicMock(exists=False) + host_address = "localhost:4001" + interceptors = [DefaultClientInterceptorImpl(METADATA)] + channel_options = GrpcChannelOptions(max_receive_message_length=1234) rpc_error = FakeRpcError(grpc.StatusCode.UNAVAILABLE) first_stub.GetInstance.side_effect = rpc_error timer = MagicMock() - with patch("durabletask.client.shared.get_grpc_channel", side_effect=[first_channel, second_channel]), patch( + with patch( + "durabletask.client.shared.get_grpc_channel", + side_effect=[first_channel, second_channel], + ) as mock_get_channel, patch( "durabletask.client.stubs.TaskHubSidecarServiceStub", side_effect=[first_stub, second_stub] ), patch("threading.Timer", return_value=timer) as mock_timer: client = TaskHubGrpcClient( - host_address="localhost:4001", + host_address=host_address, + secure_channel=True, + interceptors=interceptors, + channel_options=channel_options, resiliency_options=GrpcClientResiliencyOptions( channel_recreate_failure_threshold=1, min_recreate_interval_seconds=0.0, @@ -353,6 +362,16 @@ def test_sync_client_recreates_sdk_owned_channel_after_repeated_unavailable(): client.get_orchestration_state("abc") client.get_orchestration_state("abc") + expected_channel_call = call( + host_address=host_address, + secure_channel=True, + interceptors=interceptors, + channel_options=channel_options, + ) + assert mock_get_channel.call_args_list == [ + expected_channel_call, + expected_channel_call, + ] assert client._channel is second_channel mock_timer.assert_called_once_with(30.0, first_channel.close) assert timer.daemon is True From 2834177d84d1db872142e45e986c9e215fe3ef88 Mon Sep 17 00:00:00 2001 From: Bernd Verst Date: Thu, 23 Apr 2026 21:45:58 -0700 Subject: [PATCH 17/28] Add sync client recreation test coverage Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- tests/durabletask/test_client.py | 70 +++++++++++++++++++++++++++++++- 1 file changed, 69 insertions(+), 1 deletion(-) diff --git a/tests/durabletask/test_client.py b/tests/durabletask/test_client.py index 4c8fc45..ce2c253 100644 --- a/tests/durabletask/test_client.py +++ b/tests/durabletask/test_client.py @@ -413,7 +413,7 @@ def test_sync_client_does_not_recreate_caller_owned_channel(): with patch("durabletask.client.shared.get_grpc_channel") as mock_get_channel, patch( 
"durabletask.client.stubs.TaskHubSidecarServiceStub", return_value=stub - ) as mock_stub: + ) as mock_stub, patch("threading.Timer") as mock_timer: client = TaskHubGrpcClient( channel=provided_channel, resiliency_options=GrpcClientResiliencyOptions(channel_recreate_failure_threshold=1), @@ -426,6 +426,74 @@ def test_sync_client_does_not_recreate_caller_owned_channel(): assert client._channel is provided_channel mock_get_channel.assert_not_called() mock_stub.assert_called_once_with(provided_channel) + mock_timer.assert_not_called() + + +def test_sync_client_recreate_cooldown_prevents_immediate_repeated_recreation(): + first_channel = MagicMock(name="first-channel") + second_channel = MagicMock(name="second-channel") + third_channel = MagicMock(name="third-channel") + first_stub = MagicMock() + second_stub = MagicMock() + third_stub = MagicMock() + first_stub.GetInstance.side_effect = FakeRpcError(grpc.StatusCode.UNAVAILABLE) + second_stub.GetInstance.side_effect = [ + FakeRpcError(grpc.StatusCode.UNAVAILABLE), + FakeRpcError(grpc.StatusCode.UNAVAILABLE), + ] + timer1 = MagicMock(name="close-timer-1") + timer2 = MagicMock(name="close-timer-2") + + with patch( + "durabletask.client.shared.get_grpc_channel", + side_effect=[first_channel, second_channel, third_channel], + ) as mock_get_channel, patch( + "durabletask.client.stubs.TaskHubSidecarServiceStub", + side_effect=[first_stub, second_stub, third_stub], + ), patch( + "durabletask.client.time.monotonic", side_effect=[100.0, 101.0, 131.0] + ), patch("threading.Timer", side_effect=[timer1, timer2]) as mock_timer: + client = TaskHubGrpcClient( + host_address=HOST_ADDRESS, + resiliency_options=GrpcClientResiliencyOptions( + channel_recreate_failure_threshold=1, + min_recreate_interval_seconds=30.0, + ), + ) + with pytest.raises(FakeRpcError): + client.get_orchestration_state("abc") + assert client._channel is second_channel + assert mock_get_channel.call_count == 2 + + with pytest.raises(FakeRpcError): + client.get_orchestration_state("abc") + assert client._channel is second_channel + assert mock_get_channel.call_count == 2 + mock_timer.assert_called_once_with(30.0, first_channel.close) + + with pytest.raises(FakeRpcError): + client.get_orchestration_state("abc") + assert client._channel is third_channel + + expected_channel_call = call( + host_address=HOST_ADDRESS, + secure_channel=False, + interceptors=None, + channel_options=None, + ) + assert mock_get_channel.call_args_list == [ + expected_channel_call, + expected_channel_call, + expected_channel_call, + ] + assert mock_timer.call_args_list == [ + call(30.0, first_channel.close), + call(30.0, second_channel.close), + ] + assert timer1.daemon is True + assert timer2.daemon is True + timer1.start.assert_called_once_with() + timer2.start.assert_called_once_with() def test_sync_client_resets_failure_tracking_after_success(): From 4f4b3a3273a0bb27d0318edd86b4ae4e00a77482 Mon Sep 17 00:00:00 2001 From: Bernd Verst Date: Thu, 23 Apr 2026 22:28:43 -0700 Subject: [PATCH 18/28] Add async client gRPC channel recreation Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- CHANGELOG.md | 4 + durabletask/client.py | 121 +++++++++++++++++++---- tests/durabletask/test_client.py | 164 +++++++++++++++++++++++++++++++ 3 files changed, 272 insertions(+), 17 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 7020125..873812b 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -38,6 +38,10 @@ FIXED recreated after repeated transport failures while long-poll timeout deadlines, successful 
replies, and application-level RPC errors reset the failure tracker. +- Fixed async `AsyncTaskHubGrpcClient` transport resiliency so SDK-owned + channels are recreated after repeated transport failures while long-poll + timeout deadlines, successful replies, and application-level RPC errors + reset the failure tracker. ## v1.4.0 diff --git a/durabletask/client.py b/durabletask/client.py index 6988dd4..1648478 100644 --- a/durabletask/client.py +++ b/durabletask/client.py @@ -1,6 +1,7 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. +import asyncio import logging import threading import time @@ -614,6 +615,14 @@ def __init__(self, *, ) self._channel = channel self._stub = stubs.TaskHubSidecarServiceStub(channel) + self._client_failure_tracker = FailureTracker( + self._resiliency_options.channel_recreate_failure_threshold + ) + self._closing = False + self._recreate_lock = asyncio.Lock() + self._last_recreate_time = 0.0 + self._retired_channels: list[grpc.aio.Channel] = [] + self._retired_channel_close_tasks: set[asyncio.Task[None]] = set() self._logger = shared.get_logger("async_client", log_handler, log_formatter) self.default_version = default_version self._payload_store = payload_store @@ -627,6 +636,18 @@ async def close(self) -> None: it. """ if self._owns_channel: + self._closing = True + async with self._recreate_lock: + retired_channels = list(self._retired_channels) + self._retired_channels.clear() + close_tasks = list(self._retired_channel_close_tasks) + self._retired_channel_close_tasks.clear() + for close_task in close_tasks: + close_task.cancel() + if close_tasks: + await asyncio.gather(*close_tasks, return_exceptions=True) + for retired_channel in retired_channels: + await retired_channel.close() await self._channel.close() async def __aenter__(self): @@ -635,6 +656,64 @@ async def __aenter__(self): async def __aexit__(self, exc_type, exc_val, exc_tb): await self.close() + async def _invoke_unary( + self, + method_name: str, + request: Any, + *, + timeout: Optional[int] = None): + method = getattr(self._stub, method_name) + try: + if timeout is None: + response = await method(request) + else: + response = await method(request, timeout=timeout) + except grpc.aio.AioRpcError as rpc_error: + if is_client_transport_failure(method_name, rpc_error.code()): + should_recreate = self._client_failure_tracker.record_failure() + if should_recreate: + await self._maybe_recreate_channel() + else: + self._client_failure_tracker.record_success() + raise + else: + self._client_failure_tracker.record_success() + return response + + async def _maybe_recreate_channel(self) -> None: + if not self._owns_channel or self._closing: + return + async with self._recreate_lock: + if self._closing: + return + now = time.monotonic() + if now - self._last_recreate_time < self._resiliency_options.min_recreate_interval_seconds: + return + old_channel = self._channel + self._channel = shared.get_async_grpc_channel( + host_address=self._host_address, + secure_channel=self._secure_channel, + interceptors=self._interceptors, + channel_options=self._channel_options, + ) + self._stub = stubs.TaskHubSidecarServiceStub(self._channel) + self._last_recreate_time = now + self._client_failure_tracker.record_success() + self._retired_channels.append(old_channel) + close_task = asyncio.create_task(self._close_retired_channel(old_channel)) + self._retired_channel_close_tasks.add(close_task) + close_task.add_done_callback(self._retired_channel_close_tasks.discard) + + async def 
_close_retired_channel(self, channel: grpc.aio.Channel) -> None: + try: + await asyncio.sleep(30.0) + await channel.close() + finally: + try: + self._retired_channels.remove(channel) + except ValueError: + pass + async def schedule_new_orchestration(self, orchestrator: Union[task.Orchestrator[TInput, TOutput], str], *, input: Optional[TInput] = None, instance_id: Optional[str] = None, @@ -665,13 +744,13 @@ async def schedule_new_orchestration(self, orchestrator: Union[task.Orchestrator await payload_helpers.externalize_payloads_async( req, self._payload_store, instance_id=req.instanceId, ) - res: pb.CreateInstanceResponse = await self._stub.StartInstance(req) + res: pb.CreateInstanceResponse = await self._invoke_unary("StartInstance", req) return res.instanceId async def get_orchestration_state(self, instance_id: str, *, fetch_payloads: bool = True) -> Optional[OrchestrationState]: req = pb.GetInstanceRequest(instanceId=instance_id, getInputsAndOutputs=fetch_payloads) - res: pb.GetInstanceResponse = await self._stub.GetInstance(req) + res: pb.GetInstanceResponse = await self._invoke_unary("GetInstance", req) if self._payload_store is not None and res.exists: await payload_helpers.deexternalize_payloads_async(res, self._payload_store) return new_orchestration_state(req.instanceId, res) @@ -710,7 +789,7 @@ async def list_instance_ids(self, f"page_size={page_size}, " f"continuation_token={continuation_token}" ) - resp: pb.ListInstanceIdsResponse = await self._stub.ListInstanceIds(req) + resp: pb.ListInstanceIdsResponse = await self._invoke_unary("ListInstanceIds", req) next_token = resp.lastInstanceKey.value if resp.HasField("lastInstanceKey") else None return Page(items=list(resp.instanceIds), continuation_token=next_token) @@ -727,7 +806,7 @@ async def get_all_orchestration_states(self, while True: req = build_query_instances_req(orchestration_query, _continuation_token) - resp: pb.QueryInstancesResponse = await self._stub.QueryInstances(req) + resp: pb.QueryInstancesResponse = await self._invoke_unary("QueryInstances", req) if self._payload_store is not None: await payload_helpers.deexternalize_payloads_async(resp, self._payload_store) states += [parse_orchestration_state(res) for res in resp.orchestrationState] @@ -744,7 +823,11 @@ async def wait_for_orchestration_start(self, instance_id: str, *, req = pb.GetInstanceRequest(instanceId=instance_id, getInputsAndOutputs=fetch_payloads) try: self._logger.info(f"Waiting up to {timeout}s for instance '{instance_id}' to start.") - res: pb.GetInstanceResponse = await self._stub.WaitForInstanceStart(req, timeout=timeout) + res: pb.GetInstanceResponse = await self._invoke_unary( + "WaitForInstanceStart", + req, + timeout=timeout, + ) if self._payload_store is not None and res.exists: await payload_helpers.deexternalize_payloads_async(res, self._payload_store) return new_orchestration_state(req.instanceId, res) @@ -760,7 +843,11 @@ async def wait_for_orchestration_completion(self, instance_id: str, *, req = pb.GetInstanceRequest(instanceId=instance_id, getInputsAndOutputs=fetch_payloads) try: self._logger.info(f"Waiting {timeout}s for instance '{instance_id}' to complete.") - res: pb.GetInstanceResponse = await self._stub.WaitForInstanceCompletion(req, timeout=timeout) + res: pb.GetInstanceResponse = await self._invoke_unary( + "WaitForInstanceCompletion", + req, + timeout=timeout, + ) if self._payload_store is not None and res.exists: await payload_helpers.deexternalize_payloads_async(res, self._payload_store) state = 
new_orchestration_state(req.instanceId, res) @@ -781,7 +868,7 @@ async def raise_orchestration_event(self, instance_id: str, event_name: str, *, await payload_helpers.externalize_payloads_async( req, self._payload_store, instance_id=instance_id, ) - await self._stub.RaiseEvent(req) + await self._invoke_unary("RaiseEvent", req) async def terminate_orchestration(self, instance_id: str, *, output: Optional[Any] = None, @@ -793,17 +880,17 @@ async def terminate_orchestration(self, instance_id: str, *, await payload_helpers.externalize_payloads_async( req, self._payload_store, instance_id=instance_id, ) - await self._stub.TerminateInstance(req) + await self._invoke_unary("TerminateInstance", req) async def suspend_orchestration(self, instance_id: str) -> None: req = pb.SuspendRequest(instanceId=instance_id) self._logger.info(f"Suspending instance '{instance_id}'.") - await self._stub.SuspendInstance(req) + await self._invoke_unary("SuspendInstance", req) async def resume_orchestration(self, instance_id: str) -> None: req = pb.ResumeRequest(instanceId=instance_id) self._logger.info(f"Resuming instance '{instance_id}'.") - await self._stub.ResumeInstance(req) + await self._invoke_unary("ResumeInstance", req) async def restart_orchestration(self, instance_id: str, *, restart_with_new_instance_id: bool = False) -> str: @@ -822,13 +909,13 @@ async def restart_orchestration(self, instance_id: str, *, restartWithNewInstanceId=restart_with_new_instance_id) self._logger.info(f"Restarting instance '{instance_id}'.") - res: pb.RestartInstanceResponse = await self._stub.RestartInstance(req) + res: pb.RestartInstanceResponse = await self._invoke_unary("RestartInstance", req) return res.instanceId async def purge_orchestration(self, instance_id: str, recursive: bool = True) -> PurgeInstancesResult: req = pb.PurgeInstancesRequest(instanceId=instance_id, recursive=recursive) self._logger.info(f"Purging instance '{instance_id}'.") - resp: pb.PurgeInstancesResponse = await self._stub.PurgeInstances(req) + resp: pb.PurgeInstancesResponse = await self._invoke_unary("PurgeInstances", req) return PurgeInstancesResult(resp.deletedInstanceCount, resp.isComplete.value) async def purge_orchestrations_by(self, @@ -842,7 +929,7 @@ async def purge_orchestrations_by(self, f"runtime_status={[str(status) for status in runtime_status] if runtime_status else None}, " f"recursive={recursive}") req = build_purge_by_filter_req(created_time_from, created_time_to, runtime_status, recursive) - resp: pb.PurgeInstancesResponse = await self._stub.PurgeInstances(req) + resp: pb.PurgeInstancesResponse = await self._invoke_unary("PurgeInstances", req) return PurgeInstancesResult(resp.deletedInstanceCount, resp.isComplete.value) async def signal_entity(self, @@ -855,7 +942,7 @@ async def signal_entity(self, await payload_helpers.externalize_payloads_async( req, self._payload_store, instance_id=str(entity_instance_id), ) - await self._stub.SignalEntity(req, None) + await self._invoke_unary("SignalEntity", req) async def get_entity(self, entity_instance_id: EntityInstanceId, @@ -863,7 +950,7 @@ async def get_entity(self, ) -> Optional[EntityMetadata]: req = pb.GetEntityRequest(instanceId=str(entity_instance_id), includeState=include_state) self._logger.info(f"Getting entity '{entity_instance_id}'.") - res: pb.GetEntityResponse = await self._stub.GetEntity(req) + res: pb.GetEntityResponse = await self._invoke_unary("GetEntity", req) if not res.exists: return None if self._payload_store is not None: @@ -882,7 +969,7 @@ async def 
get_all_entities(self, while True: query_request = build_query_entities_req(entity_query, _continuation_token) - resp: pb.QueryEntitiesResponse = await self._stub.QueryEntities(query_request) + resp: pb.QueryEntitiesResponse = await self._invoke_unary("QueryEntities", query_request) if self._payload_store is not None: await payload_helpers.deexternalize_payloads_async(resp, self._payload_store) entities += [EntityMetadata.from_entity_metadata(entity, query_request.query.includeState) for entity in resp.entities] @@ -908,7 +995,7 @@ async def clean_entity_storage(self, releaseOrphanedLocks=release_orphaned_locks, continuationToken=_continuation_token ) - resp: pb.CleanEntityStorageResponse = await self._stub.CleanEntityStorage(req) + resp: pb.CleanEntityStorageResponse = await self._invoke_unary("CleanEntityStorage", req) empty_entities_removed += resp.emptyEntitiesRemoved orphaned_locks_released += resp.orphanedLocksReleased diff --git a/tests/durabletask/test_client.py b/tests/durabletask/test_client.py index ce2c253..95bb768 100644 --- a/tests/durabletask/test_client.py +++ b/tests/durabletask/test_client.py @@ -1,3 +1,4 @@ +import asyncio import json import grpc import pytest @@ -42,6 +43,14 @@ def code(self): return self._status_code +def make_aio_rpc_error(status_code: grpc.StatusCode) -> grpc.aio.AioRpcError: + return grpc.aio.AioRpcError( + status_code, + grpc.aio.Metadata(), + grpc.aio.Metadata(), + ) + + class FakePayloadStore(PayloadStore): TOKEN_PREFIX = 'fake://' @@ -556,6 +565,161 @@ def test_async_client_stores_resolved_transport_inputs(): assert client._interceptors == interceptors +@pytest.mark.asyncio +async def test_async_client_recreates_sdk_owned_channel_after_unavailable(): + rpc_error = make_aio_rpc_error(grpc.StatusCode.UNAVAILABLE) + + first_stub = MagicMock() + first_stub.GetInstance = AsyncMock(side_effect=rpc_error) + second_stub = MagicMock() + second_stub.GetInstance = AsyncMock(return_value=MagicMock(exists=False)) + + with patch("durabletask.client.shared.get_async_grpc_channel", side_effect=[MagicMock(), MagicMock()]), patch( + "durabletask.client.stubs.TaskHubSidecarServiceStub", side_effect=[first_stub, second_stub] + ): + client = AsyncTaskHubGrpcClient( + host_address="localhost:4001", + resiliency_options=GrpcClientResiliencyOptions( + channel_recreate_failure_threshold=1, + min_recreate_interval_seconds=0.0, + ), + ) + with pytest.raises(grpc.aio.AioRpcError): + await client.get_orchestration_state("abc") + await client.get_orchestration_state("abc") + + +@pytest.mark.asyncio +async def test_async_client_does_not_count_wait_for_orchestration_deadline(): + stub = MagicMock() + stub.GetInstance = AsyncMock(side_effect=make_aio_rpc_error(grpc.StatusCode.UNAVAILABLE)) + stub.WaitForInstanceCompletion = AsyncMock(side_effect=make_aio_rpc_error(grpc.StatusCode.DEADLINE_EXCEEDED)) + + with patch("durabletask.client.shared.get_async_grpc_channel", return_value=MagicMock()), patch( + "durabletask.client.stubs.TaskHubSidecarServiceStub", return_value=stub + ): + client = AsyncTaskHubGrpcClient( + resiliency_options=GrpcClientResiliencyOptions(channel_recreate_failure_threshold=2) + ) + with pytest.raises(grpc.aio.AioRpcError): + await client.get_orchestration_state("abc") + with pytest.raises(TimeoutError): + await client.wait_for_orchestration_completion("abc") + assert client._client_failure_tracker.consecutive_failures == 0 + + +@pytest.mark.asyncio +async def test_async_client_close_closes_retired_channels_immediately(): + rpc_error = 
make_aio_rpc_error(grpc.StatusCode.UNAVAILABLE) + first_channel = MagicMock(name="first-channel") + first_channel.close = AsyncMock() + second_channel = MagicMock(name="second-channel") + second_channel.close = AsyncMock() + first_stub = MagicMock() + first_stub.GetInstance = AsyncMock(side_effect=rpc_error) + second_stub = MagicMock() + second_stub.GetInstance = AsyncMock(return_value=MagicMock(exists=False)) + cleanup_started = asyncio.Event() + release_cleanup = asyncio.Event() + + async def blocked_close_retired_channel(self, channel): + cleanup_started.set() + await release_cleanup.wait() + await channel.close() + + with patch( + "durabletask.client.shared.get_async_grpc_channel", + side_effect=[first_channel, second_channel], + ), patch( + "durabletask.client.stubs.TaskHubSidecarServiceStub", side_effect=[first_stub, second_stub] + ), patch.object( + AsyncTaskHubGrpcClient, + "_close_retired_channel", + new=blocked_close_retired_channel, + ): + client = AsyncTaskHubGrpcClient( + resiliency_options=GrpcClientResiliencyOptions( + channel_recreate_failure_threshold=1, + min_recreate_interval_seconds=0.0, + ) + ) + with pytest.raises(grpc.aio.AioRpcError): + await client.get_orchestration_state("abc") + await cleanup_started.wait() + + try: + await client.close() + first_channel.close.assert_awaited_once() + second_channel.close.assert_awaited_once() + finally: + release_cleanup.set() + await asyncio.sleep(0) + + +@pytest.mark.asyncio +async def test_async_client_does_not_recreate_caller_owned_channel(): + provided_channel = MagicMock(name="provided-channel") + stub = MagicMock() + stub.GetInstance = AsyncMock(side_effect=make_aio_rpc_error(grpc.StatusCode.UNAVAILABLE)) + + with patch("durabletask.client.shared.get_async_grpc_channel") as mock_get_channel, patch( + "durabletask.client.stubs.TaskHubSidecarServiceStub", return_value=stub + ): + client = AsyncTaskHubGrpcClient( + channel=provided_channel, + resiliency_options=GrpcClientResiliencyOptions(channel_recreate_failure_threshold=1), + ) + with pytest.raises(grpc.aio.AioRpcError): + await client.get_orchestration_state("abc") + with pytest.raises(grpc.aio.AioRpcError): + await client.get_orchestration_state("abc") + + assert client._channel is provided_channel + mock_get_channel.assert_not_called() + + +@pytest.mark.asyncio +async def test_async_client_close_prevents_channel_recreation_race(): + first_channel = MagicMock(name="first-channel") + first_channel.close = AsyncMock() + second_channel = MagicMock(name="second-channel") + second_channel.close = AsyncMock() + first_stub = MagicMock() + first_stub.GetInstance = AsyncMock(side_effect=make_aio_rpc_error(grpc.StatusCode.UNAVAILABLE)) + second_stub = MagicMock() + second_stub.GetInstance = AsyncMock(return_value=MagicMock(exists=False)) + + with patch( + "durabletask.client.shared.get_async_grpc_channel", + side_effect=[first_channel, second_channel], + ) as mock_get_channel, patch( + "durabletask.client.stubs.TaskHubSidecarServiceStub", side_effect=[first_stub, second_stub] + ): + client = AsyncTaskHubGrpcClient( + resiliency_options=GrpcClientResiliencyOptions( + channel_recreate_failure_threshold=1, + min_recreate_interval_seconds=0.0, + ) + ) + await client._recreate_lock.acquire() + try: + rpc_task = asyncio.create_task(client.get_orchestration_state("abc")) + while first_stub.GetInstance.await_count == 0: + await asyncio.sleep(0) + close_task = asyncio.create_task(client.close()) + await asyncio.sleep(0) + finally: + client._recreate_lock.release() + + with 
pytest.raises(grpc.aio.AioRpcError): + await rpc_task + await close_task + + assert mock_get_channel.call_count == 1 + first_channel.close.assert_awaited_once() + second_channel.close.assert_not_awaited() + + def test_worker_stores_resiliency_options(): resiliency = GrpcWorkerResiliencyOptions(channel_recreate_failure_threshold=9) worker = TaskHubGrpcWorker(resiliency_options=resiliency) From 762b24780bc36b537cc5a535752bb382fa2df551 Mon Sep 17 00:00:00 2001 From: Bernd Verst Date: Thu, 23 Apr 2026 22:45:58 -0700 Subject: [PATCH 19/28] Add async channel recreation transport test Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- tests/durabletask/test_client.py | 59 +++++++++++++++++--------------- 1 file changed, 32 insertions(+), 27 deletions(-) diff --git a/tests/durabletask/test_client.py b/tests/durabletask/test_client.py index 95bb768..724e984 100644 --- a/tests/durabletask/test_client.py +++ b/tests/durabletask/test_client.py @@ -544,49 +544,54 @@ def test_sync_client_resets_failure_tracking_after_application_error(): assert client._client_failure_tracker.consecutive_failures == 0 -def test_async_client_stores_resolved_transport_inputs(): - resiliency = GrpcClientResiliencyOptions() - channel_options = GrpcChannelOptions(max_send_message_length=4321) - interceptors = [DefaultAsyncClientInterceptorImpl(METADATA)] - with patch("durabletask.client.shared.get_async_grpc_channel", return_value=MagicMock()), patch( - "durabletask.client.stubs.TaskHubSidecarServiceStub", return_value=MagicMock() - ): - client = AsyncTaskHubGrpcClient( - host_address="localhost:4001", - secure_channel=True, - interceptors=interceptors, - channel_options=channel_options, - resiliency_options=resiliency, - ) - assert client._resiliency_options is resiliency - assert client._host_address == "localhost:4001" - assert client._secure_channel is True - assert client._channel_options is channel_options - assert client._interceptors == interceptors - - @pytest.mark.asyncio -async def test_async_client_recreates_sdk_owned_channel_after_unavailable(): +async def test_async_client_recreates_sdk_owned_channel_with_original_transport_inputs(): rpc_error = make_aio_rpc_error(grpc.StatusCode.UNAVAILABLE) - + first_channel = MagicMock(name="first-channel") + first_channel.close = AsyncMock() + second_channel = MagicMock(name="second-channel") + second_channel.close = AsyncMock() first_stub = MagicMock() first_stub.GetInstance = AsyncMock(side_effect=rpc_error) second_stub = MagicMock() second_stub.GetInstance = AsyncMock(return_value=MagicMock(exists=False)) + host_address = "localhost:4001" + interceptors = [DefaultAsyncClientInterceptorImpl(METADATA)] + channel_options = GrpcChannelOptions(max_send_message_length=4321) - with patch("durabletask.client.shared.get_async_grpc_channel", side_effect=[MagicMock(), MagicMock()]), patch( + with patch( + "durabletask.client.shared.get_async_grpc_channel", + side_effect=[first_channel, second_channel], + ) as mock_get_channel, patch( "durabletask.client.stubs.TaskHubSidecarServiceStub", side_effect=[first_stub, second_stub] ): client = AsyncTaskHubGrpcClient( - host_address="localhost:4001", + host_address=host_address, + secure_channel=True, + interceptors=interceptors, + channel_options=channel_options, resiliency_options=GrpcClientResiliencyOptions( channel_recreate_failure_threshold=1, min_recreate_interval_seconds=0.0, ), ) - with pytest.raises(grpc.aio.AioRpcError): + try: + with pytest.raises(grpc.aio.AioRpcError): + await 
client.get_orchestration_state("abc") await client.get_orchestration_state("abc") - await client.get_orchestration_state("abc") + finally: + await client.close() + + expected_channel_call = call( + host_address=host_address, + secure_channel=True, + interceptors=interceptors, + channel_options=channel_options, + ) + assert mock_get_channel.call_args_list == [ + expected_channel_call, + expected_channel_call, + ] @pytest.mark.asyncio From 23639265a3f53b7428e9a7fca1fa1337b3da2544 Mon Sep 17 00:00:00 2001 From: Bernd Verst Date: Thu, 23 Apr 2026 23:01:08 -0700 Subject: [PATCH 20/28] Add gRPC connection resiliency Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- CHANGELOG.md | 30 ++++++++----------- .../2026-04-23-grpc-resiliency-design.md | 28 +++++++++-------- durabletask-azuremanaged/CHANGELOG.md | 6 ++-- pyproject.toml | 1 + 4 files changed, 32 insertions(+), 33 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 873812b..80cb2cf 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -12,16 +12,15 @@ ADDED - Added `GrpcChannelOptions` and `GrpcRetryPolicyOptions` for configuring gRPC transport behavior, including message-size limits, keepalive settings, and channel-level retry policy service configuration. -- Added `GrpcWorkerResiliencyOptions` and `GrpcClientResiliencyOptions` for - configuring public gRPC reconnect, hello timeout, and channel recreation - thresholds. - Added optional `channel` and `channel_options` parameters to `TaskHubGrpcClient`, `AsyncTaskHubGrpcClient`, and `TaskHubGrpcWorker` to support pre-configured channel passthrough and low-level gRPC channel customization. -- Added optional `resiliency_options` parameters to `TaskHubGrpcClient`, - `AsyncTaskHubGrpcClient`, and `TaskHubGrpcWorker` so applications can pass - gRPC resiliency settings through constructor APIs. +- Added `GrpcWorkerResiliencyOptions` and `GrpcClientResiliencyOptions`, plus + `resiliency_options` constructor parameters on `TaskHubGrpcClient`, + `AsyncTaskHubGrpcClient`, and `TaskHubGrpcWorker`, to configure hello + deadlines, silent-disconnect detection, reconnect backoff, and channel + recreation thresholds for SDK-managed gRPC connections. - Added `get_orchestration_history()` and `list_instance_ids()` to the sync and async gRPC clients. - Added in-memory backend support for `StreamInstanceHistory` and @@ -30,18 +29,13 @@ ADDED FIXED -- Hardened `TaskHubGrpcWorker` reconnect handling so configured hello timeouts - apply on fresh connections, received work items reset failure tracking, - SDK-owned channels are cleaned up on shutdown and full resets, and - caller-owned channels are never recreated or closed during worker reconnects. -- Fixed sync `TaskHubGrpcClient` transport resiliency so SDK-owned channels are - recreated after repeated transport failures while long-poll timeout - deadlines, successful replies, and application-level RPC errors reset the - failure tracker. -- Fixed async `AsyncTaskHubGrpcClient` transport resiliency so SDK-owned - channels are recreated after repeated transport failures while long-poll - timeout deadlines, successful replies, and application-level RPC errors - reset the failure tracker. +- Improved `TaskHubGrpcWorker` recovery from stale or disconnected gRPC streams + so configured hello timeouts apply on fresh connections, received work resets + failure tracking, SDK-owned channels are refreshed and cleaned up safely, and + caller-owned channels are never recreated or closed during reconnects. 
+- Improved sync and async gRPC clients so repeated transport failures recreate + SDK-owned channels, while long-poll deadlines, successful replies, and + application-level RPC errors do not trigger unnecessary channel replacement. ## v1.4.0 diff --git a/docs/superpowers/specs/2026-04-23-grpc-resiliency-design.md b/docs/superpowers/specs/2026-04-23-grpc-resiliency-design.md index a2710e2..ce1b586 100644 --- a/docs/superpowers/specs/2026-04-23-grpc-resiliency-design.md +++ b/docs/superpowers/specs/2026-04-23-grpc-resiliency-design.md @@ -135,14 +135,16 @@ The monitor reports one of these outcomes: The outer worker loop uses those outcomes as follows: - `message_received`: reset health counters -- `graceful_close_before_first_message`: count as channel poison -- `graceful_close_after_message`: reconnect immediately without poisoning the - channel +- `graceful_close_before_first_message`: immediately reset the current stream + and force a fresh SDK-owned channel on the next connect attempt +- `graceful_close_after_message`: immediately reset the current stream and + reconnect without incrementing the transport-failure counter - `silent_disconnect`: count as channel poison - `shutdown`: exit cleanly -This keeps rolling upgrades and normal peer-driven reconnects from being -treated the same as a stale half-open stream. +This keeps rolling upgrades and normal peer-driven reconnects from inflating +the failure threshold while still forcing SDK-owned workers to establish a +fresh channel after graceful stream closures. #### Failure counting and recreation @@ -150,9 +152,8 @@ The worker increments the consecutive-failure counter only for transport-shaped failures: - `UNAVAILABLE` -- `Hello` `DEADLINE_EXCEEDED` +- `DEADLINE_EXCEEDED` - explicit silent-disconnect timeout -- graceful stream close before the first message It does not increment the counter for errors that channel recreation is unlikely to fix, such as: @@ -160,10 +161,13 @@ unlikely to fix, such as: - `UNAUTHENTICATED` - `NOT_FOUND` - orchestration or activity execution failures +- graceful stream closures before or after work items When the threshold is reached and the worker owns the channel, it recreates the -channel and stub. When the worker does not own the channel, it keeps retrying -the existing transport and logs that the channel could not be recreated. +channel and stub. Graceful stream closures also force an immediate fresh +SDK-owned channel even though they do not increment the threshold. When the +worker does not own the channel, it keeps retrying the existing transport and +logs that the channel could not be recreated. ### Client behavior @@ -284,9 +288,9 @@ Add focused unit tests for the new behavior. 
- hello deadline failure counts toward recreation - silent-disconnect timeout is detected and classified -- graceful close before the first message poisons the channel -- graceful close after a message triggers reconnect without poisoning -- user-supplied channels are not recreated +- graceful stream closes force a fresh SDK-owned connection without increasing + the failure counter +- user-supplied channels are not recreated or closed ### Client tests diff --git a/durabletask-azuremanaged/CHANGELOG.md b/durabletask-azuremanaged/CHANGELOG.md index 7a0408f..ac2af0b 100644 --- a/durabletask-azuremanaged/CHANGELOG.md +++ b/durabletask-azuremanaged/CHANGELOG.md @@ -11,10 +11,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 `DurableTaskSchedulerClient`, `AsyncDurableTaskSchedulerClient`, and `DurableTaskSchedulerWorker` to allow combining custom gRPC interceptors with DTS defaults and to support pre-configured/customized gRPC channels. -- Added optional `resiliency_options` parameters to +- Added pass-through `resiliency_options` support on `DurableTaskSchedulerClient`, `AsyncDurableTaskSchedulerClient`, and - `DurableTaskSchedulerWorker` so applications can pass gRPC resiliency - settings through their constructors. + `DurableTaskSchedulerWorker` so Azure Managed applications can use the core + SDK's gRPC resiliency option types through their constructors. - Added `workerid` gRPC metadata on Durable Task Scheduler worker calls for improved worker identity and observability. - Improved sync access token refresh concurrency handling to avoid duplicate diff --git a/pyproject.toml b/pyproject.toml index 145b930..69843ba 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -51,6 +51,7 @@ include = ["durabletask", "durabletask.*"] minversion = "6.0" testpaths = ["tests"] asyncio_mode = "auto" +addopts = "--import-mode=importlib" markers = [ "azurite: tests that require Azurite (local Azure Storage emulator)", ] From 24012246ea27490a6ce7aed66e65cc9861678067 Mon Sep 17 00:00:00 2001 From: Bernd Verst Date: Thu, 23 Apr 2026 23:12:33 -0700 Subject: [PATCH 21/28] Remove repo-wide pytest importlib addopts Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- pyproject.toml | 1 - 1 file changed, 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 69843ba..145b930 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -51,7 +51,6 @@ include = ["durabletask", "durabletask.*"] minversion = "6.0" testpaths = ["tests"] asyncio_mode = "auto" -addopts = "--import-mode=importlib" markers = [ "azurite: tests that require Azurite (local Azure Storage emulator)", ] From bbaf2b08b4d6978d20c3f4f863e31c72370162f7 Mon Sep 17 00:00:00 2001 From: Bernd Verst Date: Thu, 23 Apr 2026 23:25:02 -0700 Subject: [PATCH 22/28] Update gRPC resiliency plan tracking Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- .../plans/2026-04-23-grpc-resiliency.md | 70 +++++++++---------- 1 file changed, 35 insertions(+), 35 deletions(-) diff --git a/docs/superpowers/plans/2026-04-23-grpc-resiliency.md b/docs/superpowers/plans/2026-04-23-grpc-resiliency.md index 36f17a0..a9dc226 100644 --- a/docs/superpowers/plans/2026-04-23-grpc-resiliency.md +++ b/docs/superpowers/plans/2026-04-23-grpc-resiliency.md @@ -31,7 +31,7 @@ - Modify: `durabletask/grpc_options.py` - Create: `tests/durabletask/test_grpc_resiliency.py` -- [ ] **Step 1: Write the failing option tests** +- [x] **Step 1: Write the failing option tests** ```python import pytest @@ -81,13 +81,13 @@ def 
test_client_resiliency_rejects_negative_cooldown(): GrpcClientResiliencyOptions(min_recreate_interval_seconds=-1.0) ``` -- [ ] **Step 2: Run the test to verify it fails** +- [x] **Step 2: Run the test to verify it fails** Run: `python -m pytest tests/durabletask/test_grpc_resiliency.py -v` Expected: FAIL with `ImportError` or `AttributeError` because the new option classes do not exist yet. -- [ ] **Step 3: Write the minimal implementation** +- [x] **Step 3: Write the minimal implementation** ```python from dataclasses import dataclass, field @@ -131,13 +131,13 @@ class GrpcClientResiliencyOptions: raise ValueError("min_recreate_interval_seconds must be >= 0") ``` -- [ ] **Step 4: Run the tests to verify they pass** +- [x] **Step 4: Run the tests to verify they pass** Run: `python -m pytest tests/durabletask/test_grpc_resiliency.py -v` Expected: PASS for the new option validation tests. -- [ ] **Step 5: Commit** +- [x] **Step 5: Commit** ```bash git add durabletask/grpc_options.py tests/durabletask/test_grpc_resiliency.py @@ -154,7 +154,7 @@ git commit -m "Add gRPC resiliency option types" - Modify: `tests/durabletask/test_client.py` - Create: `tests/durabletask-azuremanaged/test_grpc_resiliency.py` -- [ ] **Step 1: Write the failing constructor and wrapper tests** +- [x] **Step 1: Write the failing constructor and wrapper tests** ```python from unittest.mock import MagicMock, patch @@ -225,13 +225,13 @@ def test_dts_worker_passes_resiliency_options_to_base_worker(): assert mock_init.call_args.kwargs["resiliency_options"] is resiliency ``` -- [ ] **Step 2: Run the tests to verify they fail** +- [x] **Step 2: Run the tests to verify they fail** Run: `python -m pytest tests/durabletask/test_client.py tests/durabletask-azuremanaged/test_grpc_resiliency.py -v` Expected: FAIL because the constructors do not accept `resiliency_options` yet and do not retain enough transport state for later recreation. -- [ ] **Step 3: Write the minimal implementation** +- [x] **Step 3: Write the minimal implementation** ```python self._host_address = host_address if host_address else shared.get_default_host_address() @@ -279,13 +279,13 @@ super().__init__( ) ``` -- [ ] **Step 4: Run the tests to verify they pass** +- [x] **Step 4: Run the tests to verify they pass** Run: `python -m pytest tests/durabletask/test_client.py tests/durabletask-azuremanaged/test_grpc_resiliency.py -v` Expected: PASS for the new constructor and wrapper pass-through tests. 
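The pass-through verified in Step 4 amounts to the following caller-side shape. A sketch only, with illustrative values; the worker option field name follows the design spec earlier in this series:

```python
from durabletask.worker import TaskHubGrpcWorker
from durabletask.grpc_options import GrpcWorkerResiliencyOptions

# Detect a silent disconnect after 60s of stream idleness instead of the
# default; all other resiliency defaults stay in effect.
worker = TaskHubGrpcWorker(
    resiliency_options=GrpcWorkerResiliencyOptions(
        silent_disconnect_timeout_seconds=60.0,
    ),
)
```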
-- [ ] **Step 5: Commit** +- [x] **Step 5: Commit** ```bash git add durabletask/client.py durabletask/worker.py durabletask-azuremanaged/durabletask/azuremanaged/client.py durabletask-azuremanaged/durabletask/azuremanaged/worker.py tests/durabletask/test_client.py tests/durabletask-azuremanaged/test_grpc_resiliency.py @@ -298,7 +298,7 @@ git commit -m "Thread gRPC resiliency options through constructors" - Create: `durabletask/internal/grpc_resiliency.py` - Modify: `tests/durabletask/test_grpc_resiliency.py` -- [ ] **Step 1: Write the failing helper tests** +- [x] **Step 1: Write the failing helper tests** ```python import grpc @@ -340,13 +340,13 @@ def test_worker_transport_failure_filters_application_errors(): assert is_worker_transport_failure(grpc.StatusCode.NOT_FOUND) is False ``` -- [ ] **Step 2: Run the tests to verify they fail** +- [x] **Step 2: Run the tests to verify they fail** Run: `python -m pytest tests/durabletask/test_grpc_resiliency.py -k "jitter or tracker or transport_failure" -v` Expected: FAIL because the shared helper module and helper functions do not exist yet. -- [ ] **Step 3: Write the minimal implementation** +- [x] **Step 3: Write the minimal implementation** ```python import random @@ -399,13 +399,13 @@ def is_worker_transport_failure(status_code: grpc.StatusCode) -> bool: } ``` -- [ ] **Step 4: Run the tests to verify they pass** +- [x] **Step 4: Run the tests to verify they pass** Run: `python -m pytest tests/durabletask/test_grpc_resiliency.py -v` Expected: PASS for the helper and option tests together. -- [ ] **Step 5: Commit** +- [x] **Step 5: Commit** ```bash git add durabletask/internal/grpc_resiliency.py tests/durabletask/test_grpc_resiliency.py @@ -418,7 +418,7 @@ git commit -m "Add shared gRPC resiliency helpers" - Modify: `durabletask/worker.py` - Create: `tests/durabletask/test_worker_resiliency.py` -- [ ] **Step 1: Write the failing worker resiliency tests** +- [x] **Step 1: Write the failing worker resiliency tests** ```python import grpc @@ -461,13 +461,13 @@ def test_worker_does_not_recreate_caller_owned_channel(): assert worker._can_recreate_channel() is False ``` -- [ ] **Step 2: Run the tests to verify they fail** +- [x] **Step 2: Run the tests to verify they fail** Run: `python -m pytest tests/durabletask/test_worker_resiliency.py -v` Expected: FAIL because the worker does not expose explicit stream-outcome helpers yet and still uses ad hoc reconnect bookkeeping. -- [ ] **Step 3: Write the minimal implementation** +- [x] **Step 3: Write the minimal implementation** ```python class _WorkItemStreamOutcome(Enum): @@ -513,13 +513,13 @@ if work_item.HasField("healthPing"): continue ``` -- [ ] **Step 4: Run the worker tests** +- [x] **Step 4: Run the worker tests** Run: `python -m pytest tests/durabletask/test_worker_resiliency.py -v` Expected: PASS for the worker classification and ownership tests. 
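The reconnect waits exercised by these worker tests come from the Task 3 jitter helper, whose exact signature is elided above. A standard full-jitter computation consistent with the base/cap backoff options would look like the sketch below; the function name `full_jitter_backoff` is hypothetical:

```python
import random


def full_jitter_backoff(attempt: int, base_seconds: float, cap_seconds: float) -> float:
    """Full jitter: draw uniformly from [0, min(cap, base * 2**attempt)].

    attempt is zero-based, so the first retry waits at most base_seconds.
    """
    ceiling = min(cap_seconds, base_seconds * (2 ** attempt))
    return random.uniform(0.0, ceiling)


# Example: base 1.0s, cap 30.0s -> ceilings 1, 2, 4, 8, 16, 30, 30, ...
delays = [full_jitter_backoff(n, 1.0, 30.0) for n in range(7)]
```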
-- [ ] **Step 5: Commit** +- [x] **Step 5: Commit** ```bash git add durabletask/worker.py tests/durabletask/test_worker_resiliency.py @@ -532,7 +532,7 @@ git commit -m "Harden worker gRPC stream reconnect behavior" - Modify: `durabletask/client.py` - Modify: `tests/durabletask/test_client.py` -- [ ] **Step 1: Write the failing sync client recreation tests** +- [x] **Step 1: Write the failing sync client recreation tests** ```python import grpc @@ -589,13 +589,13 @@ def test_sync_client_does_not_count_long_poll_deadline(): assert client._client_failure_tracker.consecutive_failures == 0 ``` -- [ ] **Step 2: Run the tests to verify they fail** +- [x] **Step 2: Run the tests to verify they fail** Run: `python -m pytest tests/durabletask/test_client.py -k "recreates_sdk_owned_channel or long_poll_deadline" -v` Expected: FAIL because client calls still go directly through the stub and the client has no failure tracker or channel recreation path. -- [ ] **Step 3: Write the minimal implementation** +- [x] **Step 3: Write the minimal implementation** ```python self._client_failure_tracker = FailureTracker( @@ -643,13 +643,13 @@ def _maybe_recreate_channel(self) -> None: threading.Timer(30.0, old_channel.close).start() ``` -- [ ] **Step 4: Run the tests to verify they pass** +- [x] **Step 4: Run the tests to verify they pass** Run: `python -m pytest tests/durabletask/test_client.py -k "recreates_sdk_owned_channel or long_poll_deadline" -v` Expected: PASS for both new sync client tests and no regressions in the existing client construction tests. -- [ ] **Step 5: Commit** +- [x] **Step 5: Commit** ```bash git add durabletask/client.py tests/durabletask/test_client.py @@ -662,7 +662,7 @@ git commit -m "Add sync client gRPC channel recreation" - Modify: `durabletask/client.py` - Modify: `tests/durabletask/test_client.py` -- [ ] **Step 1: Write the failing async client recreation tests** +- [x] **Step 1: Write the failing async client recreation tests** ```python import grpc @@ -716,13 +716,13 @@ async def test_async_client_does_not_count_wait_for_orchestration_deadline(): assert client._client_failure_tracker.consecutive_failures == 0 ``` -- [ ] **Step 2: Run the tests to verify they fail** +- [x] **Step 2: Run the tests to verify they fail** Run: `python -m pytest tests/durabletask/test_client.py -k "async_client_recreates_sdk_owned_channel or async_client_does_not_count" -v` Expected: FAIL because the async client still awaits stub methods directly and has no async-safe recreation path. -- [ ] **Step 3: Write the minimal implementation** +- [x] **Step 3: Write the minimal implementation** ```python self._client_failure_tracker = FailureTracker( @@ -775,13 +775,13 @@ async def _close_retired_channel(self, channel: grpc.aio.Channel) -> None: await channel.close() ``` -- [ ] **Step 4: Run the tests to verify they pass** +- [x] **Step 4: Run the tests to verify they pass** Run: `python -m pytest tests/durabletask/test_client.py -k "async_client_recreates_sdk_owned_channel or async_client_does_not_count" -v` Expected: PASS for the async recreation tests and no regressions in the existing async client construction tests. 
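Because `close()` now cancels pending retired-channel close tasks and shuts every SDK-owned channel, using the async client as a context manager is the simplest way to get deterministic cleanup even if recreation happened mid-run. A usage sketch with illustrative values:

```python
import asyncio

from durabletask.client import AsyncTaskHubGrpcClient
from durabletask.grpc_options import GrpcClientResiliencyOptions


async def main() -> None:
    # __aexit__ awaits close(), which also drains channels parked for
    # delayed closure after a recreation.
    async with AsyncTaskHubGrpcClient(
        host_address="localhost:4001",  # illustrative address
        resiliency_options=GrpcClientResiliencyOptions(
            channel_recreate_failure_threshold=5,
        ),
    ) as client:
        state = await client.get_orchestration_state("my-instance-id")
        print(state)


asyncio.run(main())
```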
-- [ ] **Step 5: Commit** +- [x] **Step 5: Commit** ```bash git add durabletask/client.py tests/durabletask/test_client.py @@ -796,7 +796,7 @@ git commit -m "Add async client gRPC channel recreation" - Modify: `docs/superpowers/specs/2026-04-23-grpc-resiliency-design.md` (only if the implementation changed the agreed design) - Modify: `docs/superpowers/plans/2026-04-23-grpc-resiliency.md` (check off completed steps only after execution) -- [ ] **Step 1: Add the changelog entries** +- [x] **Step 1: Add the changelog entries** ```markdown ## Unreleased @@ -814,7 +814,7 @@ git commit -m "Add async client gRPC channel recreation" - Added pass-through support for the new gRPC resiliency option types on Azure Managed clients and workers. ``` -- [ ] **Step 2: Run the focused tests** +- [x] **Step 2: Run the focused tests** Run: @@ -824,7 +824,7 @@ python -m pytest tests/durabletask/test_grpc_resiliency.py tests/durabletask/tes Expected: PASS for all new and touched unit tests. -- [ ] **Step 3: Run lint on the changed Python files** +- [x] **Step 3: Run lint on the changed Python files** Run: @@ -834,7 +834,7 @@ python -m flake8 durabletask/grpc_options.py durabletask/internal/grpc_resilienc Expected: no output -- [ ] **Step 4: Run the full test suite** +- [x] **Step 4: Run the full test suite** Run: @@ -844,7 +844,7 @@ python -m pytest Expected: PASS across the repository, including the existing orchestration and Azure Managed test suites. -- [ ] **Step 5: Commit** +- [x] **Step 5: Commit** ```bash git add CHANGELOG.md durabletask-azuremanaged/CHANGELOG.md durabletask/grpc_options.py durabletask/internal/grpc_resiliency.py durabletask/client.py durabletask/worker.py durabletask-azuremanaged/durabletask/azuremanaged/client.py durabletask-azuremanaged/durabletask/azuremanaged/worker.py tests/durabletask/test_grpc_resiliency.py tests/durabletask/test_worker_resiliency.py tests/durabletask/test_client.py tests/durabletask-azuremanaged/test_grpc_resiliency.py From 72191f1271f00cfc97adaffbc338860cbfc7964b Mon Sep 17 00:00:00 2001 From: Bernd Verst Date: Thu, 23 Apr 2026 23:49:09 -0700 Subject: [PATCH 23/28] Fix worker channel retirement for in-flight completions Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- CHANGELOG.md | 2 + durabletask/worker.py | 134 ++++++++++++++++++-- tests/durabletask/test_worker_resiliency.py | 123 ++++++++++++++++++ 3 files changed, 249 insertions(+), 10 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 80cb2cf..d8b6206 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -33,6 +33,8 @@ FIXED so configured hello timeouts apply on fresh connections, received work resets failure tracking, SDK-owned channels are refreshed and cleaned up safely, and caller-owned channels are never recreated or closed during reconnects. +- Fixed `TaskHubGrpcWorker` so in-flight work item completions can finish after + a graceful gRPC stream reset before the worker retires an SDK-owned channel. - Improved sync and async gRPC clients so repeated transport failures recreate SDK-owned channels, while long-poll deadlines, successful replies, and application-level RPC errors do not trigger unnecessary channel replacement. 
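A condensed, standalone sketch of the ref-counting scheme the `worker.py` hunk below implements. Class and method names here are illustrative; the real `_InFlightChannelTracker` additionally keys state by channel id and guards against double release.

```python
# Minimal sketch of the deferred-close idea: a channel retired while
# completions are still in flight is only closed once the last release()
# runs; with no holders, retire() closes immediately.
from threading import Lock


class RefCountedRetire:
    def __init__(self, close_fn):
        self._lock = Lock()
        self._refs = 0
        self._retired = False
        self._close_fn = close_fn

    def acquire(self) -> None:
        with self._lock:
            self._refs += 1

    def release(self) -> None:
        with self._lock:
            self._refs -= 1
            close_now = self._retired and self._refs == 0
        if close_now:
            self._close_fn()

    def retire(self) -> None:
        with self._lock:
            self._retired = True
            close_now = self._refs == 0
        if close_now:
            self._close_fn()
```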
diff --git a/durabletask/worker.py b/durabletask/worker.py index 460e5f5..a9f9f4a 100644 --- a/durabletask/worker.py +++ b/durabletask/worker.py @@ -10,7 +10,7 @@ from concurrent.futures import ThreadPoolExecutor from dataclasses import dataclass, field from datetime import datetime, timedelta, timezone -from threading import Event, Thread +from threading import Event, Lock, Thread from types import GeneratorType from enum import Enum from typing import Any, Generator, Optional, Sequence, Tuple, TypeVar, Union, overload @@ -130,6 +130,73 @@ class _WorkItemStreamOutcome(Enum): SILENT_DISCONNECT = "silent_disconnect" +@dataclass +class _TrackedChannelState: + channel: Any + ref_count: int = 0 + close_when_released: bool = False + + +class _InFlightChannelTracker: + def __init__(self): + self._lock = Lock() + self._states: dict[int, _TrackedChannelState] = {} + + def acquire(self, channel: Any): + channel_key = id(channel) + with self._lock: + state = self._states.get(channel_key) + if state is None: + state = _TrackedChannelState(channel=channel) + self._states[channel_key] = state + state.ref_count += 1 + + released = False + + def release() -> None: + nonlocal released + if released: + return + released = True + + channel_to_close = None + with self._lock: + state = self._states.get(channel_key) + if state is None: + return + + state.ref_count -= 1 + if state.ref_count == 0: + if state.close_when_released: + channel_to_close = state.channel + del self._states[channel_key] + + if channel_to_close is not None: + self._close_channel(channel_to_close) + + return release + + def retire(self, channel: Any) -> None: + channel_key = id(channel) + channel_to_close = None + with self._lock: + state = self._states.get(channel_key) + if state is None: + channel_to_close = channel + else: + state.close_when_released = True + + if channel_to_close is not None: + self._close_channel(channel_to_close) + + @staticmethod + def _close_channel(channel: Any) -> None: + try: + channel.close() + except Exception: + pass + + class VersioningOptions: """Configuration options for orchestrator and activity versioning. 
@@ -642,6 +709,7 @@ async def _async_run_loop(self): failure_tracker = FailureTracker( threshold=self._resiliency_options.channel_recreate_failure_threshold, ) + in_flight_channel_tracker = _InFlightChannelTracker() def get_reconnect_delay_seconds() -> float: return get_full_jitter_delay_seconds( @@ -671,6 +739,45 @@ def create_fresh_connection(): current_stub = None raise + def wrap_execution(handler, release): + def wrapped(*args, **kwargs): + result = handler(*args, **kwargs) + release() + return result + + return wrapped + + def wrap_cancellation(handler, release): + def wrapped(*args, **kwargs): + try: + return handler(*args, **kwargs) + finally: + release() + + return wrapped + + def submit_work_item( + submit_func, + handler, + cancellation_handler, + request, + stub, + completion_token, + channel, + ): + release = in_flight_channel_tracker.acquire(channel) + try: + submit_func( + wrap_execution(handler, release), + wrap_cancellation(cancellation_handler, release), + request, + stub, + completion_token, + ) + except Exception: + release() + raise + def invalidate_connection( *, recreate_channel: bool = False, @@ -700,10 +807,7 @@ def invalidate_connection( and self._can_recreate_channel() and (recreate_channel or close_channel) ): - try: - current_channel.close() - except Exception: - pass + in_flight_channel_tracker.retire(current_channel) current_channel = None current_stub = None @@ -742,7 +846,9 @@ def should_invalidate_connection(rpc_error): continue try: assert current_stub is not None + assert current_channel is not None stub = current_stub + channel = current_channel capabilities = [] if self._payload_store is not None: capabilities.append(pb.WORKER_CAPABILITY_LARGE_PAYLOADS) @@ -822,36 +928,44 @@ def stream_reader(): failure_tracker.record_success() if work_item.HasField("orchestratorRequest"): - self._async_worker_manager.submit_orchestration( + submit_work_item( + self._async_worker_manager.submit_orchestration, self._execute_orchestrator, self._cancel_orchestrator, work_item.orchestratorRequest, stub, work_item.completionToken, + channel, ) elif work_item.HasField("activityRequest"): - self._async_worker_manager.submit_activity( + submit_work_item( + self._async_worker_manager.submit_activity, self._execute_activity, self._cancel_activity, work_item.activityRequest, stub, work_item.completionToken, + channel, ) elif work_item.HasField("entityRequest"): - self._async_worker_manager.submit_entity_batch( + submit_work_item( + self._async_worker_manager.submit_entity_batch, self._execute_entity_batch, self._cancel_entity_batch, work_item.entityRequest, stub, work_item.completionToken, + channel, ) elif work_item.HasField("entityRequestV2"): - self._async_worker_manager.submit_entity_batch( + submit_work_item( + self._async_worker_manager.submit_entity_batch, self._execute_entity_batch, self._cancel_entity_batch, work_item.entityRequestV2, stub, - work_item.completionToken + work_item.completionToken, + channel, ) else: self._logger.warning( diff --git a/tests/durabletask/test_worker_resiliency.py b/tests/durabletask/test_worker_resiliency.py index 9d3f596..6e324ad 100644 --- a/tests/durabletask/test_worker_resiliency.py +++ b/tests/durabletask/test_worker_resiliency.py @@ -78,6 +78,16 @@ def shutdown(self): self._shutdown_event.set() +def _complete_activity_request(req, stub, completion_token): + stub.CompleteActivityTask( + pb.ActivityResponse( + instanceId=req.orchestrationInstance.instanceId, + taskId=req.taskId, + completionToken=completion_token, + ) + ) + + def 
_make_activity_work_item() -> pb.WorkItem: return pb.WorkItem( activityRequest=pb.ActivityRequest( @@ -331,6 +341,119 @@ def create_stub(channel): created_channels[1].close.assert_called_once() +@pytest.mark.asyncio +async def test_worker_defers_sdk_owned_channel_close_until_inflight_completion_finishes(monkeypatch): + worker = TaskHubGrpcWorker() + worker_manager = DummyWorkerManager() + worker._async_worker_manager = worker_manager + worker._execute_activity = _complete_activity_request + monkeypatch.setattr(worker._shutdown, "wait", lambda timeout: False) + + created_channels = [] + + def get_grpc_channel(*args, **kwargs): + channel = MagicMock(name=f"channel-{len(created_channels) + 1}") + created_channels.append(channel) + return channel + + completed_responses = [] + first_stub = MagicMock() + first_stub.GetWorkItems.return_value = FakeResponseStream(items=[_make_activity_work_item()]) + + def complete_activity(response): + assert created_channels[0].close.call_count == 0 + completed_responses.append(response) + + first_stub.CompleteActivityTask.side_effect = complete_activity + + second_stub = MagicMock() + second_stub.GetWorkItems.side_effect = FakeRpcError( + grpc.StatusCode.CANCELLED, + "stop", + ) + + stubs = [first_stub, second_stub] + stub_channels = [] + + def create_stub(channel): + stub_channels.append(channel) + return stubs.pop(0) + + monkeypatch.setattr("durabletask.worker.shared.get_grpc_channel", get_grpc_channel) + monkeypatch.setattr("durabletask.worker.stubs.TaskHubSidecarServiceStub", create_stub) + + await worker._async_run_loop() + + assert len(worker_manager.submissions) == 1 + assert len(created_channels) == 2 + assert stub_channels == created_channels + created_channels[0].close.assert_not_called() + created_channels[1].close.assert_called_once() + + _, submission = worker_manager.submissions[0] + func, _, req, stub, completion_token = submission + func(req, stub, completion_token) + + assert len(completed_responses) == 1 + assert completed_responses[0].completionToken == "token" + created_channels[0].close.assert_called_once() + + +@pytest.mark.asyncio +async def test_worker_never_closes_caller_owned_channel_after_graceful_reset(monkeypatch): + provided_channel = MagicMock(name="provided-channel") + worker = TaskHubGrpcWorker(channel=provided_channel) + worker_manager = DummyWorkerManager() + worker._async_worker_manager = worker_manager + worker._execute_activity = _complete_activity_request + monkeypatch.setattr(worker._shutdown, "wait", lambda timeout: False) + + completed_responses = [] + first_stub = MagicMock() + first_stub.GetWorkItems.return_value = FakeResponseStream(items=[_make_activity_work_item()]) + + def complete_activity(response): + assert provided_channel.close.call_count == 0 + completed_responses.append(response) + + first_stub.CompleteActivityTask.side_effect = complete_activity + + second_stub = MagicMock() + second_stub.GetWorkItems.side_effect = FakeRpcError( + grpc.StatusCode.CANCELLED, + "stop", + ) + + stubs = [first_stub, second_stub] + stub_channels = [] + + def create_stub(channel): + stub_channels.append(channel) + return stubs.pop(0) + + monkeypatch.setattr( + "durabletask.worker.shared.get_grpc_channel", + lambda *args, **kwargs: pytest.fail( + "SDK channel factory should not run for caller-owned channels" + ), + ) + monkeypatch.setattr("durabletask.worker.stubs.TaskHubSidecarServiceStub", create_stub) + + await worker._async_run_loop() + + assert len(worker_manager.submissions) == 1 + assert stub_channels == 
[provided_channel, provided_channel] + provided_channel.close.assert_not_called() + + _, submission = worker_manager.submissions[0] + func, _, req, stub, completion_token = submission + func(req, stub, completion_token) + + assert len(completed_responses) == 1 + assert completed_responses[0].completionToken == "token" + provided_channel.close.assert_not_called() + + @pytest.mark.asyncio async def test_worker_uses_reconnect_backoff_helper_after_connection_failure(monkeypatch): worker = TaskHubGrpcWorker( From a49f5ce332f058bb0cafc841cc37532d06357a96 Mon Sep 17 00:00:00 2001 From: Bernd Verst Date: Fri, 24 Apr 2026 00:57:07 -0700 Subject: [PATCH 24/28] Fix worker shutdown channel draining Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- CHANGELOG.md | 5 +- durabletask/worker.py | 39 ++-- .../test_worker_concurrency_loop.py | 2 + .../test_worker_concurrency_loop_async.py | 1 + tests/durabletask/test_worker_resiliency.py | 204 +++++++++++++++++- 5 files changed, 231 insertions(+), 20 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index d8b6206..a0a9818 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -33,8 +33,9 @@ FIXED so configured hello timeouts apply on fresh connections, received work resets failure tracking, SDK-owned channels are refreshed and cleaned up safely, and caller-owned channels are never recreated or closed during reconnects. -- Fixed `TaskHubGrpcWorker` so in-flight work item completions can finish after - a graceful gRPC stream reset before the worker retires an SDK-owned channel. +- Fixed `TaskHubGrpcWorker` so in-flight and queued work item completions keep + draining across graceful gRPC stream resets and worker shutdown before the + worker retires an SDK-owned channel. - Improved sync and async gRPC clients so repeated transport failures recreate SDK-owned channels, while long-poll deadlines, successful replies, and application-level RPC errors do not trigger unnecessary channel replacement. diff --git a/durabletask/worker.py b/durabletask/worker.py index a9f9f4a..fa341e0 100644 --- a/durabletask/worker.py +++ b/durabletask/worker.py @@ -690,6 +690,8 @@ def start(self): if self._auto_generate_work_item_filters: self._work_item_filters = WorkItemFilters._from_registry(self._registry) + self._shutdown.clear() + def run_loop(): loop = asyncio.new_event_loop() asyncio.set_event_loop(loop) @@ -701,6 +703,7 @@ def run_loop(): self._is_running = True async def _async_run_loop(self): + self._async_worker_manager.prepare_for_run() worker_task = asyncio.create_task(self._async_worker_manager.run()) current_channel = self._channel current_stub = None @@ -1060,6 +1063,11 @@ def stop(self): self._response_stream.cancel() if self._runLoop is not None: self._runLoop.join(timeout=30) + if self._runLoop.is_alive(): + self._logger.info( + "Waiting for pending work items to finish before completing shutdown..." 
+ ) + self._runLoop.join() self._async_worker_manager.shutdown() self._logger.info("Worker shutdown completed") self._is_running = False @@ -2883,11 +2891,22 @@ def __init__(self, concurrency_options: ConcurrencyOptions, logger: logging.Logg self._pending_activity_work: list = [] self._pending_orchestration_work: list = [] self._pending_entity_batch_work: list = [] - self.thread_pool = ThreadPoolExecutor( - max_workers=concurrency_options.maximum_thread_pool_workers, + self.thread_pool = self._create_thread_pool() + self._shutdown = False + + def _create_thread_pool(self) -> ThreadPoolExecutor: + return ThreadPoolExecutor( + max_workers=self.concurrency_options.maximum_thread_pool_workers, thread_name_prefix="DurableTask", ) + + def _ensure_thread_pool(self) -> None: + if getattr(self.thread_pool, "_shutdown", False): + self.thread_pool = self._create_thread_pool() + + def prepare_for_run(self) -> None: self._shutdown = False + self._ensure_thread_pool() def _ensure_queues_for_current_loop(self): """Ensure queues are bound to the current event loop.""" @@ -2962,8 +2981,7 @@ def _ensure_queues_for_current_loop(self): self._pending_entity_batch_work.clear() async def run(self): - # Reset shutdown flag in case this manager is being reused - self._shutdown = False + self._ensure_thread_pool() # Ensure queues are properly bound to the current event loop self._ensure_queues_for_current_loop() @@ -3025,6 +3043,9 @@ async def run(self): except Exception as cancellation_exception: self._logger.error(f"Uncaught error while cancelling entity batch work item: {cancellation_exception}") self.shutdown() + finally: + if not getattr(self.thread_pool, "_shutdown", False): + self.thread_pool.shutdown(wait=True) async def _consume_queue(self, queue: asyncio.Queue, semaphore: asyncio.Semaphore): # List to track running tasks @@ -3068,12 +3089,7 @@ async def _run_func(self, func, *args, **kwargs): return await func(*args, **kwargs) else: loop = asyncio.get_running_loop() - # Avoid submitting to executor after shutdown - if ( - getattr(self, "_shutdown", False) and getattr(self, "thread_pool", None) and getattr( - self.thread_pool, "_shutdown", False) - ): - return None + self._ensure_thread_pool() return await loop.run_in_executor( self.thread_pool, lambda: func(*args, **kwargs) ) @@ -3113,11 +3129,10 @@ def submit_entity_batch(self, func, cancellation_func, *args, **kwargs): def shutdown(self): self._shutdown = True - self.thread_pool.shutdown(wait=True) async def reset_for_new_run(self): """Reset the manager state for a new run.""" - self._shutdown = False + self.prepare_for_run() # Clear any existing queues - they'll be recreated when needed if self.activity_queue is not None: # Clear existing queue by creating a new one diff --git a/tests/durabletask/test_worker_concurrency_loop.py b/tests/durabletask/test_worker_concurrency_loop.py index 6fd1270..199c219 100644 --- a/tests/durabletask/test_worker_concurrency_loop.py +++ b/tests/durabletask/test_worker_concurrency_loop.py @@ -73,6 +73,7 @@ def cancel_dummy_activity(req, stub, completionToken): async def run_test(): # Start the worker manager's run loop in the background + worker._async_worker_manager.prepare_for_run() worker_task = asyncio.create_task(worker._async_worker_manager.run()) for req in orchestrator_requests: worker._async_worker_manager.submit_orchestration(dummy_orchestrator, cancel_dummy_orchestrator, req, stub, DummyCompletionToken()) @@ -133,6 +134,7 @@ def fn(*args, **kwargs): # Run the manager loop in a thread (sync context) def 
run_manager(): + manager.prepare_for_run() asyncio.run(manager.run()) t = threading.Thread(target=run_manager) diff --git a/tests/durabletask/test_worker_concurrency_loop_async.py b/tests/durabletask/test_worker_concurrency_loop_async.py index 8482c20..22ffe23 100644 --- a/tests/durabletask/test_worker_concurrency_loop_async.py +++ b/tests/durabletask/test_worker_concurrency_loop_async.py @@ -72,6 +72,7 @@ async def cancel_dummy_activity(req, stub, completionToken): async def run_test(): # Clear stub state before each run stub.completed.clear() + grpc_worker._async_worker_manager.prepare_for_run() worker_task = asyncio.create_task(grpc_worker._async_worker_manager.run()) # Need to yield to that thread in order to let it start up on the second run startup_attempts = 0 diff --git a/tests/durabletask/test_worker_resiliency.py b/tests/durabletask/test_worker_resiliency.py index 6e324ad..0623847 100644 --- a/tests/durabletask/test_worker_resiliency.py +++ b/tests/durabletask/test_worker_resiliency.py @@ -1,13 +1,18 @@ import asyncio import grpc -from threading import Event +from threading import Event, Timer from unittest.mock import MagicMock import pytest from durabletask.grpc_options import GrpcWorkerResiliencyOptions from durabletask.internal import orchestrator_service_pb2 as pb -from durabletask.worker import TaskHubGrpcWorker, _WorkItemStreamOutcome +from durabletask.worker import ( + _AsyncWorkerManager, + ConcurrencyOptions, + TaskHubGrpcWorker, + _WorkItemStreamOutcome, +) class FakeRpcError(grpc.RpcError): @@ -62,6 +67,9 @@ def __init__(self): self._shutdown_event = asyncio.Event() self.submissions: list[tuple[str, tuple]] = [] + def prepare_for_run(self): + self._shutdown_event = asyncio.Event() + async def run(self): await self._shutdown_event.wait() @@ -88,16 +96,58 @@ def _complete_activity_request(req, stub, completion_token): ) -def _make_activity_work_item() -> pb.WorkItem: +def _make_activity_work_item( + task_id: int = 1, + completion_token: str = "token", + instance_id: str = "instance-id", +) -> pb.WorkItem: return pb.WorkItem( activityRequest=pb.ActivityRequest( name="test_activity", - taskId=1, - orchestrationInstance=pb.OrchestrationInstance(instanceId="instance-id"), + taskId=task_id, + orchestrationInstance=pb.OrchestrationInstance(instanceId=instance_id), ), - completionToken="token", + completionToken=completion_token, + ) + + +async def _wait_for_condition(predicate, *, timeout: float = 2.0): + loop = asyncio.get_running_loop() + deadline = loop.time() + timeout + while not predicate(): + if loop.time() >= deadline: + raise AssertionError("condition was not met before timeout") + await asyncio.sleep(0.01) + + +@pytest.mark.asyncio +async def test_async_worker_manager_honors_shutdown_requested_before_run(): + manager = _AsyncWorkerManager( + ConcurrencyOptions(maximum_thread_pool_workers=1), + MagicMock(), ) + manager.shutdown() + await asyncio.wait_for(manager.run(), timeout=1.0) + + +def test_worker_start_clears_prior_shutdown_request(): + worker = TaskHubGrpcWorker() + worker._shutdown.set() + run_started = Event() + + async def fake_run_loop(): + run_started.set() + + worker._async_run_loop = fake_run_loop + worker.start() + worker._runLoop.join(timeout=1.0) + + assert run_started.is_set() is True + assert worker._shutdown.is_set() is False + + worker.stop() + def test_worker_classifies_graceful_close_before_first_message(): worker = TaskHubGrpcWorker( @@ -399,6 +449,148 @@ def create_stub(channel): created_channels[0].close.assert_called_once() 
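The two tests added below exercise shutdown draining end to end. As a compressed sketch of the behavior under test, shutting an executor down with `wait=True` drains queued work before the caller proceeds to retire channels; all names here are illustrative.

```python
# Compressed sketch of drain-on-shutdown: the executor is shut down with
# wait=True, so queued completions finish before control returns to the
# code that retires channels.
import time
from concurrent.futures import ThreadPoolExecutor


def run_until_shutdown(jobs) -> list:
    results: list = []
    pool = ThreadPoolExecutor(max_workers=1)
    try:
        for job in jobs:
            pool.submit(lambda j=job: results.append(j()))
    finally:
        pool.shutdown(wait=True)  # drain queued work before returning
    return results


print(run_until_shutdown([lambda: 1, lambda: (time.sleep(0.1), 2)[1]]))
# -> [1, 2]: both jobs finish even though shutdown begins immediately
```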
+@pytest.mark.asyncio +async def test_worker_shutdown_drains_real_manager_work_before_closing_retired_sdk_channel(monkeypatch): + worker = TaskHubGrpcWorker( + concurrency_options=ConcurrencyOptions( + maximum_concurrent_activity_work_items=1, + maximum_thread_pool_workers=1, + ) + ) + worker._execute_activity = _complete_activity_request + monkeypatch.setattr(worker._shutdown, "wait", lambda timeout: False) + + created_channels = [] + + def get_grpc_channel(*args, **kwargs): + channel = MagicMock(name=f"channel-{len(created_channels) + 1}") + created_channels.append(channel) + return channel + + allow_first_completion = Event() + first_completion_started = Event() + completed_task_ids = [] + first_stub = MagicMock() + first_stub.GetWorkItems.return_value = FakeResponseStream(items=[ + _make_activity_work_item(task_id=1, completion_token="token-1"), + _make_activity_work_item(task_id=2, completion_token="token-2"), + ]) + + def complete_activity(response): + completed_task_ids.append(response.taskId) + if response.taskId == 1: + first_completion_started.set() + Timer(0.2, allow_first_completion.set).start() + assert allow_first_completion.wait(timeout=5.0) + elif response.taskId == 2: + assert created_channels[0].close.call_count == 0 + + first_stub.CompleteActivityTask.side_effect = complete_activity + + second_stub = MagicMock() + second_stub.GetWorkItems.side_effect = FakeRpcError( + grpc.StatusCode.CANCELLED, + "stop", + ) + + stubs = [first_stub, second_stub] + stub_channels = [] + + def create_stub(channel): + stub_channels.append(channel) + return stubs.pop(0) + + monkeypatch.setattr("durabletask.worker.shared.get_grpc_channel", get_grpc_channel) + monkeypatch.setattr("durabletask.worker.stubs.TaskHubSidecarServiceStub", create_stub) + + run_task = asyncio.create_task(worker._async_run_loop()) + await asyncio.wait_for(run_task, timeout=2.0) + + assert first_completion_started.is_set() is True + assert len(created_channels) == 2 + assert stub_channels == created_channels + assert completed_task_ids == [1, 2] + created_channels[0].close.assert_called_once() + created_channels[1].close.assert_called_once() + + +@pytest.mark.asyncio +async def test_worker_shutdown_runs_real_manager_cancellation_wrapper_before_closing_retired_sdk_channel(monkeypatch): + worker = TaskHubGrpcWorker( + concurrency_options=ConcurrencyOptions( + maximum_concurrent_activity_work_items=1, + maximum_thread_pool_workers=1, + ) + ) + monkeypatch.setattr(worker._shutdown, "wait", lambda timeout: False) + + created_channels = [] + + def get_grpc_channel(*args, **kwargs): + channel = MagicMock(name=f"channel-{len(created_channels) + 1}") + created_channels.append(channel) + return channel + + allow_first_completion = Event() + first_completion_started = Event() + completed_task_ids = [] + cancelled_task_ids = [] + + def execute_activity(req, stub, completion_token): + if req.taskId == 1: + _complete_activity_request(req, stub, completion_token) + else: + raise RuntimeError("boom") + + def cancel_activity(req, stub, completion_token): + cancelled_task_ids.append(req.taskId) + assert created_channels[0].close.call_count == 0 + + worker._execute_activity = execute_activity + worker._cancel_activity = cancel_activity + + first_stub = MagicMock() + first_stub.GetWorkItems.return_value = FakeResponseStream(items=[ + _make_activity_work_item(task_id=1, completion_token="token-1"), + _make_activity_work_item(task_id=2, completion_token="token-2"), + ]) + + def complete_activity(response): + 
completed_task_ids.append(response.taskId) + Timer(0.2, allow_first_completion.set).start() + first_completion_started.set() + assert allow_first_completion.wait(timeout=5.0) + + first_stub.CompleteActivityTask.side_effect = complete_activity + + second_stub = MagicMock() + second_stub.GetWorkItems.side_effect = FakeRpcError( + grpc.StatusCode.CANCELLED, + "stop", + ) + + stubs = [first_stub, second_stub] + stub_channels = [] + + def create_stub(channel): + stub_channels.append(channel) + return stubs.pop(0) + + monkeypatch.setattr("durabletask.worker.shared.get_grpc_channel", get_grpc_channel) + monkeypatch.setattr("durabletask.worker.stubs.TaskHubSidecarServiceStub", create_stub) + + run_task = asyncio.create_task(worker._async_run_loop()) + await asyncio.wait_for(run_task, timeout=2.0) + + assert first_completion_started.is_set() is True + assert len(created_channels) == 2 + assert stub_channels == created_channels + assert completed_task_ids == [1] + assert cancelled_task_ids == [2] + created_channels[0].close.assert_called_once() + created_channels[1].close.assert_called_once() + + @pytest.mark.asyncio async def test_worker_never_closes_caller_owned_channel_after_graceful_reset(monkeypatch): provided_channel = MagicMock(name="provided-channel") From 32f383d46cfe5944f592764106f6ee2f89bdcdf5 Mon Sep 17 00:00:00 2001 From: Bernd Verst Date: Fri, 24 Apr 2026 01:14:15 -0700 Subject: [PATCH 25/28] Rename Azure Managed gRPC resiliency test module Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- ...st_grpc_resiliency.py => test_azuremanaged_grpc_resiliency.py} | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename tests/durabletask-azuremanaged/{test_grpc_resiliency.py => test_azuremanaged_grpc_resiliency.py} (100%) diff --git a/tests/durabletask-azuremanaged/test_grpc_resiliency.py b/tests/durabletask-azuremanaged/test_azuremanaged_grpc_resiliency.py similarity index 100% rename from tests/durabletask-azuremanaged/test_grpc_resiliency.py rename to tests/durabletask-azuremanaged/test_azuremanaged_grpc_resiliency.py From b6f3b43cc6e8c2a5c2cf4998d2845f0388d30f44 Mon Sep 17 00:00:00 2001 From: Bernd Verst Date: Fri, 24 Apr 2026 01:58:04 -0700 Subject: [PATCH 26/28] Fix sync client channel cleanup Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- CHANGELOG.md | 3 + durabletask/client.py | 30 +++++++++- tests/durabletask/test_client.py | 96 ++++++++++++++++++++++++++++++-- 3 files changed, 120 insertions(+), 9 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index a0a9818..695c4c0 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -39,6 +39,9 @@ FIXED - Improved sync and async gRPC clients so repeated transport failures recreate SDK-owned channels, while long-poll deadlines, successful replies, and application-level RPC errors do not trigger unnecessary channel replacement. +- Fixed `TaskHubGrpcClient.close()` so explicit sync client shutdown now closes + any previously retired SDK-owned gRPC channels immediately instead of waiting + for the delayed cleanup timer. 
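A sketch of the `close()` behavior this entry describes: cancel each pending cleanup timer, close the retired channels immediately, then close the live channel. The function name and loose typing are illustrative; the real client performs this under its recreate lock with a closing flag, as the `client.py` hunk below shows.

```python
# Illustrative sketch only: immediate cleanup of retired channels on
# explicit shutdown, assuming a dict of retired channel -> cleanup timer.
import threading


def close_client_now(retired_channels: dict, current_channel) -> None:
    for channel, timer in list(retired_channels.items()):
        timer.cancel()   # stop the pending delayed close
        channel.close()  # close immediately instead
    retired_channels.clear()
    current_channel.close()
```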
## v1.4.0 diff --git a/durabletask/client.py b/durabletask/client.py index 1648478..3960802 100644 --- a/durabletask/client.py +++ b/durabletask/client.py @@ -211,8 +211,10 @@ def __init__(self, *, self._client_failure_tracker = FailureTracker( self._resiliency_options.channel_recreate_failure_threshold ) + self._closing = False self._last_recreate_time = 0.0 self._recreate_lock = threading.Lock() + self._retired_channels: dict[grpc.Channel, threading.Timer] = {} self._logger = shared.get_logger("client", log_handler, log_formatter) self.default_version = default_version self._payload_store = payload_store @@ -243,9 +245,11 @@ def _invoke_unary( return response def _maybe_recreate_channel(self) -> None: - if not self._owns_channel: + if not self._owns_channel or self._closing: return with self._recreate_lock: + if self._closing: + return now = time.monotonic() if now - self._last_recreate_time < self._resiliency_options.min_recreate_interval_seconds: return @@ -259,10 +263,22 @@ def _maybe_recreate_channel(self) -> None: self._stub = stubs.TaskHubSidecarServiceStub(self._channel) self._last_recreate_time = now self._client_failure_tracker.record_success() - close_timer = threading.Timer(30.0, old_channel.close) + close_timer = threading.Timer( + 30.0, + self._close_retired_channel, + args=(old_channel,), + ) close_timer.daemon = True + self._retired_channels[old_channel] = close_timer close_timer.start() + def _close_retired_channel(self, channel: grpc.Channel) -> None: + with self._recreate_lock: + close_timer = self._retired_channels.pop(channel, None) + if close_timer is None: + return + channel.close() + def close(self) -> None: """Closes the underlying gRPC channel. @@ -272,7 +288,15 @@ def close(self) -> None: it. """ if self._owns_channel: - self._channel.close() + with self._recreate_lock: + self._closing = True + retired_channels = list(self._retired_channels.items()) + self._retired_channels.clear() + current_channel = self._channel + for retired_channel, close_timer in retired_channels: + close_timer.cancel() + retired_channel.close() + current_channel.close() def schedule_new_orchestration(self, orchestrator: Union[task.Orchestrator[TInput, TOutput], str], *, input: Optional[TInput] = None, diff --git a/tests/durabletask/test_client.py b/tests/durabletask/test_client.py index 724e984..adc31b4 100644 --- a/tests/durabletask/test_client.py +++ b/tests/durabletask/test_client.py @@ -382,11 +382,88 @@ def test_sync_client_recreates_sdk_owned_channel_with_original_transport_inputs( expected_channel_call, ] assert client._channel is second_channel - mock_timer.assert_called_once_with(30.0, first_channel.close) + mock_timer.assert_called_once() + timer_call = mock_timer.call_args + assert timer_call.args[0] == 30.0 + assert timer_call.args[1].__self__ is client + assert timer_call.args[1].__func__ is TaskHubGrpcClient._close_retired_channel + assert timer_call.kwargs == {"args": (first_channel,)} assert timer.daemon is True timer.start.assert_called_once_with() +def test_sync_client_close_closes_retired_channels_immediately(): + first_channel = MagicMock(name="first-channel") + second_channel = MagicMock(name="second-channel") + first_stub = MagicMock() + first_stub.GetInstance.side_effect = FakeRpcError(grpc.StatusCode.UNAVAILABLE) + second_stub = MagicMock() + second_stub.GetInstance.return_value = MagicMock(exists=False) + close_timer = MagicMock(name="close-timer") + + with patch( + "durabletask.client.shared.get_grpc_channel", + side_effect=[first_channel, second_channel], + ), 
patch( + "durabletask.client.stubs.TaskHubSidecarServiceStub", side_effect=[first_stub, second_stub] + ), patch("threading.Timer", return_value=close_timer): + client = TaskHubGrpcClient( + resiliency_options=GrpcClientResiliencyOptions( + channel_recreate_failure_threshold=1, + min_recreate_interval_seconds=0.0, + ) + ) + with pytest.raises(FakeRpcError): + client.get_orchestration_state("abc") + + client.close() + + close_timer.cancel.assert_called_once_with() + first_channel.close.assert_called_once_with() + second_channel.close.assert_called_once_with() + assert client._retired_channels == {} + + +def test_sync_client_close_closes_all_retired_sdk_channels_immediately(): + first_channel = MagicMock(name="first-channel") + second_channel = MagicMock(name="second-channel") + third_channel = MagicMock(name="third-channel") + first_stub = MagicMock() + first_stub.GetInstance.side_effect = FakeRpcError(grpc.StatusCode.UNAVAILABLE) + second_stub = MagicMock() + second_stub.GetInstance.side_effect = FakeRpcError(grpc.StatusCode.UNAVAILABLE) + third_stub = MagicMock() + timer1 = MagicMock(name="close-timer-1") + timer2 = MagicMock(name="close-timer-2") + + with patch( + "durabletask.client.shared.get_grpc_channel", + side_effect=[first_channel, second_channel, third_channel], + ), patch( + "durabletask.client.stubs.TaskHubSidecarServiceStub", + side_effect=[first_stub, second_stub, third_stub], + ), patch("threading.Timer", side_effect=[timer1, timer2]): + client = TaskHubGrpcClient( + resiliency_options=GrpcClientResiliencyOptions( + channel_recreate_failure_threshold=1, + min_recreate_interval_seconds=0.0, + ) + ) + with pytest.raises(FakeRpcError): + client.get_orchestration_state("abc") + with pytest.raises(FakeRpcError): + client.get_orchestration_state("abc") + + client.close() + + timer1.cancel.assert_called_once_with() + timer2.cancel.assert_called_once_with() + first_channel.close.assert_called_once_with() + second_channel.close.assert_called_once_with() + third_channel.close.assert_called_once_with() + assert client._retired_channels == {} + + @pytest.mark.parametrize( ("stub_method_name", "client_method_name"), [ @@ -431,11 +508,13 @@ def test_sync_client_does_not_recreate_caller_owned_channel(): client.get_orchestration_state("abc") with pytest.raises(FakeRpcError): client.get_orchestration_state("abc") + client.close() assert client._channel is provided_channel mock_get_channel.assert_not_called() mock_stub.assert_called_once_with(provided_channel) mock_timer.assert_not_called() + provided_channel.close.assert_not_called() def test_sync_client_recreate_cooldown_prevents_immediate_repeated_recreation(): @@ -478,7 +557,6 @@ def test_sync_client_recreate_cooldown_prevents_immediate_repeated_recreation(): client.get_orchestration_state("abc") assert client._channel is second_channel assert mock_get_channel.call_count == 2 - mock_timer.assert_called_once_with(30.0, first_channel.close) with pytest.raises(FakeRpcError): client.get_orchestration_state("abc") @@ -495,10 +573,16 @@ def test_sync_client_recreate_cooldown_prevents_immediate_repeated_recreation(): expected_channel_call, expected_channel_call, ] - assert mock_timer.call_args_list == [ - call(30.0, first_channel.close), - call(30.0, second_channel.close), - ] + assert mock_timer.call_count == 2 + first_timer_call, second_timer_call = mock_timer.call_args_list + assert first_timer_call.args[0] == 30.0 + assert first_timer_call.args[1].__self__ is client + assert first_timer_call.args[1].__func__ is 
TaskHubGrpcClient._close_retired_channel + assert first_timer_call.kwargs == {"args": (first_channel,)} + assert second_timer_call.args[0] == 30.0 + assert second_timer_call.args[1].__self__ is client + assert second_timer_call.args[1].__func__ is TaskHubGrpcClient._close_retired_channel + assert second_timer_call.kwargs == {"args": (second_channel,)} assert timer1.daemon is True assert timer2.daemon is True timer1.start.assert_called_once_with() From 1f9545f7ba9244b094976231eb6dd379daa739af Mon Sep 17 00:00:00 2001 From: Bernd Verst Date: Fri, 24 Apr 2026 10:24:37 -0700 Subject: [PATCH 27/28] Address automated review feedback Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- .../plans/2026-04-23-grpc-resiliency.md | 6 ++- durabletask/client.py | 19 ++++--- durabletask/internal/grpc_resiliency.py | 3 +- durabletask/worker.py | 9 ++-- tests/durabletask/test_client.py | 4 +- tests/durabletask/test_worker_resiliency.py | 50 +++++++++++++++++++ 6 files changed, 73 insertions(+), 18 deletions(-) diff --git a/docs/superpowers/plans/2026-04-23-grpc-resiliency.md b/docs/superpowers/plans/2026-04-23-grpc-resiliency.md index a9dc226..67042c1 100644 --- a/docs/superpowers/plans/2026-04-23-grpc-resiliency.md +++ b/docs/superpowers/plans/2026-04-23-grpc-resiliency.md @@ -1,6 +1,10 @@ # gRPC Resiliency Implementation Plan -> **For agentic workers:** REQUIRED SUB-SKILL: Use superpowers:subagent-driven-development (recommended) or superpowers:executing-plans to implement this plan task-by-task. Steps use checkbox (`- [ ]`) syntax for tracking. +> [!NOTE] +> For agentic workers: REQUIRED SUB-SKILL: +> Use superpowers:subagent-driven-development (recommended) or +> superpowers:executing-plans to implement this plan task-by-task. +> Steps use checkbox (`- [ ]`) syntax for tracking. **Goal:** Implement automatic healing of stale gRPC worker streams and client channels in `durabletask-python`, aligned with the behavior added in `durabletask-dotnet` PR 708. 
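Before the `client.py` hunks below, a sketch of the async retired-channel bookkeeping they adjust: the grace sleep and the channel close happen outside the lock, while list mutation happens inside it. `RetiredChannelBin` is a hypothetical name used for illustration only.

```python
# Sketch of async retired-channel cleanup under a lock, mirroring the
# _close_retired_channel fix below: only the bookkeeping mutation is
# guarded, so a slow close cannot hold the lock across the grace period.
import asyncio

import grpc.aio


class RetiredChannelBin:
    def __init__(self) -> None:
        self._lock = asyncio.Lock()
        self._retired: list[grpc.aio.Channel] = []

    async def retire(self, channel: grpc.aio.Channel,
                     grace_seconds: float = 30.0) -> None:
        async with self._lock:
            self._retired.append(channel)
        try:
            await asyncio.sleep(grace_seconds)
            await channel.close()  # grpc.aio channel close is awaitable
        finally:
            async with self._lock:
                if channel in self._retired:
                    self._retired.remove(channel)
```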
diff --git a/durabletask/client.py b/durabletask/client.py index 3960802..e76cdfa 100644 --- a/durabletask/client.py +++ b/durabletask/client.py @@ -224,7 +224,7 @@ def _invoke_unary( method_name: str, request: Any, *, - timeout: Optional[int] = None): + timeout: Optional[float] = None): method = getattr(self._stub, method_name) try: if timeout is None: @@ -406,7 +406,7 @@ def get_all_orchestration_states(self, def wait_for_orchestration_start(self, instance_id: str, *, fetch_payloads: bool = False, - timeout: int = 60) -> Optional[OrchestrationState]: + timeout: float = 60) -> Optional[OrchestrationState]: req = pb.GetInstanceRequest(instanceId=instance_id, getInputsAndOutputs=fetch_payloads) try: self._logger.info(f"Waiting up to {timeout}s for instance '{instance_id}' to start.") @@ -427,7 +427,7 @@ def wait_for_orchestration_start(self, instance_id: str, *, def wait_for_orchestration_completion(self, instance_id: str, *, fetch_payloads: bool = True, - timeout: int = 60) -> Optional[OrchestrationState]: + timeout: float = 60) -> Optional[OrchestrationState]: req = pb.GetInstanceRequest(instanceId=instance_id, getInputsAndOutputs=fetch_payloads) try: self._logger.info(f"Waiting {timeout}s for instance '{instance_id}' to complete.") @@ -685,7 +685,7 @@ async def _invoke_unary( method_name: str, request: Any, *, - timeout: Optional[int] = None): + timeout: Optional[float] = None): method = getattr(self._stub, method_name) try: if timeout is None: @@ -733,10 +733,9 @@ async def _close_retired_channel(self, channel: grpc.aio.Channel) -> None: await asyncio.sleep(30.0) await channel.close() finally: - try: - self._retired_channels.remove(channel) - except ValueError: - pass + async with self._recreate_lock: + if channel in self._retired_channels: + self._retired_channels.remove(channel) async def schedule_new_orchestration(self, orchestrator: Union[task.Orchestrator[TInput, TOutput], str], *, input: Optional[TInput] = None, @@ -843,7 +842,7 @@ async def get_all_orchestration_states(self, async def wait_for_orchestration_start(self, instance_id: str, *, fetch_payloads: bool = False, - timeout: int = 60) -> Optional[OrchestrationState]: + timeout: float = 60) -> Optional[OrchestrationState]: req = pb.GetInstanceRequest(instanceId=instance_id, getInputsAndOutputs=fetch_payloads) try: self._logger.info(f"Waiting up to {timeout}s for instance '{instance_id}' to start.") @@ -863,7 +862,7 @@ async def wait_for_orchestration_start(self, instance_id: str, *, async def wait_for_orchestration_completion(self, instance_id: str, *, fetch_payloads: bool = True, - timeout: int = 60) -> Optional[OrchestrationState]: + timeout: float = 60) -> Optional[OrchestrationState]: req = pb.GetInstanceRequest(instanceId=instance_id, getInputsAndOutputs=fetch_payloads) try: self._logger.info(f"Waiting {timeout}s for instance '{instance_id}' to complete.") diff --git a/durabletask/internal/grpc_resiliency.py b/durabletask/internal/grpc_resiliency.py index 4845523..0a8cdd6 100644 --- a/durabletask/internal/grpc_resiliency.py +++ b/durabletask/internal/grpc_resiliency.py @@ -13,7 +13,8 @@ def get_full_jitter_delay_seconds( attempt: int, *, base_seconds: float, - cap_seconds: float) -> float: + cap_seconds: float, +) -> float: capped_attempt = min(attempt, 30) upper_bound = min(cap_seconds, base_seconds * (2 ** capped_attempt)) return random.random() * upper_bound diff --git a/durabletask/worker.py b/durabletask/worker.py index fa341e0..172a04d 100644 --- a/durabletask/worker.py +++ b/durabletask/worker.py @@ -194,7 +194,7 
@@ def _close_channel(channel: Any) -> None: try: channel.close() except Exception: - pass + logging.debug("Ignoring channel close failure during worker cleanup.", exc_info=True) class VersioningOptions: @@ -744,9 +744,10 @@ def create_fresh_connection(): def wrap_execution(handler, release): def wrapped(*args, **kwargs): - result = handler(*args, **kwargs) - release() - return result + try: + return handler(*args, **kwargs) + finally: + release() return wrapped diff --git a/tests/durabletask/test_client.py b/tests/durabletask/test_client.py index adc31b4..78cd28f 100644 --- a/tests/durabletask/test_client.py +++ b/tests/durabletask/test_client.py @@ -801,8 +801,8 @@ async def test_async_client_close_prevents_channel_recreation_race(): client._recreate_lock.release() with pytest.raises(grpc.aio.AioRpcError): - await rpc_task - await close_task + _ = await rpc_task + _ = await close_task assert mock_get_channel.call_count == 1 first_channel.close.assert_awaited_once() diff --git a/tests/durabletask/test_worker_resiliency.py b/tests/durabletask/test_worker_resiliency.py index 0623847..8ad64cd 100644 --- a/tests/durabletask/test_worker_resiliency.py +++ b/tests/durabletask/test_worker_resiliency.py @@ -449,6 +449,56 @@ def create_stub(channel): created_channels[0].close.assert_called_once() +@pytest.mark.asyncio +async def test_worker_releases_inflight_channel_when_activity_handler_raises(monkeypatch): + worker = TaskHubGrpcWorker() + worker_manager = DummyWorkerManager() + worker._async_worker_manager = worker_manager + monkeypatch.setattr(worker._shutdown, "wait", lambda timeout: False) + + def fail_activity(req, stub, completion_token): + raise RuntimeError("boom") + + worker._execute_activity = fail_activity + + created_channels = [] + + def get_grpc_channel(*args, **kwargs): + channel = MagicMock(name=f"channel-{len(created_channels) + 1}") + created_channels.append(channel) + return channel + + first_stub = MagicMock() + first_stub.GetWorkItems.return_value = FakeResponseStream(items=[_make_activity_work_item()]) + + second_stub = MagicMock() + second_stub.GetWorkItems.side_effect = FakeRpcError( + grpc.StatusCode.CANCELLED, + "stop", + ) + + stubs = [first_stub, second_stub] + + def create_stub(channel): + return stubs.pop(0) + + monkeypatch.setattr("durabletask.worker.shared.get_grpc_channel", get_grpc_channel) + monkeypatch.setattr("durabletask.worker.stubs.TaskHubSidecarServiceStub", create_stub) + + await worker._async_run_loop() + + assert len(worker_manager.submissions) == 1 + created_channels[0].close.assert_not_called() + + _, submission = worker_manager.submissions[0] + func, _, req, stub, completion_token = submission + with pytest.raises(RuntimeError, match="boom"): + func(req, stub, completion_token) + + created_channels[0].close.assert_called_once() + created_channels[1].close.assert_called_once() + + @pytest.mark.asyncio async def test_worker_shutdown_drains_real_manager_work_before_closing_retired_sdk_channel(monkeypatch): worker = TaskHubGrpcWorker( From bdb37e22f02c8d301ce71509ff23389b743f27b4 Mon Sep 17 00:00:00 2001 From: Bernd Verst Date: Fri, 24 Apr 2026 11:49:10 -0700 Subject: [PATCH 28/28] Remove superpowers docs from PR Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- .gitignore | 1 + .../plans/2026-04-23-grpc-resiliency.md | 856 ------------------ .../2026-04-23-grpc-resiliency-design.md | 325 ------- 3 files changed, 1 insertion(+), 1181 deletions(-) delete mode 100644 docs/superpowers/plans/2026-04-23-grpc-resiliency.md delete mode 
100644 docs/superpowers/specs/2026-04-23-grpc-resiliency-design.md diff --git a/.gitignore b/.gitignore index d789df9..8540077 100644 --- a/.gitignore +++ b/.gitignore @@ -132,5 +132,6 @@ dmypy.json # IDEs .idea .worktrees/ +docs/superpowers/ coverage.lcov diff --git a/docs/superpowers/plans/2026-04-23-grpc-resiliency.md b/docs/superpowers/plans/2026-04-23-grpc-resiliency.md deleted file mode 100644 index 67042c1..0000000 --- a/docs/superpowers/plans/2026-04-23-grpc-resiliency.md +++ /dev/null @@ -1,856 +0,0 @@ -# gRPC Resiliency Implementation Plan - -> [!NOTE] -> For agentic workers: REQUIRED SUB-SKILL: -> Use superpowers:subagent-driven-development (recommended) or -> superpowers:executing-plans to implement this plan task-by-task. -> Steps use checkbox (`- [ ]`) syntax for tracking. - -**Goal:** Implement automatic healing of stale gRPC worker streams and client channels in `durabletask-python`, aligned with the behavior added in `durabletask-dotnet` PR 708. - -**Architecture:** Add explicit public resiliency option types plus a shared internal transport helper module, then wire those pieces into the worker loop and the sync and async clients. SDK-owned channels will be recreated after repeated transport failures, while caller-owned channels keep their existing ownership model and are only observed and logged. - -**Tech Stack:** Python 3.10+, grpc, grpc.aio, pytest, unittest.mock, flake8 - ---- - -## File map - -- `durabletask/grpc_options.py` - public resiliency option dataclasses and validation -- `durabletask/internal/grpc_resiliency.py` - shared backoff, failure tracking, and transport-failure classification helpers -- `durabletask/client.py` - sync and async client transport state, unary invocation helpers, and channel recreation logic -- `durabletask/worker.py` - hello timeout, stream-outcome classification, worker reconnect policy, and SDK-owned channel recreation -- `durabletask-azuremanaged/durabletask/azuremanaged/client.py` - pass `resiliency_options` through Azure Managed client wrappers -- `durabletask-azuremanaged/durabletask/azuremanaged/worker.py` - pass `resiliency_options` through Azure Managed worker wrapper -- `tests/durabletask/test_grpc_resiliency.py` - option validation and shared helper tests -- `tests/durabletask/test_worker_resiliency.py` - worker stream monitoring and reconnect behavior -- `tests/durabletask/test_client.py` - sync and async client constructor and channel recreation tests -- `tests/durabletask-azuremanaged/test_grpc_resiliency.py` - wrapper pass-through tests for the new option surfaces -- `CHANGELOG.md` - user-facing changelog entry for the core SDK -- `durabletask-azuremanaged/CHANGELOG.md` - user-facing changelog entry for Azure Managed wrappers - -### Task 1: Add public resiliency option types - -**Files:** -- Modify: `durabletask/grpc_options.py` -- Create: `tests/durabletask/test_grpc_resiliency.py` - -- [x] **Step 1: Write the failing option tests** - -```python -import pytest - -from durabletask.grpc_options import ( - GrpcClientResiliencyOptions, - GrpcWorkerResiliencyOptions, -) - - -def test_worker_resiliency_defaults_are_enabled(): - options = GrpcWorkerResiliencyOptions() - assert options.hello_timeout_seconds == 30.0 - assert options.silent_disconnect_timeout_seconds == 120.0 - assert options.channel_recreate_failure_threshold == 5 - assert options.reconnect_backoff_base_seconds == 1.0 - assert options.reconnect_backoff_cap_seconds == 30.0 - - -def test_worker_resiliency_allows_disabling_timeout_and_threshold(): - options = 
GrpcWorkerResiliencyOptions( - silent_disconnect_timeout_seconds=0.0, - channel_recreate_failure_threshold=0, - ) - assert options.silent_disconnect_timeout_seconds == 0.0 - assert options.channel_recreate_failure_threshold == 0 - - -def test_worker_resiliency_rejects_invalid_durations(): - with pytest.raises(ValueError, match="hello_timeout_seconds must be > 0"): - GrpcWorkerResiliencyOptions(hello_timeout_seconds=0.0) - with pytest.raises(ValueError, match="reconnect_backoff_cap_seconds must be >= reconnect_backoff_base_seconds"): - GrpcWorkerResiliencyOptions( - reconnect_backoff_base_seconds=5.0, - reconnect_backoff_cap_seconds=1.0, - ) - - -def test_client_resiliency_defaults_are_enabled(): - options = GrpcClientResiliencyOptions() - assert options.channel_recreate_failure_threshold == 5 - assert options.min_recreate_interval_seconds == 30.0 - - -def test_client_resiliency_rejects_negative_cooldown(): - with pytest.raises(ValueError, match="min_recreate_interval_seconds must be >= 0"): - GrpcClientResiliencyOptions(min_recreate_interval_seconds=-1.0) -``` - -- [x] **Step 2: Run the test to verify it fails** - -Run: `python -m pytest tests/durabletask/test_grpc_resiliency.py -v` - -Expected: FAIL with `ImportError` or `AttributeError` because the new option classes do not exist yet. - -- [x] **Step 3: Write the minimal implementation** - -```python -from dataclasses import dataclass, field -from typing import Any, Optional - - -@dataclass -class GrpcWorkerResiliencyOptions: - hello_timeout_seconds: float = 30.0 - silent_disconnect_timeout_seconds: float = 120.0 - channel_recreate_failure_threshold: int = 5 - reconnect_backoff_base_seconds: float = 1.0 - reconnect_backoff_cap_seconds: float = 30.0 - - def __post_init__(self) -> None: - if self.hello_timeout_seconds <= 0: - raise ValueError("hello_timeout_seconds must be > 0") - if self.silent_disconnect_timeout_seconds < 0: - raise ValueError("silent_disconnect_timeout_seconds must be >= 0") - if self.channel_recreate_failure_threshold < 0: - raise ValueError("channel_recreate_failure_threshold must be >= 0") - if self.reconnect_backoff_base_seconds <= 0: - raise ValueError("reconnect_backoff_base_seconds must be > 0") - if self.reconnect_backoff_cap_seconds <= 0: - raise ValueError("reconnect_backoff_cap_seconds must be > 0") - if self.reconnect_backoff_cap_seconds < self.reconnect_backoff_base_seconds: - raise ValueError( - "reconnect_backoff_cap_seconds must be >= reconnect_backoff_base_seconds" - ) - - -@dataclass -class GrpcClientResiliencyOptions: - channel_recreate_failure_threshold: int = 5 - min_recreate_interval_seconds: float = 30.0 - - def __post_init__(self) -> None: - if self.channel_recreate_failure_threshold < 0: - raise ValueError("channel_recreate_failure_threshold must be >= 0") - if self.min_recreate_interval_seconds < 0: - raise ValueError("min_recreate_interval_seconds must be >= 0") -``` - -- [x] **Step 4: Run the tests to verify they pass** - -Run: `python -m pytest tests/durabletask/test_grpc_resiliency.py -v` - -Expected: PASS for the new option validation tests. 
- -- [x] **Step 5: Commit** - -```bash -git add durabletask/grpc_options.py tests/durabletask/test_grpc_resiliency.py -git commit -m "Add gRPC resiliency option types" -``` - -### Task 2: Thread resiliency options through constructors - -**Files:** -- Modify: `durabletask/client.py` -- Modify: `durabletask/worker.py` -- Modify: `durabletask-azuremanaged/durabletask/azuremanaged/client.py` -- Modify: `durabletask-azuremanaged/durabletask/azuremanaged/worker.py` -- Modify: `tests/durabletask/test_client.py` -- Create: `tests/durabletask-azuremanaged/test_grpc_resiliency.py` - -- [x] **Step 1: Write the failing constructor and wrapper tests** - -```python -from unittest.mock import MagicMock, patch - -from durabletask.client import AsyncTaskHubGrpcClient, TaskHubGrpcClient -from durabletask.grpc_options import ( - GrpcClientResiliencyOptions, - GrpcWorkerResiliencyOptions, -) -from durabletask.worker import TaskHubGrpcWorker -from durabletask.azuremanaged.client import DurableTaskSchedulerClient -from durabletask.azuremanaged.worker import DurableTaskSchedulerWorker - - -def test_client_stores_resiliency_options_for_recreation(): - resiliency = GrpcClientResiliencyOptions(channel_recreate_failure_threshold=7) - with patch("durabletask.client.shared.get_grpc_channel", return_value=MagicMock()), patch( - "durabletask.client.stubs.TaskHubSidecarServiceStub", return_value=MagicMock() - ): - client = TaskHubGrpcClient( - host_address="localhost:4001", - resiliency_options=resiliency, - ) - assert client._resiliency_options is resiliency - assert client._host_address == "localhost:4001" - - -def test_async_client_stores_resolved_transport_inputs(): - resiliency = GrpcClientResiliencyOptions() - with patch("durabletask.client.shared.get_async_grpc_channel", return_value=MagicMock()), patch( - "durabletask.client.stubs.TaskHubSidecarServiceStub", return_value=MagicMock() - ): - client = AsyncTaskHubGrpcClient( - host_address="localhost:4001", - resiliency_options=resiliency, - ) - assert client._resiliency_options is resiliency - assert client._host_address == "localhost:4001" - - -def test_worker_stores_resiliency_options(): - resiliency = GrpcWorkerResiliencyOptions(channel_recreate_failure_threshold=9) - worker = TaskHubGrpcWorker(resiliency_options=resiliency) - assert worker._resiliency_options is resiliency - - -def test_dts_client_passes_resiliency_options_to_base_client(): - resiliency = GrpcClientResiliencyOptions() - with patch("durabletask.azuremanaged.client.TaskHubGrpcClient.__init__", return_value=None) as mock_init: - DurableTaskSchedulerClient( - host_address="localhost:4001", - taskhub="hub", - token_credential=None, - resiliency_options=resiliency, - ) - assert mock_init.call_args.kwargs["resiliency_options"] is resiliency - - -def test_dts_worker_passes_resiliency_options_to_base_worker(): - resiliency = GrpcWorkerResiliencyOptions() - with patch("durabletask.azuremanaged.worker.TaskHubGrpcWorker.__init__", return_value=None) as mock_init: - DurableTaskSchedulerWorker( - host_address="localhost:4001", - taskhub="hub", - token_credential=None, - resiliency_options=resiliency, - ) - assert mock_init.call_args.kwargs["resiliency_options"] is resiliency -``` - -- [x] **Step 2: Run the tests to verify they fail** - -Run: `python -m pytest tests/durabletask/test_client.py tests/durabletask-azuremanaged/test_grpc_resiliency.py -v` - -Expected: FAIL because the constructors do not accept `resiliency_options` yet and do not retain enough transport state for later recreation. 
- -- [x] **Step 3: Write the minimal implementation** - -```python -self._host_address = host_address if host_address else shared.get_default_host_address() -self._secure_channel = secure_channel -self._channel_options = channel_options -self._resiliency_options = ( - resiliency_options if resiliency_options is not None else GrpcClientResiliencyOptions() -) -resolved_interceptors = ( - prepare_sync_interceptors(metadata, interceptors) if channel is None else interceptors -) -self._interceptors = list(resolved_interceptors) if resolved_interceptors is not None else None - -self._resiliency_options = ( - resiliency_options if resiliency_options is not None else GrpcWorkerResiliencyOptions() -) - -super().__init__( - host_address=host_address, - channel=channel, - secure_channel=secure_channel, - metadata=None, - log_handler=log_handler, - log_formatter=log_formatter, - interceptors=resolved_interceptors, - channel_options=channel_options, - resiliency_options=resiliency_options, - default_version=default_version, - payload_store=payload_store, -) - -super().__init__( - host_address=host_address, - channel=channel, - secure_channel=secure_channel, - metadata=None, - log_handler=log_handler, - log_formatter=log_formatter, - interceptors=resolved_interceptors, - channel_options=channel_options, - resiliency_options=resiliency_options, - concurrency_options=concurrency_options, - maximum_timer_interval=None, - payload_store=payload_store, -) -``` - -- [x] **Step 4: Run the tests to verify they pass** - -Run: `python -m pytest tests/durabletask/test_client.py tests/durabletask-azuremanaged/test_grpc_resiliency.py -v` - -Expected: PASS for the new constructor and wrapper pass-through tests. - -- [x] **Step 5: Commit** - -```bash -git add durabletask/client.py durabletask/worker.py durabletask-azuremanaged/durabletask/azuremanaged/client.py durabletask-azuremanaged/durabletask/azuremanaged/worker.py tests/durabletask/test_client.py tests/durabletask-azuremanaged/test_grpc_resiliency.py -git commit -m "Thread gRPC resiliency options through constructors" -``` - -### Task 3: Add shared internal resiliency helpers - -**Files:** -- Create: `durabletask/internal/grpc_resiliency.py` -- Modify: `tests/durabletask/test_grpc_resiliency.py` - -- [x] **Step 1: Write the failing helper tests** - -```python -import grpc -import pytest - -from durabletask.internal.grpc_resiliency import ( - FailureTracker, - get_full_jitter_delay_seconds, - is_client_transport_failure, - is_worker_transport_failure, -) - - -def test_full_jitter_delay_is_capped(monkeypatch): - monkeypatch.setattr("durabletask.internal.grpc_resiliency.random.random", lambda: 1.0) - delay = get_full_jitter_delay_seconds(10, base_seconds=1.0, cap_seconds=30.0) - assert delay == 30.0 - - -def test_failure_tracker_trips_at_threshold(): - tracker = FailureTracker(threshold=3) - assert tracker.record_failure() is False - assert tracker.record_failure() is False - assert tracker.record_failure() is True - tracker.record_success() - assert tracker.consecutive_failures == 0 - - -def test_client_transport_failure_ignores_long_poll_deadlines(): - assert is_client_transport_failure("WaitForInstanceStart", grpc.StatusCode.DEADLINE_EXCEEDED) is False - assert is_client_transport_failure("StartInstance", grpc.StatusCode.DEADLINE_EXCEEDED) is True - assert is_client_transport_failure("GetInstance", grpc.StatusCode.UNAVAILABLE) is True - - -def test_worker_transport_failure_filters_application_errors(): - assert 
is_worker_transport_failure(grpc.StatusCode.UNAVAILABLE) is True - assert is_worker_transport_failure(grpc.StatusCode.DEADLINE_EXCEEDED) is True - assert is_worker_transport_failure(grpc.StatusCode.UNAUTHENTICATED) is False - assert is_worker_transport_failure(grpc.StatusCode.NOT_FOUND) is False -``` - -- [x] **Step 2: Run the tests to verify they fail** - -Run: `python -m pytest tests/durabletask/test_grpc_resiliency.py -k "jitter or tracker or transport_failure" -v` - -Expected: FAIL because the shared helper module and helper functions do not exist yet. - -- [x] **Step 3: Write the minimal implementation** - -```python -import random -from dataclasses import dataclass - -import grpc - - -LONG_POLL_METHODS = {"WaitForInstanceStart", "WaitForInstanceCompletion"} - - -def get_full_jitter_delay_seconds( - attempt: int, - *, - base_seconds: float, - cap_seconds: float, -) -> float: - capped_attempt = min(attempt, 30) - upper_bound = min(cap_seconds, base_seconds * (2 ** capped_attempt)) - return random.random() * upper_bound - - -@dataclass -class FailureTracker: - threshold: int - consecutive_failures: int = 0 - - def record_failure(self) -> bool: - if self.threshold <= 0: - return False - self.consecutive_failures += 1 - return self.consecutive_failures >= self.threshold - - def record_success(self) -> None: - self.consecutive_failures = 0 - - -def is_client_transport_failure(method_name: str, status_code: grpc.StatusCode) -> bool: - if status_code == grpc.StatusCode.UNAVAILABLE: - return True - if status_code == grpc.StatusCode.DEADLINE_EXCEEDED: - return method_name not in LONG_POLL_METHODS - return False - - -def is_worker_transport_failure(status_code: grpc.StatusCode) -> bool: - return status_code in { - grpc.StatusCode.UNAVAILABLE, - grpc.StatusCode.DEADLINE_EXCEEDED, - } -``` - -- [x] **Step 4: Run the tests to verify they pass** - -Run: `python -m pytest tests/durabletask/test_grpc_resiliency.py -v` - -Expected: PASS for the helper and option tests together. 
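-
-Before committing, it can help to sanity-check how a reconnect loop is
-expected to combine these helpers. A minimal, hypothetical usage sketch;
-`connect` is a placeholder for the real stream-establishment call, and the
-sketch assumes the Task 3 module is importable:
-
-```python
-import time
-
-from durabletask.internal.grpc_resiliency import (
-    FailureTracker,
-    get_full_jitter_delay_seconds,
-)
-
-
-def connect() -> None:
-    """Placeholder for the real stream-establishment call."""
-    raise ConnectionError("sidecar unreachable")
-
-
-tracker = FailureTracker(threshold=3)
-for attempt in range(1, 6):
-    try:
-        connect()
-        tracker.record_success()  # healthy traffic resets the counter
-        break
-    except ConnectionError:
-        if tracker.record_failure():
-            print("threshold crossed: recreate the SDK-owned channel here")
-        # Full-jitter backoff: a uniform delay in [0, min(cap, base * 2**attempt)].
-        time.sleep(get_full_jitter_delay_seconds(attempt, base_seconds=0.1, cap_seconds=1.0))
-```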
- -- [x] **Step 5: Commit** - -```bash -git add durabletask/internal/grpc_resiliency.py tests/durabletask/test_grpc_resiliency.py -git commit -m "Add shared gRPC resiliency helpers" -``` - -### Task 4: Harden the worker stream lifecycle - -**Files:** -- Modify: `durabletask/worker.py` -- Create: `tests/durabletask/test_worker_resiliency.py` - -- [x] **Step 1: Write the failing worker resiliency tests** - -```python -import grpc -from unittest.mock import MagicMock - -from durabletask.grpc_options import GrpcWorkerResiliencyOptions -from durabletask.worker import TaskHubGrpcWorker, _WorkItemStreamOutcome - - -def test_worker_classifies_graceful_close_before_first_message(): - worker = TaskHubGrpcWorker( - resiliency_options=GrpcWorkerResiliencyOptions(silent_disconnect_timeout_seconds=5.0) - ) - outcome = worker._classify_stream_outcome( - saw_message=False, - timed_out=False, - ) - assert outcome is _WorkItemStreamOutcome.GRACEFUL_CLOSE_BEFORE_FIRST_MESSAGE - - -def test_worker_classifies_graceful_close_after_message(): - worker = TaskHubGrpcWorker() - outcome = worker._classify_stream_outcome( - saw_message=True, - timed_out=False, - ) - assert outcome is _WorkItemStreamOutcome.GRACEFUL_CLOSE_AFTER_MESSAGE - - -def test_worker_counts_only_transport_failures_for_recreation(): - worker = TaskHubGrpcWorker( - resiliency_options=GrpcWorkerResiliencyOptions(channel_recreate_failure_threshold=2) - ) - assert worker._should_count_worker_failure(grpc.StatusCode.UNAVAILABLE) is True - assert worker._should_count_worker_failure(grpc.StatusCode.UNAUTHENTICATED) is False - - -def test_worker_does_not_recreate_caller_owned_channel(): - worker = TaskHubGrpcWorker(channel=MagicMock()) - assert worker._can_recreate_channel() is False -``` - -- [x] **Step 2: Run the tests to verify they fail** - -Run: `python -m pytest tests/durabletask/test_worker_resiliency.py -v` - -Expected: FAIL because the worker does not expose explicit stream-outcome helpers yet and still uses ad hoc reconnect bookkeeping. 
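-
-The outer worker loop consumes these outcomes according to the design spec's
-stream-monitoring rules. As a reference while implementing the next step, here
-is a minimal mapping sketch; the keys mirror the `_WorkItemStreamOutcome`
-values introduced below, and the action strings are descriptive placeholders
-rather than real APIs:
-
-```python
-# Outcome-to-action mapping per the design spec (sketch only).
-OUTCOME_ACTIONS = {
-    "shutdown": "exit the worker loop cleanly",
-    "message_received": "reset health counters and keep processing",
-    "graceful_close_before_first_message": "reset the stream; force a fresh SDK-owned channel; do not count",
-    "graceful_close_after_message": "reset the stream and reconnect; do not count",
-    "silent_disconnect": "count toward the recreation threshold, back off, then reconnect",
-}
-```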
-
-- [x] **Step 3: Write the minimal implementation**
-
-```python
-# Excerpts from durabletask/worker.py; the surrounding class, imports, and
-# connection loop are elided.
-class _WorkItemStreamOutcome(Enum):
-    SHUTDOWN = "shutdown"
-    GRACEFUL_CLOSE_BEFORE_FIRST_MESSAGE = "graceful_close_before_first_message"
-    GRACEFUL_CLOSE_AFTER_MESSAGE = "graceful_close_after_message"
-    SILENT_DISCONNECT = "silent_disconnect"
-
-
-def _classify_stream_outcome(self, *, saw_message: bool, timed_out: bool) -> _WorkItemStreamOutcome:
-    if timed_out:
-        return _WorkItemStreamOutcome.SILENT_DISCONNECT
-    if saw_message:
-        return _WorkItemStreamOutcome.GRACEFUL_CLOSE_AFTER_MESSAGE
-    return _WorkItemStreamOutcome.GRACEFUL_CLOSE_BEFORE_FIRST_MESSAGE
-
-
-def _should_count_worker_failure(self, status_code: grpc.StatusCode) -> bool:
-    return is_worker_transport_failure(status_code)
-
-
-def _can_recreate_channel(self) -> bool:
-    # self._channel is only set when the caller supplied the channel, and
-    # caller-owned channels are never recreated.
-    return self._channel is None
-
-
-# Inside the connection loop: handshake with a hard deadline.
-hello_timeout = self._resiliency_options.hello_timeout_seconds
-current_stub.Hello(empty_pb2.Empty(), timeout=hello_timeout)
-
-# Silent-disconnect detection: any value <= 0 disables the idle timeout.
-silent_timeout = self._resiliency_options.silent_disconnect_timeout_seconds
-queue_timeout = silent_timeout if silent_timeout > 0 else None
-work_item = await asyncio.wait_for(
-    loop.run_in_executor(None, work_item_queue.get),
-    timeout=queue_timeout,
-)
-
-# Reconnect backoff between attempts.
-delay = get_full_jitter_delay_seconds(
-    conn_retry_count,
-    base_seconds=self._resiliency_options.reconnect_backoff_base_seconds,
-    cap_seconds=self._resiliency_options.reconnect_backoff_cap_seconds,
-)
-
-# Health pings prove the transport is alive and reset the failure counter.
-if work_item.HasField("healthPing"):
-    failure_tracker.record_success()
-    continue
-```
-
-- [x] **Step 4: Run the worker tests**
-
-Run: `python -m pytest tests/durabletask/test_worker_resiliency.py -v`
-
-Expected: PASS for the worker classification and ownership tests.
-
-- [x] **Step 5: Commit**
-
-```bash
-git add durabletask/worker.py tests/durabletask/test_worker_resiliency.py
-git commit -m "Harden worker gRPC stream reconnect behavior"
-```
-
-### Task 5: Add sync client channel recreation
-
-**Files:**
-- Modify: `durabletask/client.py`
-- Modify: `tests/durabletask/test_client.py`
-
-- [x] **Step 1: Write the failing sync client recreation tests**
-
-```python
-import grpc
-import pytest
-from unittest.mock import MagicMock, patch
-
-from durabletask.client import TaskHubGrpcClient
-from durabletask.grpc_options import GrpcClientResiliencyOptions
-
-
-class _FakeRpcError(grpc.RpcError):
-    """Raisable stand-in for a gRPC error carrying a status code.
-
-    A MagicMock(spec=grpc.RpcError) cannot be used as a side_effect here
-    because `raise` requires a real BaseException instance.
-    """
-
-    def __init__(self, code: grpc.StatusCode):
-        super().__init__()
-        self._status_code = code
-
-    def code(self) -> grpc.StatusCode:
-        return self._status_code
-
-
-def test_sync_client_recreates_sdk_owned_channel_after_repeated_unavailable():
-    first_channel = MagicMock(name="first-channel")
-    second_channel = MagicMock(name="second-channel")
-    first_stub = MagicMock()
-    first_stub.GetInstance.side_effect = _FakeRpcError(grpc.StatusCode.UNAVAILABLE)
-    second_stub = MagicMock()
-    second_stub.GetInstance.return_value = MagicMock(exists=False)
-
-    with patch("durabletask.client.shared.get_grpc_channel", side_effect=[first_channel, second_channel]), patch(
-        "durabletask.client.stubs.TaskHubSidecarServiceStub", side_effect=[first_stub, second_stub]
-    ):
-        client = TaskHubGrpcClient(
-            host_address="localhost:4001",
-            resiliency_options=GrpcClientResiliencyOptions(
-                channel_recreate_failure_threshold=1,
-                min_recreate_interval_seconds=0.0,
-            ),
-        )
-        with pytest.raises(grpc.RpcError):
-            client.get_orchestration_state("abc")
-        client.get_orchestration_state("abc")
-
-        assert client._channel is second_channel
-
-
-def test_sync_client_does_not_count_long_poll_deadline():
-    stub = MagicMock()
-    stub.WaitForInstanceStart.side_effect = _FakeRpcError(grpc.StatusCode.DEADLINE_EXCEEDED)
-
-    with patch("durabletask.client.shared.get_grpc_channel", return_value=MagicMock()), patch(
-        "durabletask.client.stubs.TaskHubSidecarServiceStub", return_value=stub
-    ):
-        client = TaskHubGrpcClient(
-            resiliency_options=GrpcClientResiliencyOptions(channel_recreate_failure_threshold=1)
-        )
-        with pytest.raises(TimeoutError):
-            client.wait_for_orchestration_start("abc")
-        assert client._client_failure_tracker.consecutive_failures == 0
-```
-
-- [x] **Step 2: Run the tests to verify they fail**
-
-Run: `python -m pytest tests/durabletask/test_client.py -k "recreates_sdk_owned_channel or long_poll_deadline" -v`
-
-Expected: FAIL because client calls still go directly through the stub and the client has no failure tracker or channel recreation path.
-
-- [x] **Step 3: Write the minimal implementation**
-
-```python
-# Excerpts from durabletask/client.py (TaskHubGrpcClient); requires the Task 3
-# helpers (FailureTracker, is_client_transport_failure) plus module-level
-# imports of threading, time, and typing.
-
-# In __init__:
-self._client_failure_tracker = FailureTracker(
-    self._resiliency_options.channel_recreate_failure_threshold
-)
-self._last_recreate_time = 0.0
-self._recreate_lock = threading.Lock()
-
-
-def _invoke_unary(self, method_name: str, request: Any, *, timeout: Optional[int] = None):
-    method = getattr(self._stub, method_name)
-    try:
-        if timeout is None:
-            response = method(request)
-        else:
-            response = method(request, timeout=timeout)
-    except grpc.RpcError as rpc_error:
-        if is_client_transport_failure(method_name, rpc_error.code()):
-            should_recreate = self._client_failure_tracker.record_failure()
-            if should_recreate:
-                self._maybe_recreate_channel()
-        else:
-            # Application-level errors prove the transport still works.
-            self._client_failure_tracker.record_success()
-        raise
-    else:
-        self._client_failure_tracker.record_success()
-        return response
-
-
-def _maybe_recreate_channel(self) -> None:
-    if not self._owns_channel:
-        return
-    with self._recreate_lock:
-        now = time.monotonic()
-        if now - self._last_recreate_time < self._resiliency_options.min_recreate_interval_seconds:
-            return
-        old_channel = self._channel
-        self._channel = shared.get_grpc_channel(
-            host_address=self._host_address,
-            secure_channel=self._secure_channel,
-            interceptors=self._interceptors,
-            channel_options=self._channel_options,
-        )
-        self._stub = stubs.TaskHubSidecarServiceStub(self._channel)
-        self._last_recreate_time = now
-        self._client_failure_tracker.record_success()
-        # Retire the old channel after a grace period so in-flight calls
-        # that captured the old stub are not interrupted.
-        threading.Timer(30.0, old_channel.close).start()
-```
-
-- [x] **Step 4: Run the tests to verify they pass**
-
-Run: `python -m pytest tests/durabletask/test_client.py -k "recreates_sdk_owned_channel or long_poll_deadline" -v`
-
-Expected: PASS for both new sync client tests and no regressions in the existing client construction tests.
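-
-With the invoker in place, public unary methods are expected to swap direct
-stub calls for `_invoke_unary`. A minimal before/after sketch; the request
-variable and timeout are illustrative, and `GetInstance` follows the tests
-above:
-
-```python
-# Before: the stub is called directly, so failures are invisible to the tracker.
-res = self._stub.GetInstance(req, timeout=timeout)
-
-# After: the invoker classifies failures and drives channel recreation.
-res = self._invoke_unary("GetInstance", req, timeout=timeout)
-```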
-
-- [x] **Step 5: Commit**
-
-```bash
-git add durabletask/client.py tests/durabletask/test_client.py
-git commit -m "Add sync client gRPC channel recreation"
-```
-
-### Task 6: Add async client channel recreation
-
-**Files:**
-- Modify: `durabletask/client.py`
-- Modify: `tests/durabletask/test_client.py`
-
-- [x] **Step 1: Write the failing async client recreation tests**
-
-```python
-import grpc
-import pytest
-from unittest.mock import AsyncMock, MagicMock, patch
-
-from durabletask.client import AsyncTaskHubGrpcClient
-from durabletask.grpc_options import GrpcClientResiliencyOptions
-
-
-class _FakeAioRpcError(grpc.aio.AioRpcError):
-    """Raisable stand-in for an async gRPC error carrying a status code.
-
-    Bypasses AioRpcError.__init__ so the test does not depend on its
-    constructor signature; code() and details() are overridden instead.
-    """
-
-    def __init__(self, code: grpc.StatusCode):
-        Exception.__init__(self)
-        self._status_code = code
-
-    def code(self) -> grpc.StatusCode:
-        return self._status_code
-
-    def details(self) -> str:
-        return "fake transport failure"
-
-
-@pytest.mark.asyncio
-async def test_async_client_recreates_sdk_owned_channel_after_unavailable():
-    first_stub = MagicMock()
-    first_stub.GetInstance = AsyncMock(side_effect=_FakeAioRpcError(grpc.StatusCode.UNAVAILABLE))
-    second_stub = MagicMock()
-    second_stub.GetInstance = AsyncMock(return_value=MagicMock(exists=False))
-
-    with patch("durabletask.client.shared.get_async_grpc_channel", side_effect=[MagicMock(), MagicMock()]), patch(
-        "durabletask.client.stubs.TaskHubSidecarServiceStub", side_effect=[first_stub, second_stub]
-    ):
-        client = AsyncTaskHubGrpcClient(
-            host_address="localhost:4001",
-            resiliency_options=GrpcClientResiliencyOptions(
-                channel_recreate_failure_threshold=1,
-                min_recreate_interval_seconds=0.0,
-            ),
-        )
-        with pytest.raises(grpc.aio.AioRpcError):
-            await client.get_orchestration_state("abc")
-        await client.get_orchestration_state("abc")
-
-
-@pytest.mark.asyncio
-async def test_async_client_does_not_count_wait_for_orchestration_deadline():
-    stub = MagicMock()
-    stub.WaitForInstanceCompletion = AsyncMock(side_effect=_FakeAioRpcError(grpc.StatusCode.DEADLINE_EXCEEDED))
-
-    with patch("durabletask.client.shared.get_async_grpc_channel", return_value=MagicMock()), patch(
-        "durabletask.client.stubs.TaskHubSidecarServiceStub", return_value=stub
-    ):
-        client = AsyncTaskHubGrpcClient(
-            resiliency_options=GrpcClientResiliencyOptions(channel_recreate_failure_threshold=1)
-        )
-        with pytest.raises(TimeoutError):
-            await client.wait_for_orchestration_completion("abc")
-        assert client._client_failure_tracker.consecutive_failures == 0
-```
-
-- [x] **Step 2: Run the tests to verify they fail**
-
-Run: `python -m pytest tests/durabletask/test_client.py -k "async_client_recreates_sdk_owned_channel or async_client_does_not_count" -v`
-
-Expected: FAIL because the async client still awaits stub methods directly and has no async-safe recreation path.
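-
-As with the sync client, the next step centralizes unary calls in an invoker.
-For orientation, a hypothetical end-user sketch of the behavior under test;
-the instance id and threshold values are illustrative:
-
-```python
-import asyncio
-
-from durabletask.client import AsyncTaskHubGrpcClient
-from durabletask.grpc_options import GrpcClientResiliencyOptions
-
-
-async def main() -> None:
-    client = AsyncTaskHubGrpcClient(
-        host_address="localhost:4001",
-        resiliency_options=GrpcClientResiliencyOptions(
-            channel_recreate_failure_threshold=3,
-            min_recreate_interval_seconds=10.0,
-        ),
-    )
-    # After three consecutive transport-shaped failures, the SDK-owned
-    # channel is rebuilt; the failing call itself still raises normally.
-    state = await client.get_orchestration_state("my-instance")
-    print(state)
-
-
-asyncio.run(main())
-```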
- -- [x] **Step 3: Write the minimal implementation** - -```python -self._client_failure_tracker = FailureTracker( - self._resiliency_options.channel_recreate_failure_threshold -) -self._recreate_lock = asyncio.Lock() -self._last_recreate_time = 0.0 - -async def _invoke_unary(self, method_name: str, request: Any, *, timeout: Optional[int] = None): - method = getattr(self._stub, method_name) - try: - if timeout is None: - response = await method(request) - else: - response = await method(request, timeout=timeout) - except grpc.aio.AioRpcError as rpc_error: - if is_client_transport_failure(method_name, rpc_error.code()): - should_recreate = self._client_failure_tracker.record_failure() - if should_recreate: - await self._maybe_recreate_channel() - else: - self._client_failure_tracker.record_success() - raise - else: - self._client_failure_tracker.record_success() - return response - -async def _maybe_recreate_channel(self) -> None: - if not self._owns_channel: - return - async with self._recreate_lock: - now = time.monotonic() - if now - self._last_recreate_time < self._resiliency_options.min_recreate_interval_seconds: - return - old_channel = self._channel - self._channel = shared.get_async_grpc_channel( - host_address=self._host_address, - secure_channel=self._secure_channel, - interceptors=self._interceptors, - channel_options=self._channel_options, - ) - self._stub = stubs.TaskHubSidecarServiceStub(self._channel) - self._last_recreate_time = now - self._client_failure_tracker.record_success() - asyncio.create_task(self._close_retired_channel(old_channel)) - - -async def _close_retired_channel(self, channel: grpc.aio.Channel) -> None: - await asyncio.sleep(30.0) - await channel.close() -``` - -- [x] **Step 4: Run the tests to verify they pass** - -Run: `python -m pytest tests/durabletask/test_client.py -k "async_client_recreates_sdk_owned_channel or async_client_does_not_count" -v` - -Expected: PASS for the async recreation tests and no regressions in the existing async client construction tests. - -- [x] **Step 5: Commit** - -```bash -git add durabletask/client.py tests/durabletask/test_client.py -git commit -m "Add async client gRPC channel recreation" -``` - -### Task 7: Update changelogs and run final verification - -**Files:** -- Modify: `CHANGELOG.md` -- Modify: `durabletask-azuremanaged/CHANGELOG.md` -- Modify: `docs/superpowers/specs/2026-04-23-grpc-resiliency-design.md` (only if the implementation changed the agreed design) -- Modify: `docs/superpowers/plans/2026-04-23-grpc-resiliency.md` (check off completed steps only after execution) - -- [x] **Step 1: Add the changelog entries** - -```markdown -## Unreleased - -### Added - -- Added automatic gRPC channel healing for SDK-owned clients and workers, with new resiliency option types for tuning hello deadlines, silent-disconnect detection, recreate thresholds, and recreate cooldowns. -``` - -```markdown -## Unreleased - -### Added - -- Added pass-through support for the new gRPC resiliency option types on Azure Managed clients and workers. -``` - -- [x] **Step 2: Run the focused tests** - -Run: - -```bash -python -m pytest tests/durabletask/test_grpc_resiliency.py tests/durabletask/test_worker_resiliency.py tests/durabletask/test_client.py tests/durabletask-azuremanaged/test_grpc_resiliency.py -v -``` - -Expected: PASS for all new and touched unit tests. 
- -- [x] **Step 3: Run lint on the changed Python files** - -Run: - -```bash -python -m flake8 durabletask/grpc_options.py durabletask/internal/grpc_resiliency.py durabletask/client.py durabletask/worker.py durabletask-azuremanaged/durabletask/azuremanaged/client.py durabletask-azuremanaged/durabletask/azuremanaged/worker.py tests/durabletask/test_grpc_resiliency.py tests/durabletask/test_worker_resiliency.py tests/durabletask/test_client.py tests/durabletask-azuremanaged/test_grpc_resiliency.py -``` - -Expected: no output - -- [x] **Step 4: Run the full test suite** - -Run: - -```bash -python -m pytest -``` - -Expected: PASS across the repository, including the existing orchestration and Azure Managed test suites. - -- [x] **Step 5: Commit** - -```bash -git add CHANGELOG.md durabletask-azuremanaged/CHANGELOG.md durabletask/grpc_options.py durabletask/internal/grpc_resiliency.py durabletask/client.py durabletask/worker.py durabletask-azuremanaged/durabletask/azuremanaged/client.py durabletask-azuremanaged/durabletask/azuremanaged/worker.py tests/durabletask/test_grpc_resiliency.py tests/durabletask/test_worker_resiliency.py tests/durabletask/test_client.py tests/durabletask-azuremanaged/test_grpc_resiliency.py -git commit -m "Add gRPC connection resiliency" -``` diff --git a/docs/superpowers/specs/2026-04-23-grpc-resiliency-design.md b/docs/superpowers/specs/2026-04-23-grpc-resiliency-design.md deleted file mode 100644 index ce1b586..0000000 --- a/docs/superpowers/specs/2026-04-23-grpc-resiliency-design.md +++ /dev/null @@ -1,325 +0,0 @@ -# gRPC connection resiliency design - -## Problem statement - -`durabletask-python` already has basic gRPC retry policy configuration, -keepalive channel settings, and worker-side reconnect logic. It does not yet -have the stronger connection-healing behavior added in -`durabletask-dotnet` PR 708: - -- worker-side silent-disconnect detection for long-lived work-item streams -- consistent transport-failure classification -- client-side channel recreation after repeated transport failures -- shared backoff and threshold logic across connection-owning components - -The current gap shows up most clearly when a channel becomes stale or -half-open. The worker may continue retrying around a poisoned stream without a -clear distinction between graceful close and silent disconnect, and clients may -keep reusing a bad channel until the application recreates the client. - -## Goals - -- Detect and heal stale or silently disconnected gRPC connections in the worker - and in sync and async clients. -- Enable the new behavior by default with conservative values and explicit - override and disable knobs. -- Preserve existing protocol behavior and support for caller-supplied channels. -- Keep low-level gRPC channel options separate from SDK-managed resiliency - policy. -- Add focused regression tests for failure classification, backoff, and channel - recreation. - -## Non-goals - -- Redesign the public orchestration APIs or the sidecar protocol. -- Add general channel pooling or multi-endpoint load-balancing support. -- Automatically recreate caller-supplied channels in this iteration. -- Expand every possible raw gRPC channel knob as part of this work. - -## Proposed public API - -Add two new option dataclasses in `durabletask.grpc_options`. - -### `GrpcWorkerResiliencyOptions` - -Used by `TaskHubGrpcWorker` and Azure Managed worker wrappers. 
- -| Field | Default | Meaning | -| --- | --- | --- | -| `hello_timeout_seconds` | `30.0` | Deadline for the initial `Hello` handshake on a fresh connection. | -| `silent_disconnect_timeout_seconds` | `120.0` | Maximum idle period on the `GetWorkItems` stream before the worker treats the connection as stale. A value `<= 0` disables silent-disconnect detection. | -| `channel_recreate_failure_threshold` | `5` | Number of consecutive transport-shaped failures before the worker recreates an SDK-owned channel. A value `<= 0` disables recreation. | -| `reconnect_backoff_base_seconds` | `1.0` | Base delay for reconnect backoff. | -| `reconnect_backoff_cap_seconds` | `30.0` | Maximum reconnect delay. | - -### `GrpcClientResiliencyOptions` - -Used by `TaskHubGrpcClient`, `AsyncTaskHubGrpcClient`, and Azure Managed client -wrappers. - -| Field | Default | Meaning | -| --- | --- | --- | -| `channel_recreate_failure_threshold` | `5` | Number of consecutive transport-shaped unary RPC failures before recreating an SDK-owned channel. A value `<= 0` disables recreation. | -| `min_recreate_interval_seconds` | `30.0` | Minimum interval between channel recreation attempts. | - -### Constructor changes - -Add a new optional `resiliency_options` parameter to these constructors: - -- `TaskHubGrpcWorker` -- `TaskHubGrpcClient` -- `AsyncTaskHubGrpcClient` -- `DurableTaskSchedulerWorker` -- `DurableTaskSchedulerClient` -- `AsyncDurableTaskSchedulerClient` - -If the parameter is omitted, the SDK uses the defaults above. This keeps the -new behavior enabled by default while still allowing targeted disablement. - -`GrpcChannelOptions` remains the place for raw gRPC transport settings such as -keepalive and retry service config. Resiliency policy stays separate because it -controls SDK behavior, not just channel construction. - -## Runtime design - -### Shared internal helpers - -Add a small internal module dedicated to resiliency primitives. It should stay -transport-focused and reusable by the worker and clients. - -Planned responsibilities: - -- full-jitter exponential backoff calculation -- transport-failure classification helpers -- consecutive-failure tracking with reset semantics -- small immutable state objects where atomic swap is needed - -The worker and client should share the same definition of -"transport-shaped failure" instead of maintaining separate ad hoc rules. - -### Worker behavior - -The worker keeps its current high-level reconnect loop but replaces the -connection-health logic with clearer internal pieces. - -#### Fresh connection establishment - -When the worker creates an SDK-owned channel, it: - -1. builds the channel and stub as it does today -2. sends `Hello` with `hello_timeout_seconds` -3. treats `UNAVAILABLE` and `DEADLINE_EXCEEDED` on that handshake as transport - failures - -Successful `Hello` resets the worker reconnect attempt counter. 
- -#### Stream monitoring - -Wrap the `GetWorkItems` stream in an internal monitor that tracks two things: - -- whether any message has ever been observed on the stream -- whether the stream has remained idle longer than - `silent_disconnect_timeout_seconds` - -The monitor reports one of these outcomes: - -- `shutdown`: worker shutdown was requested -- `message_received`: at least one message arrived and normal processing - continues -- `graceful_close_before_first_message`: peer closed the stream before the - worker observed any message -- `graceful_close_after_message`: peer closed the stream after at least one - message was observed -- `silent_disconnect`: the stream remained idle past the configured timeout - -The outer worker loop uses those outcomes as follows: - -- `message_received`: reset health counters -- `graceful_close_before_first_message`: immediately reset the current stream - and force a fresh SDK-owned channel on the next connect attempt -- `graceful_close_after_message`: immediately reset the current stream and - reconnect without incrementing the transport-failure counter -- `silent_disconnect`: count as channel poison -- `shutdown`: exit cleanly - -This keeps rolling upgrades and normal peer-driven reconnects from inflating -the failure threshold while still forcing SDK-owned workers to establish a -fresh channel after graceful stream closures. - -#### Failure counting and recreation - -The worker increments the consecutive-failure counter only for -transport-shaped failures: - -- `UNAVAILABLE` -- `DEADLINE_EXCEEDED` -- explicit silent-disconnect timeout - -It does not increment the counter for errors that channel recreation is -unlikely to fix, such as: - -- `UNAUTHENTICATED` -- `NOT_FOUND` -- orchestration or activity execution failures -- graceful stream closures before or after work items - -When the threshold is reached and the worker owns the channel, it recreates the -channel and stub. Graceful stream closures also force an immediate fresh -SDK-owned channel even though they do not increment the threshold. When the -worker does not own the channel, it keeps retrying the existing transport and -logs that the channel could not be recreated. - -### Client behavior - -Both sync and async clients route unary RPCs through a small internal invoker -helper instead of calling generated stub methods directly. - -The helper: - -- invokes the target unary RPC -- classifies the outcome -- updates a shared failure counter -- schedules channel recreation when the threshold is crossed - -#### Counted failures - -Count these failures toward the client recreation threshold: - -- `UNAVAILABLE` -- `DEADLINE_EXCEEDED` for ordinary unary calls - -Do not count deadline failures for long-poll methods because those calls are -expected to wait: - -- `wait_for_orchestration_start` -- `wait_for_orchestration_completion` -- async variants of those methods - -Successful replies and application-level RPC errors reset the failure counter, -because they prove the underlying transport is still usable. - -#### Channel recreation mechanics - -When the threshold is reached and the client owns the channel: - -1. enforce `min_recreate_interval_seconds` -2. allow only one recreation in flight at a time -3. build a fresh channel and stub with the existing host, interceptors, secure - channel flag, and `GrpcChannelOptions` -4. atomically swap the active channel and stub -5. retire the previous channel after a short grace period - -The failing RPC still fails normally. 
The recreated channel benefits later RPCs. - -If the caller supplied the channel, the client still tracks and logs transport -failures but does not attempt replacement. - -### Retiring replaced channels - -Closing the old channel immediately after a successful swap risks interrupting -in-flight work that captured the old stub before the swap. To avoid that, the -SDK keeps replaced SDK-owned channels alive for a short grace period and then -closes them. - -The implementation can use a small internal scheduler that is appropriate for -the transport: - -- sync clients and the worker: daemon timer or background thread -- async clients: background task plus `asyncio.sleep` - -All retired channels are also closed during final client or worker shutdown. - -## File-level implementation plan - -### `durabletask/grpc_options.py` - -- add `GrpcWorkerResiliencyOptions` -- add `GrpcClientResiliencyOptions` -- add validation for positive durations when enabled - -### `durabletask/internal/grpc_resiliency.py` - -Add shared internals for: - -- backoff calculation -- failure classification -- failure-threshold tracking -- small helper types used by worker and client code - -### `durabletask/worker.py` - -- accept `resiliency_options` -- replace the current ad hoc reconnect bookkeeping with the shared helpers -- add hello deadline handling -- add stream-outcome monitoring -- recreate SDK-owned channels when the threshold is crossed - -### `durabletask/client.py` - -- accept `resiliency_options` in sync and async clients -- centralize unary RPC invocation through internal helpers -- add single-flight channel recreation and cooldown logic -- retain current ownership semantics for caller-supplied channels - -### Azure Managed wrappers - -Thread the new `resiliency_options` parameter through: - -- `DurableTaskSchedulerWorker` -- `DurableTaskSchedulerClient` -- `AsyncDurableTaskSchedulerClient` - -No Azure-specific recreation behavior is required in this iteration because the -wrappers already build SDK-owned channels through the base client and worker -constructors. - -## Testing strategy - -Add focused unit tests for the new behavior. - -### Options and helper tests - -- new resiliency option validation -- full-jitter backoff bounds and cap behavior -- failure counter reset and threshold logic -- transport-failure classification rules - -### Worker tests - -- hello deadline failure counts toward recreation -- silent-disconnect timeout is detected and classified -- graceful stream closes force a fresh SDK-owned connection without increasing - the failure counter -- user-supplied channels are not recreated or closed - -### Client tests - -- repeated `UNAVAILABLE` failures trigger recreation for SDK-owned channels -- long-poll `DEADLINE_EXCEEDED` does not count toward recreation -- application-level RPC errors reset the counter -- recreation is single-flight and cooldown-limited -- replaced channels are closed after the grace period -- caller-supplied channels are observed but not replaced - -### Regression coverage - -Existing client and worker tests should continue to pass without requiring users -to opt into the new behavior. - -## Compatibility and rollout - -- The change is backward compatible because all new constructor parameters are - optional. -- The new behavior is enabled by default for SDK-owned channels only. -- Caller-supplied channels preserve existing ownership and lifecycle behavior. -- No protocol changes are required between the Python SDK and the sidecar. 
-- The changelog should describe the new automatic healing of stale gRPC worker - and client connections and the new resiliency option types. - -## Decision summary - -Implement parity-inspired connection healing from `durabletask-dotnet` PR 708 -by adding explicit worker stream monitoring, shared failure classification, and -client-side channel recreation for SDK-owned channels. Keep raw gRPC channel -configuration separate from SDK resiliency policy and leave broader channel -pooling and user-supplied channel recreation out of this iteration.
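-
-For completeness, disabling the new behavior remains a targeted, one-line
-change per component. A usage sketch based on the option semantics documented
-above; the `<= 0` values are the documented disable sentinels, and the
-constructor shapes are illustrative:
-
-```python
-from durabletask.client import TaskHubGrpcClient
-from durabletask.grpc_options import (
-    GrpcClientResiliencyOptions,
-    GrpcWorkerResiliencyOptions,
-)
-from durabletask.worker import TaskHubGrpcWorker
-
-# A threshold <= 0 disables client-side channel recreation.
-client = TaskHubGrpcClient(
-    resiliency_options=GrpcClientResiliencyOptions(channel_recreate_failure_threshold=0)
-)
-
-# A timeout <= 0 disables silent-disconnect detection; a threshold <= 0
-# disables worker-side channel recreation.
-worker = TaskHubGrpcWorker(
-    resiliency_options=GrpcWorkerResiliencyOptions(
-        silent_disconnect_timeout_seconds=0.0,
-        channel_recreate_failure_threshold=0,
-    )
-)
-```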