From 7929e2358102a9ca88e40739c8362a0d4a8d4ffd Mon Sep 17 00:00:00 2001 From: Frank Chen Date: Tue, 5 May 2026 20:10:27 -0700 Subject: [PATCH] feat: add plugin interface --- .../execution.py | 389 +++--- .../lambda_service.py | 9 + .../plugin.py | 351 ++++++ src/aws_durable_execution_sdk_python/state.py | 22 +- tests/e2e/checkpoint_response_int_test.py | 40 +- tests/e2e/execution_int_test.py | 26 +- .../e2e/map_with_concurrent_waits_int_test.py | 2 + tests/execution_test.py | 293 +++++ tests/logger_test.py | 5 + tests/plugin_test.py | 1116 +++++++++++++++++ tests/state_test.py | 540 ++++++++ 11 files changed, 2589 insertions(+), 204 deletions(-) create mode 100644 src/aws_durable_execution_sdk_python/plugin.py create mode 100644 tests/plugin_test.py diff --git a/src/aws_durable_execution_sdk_python/execution.py b/src/aws_durable_execution_sdk_python/execution.py index 977834de..25596d16 100644 --- a/src/aws_durable_execution_sdk_python/execution.py +++ b/src/aws_durable_execution_sdk_python/execution.py @@ -4,9 +4,9 @@ import functools import json import logging +import sys from concurrent.futures import ThreadPoolExecutor from dataclasses import dataclass -from enum import Enum from typing import TYPE_CHECKING, Any, cast, Callable from aws_durable_execution_sdk_python.context import DurableContext @@ -18,9 +18,14 @@ InvocationError, SuspendExecution, ) +from aws_durable_execution_sdk_python.plugin import ( + DurableExecutionPlugin, + PluginExecutor, +) from aws_durable_execution_sdk_python.lambda_service import ( DurableServiceClient, ErrorObject, + InvocationStatus, LambdaClient, Operation, OperationUpdate, @@ -147,12 +152,6 @@ def from_durable_execution_invocation_input( ) -class InvocationStatus(Enum): - SUCCEEDED = "SUCCEEDED" - FAILED = "FAILED" - PENDING = "PENDING" - - @dataclass(frozen=True) class DurableExecutionInvocationOutput: """Representation the DurableExecutionInvocationOutput. This is what the Durable lambda handler returns. @@ -205,11 +204,23 @@ def durable_execution( func: Callable[[Any, DurableContext], Any] | None = None, *, boto3_client: Boto3LambdaClient | None = None, + plugins: list[DurableExecutionPlugin] | None = None, ) -> Callable[[Any, LambdaContext], Any]: + """ + Decorator to create a durable execution handler. + + Args: + func: The user function to decorate + boto3_client: Optional boto3 Lambda client to use + plugins: Optional list of plugins to use (EXPERIMENTAL: This + parameter may change or be removed.) + """ # Decorator called with parameters if func is None: logger.debug("Decorator called with parameters") - return functools.partial(durable_execution, boto3_client=boto3_client) + return functools.partial( + durable_execution, boto3_client=boto3_client, plugins=plugins + ) else: logger.debug("Starting durable execution handler...") @@ -217,6 +228,7 @@ def wrapper(event: Any, context: LambdaContext) -> MutableMapping[str, Any]: executor = DurableExecutionExecutor( cast(Callable[[Any, DurableContext], Any], func), boto3_client, + plugins, event, context, ) @@ -230,43 +242,57 @@ def __init__( self, func: Callable[[Any, DurableContext], Any], boto3_client: Boto3LambdaClient | None, + plugins: list[DurableExecutionPlugin] | None, event: Any, context: LambdaContext, ): - self.func = func - self.boto3_client = boto3_client - self.event = event - self.context = context - self.invocation_input = self._parse_invocation_input(event) - self.service_client = self._parse_service_client(event, boto3_client) + self._func = func + self._context = context + self._plugin_executor = PluginExecutor(plugins) + + self._print_durable_execution_arn(event) + + # parsing input, which may throw exceptions + self._invocation_input = self._parse_invocation_input(event) + self._service_client = self._parse_service_client(event, boto3_client) + + self._execution_state = ExecutionState( + durable_execution_arn=self._invocation_input.durable_execution_arn, + initial_checkpoint_token=self._invocation_input.checkpoint_token, + operations={}, + plugin_executor=self._plugin_executor, + service_client=self._service_client, + replay_status=ReplayStatus.NEW, + ) + + @staticmethod + def _print_durable_execution_arn(event): + try: + arn = event.get("DurableExecutionArn") + logger.debug("durableExecutionArn: %s", arn) + except (KeyError, TypeError, AttributeError): + logger.warning("Durable Execution ARN not found in event") def _parse_invocation_input(self, event: Any) -> DurableExecutionInvocationInput: # event likely only to be DurableExecutionInvocationInputWithClient when directly injected by test framework - invocation_input: ( - DurableExecutionInvocationInputWithClient | DurableExecutionInvocationInput - ) if isinstance(event, DurableExecutionInvocationInputWithClient): - invocation_input = event - else: - try: - invocation_input = DurableExecutionInvocationInput.from_json_dict(event) - except (KeyError, TypeError, AttributeError): - msg = ( - "Unexpected payload provided to start the durable execution. " - "Check your resource configurations to confirm the durability is set." - ) - # throws ExecutionError to terminate the invocation - self._handle_execution_output( - exception=ExecutionError(msg), retryable=True - ) - # add a redundant raise to make type checker happy - raise ExecutionError(msg) - - logger.debug("durableExecutionArn: %s", invocation_input.durable_execution_arn) - return invocation_input + return event + try: + return DurableExecutionInvocationInput.from_json_dict(event) + except (KeyError, TypeError, AttributeError): + msg = ( + "Unexpected payload provided to start the durable execution. " + "Check your resource configurations to confirm the durability is set." + ) + # throws ExecutionError to terminate the invocation + self._handle_execution_exception( + exception=ExecutionError(msg), retryable=True + ) + # add a redundant raise to make type checker happy + raise ExecutionError(msg) @staticmethod - def _parse_service_client(event, boto3_client): + def _parse_service_client(event, boto3_client) -> DurableServiceClient: if isinstance(event, DurableExecutionInvocationInputWithClient): return event.service_client elif boto3_client: @@ -276,19 +302,11 @@ def _parse_service_client(event, boto3_client): return LambdaClient.initialize_client() def execute(self): - execution_state: ExecutionState = ExecutionState( - durable_execution_arn=self.invocation_input.durable_execution_arn, - initial_checkpoint_token=self.invocation_input.checkpoint_token, - operations={}, - service_client=self.service_client, - replay_status=ReplayStatus.NEW, - ) - try: - execution_state.fetch_paginated_operations( - self.invocation_input.initial_execution_state.operations, - self.invocation_input.checkpoint_token, - self.invocation_input.initial_execution_state.next_marker, + self._execution_state.fetch_paginated_operations( + self._invocation_input.initial_execution_state.operations, + self._invocation_input.checkpoint_token, + self._invocation_input.initial_execution_state.next_marker, ) except BotoClientError as e: # Non-retryable Durable API errors (e.g., customer configuration issues, @@ -299,29 +317,16 @@ def execute(self): "without retry.", extra=e.build_logger_extras(), ) - return self._handle_execution_output( + return self._handle_execution_exception( exception=e, retryable=e.is_retryable() ) - execution_state.mark_replaying_if_prior_operations_exist() + self._execution_state.mark_replaying_if_prior_operations_exist() - raw_input_payload: str | None = execution_state.get_input_payload() - - # Python RIC LambdaMarshaller just uses standard json deserialization for event - # https://github.com/aws/aws-lambda-python-runtime-interface-client/blob/main/awslambdaric/lambda_runtime_marshaller.py#L46 - input_event: MutableMapping[str, Any] = {} - if raw_input_payload and raw_input_payload.strip(): - try: - input_event = json.loads(raw_input_payload) - except json.JSONDecodeError as e: - logger.exception( - "Failed to parse input payload as JSON: payload: %r", - raw_input_payload, - ) - self._handle_execution_output(exception=e, retryable=True) + input_event: Any = self._get_input_event() durable_context: DurableContext = DurableContext.from_lambda_context( - state=execution_state, lambda_context=self.context + state=self._execution_state, lambda_context=self._context ) # Use ThreadPoolExecutor for concurrent execution of user code and background checkpoint processing @@ -329,113 +334,121 @@ def execute(self): ThreadPoolExecutor( max_workers=2, thread_name_prefix="dex-handler" ) as executor, - contextlib.closing(execution_state) as execution_state, + contextlib.closing(self._plugin_executor), + contextlib.closing(self._execution_state), ): + # execute the plugins + self._plugin_executor.on_invocation_start( + durable_execution_arn=self._invocation_input.durable_execution_arn, + context=self._context, + execution_operation=self._execution_state.get_execution_operation(), + is_replaying=self._execution_state.is_replaying(), + ) + # Thread 1: Run background checkpoint processing - executor.submit(execution_state.checkpoint_batches_forever) + executor.submit(self._execution_state.checkpoint_batches_forever) # Thread 2: Execute user function logger.debug( - "%s entering user-space...", self.invocation_input.durable_execution_arn + "%s entering user-space...", + self._invocation_input.durable_execution_arn, ) - user_future = executor.submit(self.func, input_event, durable_context) + user_future = executor.submit(self._func, input_event, durable_context) - logger.debug( - "%s waiting for user code completion...", - self.invocation_input.durable_execution_arn, - ) + return self._handle_user_future_result(user_future) - try: - # Background checkpointing errors will propagate through CompletionEvent.wait() as BackgroundThreadError - result = user_future.result() + def _handle_user_future_result(self, user_future): + logger.debug( + "%s waiting for user code completion...", + self._invocation_input.durable_execution_arn, + ) - # done with userland - logger.debug( - "%s exiting user-space...", - self.invocation_input.durable_execution_arn, - ) - serialized_result = self._handle_large_result(execution_state, result) + try: + # Background checkpointing errors will propagate through CompletionEvent.wait() as BackgroundThreadError + result = user_future.result() - return self._handle_execution_output(result=serialized_result) + # done with userland + logger.debug( + "%s exiting user-space...", self._invocation_input.durable_execution_arn + ) + serialized_result = self._handle_large_result(result) - except BackgroundThreadError as bg_error: - # Background checkpoint system failed - propagated through CompletionEvent - # Do not attempt to checkpoint anything, just terminate immediately - cause = bg_error.source_exception + return self._handle_execution_output(result=serialized_result) - if isinstance(cause, BotoClientError): + except BackgroundThreadError as bg_error: + # Background checkpoint system failed - propagated through CompletionEvent + # Do not attempt to checkpoint anything, just terminate immediately + cause = bg_error.source_exception + + if isinstance(cause, BotoClientError): + cause_extra = cause.build_logger_extras() + logger.exception("Checkpoint processing failed", extra=cause_extra) + # Non-retryable Durable API errors (e.g., customer configuration issues, + # 4xx client errors) will never succeed on retry — fail the execution immediately. + if not cause.is_retryable(): logger.exception( - "Checkpoint processing failed", - extra=cause.build_logger_extras(), + "Non-retryable Durable API error from background thread. Must fail execution " + "without retry.", + extra=cause_extra, ) - # Non-retryable Durable API errors (e.g., customer configuration issues, - # 4xx client errors) will never succeed on retry — fail the execution immediately. - if not cause.is_retryable(): - logger.exception( - "Non-retryable Durable API error from background thread. Must fail execution " - "without retry.", - extra=cause.build_logger_extras(), - ) - else: - logger.exception("Checkpoint processing failed") - - retryable = ( - not isinstance(cause, BotoClientError) or cause.is_retryable() - ) - return self._handle_execution_output( - exception=cause, retryable=retryable - ) + else: + logger.exception("Checkpoint processing failed") - except SuspendExecution: - # User code suspended - stop background checkpointing thread - logger.debug("Suspending execution...") - return self._handle_execution_output(status=InvocationStatus.PENDING) + retryable = not isinstance(cause, BotoClientError) or cause.is_retryable() + return self._handle_execution_exception( + exception=cause, retryable=retryable + ) + + except SuspendExecution: + # User code suspended - stop background checkpointing thread + logger.debug("Suspending execution...") + return self._handle_execution_output(status=InvocationStatus.PENDING) - except CheckpointError as e: - # Checkpoint system is broken - stop background thread and exit immediately + except CheckpointError as e: + # Checkpoint system is broken - stop background thread and exit immediately + logger.exception( + "Checkpoint system failed", + extra=e.build_logger_extras(), + ) + # Terminate Lambda invocation immediately and have it be retried if retryable + return self._handle_execution_exception( + exception=e, retryable=e.is_retryable() + ) + except InvocationError as e: + if e.is_retryable(): + logger.exception("Invocation error. Must terminate.") + else: + # Non-retryable Durable API errors (e.g., customer configuration issues, + # 4xx client errors) will never succeed on retry — fail the execution immediately. logger.exception( - "Checkpoint system failed", - extra=e.build_logger_extras(), + "Non-retryable Durable API error. Must fail execution without retry.", + extra=e.build_logger_extras(), # type: ignore[attr-defined] ) + return self._handle_execution_exception( + exception=e, retryable=e.is_retryable() + ) + except ExecutionError as e: + logger.exception("Execution error. Must fail execution without retry.") + return self._handle_execution_exception(exception=e, retryable=False) + except Exception as user_exception: + # all user-space errors go here + logger.exception("Execution failed") + + try: + error = self._handle_large_error(exception=user_exception) + except CheckpointError as checkpoint_error: # Terminate Lambda invocation immediately and have it be retried if retryable - return self._handle_execution_output( - exception=e, retryable=e.is_retryable() - ) - except InvocationError as e: - if e.is_retryable(): - logger.exception("Invocation error. Must terminate.") - else: - # Non-retryable Durable API errors (e.g., customer configuration issues, - # 4xx client errors) will never succeed on retry — fail the execution immediately. - logger.exception( - "Non-retryable Durable API error. Must fail execution without retry.", - extra=e.build_logger_extras(), # type: ignore[attr-defined] - ) - return self._handle_execution_output( - exception=e, retryable=e.is_retryable() + return self._handle_execution_exception( + exception=checkpoint_error, + retryable=checkpoint_error.is_retryable(), ) - except ExecutionError as e: - logger.exception("Execution error. Must fail execution without retry.") - return self._handle_execution_output(exception=e) - except Exception as e: - # all user-space errors go here - logger.exception("Execution failed") - - try: - error = self._handle_large_error(execution_state, exception=e) - except CheckpointError as e: - # Terminate Lambda invocation immediately and have it be retried if retryable - return self._handle_execution_output( - exception=e, retryable=e.is_retryable() - ) - # fail without an ErrorObject - return self._handle_execution_output( - status=InvocationStatus.FAILED, error=error - ) + # fail with user's error + return self._handle_execution_output( + status=InvocationStatus.FAILED, error=error + ) - @staticmethod - def _handle_large_result(execution_state: ExecutionState, result: Any) -> str: + def _handle_large_result(self, result: Any) -> str: # large response handling here. Remember if checkpointing to complete, NOT to include # payload in response serialized_result = json.dumps(result) @@ -452,15 +465,12 @@ def _handle_large_result(execution_state: ExecutionState, result: Any) -> str: # Must ensure the result is persisted before returning to Lambda. # Large results exceed Lambda response limits and must be stored durably # before the execution completes. - execution_state.create_checkpoint(success_operation, is_sync=True) + self._execution_state.create_checkpoint(success_operation, is_sync=True) return "" return serialized_result - @staticmethod - def _handle_large_error( - execution_state: ExecutionState, exception: Exception - ) -> ErrorObject | None: + def _handle_large_error(self, exception: Exception) -> ErrorObject | None: # large response handling here. Remember if checkpointing to complete, NOT to include # payload in response error = ErrorObject.from_exception(exception) @@ -476,32 +486,48 @@ def _handle_large_error( # Must ensure the result is persisted before returning to Lambda. # Large results exceed Lambda response limits and must be stored durably # before the execution completes. - execution_state.create_checkpoint_sync(failed_operation) + self._execution_state.create_checkpoint_sync(failed_operation) # return fail without an ErrorObject return None return error + def _handle_execution_exception(self, exception: Exception, retryable: bool): + execution_arn: str | None = getattr(self, "_durable_execution_arn", None) + execution_state: ExecutionState | None = getattr(self, "execution_state", None) + execution_operation = ( + execution_state.get_execution_operation() if execution_state else None + ) + if retryable: + # Throw the error to trigger Lambda retry + self._plugin_executor.on_invocation_end( + execution_arn, + self._context, + execution_operation, + # an invocation output object used by plugin only + DurableExecutionInvocationOutput( + status=InvocationStatus.RETRY, + error=ErrorObject.from_exception(exception), + ), + ) + if exception is sys.exception(): + # exception is the current exception being handled, re-raise it + raise + raise exception + else: + # reset the error field + return self._handle_execution_output( + status=InvocationStatus.FAILED, + error=ErrorObject.from_exception(exception), + ) + def _handle_execution_output( self, result: str | None = None, error: ErrorObject | None = None, - exception: Exception | None = None, - retryable: bool = False, status: InvocationStatus | None = None, ) -> MutableMapping[str, Any]: - if exception: - if retryable: - # Throw the error to trigger Lambda retry - raise exception - else: - return self._handle_execution_output( - result=result, - error=ErrorObject.from_exception(exception), - status=status, - ) - if error: output = DurableExecutionInvocationOutput( status=InvocationStatus.FAILED, result=result, error=error @@ -514,4 +540,29 @@ def _handle_execution_output( output = DurableExecutionInvocationOutput(status=status) else: raise ValueError("Unexpected durable execution output") + + self._plugin_executor.on_invocation_end( + self._invocation_input.durable_execution_arn, + self._context, + self._execution_state.get_execution_operation(), + output, + ) return output.to_dict() + + def _get_input_event(self) -> Any: + raw_input_payload: str | None = self._execution_state.get_input_payload() + + # Python RIC LambdaMarshaller just uses standard json deserialization for event + # https://github.com/aws/aws-lambda-python-runtime-interface-client/blob/main/awslambdaric/lambda_runtime_marshaller.py#L46 + if raw_input_payload and raw_input_payload.strip(): + try: + return json.loads(raw_input_payload) + except json.JSONDecodeError as e: + logger.exception( + "Failed to parse input payload as JSON: payload: %r", + raw_input_payload, + ) + self._handle_execution_exception(exception=e, retryable=True) + raise + + return {} diff --git a/src/aws_durable_execution_sdk_python/lambda_service.py b/src/aws_durable_execution_sdk_python/lambda_service.py index aa78e4e8..dd7fc7d3 100644 --- a/src/aws_durable_execution_sdk_python/lambda_service.py +++ b/src/aws_durable_execution_sdk_python/lambda_service.py @@ -105,6 +105,15 @@ class OperationSubType(Enum): CHAINED_INVOKE = "ChainedInvoke" +class InvocationStatus(Enum): + SUCCEEDED = "SUCCEEDED" + FAILED = "FAILED" + PENDING = "PENDING" + + # Used internally only: the invocation failed and the backend will retry + RETRY = "RETRY" + + @dataclass(frozen=True) class ExecutionDetails: input_payload: str | None = None diff --git a/src/aws_durable_execution_sdk_python/plugin.py b/src/aws_durable_execution_sdk_python/plugin.py new file mode 100644 index 00000000..ef2ed752 --- /dev/null +++ b/src/aws_durable_execution_sdk_python/plugin.py @@ -0,0 +1,351 @@ +import datetime +import logging +from abc import ABC +from concurrent.futures import ThreadPoolExecutor +from dataclasses import dataclass +from typing import TYPE_CHECKING + +from aws_durable_execution_sdk_python.lambda_service import ( + OperationType, + OperationStatus, + OperationAction, + OperationSubType, + ErrorObject, + InvocationStatus, + Operation, + OperationUpdate, +) +from aws_durable_execution_sdk_python.types import LambdaContext + +if TYPE_CHECKING: + from aws_durable_execution_sdk_python.execution import ( + DurableExecutionInvocationOutput, + ) + +logger = logging.getLogger(__name__) + + +@dataclass +class OperationStartInfo: + operation_id: str + operation_type: OperationType + sub_type: OperationSubType | None = None + name: str | None = None + parent_id: str | None = None + start_timestamp: datetime.datetime | None = None + + +@dataclass +class OperationEndInfo(OperationStartInfo): + status: OperationStatus = OperationStatus.SUCCEEDED + end_timestamp: datetime.datetime | None = None + attempt: int | None = None + error: ErrorObject | None = None + + +@dataclass +class AttemptStartInfo(OperationStartInfo): + attempt: int = 1 + + +@dataclass +class AttemptEndInfo(AttemptStartInfo): + succeeded: bool | None = None + end_timestamp: datetime.datetime | None = None + error: ErrorObject | None = None + next_attempt_delay_seconds: int | None = None + + +@dataclass +class InvocationStartInfo: + request_id: str | None + execution_arn: str | None + start_timestamp: datetime.datetime | None + + +@dataclass +class InvocationEndInfo(InvocationStartInfo): + status: InvocationStatus = InvocationStatus.SUCCEEDED + end_timestamp: datetime.datetime | None = None + error: ErrorObject | None = None + + +@dataclass +class ExecutionStartInfo(InvocationStartInfo): + pass + + +@dataclass +class ExecutionEndInfo(ExecutionStartInfo): + status: InvocationStatus = InvocationStatus.SUCCEEDED + end_timestamp: datetime.datetime | None = None + error: ErrorObject | None = None + + +class DurableExecutionPlugin(ABC): + """Base class for plugins. Override only the methods you need.""" + + def on_execution_start(self, info: ExecutionStartInfo) -> None: + pass + + def on_execution_end(self, info: ExecutionEndInfo) -> None: + pass + + def on_invocation_start(self, info: InvocationStartInfo) -> None: + pass + + def on_invocation_end(self, info: InvocationEndInfo) -> None: + pass + + def on_operation_start(self, info: OperationStartInfo) -> None: + pass + + def on_operation_end(self, info: OperationEndInfo) -> None: + pass + + def on_operation_attempt_start(self, info: AttemptStartInfo) -> None: + pass + + def on_operation_attempt_end(self, info: AttemptEndInfo) -> None: + pass + + # Todo: further discussions required to finalize the following interface + # def enrich_log_context(self, info: OperationStartInfo | None) -> Dict[str, Any] | None: pass + + +class PluginExecutor: + _DEFAULT_MAX_WORKERS = 4 + + def __init__( + self, + plugins: list[DurableExecutionPlugin] | None, + max_workers: int | None = None, + ): + self.plugins = plugins or [] + self._pending_futures: list = [] + self._executor: ThreadPoolExecutor | None = ( + ThreadPoolExecutor( + max_workers=max_workers or self._DEFAULT_MAX_WORKERS, + thread_name_prefix="plugin-executor", + ) + if self.plugins + else None + ) + + def close(self) -> None: + """Shut down the thread pool, waiting for pending tasks to complete.""" + if self._executor is None: + return + self.flush() + self._executor.shutdown(wait=True) + + def flush(self) -> None: + """Wait for all pending plugin tasks to complete. Useful for testing.""" + for future in self._pending_futures: + future.result() + self._pending_futures.clear() + + def _dispatch_plugin(self, plugin: DurableExecutionPlugin, info) -> None: + """Invoke the appropriate plugin callback. Runs inside the thread pool.""" + try: + match info: + case ExecutionEndInfo(): + plugin.on_execution_end(info) + case InvocationEndInfo(): + plugin.on_invocation_end(info) + case ExecutionStartInfo(): + plugin.on_execution_start(info) + case InvocationStartInfo(): + plugin.on_invocation_start(info) + case AttemptEndInfo(): + plugin.on_operation_attempt_end(info) + case OperationEndInfo(): + plugin.on_operation_end(info) + case AttemptStartInfo(): + plugin.on_operation_attempt_start(info) + case OperationStartInfo(): + plugin.on_operation_start(info) + case _: + raise ValueError(f"Unknown info type: {type(info)}") + except Exception: + # log and ignore the exception + logger.exception("Plugin %s exception ignored", plugin.__class__.__name__) + + def execute_plugins(self, info): + if not self.plugins: + return + for plugin in self.plugins: + future = self._executor.submit(self._dispatch_plugin, plugin, info) + self._pending_futures.append(future) + + def on_invocation_start( + self, + durable_execution_arn: str, + context: LambdaContext | None, + execution_operation: Operation | None, + is_replaying: bool, + ) -> None: + aws_request_id = context.aws_request_id if context else None + start_timestamp = ( + execution_operation.start_timestamp if execution_operation else None + ) + + if not is_replaying: + self.execute_plugins( + ExecutionStartInfo( + request_id=aws_request_id, + execution_arn=durable_execution_arn, + start_timestamp=start_timestamp, + ) + ) + + self.execute_plugins( + InvocationStartInfo( + request_id=aws_request_id, + execution_arn=durable_execution_arn, + start_timestamp=start_timestamp, + ) + ) + + def on_invocation_end( + self, + durable_execution_arn: str | None, + context: LambdaContext, + execution_operation: Operation | None, + output: "DurableExecutionInvocationOutput", + ) -> None: + start_timestamp = ( + execution_operation.start_timestamp if execution_operation else None + ) + # the actual end timestamp may be unknown because it's not checkpointed yet + end_timestamp: datetime.datetime = ( + execution_operation.end_timestamp if execution_operation else None + ) or datetime.datetime.now() + request_id = context.aws_request_id if context else None + + self.execute_plugins( + InvocationEndInfo( + request_id=request_id, + execution_arn=durable_execution_arn, + start_timestamp=start_timestamp, + status=output.status, + end_timestamp=end_timestamp, + error=output.error, + ) + ) + + if output.status in [InvocationStatus.SUCCEEDED, InvocationStatus.FAILED]: + self.execute_plugins( + ExecutionEndInfo( + request_id=request_id, + execution_arn=durable_execution_arn, + start_timestamp=start_timestamp, + status=output.status, + end_timestamp=end_timestamp, + error=output.error, + ) + ) + + def on_operation_action(self, operation: Operation | None, update: OperationUpdate): + """Execute any registered plugins for a given operation before it is updated. + + Args: + update: the operation update that is pending checkpoint + """ + if update.action is not OperationAction.START: + return + + self.execute_plugins( + OperationStartInfo( + operation_id=update.operation_id, + operation_type=update.operation_type, + sub_type=update.sub_type, + name=update.name, + parent_id=update.parent_id, + start_timestamp=datetime.datetime.now(), + ) + ) + + if update.operation_type is OperationType.STEP: + attempt = ( + operation.step_details.attempt + if operation and operation.step_details + else 1 + ) + self.execute_plugins( + AttemptStartInfo( + operation_id=update.operation_id, + operation_type=update.operation_type, + sub_type=update.sub_type, + name=update.name, + parent_id=update.parent_id, + start_timestamp=datetime.datetime.now(), + attempt=attempt, + ) + ) + + def on_operation_update(self, operation): + """Execute any registered plugins for a given operation after it is updated. + + Updates such as STARTED might be omitted because START and completion action (e.g. SUCCEED/FAIL) may be + checkpointed in batch and the backend returns only the terminal status (e.g. SUCCEEDED/PENDING/FAILED). + + Args: + operation: the operation is just checkpointed + """ + params = dict( + operation_id=operation.operation_id, + operation_type=operation.operation_type, + sub_type=operation.sub_type, + name=operation.name, + parent_id=operation.parent_id, + start_timestamp=operation.start_timestamp, + ) + if operation.step_details and ( + self._is_terminal_status(operation.status) + # PENDING in addition to terminal status + or operation.status is OperationStatus.PENDING + ): + self.execute_plugins( + AttemptEndInfo( + **params, + end_timestamp=operation.end_timestamp, + attempt=operation.step_details.attempt, + succeeded=operation.status is OperationStatus.SUCCEEDED, + error=operation.step_details.error, + ) + ) + + if self._is_terminal_status(operation.status): + attempt = operation.step_details.attempt if operation.step_details else None + self.execute_plugins( + OperationEndInfo( + **params, + end_timestamp=operation.end_timestamp, + status=operation.status, + error=self._extract_error(operation), + attempt=attempt, + ) + ) + + @staticmethod + def _extract_error(operation: Operation): + if operation.step_details and operation.step_details.error: + return operation.step_details.error + if operation.callback_details and operation.callback_details.error: + return operation.callback_details.error + if operation.chained_invoke_details and operation.chained_invoke_details.error: + return operation.chained_invoke_details.error + if operation.context_details and operation.context_details.error: + return operation.context_details.error + return None + + @staticmethod + def _is_terminal_status(status): + return status in [ + OperationStatus.SUCCEEDED, + OperationStatus.FAILED, + OperationStatus.TIMED_OUT, + OperationStatus.CANCELLED, + OperationStatus.STOPPED, + ] diff --git a/src/aws_durable_execution_sdk_python/state.py b/src/aws_durable_execution_sdk_python/state.py index 83175503..cada4c52 100644 --- a/src/aws_durable_execution_sdk_python/state.py +++ b/src/aws_durable_execution_sdk_python/state.py @@ -7,6 +7,8 @@ import queue import threading import time +from concurrent.futures import Executor +from contextlib import contextmanager from dataclasses import dataclass from enum import Enum from threading import Lock @@ -30,6 +32,9 @@ OperationUpdate, StateOutput, ) +from aws_durable_execution_sdk_python.plugin import ( + PluginExecutor, +) from aws_durable_execution_sdk_python.threading import CompletionEvent, OrderedLock if TYPE_CHECKING: @@ -236,6 +241,7 @@ def __init__( initial_checkpoint_token: str, operations: MutableMapping[str, Operation], service_client: DurableServiceClient, + plugin_executor: PluginExecutor, batcher_config: CheckpointBatcherConfig | None = None, replay_status: ReplayStatus = ReplayStatus.NEW, ): @@ -243,6 +249,7 @@ def __init__( self._current_checkpoint_token: str = initial_checkpoint_token self.operations: MutableMapping[str, Operation] = operations self._service_client: DurableServiceClient = service_client + self._plugin_executor: PluginExecutor = plugin_executor or PluginExecutor(None) self._ordered_checkpoint_lock: OrderedLock = OrderedLock() self._operations_lock: Lock = Lock() @@ -274,7 +281,7 @@ def fetch_paginated_operations( initial_operations: list[Operation], checkpoint_token: str, next_marker: str | None, - ) -> None: + ) -> list[Operation]: """Add initial operations and fetch all paginated operations from the Durable Functions API. This method is thread_safe. The checkpoint_token is passed explicitly as a parameter rather than using the instance variable to ensure thread safety. @@ -283,6 +290,8 @@ def fetch_paginated_operations( initial_operations: initial operations to be added to ExecutionState checkpoint_token: checkpoint token used to call Durable Functions API. next_marker: a marker indicates that there are paginated operations. + Returns: + List of all operations fetched from the Durable Functions API Raises: GetExecutionStateError: If the API call fails. The error is logged @@ -315,6 +324,7 @@ def fetch_paginated_operations( self.operations.update( {op.operation_id: op for op in all_operations} ) + return all_operations def get_input_payload(self) -> str | None: # It is possible that backend will not provide an execution operation @@ -689,12 +699,20 @@ def checkpoint_batches_forever(self) -> None: current_checkpoint_token = output.checkpoint_token # Fetch new operations from the API before unblocking sync waiters - self.fetch_paginated_operations( + updated_operations = self.fetch_paginated_operations( output.new_execution_state.operations, output.checkpoint_token, output.new_execution_state.next_marker, ) + for update in updates: + with self._operations_lock: + op = self.operations.get(update.operation_id) + self._plugin_executor.on_operation_action(op, update) + + for operation in updated_operations: + self._plugin_executor.on_operation_update(operation) + # Signal completion for any synchronous operations for queued_op in batch: if queued_op.completion_event is not None: diff --git a/tests/e2e/checkpoint_response_int_test.py b/tests/e2e/checkpoint_response_int_test.py index c0fd0f50..2fb9141c 100644 --- a/tests/e2e/checkpoint_response_int_test.py +++ b/tests/e2e/checkpoint_response_int_test.py @@ -101,7 +101,7 @@ def my_handler(event, context: DurableContext) -> str: mock_client.checkpoint = mock_checkpoint event = { - "DurableExecutionArn": "test-arn", + "DurableExecutionArn": "test-arn/execution-1", "CheckpointToken": "test-token", "InitialExecutionState": { "Operations": [ @@ -122,7 +122,7 @@ def my_handler(event, context: DurableContext) -> str: lambda_context.client_context = None lambda_context.identity = None lambda_context._epoch_deadline_time_in_ms = 0 # noqa: SLF001 - lambda_context.invoked_function_arn = "test-arn" + lambda_context.invoked_function_arn = "test-arn/execution-1" lambda_context.tenant_id = None result = my_handler(event, lambda_context) @@ -164,7 +164,7 @@ def my_handler(event, context: DurableContext) -> list[str]: mock_client.checkpoint = mock_checkpoint event = { - "DurableExecutionArn": "test-arn", + "DurableExecutionArn": "test-arn/execution-1", "CheckpointToken": "test-token", "InitialExecutionState": { "Operations": [ @@ -185,7 +185,7 @@ def my_handler(event, context: DurableContext) -> list[str]: lambda_context.client_context = None lambda_context.identity = None lambda_context._epoch_deadline_time_in_ms = 0 # noqa: SLF001 - lambda_context.invoked_function_arn = "test-arn" + lambda_context.invoked_function_arn = "test-arn/execution-1" lambda_context.tenant_id = None result = my_handler(event, lambda_context) @@ -220,7 +220,7 @@ def my_handler(event, context: DurableContext) -> str: mock_client.checkpoint = mock_checkpoint event = { - "DurableExecutionArn": "test-arn", + "DurableExecutionArn": "test-arn/execution-1", "CheckpointToken": "test-token", "InitialExecutionState": { "Operations": [ @@ -241,7 +241,7 @@ def my_handler(event, context: DurableContext) -> str: lambda_context.client_context = None lambda_context.identity = None lambda_context._epoch_deadline_time_in_ms = 0 # noqa: SLF001 - lambda_context.invoked_function_arn = "test-arn" + lambda_context.invoked_function_arn = "test-arn/execution-1" lambda_context.tenant_id = None # Wait will suspend, so we expect PENDING status @@ -279,7 +279,7 @@ def my_handler(event, context: DurableContext) -> str: mock_client.checkpoint = mock_checkpoint event = { - "DurableExecutionArn": "test-arn", + "DurableExecutionArn": "test-arn/execution-1", "CheckpointToken": "test-token", "InitialExecutionState": { "Operations": [ @@ -300,7 +300,7 @@ def my_handler(event, context: DurableContext) -> str: lambda_context.client_context = None lambda_context.identity = None lambda_context._epoch_deadline_time_in_ms = 0 # noqa: SLF001 - lambda_context.invoked_function_arn = "test-arn" + lambda_context.invoked_function_arn = "test-arn/execution-1" lambda_context.tenant_id = None result = my_handler(event, lambda_context) @@ -388,7 +388,7 @@ def mock_checkpoint( mock_client.checkpoint = mock_checkpoint event = { - "DurableExecutionArn": "test-arn", + "DurableExecutionArn": "test-arn/execution-1", "CheckpointToken": "test-token", "InitialExecutionState": { "Operations": [ @@ -409,7 +409,7 @@ def mock_checkpoint( lambda_context.client_context = None lambda_context.identity = None lambda_context._epoch_deadline_time_in_ms = 0 # noqa: SLF001 - lambda_context.invoked_function_arn = "test-arn" + lambda_context.invoked_function_arn = "test-arn/execution-1" lambda_context.tenant_id = None result = my_handler(event, lambda_context) @@ -440,7 +440,7 @@ def my_handler(event, context: DurableContext): mock_client.checkpoint = mock_checkpoint event = { - "DurableExecutionArn": "test-arn", + "DurableExecutionArn": "test-arn/execution-1", "CheckpointToken": "test-token", "InitialExecutionState": { "Operations": [ @@ -461,7 +461,7 @@ def my_handler(event, context: DurableContext): lambda_context.client_context = None lambda_context.identity = None lambda_context._epoch_deadline_time_in_ms = 0 # noqa: SLF001 - lambda_context.invoked_function_arn = "test-arn" + lambda_context.invoked_function_arn = "test-arn/execution-1" lambda_context.tenant_id = None # Invoke will suspend, so we expect PENDING status @@ -499,7 +499,7 @@ def my_handler(event, context: DurableContext) -> str: mock_client.checkpoint = mock_checkpoint event = { - "DurableExecutionArn": "test-arn", + "DurableExecutionArn": "test-arn/execution-1", "CheckpointToken": "test-token", "InitialExecutionState": { "Operations": [ @@ -520,7 +520,7 @@ def my_handler(event, context: DurableContext) -> str: lambda_context.client_context = None lambda_context.identity = None lambda_context._epoch_deadline_time_in_ms = 0 # noqa: SLF001 - lambda_context.invoked_function_arn = "test-arn" + lambda_context.invoked_function_arn = "test-arn/execution-1" lambda_context.tenant_id = None result = my_handler(event, lambda_context) @@ -598,7 +598,7 @@ def mock_checkpoint( mock_client.checkpoint = mock_checkpoint event = { - "DurableExecutionArn": "test-arn", + "DurableExecutionArn": "test-arn/execution-1", "CheckpointToken": "test-token", "InitialExecutionState": { "Operations": [ @@ -619,7 +619,7 @@ def mock_checkpoint( lambda_context.client_context = None lambda_context.identity = None lambda_context._epoch_deadline_time_in_ms = 0 # noqa: SLF001 - lambda_context.invoked_function_arn = "test-arn" + lambda_context.invoked_function_arn = "test-arn/execution-1" lambda_context.tenant_id = None result = my_handler(event, lambda_context) @@ -665,7 +665,7 @@ def my_handler(event, context: DurableContext) -> str: mock_client.checkpoint = mock_checkpoint event = { - "DurableExecutionArn": "test-arn", + "DurableExecutionArn": "test-arn/execution-1", "CheckpointToken": "test-token", "InitialExecutionState": { "Operations": [ @@ -686,7 +686,7 @@ def my_handler(event, context: DurableContext) -> str: lambda_context.client_context = None lambda_context.identity = None lambda_context._epoch_deadline_time_in_ms = 0 # noqa: SLF001 - lambda_context.invoked_function_arn = "test-arn" + lambda_context.invoked_function_arn = "test-arn/execution-1" lambda_context.tenant_id = None result = my_handler(event, lambda_context) @@ -730,7 +730,7 @@ def my_handler(event, context: DurableContext) -> str: mock_client.checkpoint = mock_checkpoint event = { - "DurableExecutionArn": "test-arn", + "DurableExecutionArn": "test-arn/execution-1", "CheckpointToken": "test-token", "InitialExecutionState": { "Operations": [ @@ -751,7 +751,7 @@ def my_handler(event, context: DurableContext) -> str: lambda_context.client_context = None lambda_context.identity = None lambda_context._epoch_deadline_time_in_ms = 0 # noqa: SLF001 - lambda_context.invoked_function_arn = "test-arn" + lambda_context.invoked_function_arn = "test-arn/execution-1" lambda_context.tenant_id = None # InvocationError should be re-raised (not wrapped) to trigger Lambda retry diff --git a/tests/e2e/execution_int_test.py b/tests/e2e/execution_int_test.py index 5a884bff..cc1cf067 100644 --- a/tests/e2e/execution_int_test.py +++ b/tests/e2e/execution_int_test.py @@ -135,7 +135,7 @@ def mock_checkpoint( # Create test event event = { - "DurableExecutionArn": "test-arn", + "DurableExecutionArn": "test-arn/execution-1", "CheckpointToken": "test-token", "InitialExecutionState": { "Operations": [ @@ -157,7 +157,7 @@ def mock_checkpoint( lambda_context.client_context = None lambda_context.identity = None lambda_context._epoch_deadline_time_in_ms = 0 # noqa: SLF001 - lambda_context.invoked_function_arn = "test-arn" + lambda_context.invoked_function_arn = "test-arn/execution-1" lambda_context.tenant_id = None # Execute the handler @@ -221,7 +221,7 @@ def mock_checkpoint( # Create test event event = { - "DurableExecutionArn": "test-arn", + "DurableExecutionArn": "test-arn/execution-1", "CheckpointToken": "test-token", "InitialExecutionState": { "Operations": [ @@ -243,7 +243,7 @@ def mock_checkpoint( lambda_context.client_context = None lambda_context.identity = None lambda_context._epoch_deadline_time_in_ms = 0 # noqa: SLF001 - lambda_context.invoked_function_arn = "test-arn" + lambda_context.invoked_function_arn = "test-arn/execution-1" lambda_context.tenant_id = None # Execute the handler @@ -262,7 +262,7 @@ def mock_checkpoint( 123, "str", extra={ - "executionArn": "test-arn", + "executionArn": "test-arn/execution-1", "operationName": "mystep", "attempt": 1, "operationId": operation_id, @@ -308,7 +308,7 @@ def my_handler(event, context): # Create test event event = { - "DurableExecutionArn": "test-arn", + "DurableExecutionArn": "test-arn/execution-1", "CheckpointToken": "test-token", "InitialExecutionState": { "Operations": [ @@ -330,7 +330,7 @@ def my_handler(event, context): lambda_context.client_context = None lambda_context.identity = None lambda_context._epoch_deadline_time_in_ms = 0 # noqa: SLF001 - lambda_context.invoked_function_arn = "test-arn" + lambda_context.invoked_function_arn = "test-arn/execution-1" lambda_context.tenant_id = None # Execute the handler @@ -409,7 +409,7 @@ def mock_checkpoint_failure( # Create test event event = { - "DurableExecutionArn": "test-arn", + "DurableExecutionArn": "test-arn/execution-1", "CheckpointToken": "test-token", "InitialExecutionState": { "Operations": [ @@ -431,7 +431,7 @@ def mock_checkpoint_failure( lambda_context.client_context = None lambda_context.identity = None lambda_context._epoch_deadline_time_in_ms = 0 # noqa: SLF001 - lambda_context.invoked_function_arn = "test-arn" + lambda_context.invoked_function_arn = "test-arn/execution-1" lambda_context.tenant_id = None # Execute the handler - should propagate the checkpoint error @@ -463,7 +463,7 @@ def my_handler(event: Any, context: DurableContext): # Create test event event = { - "DurableExecutionArn": "test-arn", + "DurableExecutionArn": "test-arn/execution-1", "CheckpointToken": "test-token", "InitialExecutionState": { "Operations": [ @@ -485,7 +485,7 @@ def my_handler(event: Any, context: DurableContext): lambda_context.client_context = None lambda_context.identity = None lambda_context._epoch_deadline_time_in_ms = 0 # noqa: SLF001 - lambda_context.invoked_function_arn = "test-arn" + lambda_context.invoked_function_arn = "test-arn/execution-1" lambda_context.tenant_id = None # Execute the handler @@ -560,7 +560,7 @@ def mock_checkpoint( mock_client.checkpoint = mock_checkpoint event = { - "DurableExecutionArn": "test-arn", + "DurableExecutionArn": "test-arn/execution-1", "CheckpointToken": "test-token", "InitialExecutionState": { "Operations": [ @@ -581,7 +581,7 @@ def mock_checkpoint( lambda_context.client_context = None lambda_context.identity = None lambda_context._epoch_deadline_time_in_ms = 0 # noqa: SLF001 - lambda_context.invoked_function_arn = "test-arn" + lambda_context.invoked_function_arn = "test-arn/execution-1" lambda_context.tenant_id = None result = my_handler(event, lambda_context) diff --git a/tests/e2e/map_with_concurrent_waits_int_test.py b/tests/e2e/map_with_concurrent_waits_int_test.py index 8ad812e4..62ad7c2b 100644 --- a/tests/e2e/map_with_concurrent_waits_int_test.py +++ b/tests/e2e/map_with_concurrent_waits_int_test.py @@ -42,6 +42,7 @@ OperationUpdate, OperationType, ) +from aws_durable_execution_sdk_python.plugin import PluginExecutor from aws_durable_execution_sdk_python.state import ( CheckpointBatcherConfig, ExecutionState, @@ -68,6 +69,7 @@ def _make_state( operations={}, service_client=mock_client, batcher_config=config, + plugin_executor=PluginExecutor([]), ) diff --git a/tests/execution_test.py b/tests/execution_test.py index db13b5a9..c88bf895 100644 --- a/tests/execution_test.py +++ b/tests/execution_test.py @@ -47,6 +47,7 @@ StepDetails, WaitDetails, ) +from aws_durable_execution_sdk_python.plugin import DurableExecutionPlugin LARGE_RESULT = "large_success" * 1024 * 1024 @@ -2827,3 +2828,295 @@ def test_handler(event: Any, context: DurableContext) -> dict: _make_invocation_input(mock_client, next_marker="next-page-marker"), _make_lambda_context(), ) + + +# region Plugin Integration Tests + + +class _RecordingPlugin(DurableExecutionPlugin): + """Plugin that records all hook calls for assertion.""" + + def __init__(self) -> None: + self.calls: list[str] = [] + + def on_execution_start(self, info): + self.calls.append("execution_start") + + def on_execution_end(self, info): + self.calls.append(f"execution_end:{info.status.value}") + + def on_invocation_start(self, info): + self.calls.append("invocation_start") + + def on_invocation_end(self, info): + self.calls.append(f"invocation_end:{info.status.value}") + + def on_operation_start(self, info): + self.calls.append(f"operation_start:{info.operation_id}") + + def on_operation_end(self, info): + self.calls.append(f"operation_end:{info.operation_id}") + + def on_operation_attempt_start(self, info): + self.calls.append(f"attempt_start:{info.operation_id}") + + def on_operation_attempt_end(self, info): + self.calls.append(f"attempt_end:{info.operation_id}") + + +class _FailingPlugin(DurableExecutionPlugin): + """Plugin that raises on every hook call.""" + + def on_execution_start(self, info): + raise RuntimeError("plugin boom") + + def on_execution_end(self, info): + raise RuntimeError("plugin boom") + + def on_invocation_start(self, info): + raise RuntimeError("plugin boom") + + def on_invocation_end(self, info): + raise RuntimeError("plugin boom") + + def on_operation_start(self, info): + raise RuntimeError("plugin boom") + + def on_operation_end(self, info): + raise RuntimeError("plugin boom") + + def on_operation_attempt_start(self, info): + raise RuntimeError("plugin boom") + + def on_operation_attempt_end(self, info): + raise RuntimeError("plugin boom") + + +def test_durable_execution_with_plugins_success(): + """Test that plugins receive invocation start/end and execution end on success.""" + mock_client = Mock(spec=DurableServiceClient) + mock_output = CheckpointOutput( + checkpoint_token="new_token", # noqa: S106 + new_execution_state=CheckpointUpdatedExecutionState(), + ) + mock_client.checkpoint.return_value = mock_output + + plugin = _RecordingPlugin() + + @durable_execution(plugins=[plugin]) + def test_handler(event: Any, context: DurableContext) -> dict: + return {"result": "success"} + + result = test_handler( + _make_invocation_input(mock_client), + _make_lambda_context(), + ) + + assert result["Status"] == InvocationStatus.SUCCEEDED.value + # ExecutionStartInfo dispatches to on_invocation_start in the match block + assert "invocation_start" in plugin.calls + assert "invocation_end:SUCCEEDED" in plugin.calls + assert "execution_end:SUCCEEDED" in plugin.calls + + +def test_durable_execution_with_plugins_failure(): + """Test that plugins receive invocation end and execution end on user error.""" + mock_client = Mock(spec=DurableServiceClient) + mock_output = CheckpointOutput( + checkpoint_token="new_token", # noqa: S106 + new_execution_state=CheckpointUpdatedExecutionState(), + ) + mock_client.checkpoint.return_value = mock_output + + plugin = _RecordingPlugin() + + @durable_execution(plugins=[plugin]) + def test_handler(event: Any, context: DurableContext) -> dict: + msg = "user error" + raise ValueError(msg) + + result = test_handler( + _make_invocation_input(mock_client), + _make_lambda_context(), + ) + + assert result["Status"] == InvocationStatus.FAILED.value + assert "invocation_start" in plugin.calls + assert "invocation_end:FAILED" in plugin.calls + assert "execution_end:FAILED" in plugin.calls + + +def test_durable_execution_with_plugins_pending(): + """Test that plugins receive invocation end with PENDING status on suspend.""" + mock_client = Mock(spec=DurableServiceClient) + mock_output = CheckpointOutput( + checkpoint_token="new_token", # noqa: S106 + new_execution_state=CheckpointUpdatedExecutionState(), + ) + mock_client.checkpoint.return_value = mock_output + + plugin = _RecordingPlugin() + + @durable_execution(plugins=[plugin]) + def test_handler(event: Any, context: DurableContext) -> dict: + raise SuspendExecution("test") + + result = test_handler( + _make_invocation_input(mock_client), + _make_lambda_context(), + ) + + assert result["Status"] == InvocationStatus.PENDING.value + assert "invocation_start" in plugin.calls + assert "invocation_end:PENDING" in plugin.calls + # Execution end should NOT be fired for PENDING + execution_end_calls = [c for c in plugin.calls if c.startswith("execution_end")] + assert len(execution_end_calls) == 0 + + +def test_durable_execution_with_plugins_retryable_error(): + """Test that plugins receive invocation end with RETRY status on retryable error.""" + mock_client = Mock(spec=DurableServiceClient) + + plugin = _RecordingPlugin() + + @durable_execution(plugins=[plugin]) + def test_handler(event: Any, context: DurableContext) -> dict: + msg = "Retriable error" + raise InvocationError(msg) + + with pytest.raises(InvocationError): + test_handler( + _make_invocation_input(mock_client), + _make_lambda_context(), + ) + + assert "invocation_start" in plugin.calls + assert "invocation_end:RETRY" in plugin.calls + + +def test_durable_execution_with_multiple_plugins(): + """Test that multiple plugins all receive callbacks.""" + mock_client = Mock(spec=DurableServiceClient) + mock_output = CheckpointOutput( + checkpoint_token="new_token", # noqa: S106 + new_execution_state=CheckpointUpdatedExecutionState(), + ) + mock_client.checkpoint.return_value = mock_output + + plugin1 = _RecordingPlugin() + plugin2 = _RecordingPlugin() + + @durable_execution(plugins=[plugin1, plugin2]) + def test_handler(event: Any, context: DurableContext) -> dict: + return {"result": "success"} + + result = test_handler( + _make_invocation_input(mock_client), + _make_lambda_context(), + ) + + assert result["Status"] == InvocationStatus.SUCCEEDED.value + assert "invocation_start" in plugin1.calls + assert "invocation_start" in plugin2.calls + assert "invocation_end:SUCCEEDED" in plugin1.calls + assert "invocation_end:SUCCEEDED" in plugin2.calls + + +def test_durable_execution_with_failing_plugin_does_not_break_execution(): + """Test that a failing plugin does not prevent the handler from completing.""" + mock_client = Mock(spec=DurableServiceClient) + mock_output = CheckpointOutput( + checkpoint_token="new_token", # noqa: S106 + new_execution_state=CheckpointUpdatedExecutionState(), + ) + mock_client.checkpoint.return_value = mock_output + + failing_plugin = _FailingPlugin() + recording_plugin = _RecordingPlugin() + + @durable_execution(plugins=[failing_plugin, recording_plugin]) + def test_handler(event: Any, context: DurableContext) -> dict: + return {"result": "success"} + + result = test_handler( + _make_invocation_input(mock_client), + _make_lambda_context(), + ) + + # Execution should still succeed despite the failing plugin + assert result["Status"] == InvocationStatus.SUCCEEDED.value + # The recording plugin should still have been called + assert "invocation_start" in recording_plugin.calls + assert "invocation_end:SUCCEEDED" in recording_plugin.calls + + +def test_durable_execution_with_no_plugins(): + """Test that passing no plugins (None) works correctly.""" + mock_client = Mock(spec=DurableServiceClient) + mock_output = CheckpointOutput( + checkpoint_token="new_token", # noqa: S106 + new_execution_state=CheckpointUpdatedExecutionState(), + ) + mock_client.checkpoint.return_value = mock_output + + @durable_execution(plugins=None) + def test_handler(event: Any, context: DurableContext) -> dict: + return {"result": "success"} + + result = test_handler( + _make_invocation_input(mock_client), + _make_lambda_context(), + ) + + assert result["Status"] == InvocationStatus.SUCCEEDED.value + + +def test_durable_execution_with_empty_plugins_list(): + """Test that passing an empty plugins list works correctly.""" + mock_client = Mock(spec=DurableServiceClient) + mock_output = CheckpointOutput( + checkpoint_token="new_token", # noqa: S106 + new_execution_state=CheckpointUpdatedExecutionState(), + ) + mock_client.checkpoint.return_value = mock_output + + @durable_execution(plugins=[]) + def test_handler(event: Any, context: DurableContext) -> dict: + return {"result": "success"} + + result = test_handler( + _make_invocation_input(mock_client), + _make_lambda_context(), + ) + + assert result["Status"] == InvocationStatus.SUCCEEDED.value + + +def test_durable_execution_decorator_with_plugins_and_boto3_client(): + """Test that plugins parameter works alongside boto3_client parameter.""" + mock_client = Mock(spec=DurableServiceClient) + mock_output = CheckpointOutput( + checkpoint_token="new_token", # noqa: S106 + new_execution_state=CheckpointUpdatedExecutionState(), + ) + mock_client.checkpoint.return_value = mock_output + + plugin = _RecordingPlugin() + + # When using DurableExecutionInvocationInputWithClient, boto3_client is ignored + # but we verify the decorator accepts both parameters + @durable_execution(boto3_client=None, plugins=[plugin]) + def test_handler(event: Any, context: DurableContext) -> dict: + return {"result": "success"} + + result = test_handler( + _make_invocation_input(mock_client), + _make_lambda_context(), + ) + + assert result["Status"] == InvocationStatus.SUCCEEDED.value + assert "invocation_start" in plugin.calls + + +# endregion Plugin Integration Tests diff --git a/tests/logger_test.py b/tests/logger_test.py index b6017fa6..1966e276 100644 --- a/tests/logger_test.py +++ b/tests/logger_test.py @@ -11,6 +11,7 @@ OperationType, ) from aws_durable_execution_sdk_python.logger import Logger, LoggerInterface, LogInfo +from aws_durable_execution_sdk_python.plugin import PluginExecutor from aws_durable_execution_sdk_python.state import ExecutionState, ReplayStatus @@ -83,6 +84,7 @@ def exception( initial_checkpoint_token="test_token", # noqa: S106 operations={}, service_client=Mock(), + plugin_executor=PluginExecutor(plugins=None), ) @@ -227,6 +229,7 @@ def test_logger_with_log_info(): initial_checkpoint_token="test_token", # noqa: S106 operations={}, service_client=Mock(), + plugin_executor=PluginExecutor([]), ) new_info = LogInfo(execution_state_new, "parent2", "op123", "new_name") new_logger = logger.with_log_info(new_info) @@ -377,6 +380,7 @@ def test_logger_replay_no_logging(): operations={"op1": operation}, service_client=Mock(), replay_status=ReplayStatus.REPLAY, + plugin_executor=PluginExecutor([]), ) log_info = LogInfo(replay_execution_state, "parent123", "test_name", 5) mock_logger = Mock() @@ -404,6 +408,7 @@ def test_logger_replay_then_new_logging(): operations={"op1": operation1, "op2": operation2}, service_client=Mock(), replay_status=ReplayStatus.REPLAY, + plugin_executor=PluginExecutor([]), ) log_info = LogInfo(execution_state, "parent123", "test_name", 5) mock_logger = Mock() diff --git a/tests/plugin_test.py b/tests/plugin_test.py new file mode 100644 index 00000000..8dbd6be1 --- /dev/null +++ b/tests/plugin_test.py @@ -0,0 +1,1116 @@ +import datetime +import logging +import unittest +from dataclasses import dataclass +from unittest.mock import MagicMock, patch + +from aws_durable_execution_sdk_python.execution import ( + DurableExecutionInvocationOutput, +) +from aws_durable_execution_sdk_python.lambda_service import ( + CallbackDetails, + ChainedInvokeDetails, + ContextDetails, + ErrorObject, + InvocationStatus, + Operation, + OperationAction, + OperationStatus, + OperationSubType, + OperationType, + OperationUpdate, + StepDetails, +) +from aws_durable_execution_sdk_python.plugin import ( + AttemptEndInfo, + AttemptStartInfo, + DurableExecutionPlugin, + ExecutionEndInfo, + ExecutionStartInfo, + InvocationEndInfo, + InvocationStartInfo, + OperationEndInfo, + OperationStartInfo, + PluginExecutor, +) + + +# region Dataclass Tests + + +class TestOperationStartInfo(unittest.TestCase): + def test_required_fields(self): + info = OperationStartInfo( + operation_id="op-1", operation_type=OperationType.STEP + ) + self.assertEqual(info.operation_id, "op-1") + self.assertEqual(info.operation_type, OperationType.STEP) + self.assertIsNone(info.sub_type) + self.assertIsNone(info.name) + self.assertIsNone(info.parent_id) + self.assertIsNone(info.start_timestamp) + + def test_all_fields(self): + ts = datetime.datetime(2025, 1, 1, tzinfo=datetime.UTC) + info = OperationStartInfo( + operation_id="op-2", + operation_type=OperationType.CALLBACK, + sub_type=OperationSubType.CALLBACK, + name="my-op", + parent_id="parent-1", + start_timestamp=ts, + ) + self.assertEqual(info.sub_type, OperationSubType.CALLBACK) + self.assertEqual(info.name, "my-op") + self.assertEqual(info.parent_id, "parent-1") + self.assertEqual(info.start_timestamp, ts) + + +class TestOperationEndInfo(unittest.TestCase): + def test_inherits_operation_start_info(self): + self.assertTrue(issubclass(OperationEndInfo, OperationStartInfo)) + + def test_defaults(self): + info = OperationEndInfo(operation_id="op-1", operation_type=OperationType.STEP) + self.assertEqual(info.status, OperationStatus.SUCCEEDED) + self.assertIsNone(info.end_timestamp) + self.assertIsNone(info.attempt) + self.assertIsNone(info.error) + + def test_with_error(self): + err = ErrorObject( + message="fail", type="RuntimeError", data=None, stack_trace=None + ) + info = OperationEndInfo( + operation_id="op-1", + operation_type=OperationType.STEP, + status=OperationStatus.FAILED, + error=err, + attempt=3, + ) + self.assertEqual(info.status, OperationStatus.FAILED) + self.assertEqual(info.attempt, 3) + self.assertEqual(info.error.message, "fail") + + +class TestAttemptStartInfo(unittest.TestCase): + def test_inherits_operation_start_info(self): + self.assertTrue(issubclass(AttemptStartInfo, OperationStartInfo)) + + def test_default_attempt(self): + info = AttemptStartInfo(operation_id="op-1", operation_type=OperationType.STEP) + self.assertEqual(info.attempt, 1) + + def test_custom_attempt(self): + info = AttemptStartInfo( + operation_id="op-1", operation_type=OperationType.STEP, attempt=5 + ) + self.assertEqual(info.attempt, 5) + + +class TestAttemptEndInfo(unittest.TestCase): + def test_inherits_attempt_start_info(self): + self.assertTrue(issubclass(AttemptEndInfo, AttemptStartInfo)) + + def test_defaults(self): + info = AttemptEndInfo(operation_id="op-1", operation_type=OperationType.STEP) + self.assertIsNone(info.succeeded) + self.assertIsNone(info.error) + self.assertIsNone(info.next_attempt_delay_seconds) + + def test_retry_with_delay(self): + err = ErrorObject( + message="timeout", type="TimeoutError", data=None, stack_trace=None + ) + info = AttemptEndInfo( + operation_id="op-1", + operation_type=OperationType.STEP, + succeeded=False, + error=err, + next_attempt_delay_seconds=30, + ) + self.assertFalse(info.succeeded) + self.assertEqual(info.next_attempt_delay_seconds, 30) + self.assertEqual(info.error.type, "TimeoutError") + + +class TestInvocationStartInfo(unittest.TestCase): + def test_fields(self): + ts = datetime.datetime(2025, 6, 1, tzinfo=datetime.UTC) + info = InvocationStartInfo( + request_id="req-1", + execution_arn="arn:aws:lambda:us-east-1:123:durable:abc", + start_timestamp=ts, + ) + self.assertEqual(info.request_id, "req-1") + self.assertEqual(info.execution_arn, "arn:aws:lambda:us-east-1:123:durable:abc") + self.assertEqual(info.start_timestamp, ts) + + +class TestInvocationEndInfo(unittest.TestCase): + def test_inherits_invocation_start_info(self): + self.assertTrue(issubclass(InvocationEndInfo, InvocationStartInfo)) + + def test_defaults(self): + ts = datetime.datetime(2025, 6, 1, tzinfo=datetime.UTC) + info = InvocationEndInfo( + request_id="req-1", execution_arn="arn:test", start_timestamp=ts + ) + self.assertEqual(info.status, InvocationStatus.SUCCEEDED) + self.assertIsNone(info.error) + + def test_failed(self): + ts = datetime.datetime(2025, 6, 1, tzinfo=datetime.UTC) + err = ErrorObject(message="boom", type="Error", data=None, stack_trace=None) + info = InvocationEndInfo( + request_id="req-1", + execution_arn="arn:test", + start_timestamp=ts, + status=InvocationStatus.FAILED, + error=err, + ) + self.assertEqual(info.status, InvocationStatus.FAILED) + self.assertEqual(info.error.message, "boom") + + +class TestExecutionStartInfo(unittest.TestCase): + def test_inherits_invocation_start_info(self): + self.assertTrue(issubclass(ExecutionStartInfo, InvocationStartInfo)) + + def test_construction(self): + ts = datetime.datetime(2025, 6, 1, tzinfo=datetime.UTC) + info = ExecutionStartInfo( + request_id="req-1", execution_arn="arn:test", start_timestamp=ts + ) + self.assertEqual(info.request_id, "req-1") + + +class TestExecutionEndInfo(unittest.TestCase): + def test_inherits_execution_start_info(self): + self.assertTrue(issubclass(ExecutionEndInfo, ExecutionStartInfo)) + + def test_defaults(self): + ts = datetime.datetime(2025, 6, 1, tzinfo=datetime.UTC) + info = ExecutionEndInfo( + request_id="req-1", execution_arn="arn:test", start_timestamp=ts + ) + self.assertEqual(info.status, InvocationStatus.SUCCEEDED) + self.assertIsNone(info.error) + + def test_with_error(self): + ts = datetime.datetime(2025, 6, 1, tzinfo=datetime.UTC) + err = ErrorObject(message="crash", type="Error", data=None, stack_trace=None) + info = ExecutionEndInfo( + request_id="req-1", + execution_arn="arn:test", + start_timestamp=ts, + status=InvocationStatus.FAILED, + end_timestamp=ts, + error=err, + ) + self.assertEqual(info.status, InvocationStatus.FAILED) + self.assertEqual(info.end_timestamp, ts) + self.assertEqual(info.error.message, "crash") + + +# endregion Dataclass Tests + + +# region DurableExecutionPlugin Tests + + +class TestDurableExecutionPlugin(unittest.TestCase): + def test_default_methods_are_noop(self): + """All default hook methods should be callable and return None.""" + plugin = _NoOpPlugin() + ts = datetime.datetime(2025, 1, 1, tzinfo=datetime.UTC) + + exec_start = ExecutionStartInfo( + request_id="r", execution_arn="a", start_timestamp=ts + ) + exec_end = ExecutionEndInfo( + request_id="r", execution_arn="a", start_timestamp=ts + ) + inv_start = InvocationStartInfo( + request_id="r", execution_arn="a", start_timestamp=ts + ) + inv_end = InvocationEndInfo( + request_id="r", execution_arn="a", start_timestamp=ts + ) + op_start = OperationStartInfo( + operation_id="o", operation_type=OperationType.STEP + ) + op_end = OperationEndInfo(operation_id="o", operation_type=OperationType.STEP) + att_start = AttemptStartInfo( + operation_id="o", operation_type=OperationType.STEP + ) + att_end = AttemptEndInfo(operation_id="o", operation_type=OperationType.STEP) + + self.assertIsNone(plugin.on_execution_start(exec_start)) + self.assertIsNone(plugin.on_execution_end(exec_end)) + self.assertIsNone(plugin.on_invocation_start(inv_start)) + self.assertIsNone(plugin.on_invocation_end(inv_end)) + self.assertIsNone(plugin.on_operation_start(op_start)) + self.assertIsNone(plugin.on_operation_end(op_end)) + self.assertIsNone(plugin.on_operation_attempt_start(att_start)) + self.assertIsNone(plugin.on_operation_attempt_end(att_end)) + + def test_subclass_override(self): + """A subclass can override specific hooks.""" + plugin = _TrackingPlugin() + ts = datetime.datetime(2025, 1, 1, tzinfo=datetime.UTC) + + plugin.on_execution_start( + ExecutionStartInfo(request_id="r", execution_arn="a", start_timestamp=ts) + ) + plugin.on_operation_start( + OperationStartInfo(operation_id="o", operation_type=OperationType.WAIT) + ) + + self.assertEqual(plugin.calls, ["execution_start:r", "operation_start:o"]) + + def test_cannot_instantiate_abc_directly(self): + """DurableExecutionPlugin is abstract but has no abstract methods, so it can be instantiated via a subclass.""" + self.assertTrue(issubclass(DurableExecutionPlugin, object)) + + +# endregion DurableExecutionPlugin Tests + + +# region PluginExecutor Tests + + +class TestPluginExecutorInit(unittest.TestCase): + def test_init_with_none(self): + executor = PluginExecutor(plugins=None) + self.assertEqual(executor.plugins, []) + + def test_init_with_empty_list(self): + executor = PluginExecutor(plugins=[]) + self.assertEqual(executor.plugins, []) + + def test_init_with_plugins(self): + p1 = _NoOpPlugin() + p2 = _TrackingPlugin() + executor = PluginExecutor(plugins=[p1, p2]) + self.assertEqual(len(executor.plugins), 2) + + +class TestPluginExecutorEmptyPlugins(unittest.TestCase): + """Tests that PluginExecutor does not create a thread pool when plugins is empty.""" + + def test_no_thread_pool_when_plugins_is_none(self): + executor = PluginExecutor(plugins=None) + self.assertIsNone(executor._executor) + + def test_no_thread_pool_when_plugins_is_empty_list(self): + executor = PluginExecutor(plugins=[]) + self.assertIsNone(executor._executor) + + def test_thread_pool_created_when_plugins_provided(self): + executor = PluginExecutor(plugins=[_NoOpPlugin()]) + self.assertIsNotNone(executor._executor) + executor.close() + + def test_execute_plugins_is_noop_when_empty(self): + executor = PluginExecutor(plugins=[]) + info = OperationStartInfo( + operation_id="op-1", operation_type=OperationType.STEP + ) + # Should not raise even though _executor is None + executor.execute_plugins(info) + self.assertEqual(executor._pending_futures, []) + + def test_flush_is_noop_when_empty(self): + executor = PluginExecutor(plugins=[]) + # Should not raise + executor.flush() + + def test_close_is_noop_when_empty(self): + executor = PluginExecutor(plugins=[]) + # Should not raise + executor.close() + + def test_on_invocation_start_is_safe_when_empty(self): + executor = PluginExecutor(plugins=[]) + ctx = MagicMock() + ctx.aws_request_id = "req-1" + op = MagicMock() + op.start_timestamp = datetime.datetime(2025, 1, 1, tzinfo=datetime.UTC) + + # Should not raise + executor.on_invocation_start( + durable_execution_arn="arn:exec", + context=ctx, + execution_operation=op, + is_replaying=False, + ) + + def test_on_invocation_end_is_safe_when_empty(self): + executor = PluginExecutor(plugins=[]) + ctx = MagicMock() + ctx.aws_request_id = "req-1" + op = MagicMock() + op.start_timestamp = datetime.datetime(2025, 1, 1, tzinfo=datetime.UTC) + op.end_timestamp = datetime.datetime(2025, 1, 1, tzinfo=datetime.UTC) + output = DurableExecutionInvocationOutput( + status=InvocationStatus.SUCCEEDED, result=None, error=None + ) + + # Should not raise + executor.on_invocation_end( + durable_execution_arn="arn:exec", + context=ctx, + execution_operation=op, + output=output, + ) + + def test_on_operation_action_is_safe_when_empty(self): + executor = PluginExecutor(plugins=[]) + update = MagicMock() + update.action = OperationAction.START + update.operation_id = "op-1" + update.operation_type = OperationType.STEP + update.sub_type = OperationSubType.STEP + update.name = "my-step" + update.parent_id = None + + # Should not raise + executor.on_operation_action(None, update) + + def test_on_operation_update_is_safe_when_empty(self): + executor = PluginExecutor(plugins=[]) + op = MagicMock() + op.operation_id = "op-1" + op.operation_type = OperationType.STEP + op.sub_type = OperationSubType.STEP + op.name = "my-step" + op.parent_id = None + op.start_timestamp = datetime.datetime(2025, 1, 1, tzinfo=datetime.UTC) + op.end_timestamp = datetime.datetime(2025, 1, 1, tzinfo=datetime.UTC) + op.status = OperationStatus.SUCCEEDED + op.step_details = MagicMock() + op.step_details.attempt = 1 + op.step_details.error = None + op.callback_details = None + op.chained_invoke_details = None + op.context_details = None + + # Should not raise + executor.on_operation_update(op) + + +class TestPluginExecutorExecutePlugins(unittest.TestCase): + """Tests for the execute_plugins dispatch method.""" + + def setUp(self): + self.plugin = _TrackingPlugin() + self.executor = PluginExecutor(plugins=[self.plugin]) + + def tearDown(self): + self.executor.flush() + + def test_dispatch_execution_start_info(self): + ts = datetime.datetime(2025, 1, 1, tzinfo=datetime.UTC) + info = ExecutionStartInfo( + request_id="req-1", execution_arn="arn:test", start_timestamp=ts + ) + self.executor.execute_plugins(info) + self.assertIn("execution_start:req-1", self.plugin.calls) + + def test_dispatch_execution_end_info(self): + ts = datetime.datetime(2025, 1, 1, tzinfo=datetime.UTC) + info = ExecutionEndInfo( + request_id="req-1", execution_arn="arn:test", start_timestamp=ts + ) + self.executor.execute_plugins(info) + self.assertIn("execution_end:req-1", self.plugin.calls) + + def test_dispatch_invocation_start_info(self): + ts = datetime.datetime(2025, 1, 1, tzinfo=datetime.UTC) + info = InvocationStartInfo( + request_id="req-1", execution_arn="arn:test", start_timestamp=ts + ) + self.executor.execute_plugins(info) + self.assertIn("invocation_start:req-1", self.plugin.calls) + + def test_dispatch_invocation_end_info(self): + ts = datetime.datetime(2025, 1, 1, tzinfo=datetime.UTC) + info = InvocationEndInfo( + request_id="req-1", execution_arn="arn:test", start_timestamp=ts + ) + self.executor.execute_plugins(info) + self.assertIn("invocation_end:req-1", self.plugin.calls) + + def test_dispatch_operation_end_info(self): + info = OperationEndInfo(operation_id="op-1", operation_type=OperationType.STEP) + self.executor.execute_plugins(info) + self.assertIn("operation_end:op-1", self.plugin.calls) + + def test_dispatch_operation_start_info(self): + info = OperationStartInfo( + operation_id="op-1", operation_type=OperationType.STEP + ) + self.executor.execute_plugins(info) + self.assertIn("operation_start:op-1", self.plugin.calls) + + def test_dispatch_attempt_start_info(self): + info = AttemptStartInfo(operation_id="op-1", operation_type=OperationType.STEP) + self.executor.execute_plugins(info) + self.assertIn("attempt_start:op-1", self.plugin.calls) + + def test_dispatch_attempt_end_info(self): + info = AttemptEndInfo(operation_id="op-1", operation_type=OperationType.STEP) + self.executor.execute_plugins(info) + self.assertIn("attempt_end:op-1", self.plugin.calls) + + def test_dispatch_unknown_type_logs_exception(self): + """Unknown info types should be caught and logged.""" + with self.assertLogs( + "aws_durable_execution_sdk_python.plugin", level=logging.ERROR + ): + self.executor.execute_plugins("not a valid info type") + self.executor.flush() + + def test_plugin_exception_is_swallowed(self): + """If a plugin raises, the exception is logged and execution continues.""" + failing_plugin = _FailingPlugin() + tracking_plugin = _TrackingPlugin() + executor = PluginExecutor(plugins=[failing_plugin, tracking_plugin]) + + info = OperationStartInfo( + operation_id="op-1", operation_type=OperationType.STEP + ) + with self.assertLogs( + "aws_durable_execution_sdk_python.plugin", level=logging.ERROR + ): + executor.execute_plugins(info) + executor.flush() + + # The second plugin should still have been called + self.assertIn("operation_start:op-1", tracking_plugin.calls) + + def test_multiple_plugins_all_called(self): + p1 = _TrackingPlugin() + p2 = _TrackingPlugin() + executor = PluginExecutor(plugins=[p1, p2]) + + info = OperationStartInfo( + operation_id="op-1", operation_type=OperationType.STEP + ) + executor.execute_plugins(info) + executor.flush() + + self.assertIn("operation_start:op-1", p1.calls) + self.assertIn("operation_start:op-1", p2.calls) + + +class TestPluginExecutorOnInvocationStart(unittest.TestCase): + """Tests for PluginExecutor.on_invocation_start.""" + + def setUp(self): + self.plugin = _TrackingPlugin() + self.executor = PluginExecutor(plugins=[self.plugin]) + self.ts = datetime.datetime(2025, 1, 1, tzinfo=datetime.UTC) + + def tearDown(self): + self.executor.flush() + + def _make_context(self, request_id="req-123"): + ctx = MagicMock() + ctx.aws_request_id = request_id + return ctx + + def _make_operation(self, start_timestamp=None): + op = MagicMock() + op.start_timestamp = start_timestamp or self.ts + return op + + def test_first_invocation_fires_execution_start_and_invocation_start(self): + ctx = self._make_context() + op = self._make_operation() + + self.executor.on_invocation_start( + durable_execution_arn="arn:exec", + context=ctx, + execution_operation=op, + is_replaying=False, + ) + self.executor.flush() + + # ExecutionStartInfo dispatches to on_invocation_start in match + # InvocationStartInfo dispatches to on_invocation_start in match + # So we expect two invocation_start calls + invocation_calls = [ + c + for c in self.plugin.calls + if c.startswith("invocation_start") or c.startswith("execution_start") + ] + self.assertEqual(len(invocation_calls), 2) + + def test_replay_invocation_skips_execution_start(self): + ctx = self._make_context() + op = self._make_operation() + + self.executor.on_invocation_start( + durable_execution_arn="arn:exec", + context=ctx, + execution_operation=op, + is_replaying=True, + ) + self.executor.flush() + + # Only InvocationStartInfo should be dispatched (not ExecutionStartInfo) + invocation_calls = [ + c + for c in self.plugin.calls + if c.startswith("invocation_start") or c.startswith("execution_start") + ] + self.assertEqual(len(invocation_calls), 1) + + def test_none_context_uses_none_request_id(self): + op = self._make_operation() + + self.executor.on_invocation_start( + durable_execution_arn="arn:exec", + context=None, + execution_operation=op, + is_replaying=False, + ) + self.executor.flush() + + invocation_calls = [ + c + for c in self.plugin.calls + if c.startswith("invocation_start") or c.startswith("execution_start") + ] + # Both ExecutionStartInfo and InvocationStartInfo dispatched + self.assertEqual(len(invocation_calls), 2) + # request_id should be None + self.assertIn("invocation_start:None", self.plugin.calls) + + +class TestPluginExecutorOnInvocationEnd(unittest.TestCase): + """Tests for PluginExecutor.on_invocation_end.""" + + def setUp(self): + self.plugin = _TrackingPlugin() + self.executor = PluginExecutor(plugins=[self.plugin]) + self.ts = datetime.datetime(2025, 1, 1, tzinfo=datetime.UTC) + + def tearDown(self): + self.executor.flush() + + def _make_context(self, request_id="req-123"): + ctx = MagicMock() + ctx.aws_request_id = request_id + return ctx + + def _make_operation(self, start_ts=None, end_ts=None): + op = MagicMock() + op.start_timestamp = start_ts or self.ts + op.end_timestamp = end_ts + return op + + def test_succeeded_fires_invocation_end_and_execution_end(self): + ctx = self._make_context() + op = self._make_operation(end_ts=self.ts) + output = DurableExecutionInvocationOutput( + status=InvocationStatus.SUCCEEDED, result=None, error=None + ) + + self.executor.on_invocation_end( + durable_execution_arn="arn:exec", + context=ctx, + execution_operation=op, + output=output, + ) + self.executor.flush() + + self.assertIn("invocation_end:req-123", self.plugin.calls) + self.assertIn("execution_end:req-123", self.plugin.calls) + + def test_failed_fires_invocation_end_and_execution_end(self): + ctx = self._make_context() + op = self._make_operation(end_ts=self.ts) + err = ErrorObject(message="oops", type="Error", data=None, stack_trace=None) + output = DurableExecutionInvocationOutput( + status=InvocationStatus.FAILED, result=None, error=err + ) + + self.executor.on_invocation_end( + durable_execution_arn="arn:exec", + context=ctx, + execution_operation=op, + output=output, + ) + self.executor.flush() + + self.assertIn("invocation_end:req-123", self.plugin.calls) + self.assertIn("execution_end:req-123", self.plugin.calls) + + def test_pending_fires_only_invocation_end(self): + ctx = self._make_context() + op = self._make_operation(end_ts=self.ts) + output = DurableExecutionInvocationOutput( + status=InvocationStatus.PENDING, result=None, error=None + ) + + self.executor.on_invocation_end( + durable_execution_arn="arn:exec", + context=ctx, + execution_operation=op, + output=output, + ) + self.executor.flush() + + self.assertIn("invocation_end:req-123", self.plugin.calls) + execution_end_calls = [ + c for c in self.plugin.calls if c.startswith("execution_end") + ] + self.assertEqual(len(execution_end_calls), 0) + + def test_none_execution_operation_uses_now_for_end_timestamp(self): + ctx = self._make_context() + output = DurableExecutionInvocationOutput( + status=InvocationStatus.SUCCEEDED, result=None, error=None + ) + + with patch("aws_durable_execution_sdk_python.plugin.datetime") as mock_dt: + mock_dt.datetime.now.return_value = self.ts + self.executor.on_invocation_end( + durable_execution_arn="arn:exec", + context=ctx, + execution_operation=None, + output=output, + ) + self.executor.flush() + + self.assertIn("invocation_end:req-123", self.plugin.calls) + + def test_none_end_timestamp_on_operation_uses_now(self): + ctx = self._make_context() + op = self._make_operation(end_ts=None) + output = DurableExecutionInvocationOutput( + status=InvocationStatus.SUCCEEDED, result=None, error=None + ) + + self.executor.on_invocation_end( + durable_execution_arn="arn:exec", + context=ctx, + execution_operation=op, + output=output, + ) + self.executor.flush() + + self.assertIn("invocation_end:req-123", self.plugin.calls) + + +class TestPluginExecutorOnOperationAction(unittest.TestCase): + """Tests for PluginExecutor.on_operation_action.""" + + def setUp(self): + self.plugin = _TrackingPlugin() + self.executor = PluginExecutor(plugins=[self.plugin]) + + def tearDown(self): + self.executor.flush() + + def test_start_action_fires_operation_start(self): + update = MagicMock() + update.action = OperationAction.START + update.operation_id = "op-1" + update.operation_type = OperationType.STEP + update.sub_type = OperationSubType.STEP + update.name = "my-step" + update.parent_id = "parent-1" + + self.executor.on_operation_action(None, update) + self.executor.flush() + + self.assertIn("operation_start:op-1", self.plugin.calls) + + def test_start_action_for_step_fires_attempt_start(self): + update = MagicMock() + update.action = OperationAction.START + update.operation_id = "op-1" + update.operation_type = OperationType.STEP + update.sub_type = OperationSubType.STEP + update.name = "my-step" + update.parent_id = "parent-1" + + self.executor.on_operation_action(None, update) + self.executor.flush() + + self.assertIn("attempt_start:op-1", self.plugin.calls) + + def test_start_action_for_step_with_existing_operation_uses_attempt(self): + update = MagicMock() + update.action = OperationAction.START + update.operation_id = "op-1" + update.operation_type = OperationType.STEP + update.sub_type = OperationSubType.STEP + update.name = "my-step" + update.parent_id = "parent-1" + + operation = MagicMock() + operation.step_details = MagicMock() + operation.step_details.attempt = 3 + + self.executor.on_operation_action(operation, update) + self.executor.flush() + + self.assertIn("attempt_start:op-1", self.plugin.calls) + + def test_start_action_for_non_step_does_not_fire_attempt_start(self): + update = MagicMock() + update.action = OperationAction.START + update.operation_id = "op-1" + update.operation_type = OperationType.WAIT + update.sub_type = OperationSubType.WAIT + update.name = "my-wait" + update.parent_id = "parent-1" + + self.executor.on_operation_action(None, update) + self.executor.flush() + + self.assertIn("operation_start:op-1", self.plugin.calls) + attempt_calls = [c for c in self.plugin.calls if c.startswith("attempt")] + self.assertEqual(len(attempt_calls), 0) + + def test_non_start_action_does_not_fire(self): + update = MagicMock() + update.action = OperationAction.SUCCEED + update.operation_id = "op-1" + + self.executor.on_operation_action(None, update) + + self.assertEqual(self.plugin.calls, []) + + def test_fail_action_does_not_fire(self): + update = MagicMock() + update.action = OperationAction.FAIL + update.operation_id = "op-1" + + self.executor.on_operation_action(None, update) + + self.assertEqual(self.plugin.calls, []) + + +class TestPluginExecutorOnOperationUpdate(unittest.TestCase): + """Tests for PluginExecutor.on_operation_update.""" + + def setUp(self): + self.plugin = _TrackingPlugin() + self.executor = PluginExecutor(plugins=[self.plugin]) + self.ts = datetime.datetime(2025, 1, 1, tzinfo=datetime.UTC) + + def tearDown(self): + self.executor.flush() + + def _make_operation( + self, + status=OperationStatus.SUCCEEDED, + step_details=None, + callback_details=None, + chained_invoke_details=None, + context_details=None, + ): + op = MagicMock() + op.operation_id = "op-1" + op.operation_type = OperationType.STEP + op.sub_type = OperationSubType.STEP + op.name = "my-step" + op.parent_id = "parent-1" + op.start_timestamp = self.ts + op.end_timestamp = self.ts + op.status = status + op.step_details = step_details + op.callback_details = callback_details + op.chained_invoke_details = chained_invoke_details + op.context_details = context_details + return op + + def test_terminal_status_with_step_details_fires_attempt_and_operation(self): + step_details = MagicMock() + step_details.attempt = 2 + step_details.error = None + op = self._make_operation( + status=OperationStatus.SUCCEEDED, step_details=step_details + ) + + self.executor.on_operation_update(op) + self.executor.flush() + + self.assertIn("attempt_end:op-1", self.plugin.calls) + self.assertIn("operation_end:op-1", self.plugin.calls) + + def test_pending_status_with_step_details_fires_attempt_only(self): + step_details = MagicMock() + step_details.attempt = 1 + step_details.error = ErrorObject( + message="retry", type="Error", data=None, stack_trace=None + ) + op = self._make_operation( + status=OperationStatus.PENDING, step_details=step_details + ) + + self.executor.on_operation_update(op) + self.executor.flush() + + self.assertIn("attempt_end:op-1", self.plugin.calls) + # Should NOT fire operation_end for PENDING + operation_end_calls = [ + c for c in self.plugin.calls if c.startswith("operation_end") + ] + self.assertEqual(len(operation_end_calls), 0) + + def test_terminal_status_without_step_details_fires_operation_only(self): + op = self._make_operation(status=OperationStatus.FAILED, step_details=None) + + self.executor.on_operation_update(op) + self.executor.flush() + + self.assertIn("operation_end:op-1", self.plugin.calls) + attempt_calls = [c for c in self.plugin.calls if c.startswith("attempt")] + self.assertEqual(len(attempt_calls), 0) + + def test_non_terminal_status_without_step_details_fires_nothing(self): + op = self._make_operation(status=OperationStatus.STARTED, step_details=None) + + self.executor.on_operation_update(op) + self.executor.flush() + + self.assertEqual(self.plugin.calls, []) + + def test_ready_status_fires_nothing(self): + op = self._make_operation(status=OperationStatus.READY, step_details=None) + + self.executor.on_operation_update(op) + self.executor.flush() + + self.assertEqual(self.plugin.calls, []) + + def test_timed_out_is_terminal(self): + op = self._make_operation(status=OperationStatus.TIMED_OUT, step_details=None) + + self.executor.on_operation_update(op) + self.executor.flush() + + self.assertIn("operation_end:op-1", self.plugin.calls) + + def test_cancelled_is_terminal(self): + op = self._make_operation(status=OperationStatus.CANCELLED, step_details=None) + + self.executor.on_operation_update(op) + self.executor.flush() + + self.assertIn("operation_end:op-1", self.plugin.calls) + + def test_stopped_is_terminal(self): + op = self._make_operation(status=OperationStatus.STOPPED, step_details=None) + + self.executor.on_operation_update(op) + self.executor.flush() + + self.assertIn("operation_end:op-1", self.plugin.calls) + + +class TestPluginExecutorExtractError(unittest.TestCase): + """Tests for PluginExecutor._extract_error static method.""" + + def _make_error(self, msg="error"): + return ErrorObject(message=msg, type="Error", data=None, stack_trace=None) + + def test_extract_error_from_step_details(self): + err = self._make_error("step error") + op = MagicMock() + op.step_details = MagicMock() + op.step_details.error = err + op.callback_details = None + op.chained_invoke_details = None + op.context_details = None + + result = PluginExecutor._extract_error(op) + self.assertEqual(result.message, "step error") + + def test_extract_error_from_callback_details(self): + err = self._make_error("callback error") + op = MagicMock() + op.step_details = None + op.callback_details = MagicMock() + op.callback_details.error = err + op.chained_invoke_details = None + op.context_details = None + + result = PluginExecutor._extract_error(op) + self.assertEqual(result.message, "callback error") + + def test_extract_error_from_chained_invoke_details(self): + err = self._make_error("invoke error") + op = MagicMock() + op.step_details = None + op.callback_details = None + op.chained_invoke_details = MagicMock() + op.chained_invoke_details.error = err + op.context_details = None + + result = PluginExecutor._extract_error(op) + self.assertEqual(result.message, "invoke error") + + def test_extract_error_from_context_details(self): + err = self._make_error("context error") + op = MagicMock() + op.step_details = None + op.callback_details = None + op.chained_invoke_details = None + op.context_details = MagicMock() + op.context_details.error = err + + result = PluginExecutor._extract_error(op) + self.assertEqual(result.message, "context error") + + def test_extract_error_returns_none_when_no_error(self): + op = MagicMock() + op.step_details = None + op.callback_details = None + op.chained_invoke_details = None + op.context_details = None + + result = PluginExecutor._extract_error(op) + self.assertIsNone(result) + + def test_extract_error_step_details_no_error(self): + """step_details exists but has no error - falls through to callback.""" + err = self._make_error("callback error") + op = MagicMock() + op.step_details = MagicMock() + op.step_details.error = None + op.callback_details = MagicMock() + op.callback_details.error = err + op.chained_invoke_details = None + op.context_details = None + + result = PluginExecutor._extract_error(op) + self.assertEqual(result.message, "callback error") + + def test_extract_error_priority_step_over_callback(self): + """step_details error takes priority over callback error.""" + step_err = self._make_error("step error") + cb_err = self._make_error("callback error") + op = MagicMock() + op.step_details = MagicMock() + op.step_details.error = step_err + op.callback_details = MagicMock() + op.callback_details.error = cb_err + op.chained_invoke_details = None + op.context_details = None + + result = PluginExecutor._extract_error(op) + self.assertEqual(result.message, "step error") + + +class TestPluginExecutorIsTerminalStatus(unittest.TestCase): + """Tests for PluginExecutor._is_terminal_status static method.""" + + def test_succeeded_is_terminal(self): + self.assertTrue(PluginExecutor._is_terminal_status(OperationStatus.SUCCEEDED)) + + def test_failed_is_terminal(self): + self.assertTrue(PluginExecutor._is_terminal_status(OperationStatus.FAILED)) + + def test_timed_out_is_terminal(self): + self.assertTrue(PluginExecutor._is_terminal_status(OperationStatus.TIMED_OUT)) + + def test_cancelled_is_terminal(self): + self.assertTrue(PluginExecutor._is_terminal_status(OperationStatus.CANCELLED)) + + def test_stopped_is_terminal(self): + self.assertTrue(PluginExecutor._is_terminal_status(OperationStatus.STOPPED)) + + def test_started_is_not_terminal(self): + self.assertFalse(PluginExecutor._is_terminal_status(OperationStatus.STARTED)) + + def test_pending_is_not_terminal(self): + self.assertFalse(PluginExecutor._is_terminal_status(OperationStatus.PENDING)) + + def test_ready_is_not_terminal(self): + self.assertFalse(PluginExecutor._is_terminal_status(OperationStatus.READY)) + + +# endregion PluginExecutor Tests + + +# region Helper Classes + + +class _NoOpPlugin(DurableExecutionPlugin): + """Concrete subclass that inherits all default no-op methods.""" + + pass + + +class _TrackingPlugin(DurableExecutionPlugin): + """Concrete subclass that tracks calls to all hooks.""" + + def __init__(self) -> None: + self.calls: list[str] = [] + + def on_execution_start(self, info: ExecutionStartInfo) -> None: + self.calls.append(f"execution_start:{info.request_id}") + + def on_execution_end(self, info: ExecutionEndInfo) -> None: + self.calls.append(f"execution_end:{info.request_id}") + + def on_invocation_start(self, info: InvocationStartInfo) -> None: + self.calls.append(f"invocation_start:{info.request_id}") + + def on_invocation_end(self, info: InvocationEndInfo) -> None: + self.calls.append(f"invocation_end:{info.request_id}") + + def on_operation_start(self, info: OperationStartInfo) -> None: + self.calls.append(f"operation_start:{info.operation_id}") + + def on_operation_end(self, info: OperationEndInfo) -> None: + self.calls.append(f"operation_end:{info.operation_id}") + + def on_operation_attempt_start(self, info: AttemptStartInfo) -> None: + self.calls.append(f"attempt_start:{info.operation_id}") + + def on_operation_attempt_end(self, info: AttemptEndInfo) -> None: + self.calls.append(f"attempt_end:{info.operation_id}") + + +class _FailingPlugin(DurableExecutionPlugin): + """Plugin that raises on every hook call.""" + + def on_execution_start(self, info): + raise RuntimeError("boom") + + def on_execution_end(self, info): + raise RuntimeError("boom") + + def on_invocation_start(self, info): + raise RuntimeError("boom") + + def on_invocation_end(self, info): + raise RuntimeError("boom") + + def on_operation_start(self, info): + raise RuntimeError("boom") + + def on_operation_end(self, info): + raise RuntimeError("boom") + + def on_operation_attempt_start(self, info): + raise RuntimeError("boom") + + def on_operation_attempt_end(self, info): + raise RuntimeError("boom") + + +# endregion Helper Classes + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/state_test.py b/tests/state_test.py index 0152ca6c..c2055135 100644 --- a/tests/state_test.py +++ b/tests/state_test.py @@ -37,6 +37,10 @@ StateOutput, StepDetails, ) +from aws_durable_execution_sdk_python.plugin import ( + DurableExecutionPlugin, + PluginExecutor, +) from aws_durable_execution_sdk_python.state import ( CheckpointBatcherConfig, CheckpointedResult, @@ -405,6 +409,7 @@ def test_execution_state_creation(): initial_checkpoint_token="test_token", # noqa: S106 operations={}, service_client=mock_lambda_client, + plugin_executor=PluginExecutor(plugins=None), ) assert state.durable_execution_arn == "test_arn" assert state.operations == {} @@ -425,6 +430,7 @@ def test_get_checkpoint_result_success_with_result(): initial_checkpoint_token="token123", # noqa: S106 operations={"op1": operation}, service_client=mock_lambda_client, + plugin_executor=PluginExecutor(plugins=None), ) result = state.get_checkpoint_result("op1") @@ -446,6 +452,7 @@ def test_get_checkpoint_result_success_without_step_details(): initial_checkpoint_token="token123", # noqa: S106 operations={"op1": operation}, service_client=mock_lambda_client, + plugin_executor=PluginExecutor(plugins=None), ) result = state.get_checkpoint_result("op1") @@ -467,6 +474,7 @@ def test_get_checkpoint_result_operation_not_succeeded(): initial_checkpoint_token="token123", # noqa: S106 operations={"op1": operation}, service_client=mock_lambda_client, + plugin_executor=PluginExecutor(plugins=None), ) result = state.get_checkpoint_result("op1") @@ -483,6 +491,7 @@ def test_get_checkpoint_result_operation_not_found(): initial_checkpoint_token="token123", # noqa: S106 operations={}, service_client=mock_lambda_client, + plugin_executor=PluginExecutor(plugins=None), ) result = state.get_checkpoint_result("nonexistent") @@ -500,6 +509,7 @@ def test_create_checkpoint(): initial_checkpoint_token="token123", # noqa: S106 operations={}, service_client=mock_lambda_client, + plugin_executor=PluginExecutor(plugins=None), ) operation_update = OperationUpdate( @@ -530,6 +540,7 @@ def test_create_checkpoint_with_none(): initial_checkpoint_token="token123", # noqa: S106 operations={}, service_client=mock_lambda_client, + plugin_executor=PluginExecutor(plugins=None), ) # create_checkpoint with None and is_sync=False enqueues an empty checkpoint @@ -554,6 +565,7 @@ def test_create_checkpoint_with_no_args(): initial_checkpoint_token="token123", # noqa: S106 operations={}, service_client=mock_lambda_client, + plugin_executor=PluginExecutor(plugins=None), ) # create_checkpoint with no args and is_sync=False enqueues an empty checkpoint @@ -582,6 +594,7 @@ def test_get_checkpoint_result_started(): initial_checkpoint_token="token123", # noqa: S106 operations={"op1": operation}, service_client=mock_lambda_client, + plugin_executor=PluginExecutor(plugins=None), ) result = state.get_checkpoint_result("op1") @@ -675,6 +688,7 @@ def mock_get_execution_state(durable_execution_arn, checkpoint_token, next_marke initial_checkpoint_token="token123", # noqa: S106 operations={}, service_client=mock_lambda_client, + plugin_executor=PluginExecutor(plugins=None), ) state.fetch_paginated_operations( @@ -773,6 +787,7 @@ def mock_get_execution_state(durable_execution_arn, checkpoint_token, next_marke initial_checkpoint_token="token123", # noqa: S106 operations={}, service_client=mock_lambda_client, + plugin_executor=PluginExecutor(plugins=None), ) with pytest.raises(GetExecutionStateError): @@ -811,6 +826,7 @@ def test_fetch_paginated_operations_logs_error(caplog): initial_checkpoint_token="token123", # noqa: S106 operations={}, service_client=mock_lambda_client, + plugin_executor=PluginExecutor(plugins=None), ) with pytest.raises(GetExecutionStateError): @@ -920,6 +936,7 @@ def test_checkpoint_batch_respects_default_max_items_limit(): initial_checkpoint_token="token123", # noqa: S106 operations={}, service_client=mock_lambda_client, + plugin_executor=PluginExecutor(plugins=None), batcher_config=config, ) @@ -988,6 +1005,7 @@ def test_collect_checkpoint_batch_respects_size_limit(): initial_checkpoint_token="token123", # noqa: S106 operations={}, service_client=mock_lambda_client, + plugin_executor=PluginExecutor(plugins=None), batcher_config=config, ) @@ -1021,6 +1039,7 @@ def test_collect_checkpoint_batch_uses_overflow_queue(): initial_checkpoint_token="token123", # noqa: S106 operations={}, service_client=mock_lambda_client, + plugin_executor=PluginExecutor(plugins=None), ) # Put operations in overflow queue @@ -1072,6 +1091,7 @@ def test_collect_checkpoint_batch_handles_empty_checkpoint(): initial_checkpoint_token="token123", # noqa: S106 operations={}, service_client=mock_lambda_client, + plugin_executor=PluginExecutor(plugins=None), ) # Enqueue empty checkpoint @@ -1107,6 +1127,7 @@ def test_collect_checkpoint_batch_returns_empty_when_stopped(): initial_checkpoint_token="token123", # noqa: S106 operations={}, service_client=mock_lambda_client, + plugin_executor=PluginExecutor(plugins=None), ) # Signal stop before collecting @@ -1128,6 +1149,7 @@ def test_parent_child_relationship_building(): initial_checkpoint_token="token123", # noqa: S106 operations={}, service_client=mock_lambda_client, + plugin_executor=PluginExecutor(plugins=None), ) # Create parent operation @@ -1169,6 +1191,7 @@ def test_descendant_cancellation_when_parent_completes(): initial_checkpoint_token="token123", # noqa: S106 operations={}, service_client=mock_lambda_client, + plugin_executor=PluginExecutor(plugins=None), ) # Build parent-child hierarchy @@ -1208,6 +1231,7 @@ def test_rejection_of_operations_from_completed_parents(): initial_checkpoint_token="token123", # noqa: S106 operations={}, service_client=mock_lambda_client, + plugin_executor=PluginExecutor(plugins=None), ) # Build parent-child hierarchy @@ -1257,6 +1281,7 @@ def test_nested_parallel_operations_deep_hierarchy(): initial_checkpoint_token="token123", # noqa: S106 operations={}, service_client=mock_lambda_client, + plugin_executor=PluginExecutor(plugins=None), ) # Build deep hierarchy: grandparent -> parent -> child @@ -1313,6 +1338,7 @@ def test_synchronous_checkpoint_blocks_until_complete(): initial_checkpoint_token="token123", # noqa: S106 operations={}, service_client=mock_lambda_client, + plugin_executor=PluginExecutor(plugins=None), ) operation_update = OperationUpdate( @@ -1361,6 +1387,7 @@ def test_concurrent_access_to_operations_dictionary(): initial_checkpoint_token="token123", # noqa: S106 operations={}, service_client=mock_lambda_client, + plugin_executor=PluginExecutor(plugins=None), ) # Add initial operation @@ -1430,6 +1457,7 @@ def test_stop_checkpointing_signals_background_thread(): initial_checkpoint_token="token123", # noqa: S106 operations={}, service_client=mock_lambda_client, + plugin_executor=PluginExecutor(plugins=None), ) # Verify event is not set initially @@ -1523,6 +1551,7 @@ def test_create_checkpoint_sync_with_parent_id(): initial_checkpoint_token="token123", # noqa: S106 operations={}, service_client=mock_lambda_client, + plugin_executor=PluginExecutor(plugins=None), ) # Create parent operation @@ -1574,6 +1603,7 @@ def test_create_checkpoint_sync_rejects_orphaned_operation(): initial_checkpoint_token="token123", # noqa: S106 operations={}, service_client=mock_lambda_client, + plugin_executor=PluginExecutor(plugins=None), ) # Build parent-child relationship @@ -1638,6 +1668,7 @@ def test_mark_orphans_handles_cycles(): initial_checkpoint_token="token123", # noqa: S106 operations={}, service_client=mock_lambda_client, + plugin_executor=PluginExecutor(plugins=None), ) # Manually create a cycle (shouldn't happen in practice, but test defensive code) @@ -1668,6 +1699,7 @@ def test_checkpoint_batches_forever_exception_handling(): initial_checkpoint_token="token123", # noqa: S106 operations={}, service_client=mock_lambda_client, + plugin_executor=PluginExecutor(plugins=None), ) # Create synchronous operation @@ -1715,6 +1747,7 @@ def test_collect_checkpoint_batch_shutdown_path(): initial_checkpoint_token="token123", # noqa: S106 operations={}, service_client=mock_lambda_client, + plugin_executor=PluginExecutor(plugins=None), ) # Add operation to queue (would be a non-essential async checkpoint in practice) @@ -1744,6 +1777,7 @@ def test_collect_checkpoint_batch_shutdown_empty_queue(): initial_checkpoint_token="token123", # noqa: S106 operations={}, service_client=mock_lambda_client, + plugin_executor=PluginExecutor(plugins=None), ) # Signal shutdown with empty queue @@ -1771,6 +1805,7 @@ def test_collect_checkpoint_batch_overflow_put_back(): initial_checkpoint_token="token123", # noqa: S106 operations={}, service_client=mock_lambda_client, + plugin_executor=PluginExecutor(plugins=None), batcher_config=config, ) @@ -1816,6 +1851,7 @@ def test_create_checkpoint_sync_with_none_operation_update(): initial_checkpoint_token="token123", # noqa: S106 operations={}, service_client=mock_lambda_client, + plugin_executor=PluginExecutor(plugins=None), ) # Simulate background processor @@ -1848,6 +1884,7 @@ def test_checkpoint_batches_forever_exception_with_no_sync_operations(): initial_checkpoint_token="token123", # noqa: S106 operations={}, service_client=mock_lambda_client, + plugin_executor=PluginExecutor(plugins=None), ) # Create async operation (no completion event) @@ -1887,6 +1924,7 @@ def test_collect_checkpoint_batch_size_limit_during_time_window(): initial_checkpoint_token="token123", # noqa: S106 operations={}, service_client=mock_lambda_client, + plugin_executor=PluginExecutor(plugins=None), batcher_config=config, ) @@ -1940,6 +1978,7 @@ def test_collect_checkpoint_batch_respects_max_operations_limit(): initial_checkpoint_token="token123", # noqa: S106 operations={}, service_client=mock_lambda_client, + plugin_executor=PluginExecutor(plugins=None), batcher_config=config, ) @@ -1983,6 +2022,7 @@ def test_collect_checkpoint_batch_time_window_expires(): initial_checkpoint_token="token123", # noqa: S106 operations={}, service_client=mock_lambda_client, + plugin_executor=PluginExecutor(plugins=None), batcher_config=config, ) @@ -2030,6 +2070,7 @@ def test_collect_checkpoint_batch_empty_overflow_queue_path(): initial_checkpoint_token="token123", # noqa: S106 operations={}, service_client=mock_lambda_client, + plugin_executor=PluginExecutor(plugins=None), ) # Ensure overflow queue is empty (it should be by default) @@ -2067,6 +2108,7 @@ def test_collect_checkpoint_batch_overflow_queue_hits_operation_limit(): initial_checkpoint_token="token123", # noqa: S106 operations={}, service_client=mock_lambda_client, + plugin_executor=PluginExecutor(plugins=None), batcher_config=config, ) @@ -2106,6 +2148,7 @@ def test_collect_checkpoint_batch_overflow_queue_size_limit(): initial_checkpoint_token="token123", # noqa: S106 operations={}, service_client=mock_lambda_client, + plugin_executor=PluginExecutor(plugins=None), batcher_config=config, ) @@ -2155,6 +2198,7 @@ def test_checkpoint_error_signals_completion_events_with_error(): initial_checkpoint_token="token123", # noqa: S106 operations={}, service_client=mock_lambda_client, + plugin_executor=PluginExecutor(plugins=None), ) # Create synchronous operation with completion event @@ -2211,6 +2255,7 @@ def test_synchronous_caller_receives_error_on_background_thread_failure(): initial_checkpoint_token="token123", # noqa: S106 operations={}, service_client=mock_lambda_client, + plugin_executor=PluginExecutor(plugins=None), ) operation_update = OperationUpdate( @@ -2288,6 +2333,7 @@ def test_exception_propagates_through_threadpoolexecutor(): initial_checkpoint_token="token123", # noqa: S106 operations={}, service_client=mock_lambda_client, + plugin_executor=PluginExecutor(plugins=None), ) # Enqueue an operation @@ -2321,6 +2367,7 @@ def test_multiple_sync_operations_all_remain_blocked_on_error(): initial_checkpoint_token="token123", # noqa: S106 operations={}, service_client=mock_lambda_client, + plugin_executor=PluginExecutor(plugins=None), ) # Create multiple synchronous operations @@ -2372,6 +2419,7 @@ def test_async_operations_not_affected_by_error_handling(): initial_checkpoint_token="token123", # noqa: S106 operations={}, service_client=mock_lambda_client, + plugin_executor=PluginExecutor(plugins=None), ) # Create async operation (no completion event) @@ -2409,6 +2457,7 @@ def test_mixed_sync_async_operations_only_sync_blocked_on_error(): initial_checkpoint_token="token123", # noqa: S106 operations={}, service_client=mock_lambda_client, + plugin_executor=PluginExecutor(plugins=None), ) # Create sync operation with completion event @@ -2469,6 +2518,7 @@ def test_create_checkpoint_accepts_is_sync_parameter(): initial_checkpoint_token="token123", # noqa: S106 operations={}, service_client=mock_lambda_client, + plugin_executor=PluginExecutor(plugins=None), ) operation_update = OperationUpdate( @@ -2503,6 +2553,7 @@ def test_create_checkpoint_default_is_sync_true(): initial_checkpoint_token="token123", # noqa: S106 operations={}, service_client=mock_lambda_client, + plugin_executor=PluginExecutor(plugins=None), ) operation_update = OperationUpdate( @@ -2549,6 +2600,7 @@ def test_create_checkpoint_explicit_is_sync_true(): initial_checkpoint_token="token123", # noqa: S106 operations={}, service_client=mock_lambda_client, + plugin_executor=PluginExecutor(plugins=None), ) operation_update = OperationUpdate( @@ -2590,6 +2642,7 @@ def test_create_checkpoint_is_sync_false_no_completion_event(): initial_checkpoint_token="token123", # noqa: S106 operations={}, service_client=mock_lambda_client, + plugin_executor=PluginExecutor(plugins=None), ) operation_update = OperationUpdate( @@ -2620,6 +2673,7 @@ def test_create_checkpoint_is_sync_false_returns_immediately(): initial_checkpoint_token="token123", # noqa: S106 operations={}, service_client=mock_lambda_client, + plugin_executor=PluginExecutor(plugins=None), ) operation_update = OperationUpdate( @@ -2658,6 +2712,7 @@ def test_create_checkpoint_with_none_defaults_to_sync(): initial_checkpoint_token="token123", # noqa: S106 operations={}, service_client=mock_lambda_client, + plugin_executor=PluginExecutor(plugins=None), ) # Use a thread to call with None (will block) @@ -2694,6 +2749,7 @@ def test_create_checkpoint_no_args_defaults_to_sync(): initial_checkpoint_token="token123", # noqa: S106 operations={}, service_client=mock_lambda_client, + plugin_executor=PluginExecutor(plugins=None), ) # Use a thread to call with no arguments (will block) @@ -2733,6 +2789,7 @@ def test_collect_checkpoint_batch_overflow_queue_size_limit_final(): initial_checkpoint_token="token123", # noqa: S106 operations={}, service_client=mock_lambda_client, + plugin_executor=PluginExecutor(plugins=None), batcher_config=config, ) @@ -2788,6 +2845,7 @@ def test_create_checkpoint_blocks_until_completion_default(): initial_checkpoint_token="token123", # noqa: S106 operations={}, service_client=mock_lambda_client, + plugin_executor=PluginExecutor(plugins=None), ) operation_update = OperationUpdate( @@ -2859,6 +2917,7 @@ def test_create_checkpoint_blocks_until_completion_explicit_true(): initial_checkpoint_token="token123", # noqa: S106 operations={}, service_client=mock_lambda_client, + plugin_executor=PluginExecutor(plugins=None), ) operation_update = OperationUpdate( @@ -2930,6 +2989,7 @@ def test_create_checkpoint_completion_event_created_and_signaled(): initial_checkpoint_token="token123", # noqa: S106 operations={}, service_client=mock_lambda_client, + plugin_executor=PluginExecutor(plugins=None), ) operation_update = OperationUpdate( @@ -2994,6 +3054,7 @@ def test_create_checkpoint_completion_event_not_signaled_on_failure(): initial_checkpoint_token="token123", # noqa: S106 operations={}, service_client=mock_lambda_client, + plugin_executor=PluginExecutor(plugins=None), ) operation_update = OperationUpdate( @@ -3080,6 +3141,7 @@ def test_create_checkpoint_caller_remains_blocked_on_background_failure(): initial_checkpoint_token="token123", # noqa: S106 operations={}, service_client=mock_lambda_client, + plugin_executor=PluginExecutor(plugins=None), ) operation_update = OperationUpdate( @@ -3162,6 +3224,7 @@ def test_create_checkpoint_multiple_sync_calls_all_block(): initial_checkpoint_token="token123", # noqa: S106 operations={}, service_client=mock_lambda_client, + plugin_executor=PluginExecutor(plugins=None), ) num_callers = 3 @@ -3238,6 +3301,7 @@ def test_create_checkpoint_sync_with_empty_checkpoint(): initial_checkpoint_token="token123", # noqa: S106 operations={}, service_client=mock_lambda_client, + plugin_executor=PluginExecutor(plugins=None), ) # Track timing and completion @@ -3296,6 +3360,7 @@ def test_create_checkpoint_sync_success(): initial_checkpoint_token="initial-token", # noqa: S106 operations={}, service_client=mock_client, + plugin_executor=PluginExecutor(plugins=None), ) # Start background thread @@ -3330,6 +3395,7 @@ def test_create_checkpoint_sync_unwraps_background_thread_error(): initial_checkpoint_token="initial-token", # noqa: S106 operations={}, service_client=mock_client, + plugin_executor=PluginExecutor(plugins=None), ) # Start background thread @@ -3363,6 +3429,7 @@ def test_create_checkpoint_sync_always_synchronous(): initial_checkpoint_token="initial-token", # noqa: S106 operations={}, service_client=mock_client, + plugin_executor=PluginExecutor(plugins=None), ) # Start background thread @@ -3400,6 +3467,7 @@ def test_state_replay_mode(): initial_checkpoint_token="test_token", # noqa: S106 operations={"op1": operation1, "op2": operation2}, service_client=Mock(), + plugin_executor=PluginExecutor(plugins=None), replay_status=ReplayStatus.REPLAY, ) assert execution_state.is_replaying() is True @@ -3433,6 +3501,7 @@ def test_state_replay_mode_with_timed_out(): initial_checkpoint_token="test_token", # noqa: S106 operations={"op1": operation1, "op2": operation2}, service_client=Mock(), + plugin_executor=PluginExecutor(plugins=None), replay_status=ReplayStatus.REPLAY, ) assert execution_state.is_replaying() is True @@ -3464,6 +3533,7 @@ def test_collect_checkpoint_batch_coalesces_many_empty_checkpoints(): initial_checkpoint_token="token123", # noqa: S106 operations={}, service_client=mock_lambda_client, + plugin_executor=PluginExecutor(plugins=None), batcher_config=config, ) @@ -3497,6 +3567,7 @@ def test_collect_checkpoint_batch_empty_checkpoints_with_real_ops_respects_limit initial_checkpoint_token="token123", # noqa: S106 operations={}, service_client=mock_lambda_client, + plugin_executor=PluginExecutor(plugins=None), batcher_config=config, ) @@ -3536,6 +3607,7 @@ def test_collect_checkpoint_batch_overflow_coalesces_empty_checkpoints(): initial_checkpoint_token="token123", # noqa: S106 operations={}, service_client=mock_lambda_client, + plugin_executor=PluginExecutor(plugins=None), batcher_config=config, ) @@ -3576,6 +3648,7 @@ def test_checkpoint_batches_forever_single_api_call_for_many_empty_checkpoints() initial_checkpoint_token="token123", # noqa: S106 operations={}, service_client=mock_lambda_client, + plugin_executor=PluginExecutor(plugins=None), batcher_config=config, ) @@ -3624,6 +3697,7 @@ def test_collect_checkpoint_batch_first_empty_counts_toward_limit(): initial_checkpoint_token="token123", # noqa: S106 operations={}, service_client=mock_lambda_client, + plugin_executor=PluginExecutor(plugins=None), batcher_config=config, ) @@ -3676,6 +3750,7 @@ def test_execution_state_get_execution_operation_no_operations(): initial_checkpoint_token="token123", # noqa: S106 operations={}, service_client=mock_lambda_client, + plugin_executor=PluginExecutor(plugins=None), batcher_config=config, ) @@ -3707,6 +3782,7 @@ def test_initial_execution_state_get_execution_operation_wrong_type(): initial_checkpoint_token="token123", # noqa: S106 operations={"step1": operation}, service_client=mock_lambda_client, + plugin_executor=PluginExecutor(plugins=None), batcher_config=config, ) @@ -3743,8 +3819,472 @@ def test_initial_execution_state_get_input_payload_none(): initial_checkpoint_token="token123", # noqa: S106 operations={"step1": operation}, service_client=mock_lambda_client, + plugin_executor=PluginExecutor(plugins=None), batcher_config=config, ) result = state.get_input_payload() assert result is None + + +# region Plugin Executor Integration Tests + + +class _RecordingPlugin(DurableExecutionPlugin): + """Plugin that records all hook calls for assertion.""" + + def __init__(self): + self.calls: list[str] = [] + + def on_execution_start(self, info): + self.calls.append("execution_start") + + def on_execution_end(self, info): + self.calls.append("execution_end") + + def on_invocation_start(self, info): + self.calls.append("invocation_start") + + def on_invocation_end(self, info): + self.calls.append("invocation_end") + + def on_operation_start(self, info): + self.calls.append(f"operation_start:{info.operation_id}") + + def on_operation_end(self, info): + self.calls.append(f"operation_end:{info.operation_id}") + + def on_operation_attempt_start(self, info): + self.calls.append(f"attempt_start:{info.operation_id}") + + def on_operation_attempt_end(self, info): + self.calls.append(f"attempt_end:{info.operation_id}") + + +def test_execution_state_accepts_plugin_executor_parameter(): + """Test that ExecutionState can be created with a plugin_executor parameter.""" + mock_client = Mock(spec=LambdaClient) + plugin = _RecordingPlugin() + plugin_executor = PluginExecutor(plugins=[plugin]) + + state = ExecutionState( + durable_execution_arn="test_arn", + initial_checkpoint_token="token123", # noqa: S106 + operations={}, + service_client=mock_client, + plugin_executor=plugin_executor, + ) + + assert state._plugin_executor is plugin_executor + + +def test_execution_state_defaults_plugin_executor_when_none(): + """Test that ExecutionState creates a default PluginExecutor when None is passed.""" + mock_client = Mock(spec=LambdaClient) + + state = ExecutionState( + durable_execution_arn="test_arn", + initial_checkpoint_token="token123", # noqa: S106 + operations={}, + service_client=mock_client, + plugin_executor=None, + ) + + assert state._plugin_executor is not None + assert isinstance(state._plugin_executor, PluginExecutor) + assert state._plugin_executor.plugins == [] + + +def test_plugin_executor_on_operation_action_called_on_checkpoint(): + """Test that plugin_executor.on_operation_action is called for each update after checkpoint.""" + mock_client = Mock(spec=LambdaClient) + + # Return a succeeded step operation from checkpoint + step_op = Operation( + operation_id="step-1", + operation_type=OperationType.STEP, + status=OperationStatus.SUCCEEDED, + step_details=StepDetails(attempt=1, result='"done"'), + ) + mock_client.checkpoint.return_value = CheckpointOutput( + checkpoint_token="new_token", # noqa: S106 + new_execution_state=CheckpointUpdatedExecutionState( + operations=[step_op], + next_marker=None, + ), + ) + + plugin = _RecordingPlugin() + plugin_executor = PluginExecutor(plugins=[plugin]) + + state = ExecutionState( + durable_execution_arn="test_arn", + initial_checkpoint_token="token123", # noqa: S106 + operations={}, + service_client=mock_client, + plugin_executor=plugin_executor, + ) + + # Start background thread + executor = ThreadPoolExecutor(max_workers=1) + executor.submit(state.checkpoint_batches_forever) + + try: + operation_update = OperationUpdate( + operation_id="step-1", + operation_type=OperationType.STEP, + action=OperationAction.START, + name="my-step", + ) + state.create_checkpoint(operation_update, is_sync=True) + plugin_executor.flush() + + # on_operation_action is called for START updates + assert "operation_start:step-1" in plugin.calls + assert "attempt_start:step-1" in plugin.calls + finally: + state.stop_checkpointing() + executor.shutdown(wait=True) + + +def test_plugin_executor_on_operation_update_called_for_terminal_operations(): + """Test that plugin_executor.on_operation_update is called for terminal operations.""" + mock_client = Mock(spec=LambdaClient) + + # Return a succeeded step operation from checkpoint + step_op = Operation( + operation_id="step-1", + operation_type=OperationType.STEP, + status=OperationStatus.SUCCEEDED, + step_details=StepDetails(attempt=1, result='"done"'), + ) + mock_client.checkpoint.return_value = CheckpointOutput( + checkpoint_token="new_token", # noqa: S106 + new_execution_state=CheckpointUpdatedExecutionState( + operations=[step_op], + next_marker=None, + ), + ) + + plugin = _RecordingPlugin() + plugin_executor = PluginExecutor(plugins=[plugin]) + + state = ExecutionState( + durable_execution_arn="test_arn", + initial_checkpoint_token="token123", # noqa: S106 + operations={}, + service_client=mock_client, + plugin_executor=plugin_executor, + ) + + executor = ThreadPoolExecutor(max_workers=1) + executor.submit(state.checkpoint_batches_forever) + + try: + operation_update = OperationUpdate( + operation_id="step-1", + operation_type=OperationType.STEP, + action=OperationAction.SUCCEED, + name="my-step", + payload='"done"', + ) + state.create_checkpoint(operation_update, is_sync=True) + + # Terminal status triggers on_operation_update which fires operation_start + operation_end + # and attempt_start + attempt_end (because step_details is present) + plugin_executor.flush() + assert "operation_end:step-1" in plugin.calls + assert "attempt_end:step-1" in plugin.calls + finally: + state.stop_checkpointing() + executor.shutdown(wait=True) + + +def test_plugin_executor_not_called_for_non_terminal_operations(): + """Test that plugin_executor.on_operation_update does not fire for non-terminal operations.""" + mock_client = Mock(spec=LambdaClient) + + # Return a STARTED step operation from checkpoint + step_op = Operation( + operation_id="step-1", + operation_type=OperationType.STEP, + status=OperationStatus.STARTED, + step_details=None, + ) + mock_client.checkpoint.return_value = CheckpointOutput( + checkpoint_token="new_token", # noqa: S106 + new_execution_state=CheckpointUpdatedExecutionState( + operations=[step_op], + next_marker=None, + ), + ) + + plugin = _RecordingPlugin() + plugin_executor = PluginExecutor(plugins=[plugin]) + + state = ExecutionState( + durable_execution_arn="test_arn", + initial_checkpoint_token="token123", # noqa: S106 + operations={}, + service_client=mock_client, + plugin_executor=plugin_executor, + ) + + executor = ThreadPoolExecutor(max_workers=1) + executor.submit(state.checkpoint_batches_forever) + + try: + operation_update = OperationUpdate( + operation_id="step-1", + operation_type=OperationType.STEP, + action=OperationAction.START, + name="my-step", + ) + state.create_checkpoint(operation_update, is_sync=True) + plugin_executor.flush() + + # on_operation_action fires for START + assert "operation_start:step-1" in plugin.calls + # But on_operation_update should NOT fire operation_end for STARTED status + operation_end_calls = [c for c in plugin.calls if c.startswith("operation_end")] + assert len(operation_end_calls) == 0 + finally: + state.stop_checkpointing() + executor.shutdown(wait=True) + + +def test_plugin_executor_called_for_multiple_updates_in_batch(): + """Test that plugin_executor is called for each update in a batch.""" + mock_client = Mock(spec=LambdaClient) + + # Return multiple operations from checkpoint + step_op1 = Operation( + operation_id="step-1", + operation_type=OperationType.STEP, + status=OperationStatus.SUCCEEDED, + step_details=StepDetails(attempt=1, result='"result1"'), + ) + step_op2 = Operation( + operation_id="step-2", + operation_type=OperationType.STEP, + status=OperationStatus.SUCCEEDED, + step_details=StepDetails(attempt=1, result='"result2"'), + ) + mock_client.checkpoint.return_value = CheckpointOutput( + checkpoint_token="new_token", # noqa: S106 + new_execution_state=CheckpointUpdatedExecutionState( + operations=[step_op1, step_op2], + next_marker=None, + ), + ) + + plugin = _RecordingPlugin() + plugin_executor = PluginExecutor(plugins=[plugin]) + + config = CheckpointBatcherConfig( + max_batch_time_seconds=0.2, + max_batch_operations=10, + ) + + state = ExecutionState( + durable_execution_arn="test_arn", + initial_checkpoint_token="token123", # noqa: S106 + operations={}, + service_client=mock_client, + plugin_executor=plugin_executor, + batcher_config=config, + ) + + executor = ThreadPoolExecutor(max_workers=1) + executor.submit(state.checkpoint_batches_forever) + + try: + op1 = OperationUpdate( + operation_id="step-1", + operation_type=OperationType.STEP, + action=OperationAction.START, + name="step-1", + ) + op2 = OperationUpdate( + operation_id="step-2", + operation_type=OperationType.STEP, + action=OperationAction.START, + name="step-2", + ) + # Enqueue both without blocking so they batch together + state.create_checkpoint(op1, is_sync=False) + state.create_checkpoint(op2, is_sync=True) + + # Both operations should have triggered on_operation_action + plugin_executor.flush() + assert "operation_start:step-1" in plugin.calls + assert "operation_start:step-2" in plugin.calls + # Both terminal operations should have triggered on_operation_update + assert "operation_end:step-1" in plugin.calls + assert "operation_end:step-2" in plugin.calls + finally: + state.stop_checkpointing() + executor.shutdown(wait=True) + + +def test_plugin_executor_not_called_on_checkpoint_failure(): + """Test that plugin_executor is NOT called when checkpoint API fails.""" + mock_client = Mock(spec=LambdaClient) + mock_client.checkpoint.side_effect = RuntimeError("API error") + + plugin = _RecordingPlugin() + plugin_executor = PluginExecutor(plugins=[plugin]) + + state = ExecutionState( + durable_execution_arn="test_arn", + initial_checkpoint_token="token123", # noqa: S106 + operations={}, + service_client=mock_client, + plugin_executor=plugin_executor, + ) + + executor = ThreadPoolExecutor(max_workers=1) + executor.submit(state.checkpoint_batches_forever) + + try: + operation_update = OperationUpdate( + operation_id="step-1", + operation_type=OperationType.STEP, + action=OperationAction.START, + name="my-step", + ) + + with pytest.raises(BackgroundThreadError): + state.create_checkpoint(operation_update, is_sync=True) + + plugin_executor.flush() + + # Plugin should NOT have been called since checkpoint failed + assert "operation_start:step-1" not in plugin.calls + assert "operation_end:step-1" not in plugin.calls + finally: + plugin_executor.close() + state.stop_checkpointing() + executor.shutdown(wait=True) + + +def test_plugin_executor_exception_does_not_break_checkpointing(): + """Test that a plugin exception does not break the checkpoint processing loop.""" + mock_client = Mock(spec=LambdaClient) + + step_op = Operation( + operation_id="step-1", + operation_type=OperationType.STEP, + status=OperationStatus.SUCCEEDED, + step_details=StepDetails(attempt=1, result='"done"'), + ) + mock_client.checkpoint.return_value = CheckpointOutput( + checkpoint_token="new_token", # noqa: S106 + new_execution_state=CheckpointUpdatedExecutionState( + operations=[step_op], + next_marker=None, + ), + ) + + class _ExplodingPlugin(DurableExecutionPlugin): + def on_operation_start(self, info): + raise RuntimeError("plugin exploded") + + def on_operation_end(self, info): + raise RuntimeError("plugin exploded") + + exploding_plugin = _ExplodingPlugin() + plugin_executor = PluginExecutor(plugins=[exploding_plugin]) + + state = ExecutionState( + durable_execution_arn="test_arn", + initial_checkpoint_token="token123", # noqa: S106 + operations={}, + service_client=mock_client, + plugin_executor=plugin_executor, + ) + + executor = ThreadPoolExecutor(max_workers=1) + executor.submit(state.checkpoint_batches_forever) + + try: + operation_update = OperationUpdate( + operation_id="step-1", + operation_type=OperationType.STEP, + action=OperationAction.START, + name="my-step", + ) + # Should not raise even though plugin explodes + state.create_checkpoint(operation_update, is_sync=True) + + # Checkpoint should still have been called successfully + assert mock_client.checkpoint.call_count == 1 + finally: + plugin_executor.close() + state.stop_checkpointing() + executor.shutdown(wait=True) + + +def test_plugin_executor_called_for_pending_operations(): + """Test that plugin_executor.on_operation_update fires on_attempt_end for PENDING operations.""" + mock_client = Mock(spec=LambdaClient) + + # Return a PENDING step operation from checkpoint (simulates a retry scenario) + step_op = Operation( + operation_id="step-1", + operation_type=OperationType.STEP, + status=OperationStatus.PENDING, + step_details=StepDetails( + attempt=1, + result=None, + error=ErrorObject( + message="transient failure", + type="RetryableError", + data=None, + stack_trace=None, + ), + ), + ) + mock_client.checkpoint.return_value = CheckpointOutput( + checkpoint_token="new_token", # noqa: S106 + new_execution_state=CheckpointUpdatedExecutionState( + operations=[step_op], + next_marker=None, + ), + ) + + plugin = _RecordingPlugin() + plugin_executor = PluginExecutor(plugins=[plugin]) + + state = ExecutionState( + durable_execution_arn="test_arn", + initial_checkpoint_token="token123", # noqa: S106 + operations={}, + service_client=mock_client, + plugin_executor=plugin_executor, + ) + + executor = ThreadPoolExecutor(max_workers=1) + executor.submit(state.checkpoint_batches_forever) + + try: + operation_update = OperationUpdate( + operation_id="step-1", + operation_type=OperationType.STEP, + action=OperationAction.START, + name="my-step", + ) + state.create_checkpoint(operation_update, is_sync=True) + + # on_attempt_end should fire for PENDING operations with step_details + assert "attempt_end:step-1" in plugin.calls + # operation_end should NOT fire for PENDING (only for terminal statuses) + operation_end_calls = [c for c in plugin.calls if c.startswith("operation_end")] + assert len(operation_end_calls) == 0 + finally: + plugin_executor.close() + state.stop_checkpointing() + executor.shutdown(wait=True) + + +# endregion Plugin Executor Integration Tests