Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
389 changes: 220 additions & 169 deletions src/aws_durable_execution_sdk_python/execution.py

Large diffs are not rendered by default.

9 changes: 9 additions & 0 deletions src/aws_durable_execution_sdk_python/lambda_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,15 @@ class OperationSubType(Enum):
CHAINED_INVOKE = "ChainedInvoke"


class InvocationStatus(Enum):
    """Status values that describe the outcome of an invocation."""

    SUCCEEDED = "SUCCEEDED"
    FAILED = "FAILED"
    PENDING = "PENDING"

    # Used internally only: the invocation failed and the backend will retry
    RETRY = "RETRY"


@dataclass(frozen=True)
class ExecutionDetails:
input_payload: str | None = None
Expand Down
351 changes: 351 additions & 0 deletions src/aws_durable_execution_sdk_python/plugin.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,351 @@
import datetime
import logging
from abc import ABC
from concurrent.futures import ThreadPoolExecutor
from dataclasses import dataclass
from typing import TYPE_CHECKING

from aws_durable_execution_sdk_python.lambda_service import (
OperationType,
OperationStatus,
OperationAction,
OperationSubType,
ErrorObject,
InvocationStatus,
Operation,
OperationUpdate,
)
from aws_durable_execution_sdk_python.types import LambdaContext

if TYPE_CHECKING:
from aws_durable_execution_sdk_python.execution import (
DurableExecutionInvocationOutput,
)

logger = logging.getLogger(__name__)


@dataclass
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

these can be frozen?

class OperationStartInfo:
    """Payload handed to plugins when an operation starts.

    Also serves as the base for the other operation/attempt info classes,
    which extend it with additional fields.
    """

    # Identity and classification of the operation.
    operation_id: str
    operation_type: OperationType
    sub_type: OperationSubType | None = None
    name: str | None = None
    # Identifier of the parent operation, when there is one.
    parent_id: str | None = None
    start_timestamp: datetime.datetime | None = None


@dataclass
class OperationEndInfo(OperationStartInfo):
    """Payload handed to plugins when an operation ends.

    Carries every field of OperationStartInfo plus the outcome below.
    """

    status: OperationStatus = OperationStatus.SUCCEEDED
    end_timestamp: datetime.datetime | None = None
    # Attempt number taken from the step details; None when the operation
    # has no step details.
    attempt: int | None = None
    error: ErrorObject | None = None


@dataclass
class AttemptStartInfo(OperationStartInfo):
    """Payload handed to plugins when an attempt of an operation starts."""

    # Attempt number; defaults to the first attempt.
    attempt: int = 1


@dataclass
class AttemptEndInfo(AttemptStartInfo):
    """Payload handed to plugins when an attempt of an operation ends."""

    # True when the attempt ended with a SUCCEEDED operation status.
    succeeded: bool | None = None
    end_timestamp: datetime.datetime | None = None
    error: ErrorObject | None = None
    # NOTE(review): nothing in the visible PluginExecutor code sets this
    # field, so it keeps its default of None -- confirm whether it should
    # be populated.
    next_attempt_delay_seconds: int | None = None
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this doesn't get populated currently?



@dataclass
class InvocationStartInfo:
    """Payload handed to plugins when an invocation starts.

    All three fields are required positional fields but may be None.
    """

    # Request id read from the Lambda context, when a context is available.
    request_id: str | None
    execution_arn: str | None
    # Start timestamp of the execution operation, when one is available.
    start_timestamp: datetime.datetime | None


@dataclass
class InvocationEndInfo(InvocationStartInfo):
    """Payload handed to plugins when an invocation ends.

    Extends InvocationStartInfo with the invocation outcome.
    """

    status: InvocationStatus = InvocationStatus.SUCCEEDED
    end_timestamp: datetime.datetime | None = None
    error: ErrorObject | None = None


@dataclass
class ExecutionStartInfo(InvocationStartInfo):
    """Payload handed to plugins when an execution starts.

    Adds no fields of its own; the distinct type lets the dispatcher route
    it to the execution-level hook rather than the invocation-level one.
    """


@dataclass
class ExecutionEndInfo(ExecutionStartInfo):
    """Payload handed to plugins when an execution ends.

    Extends ExecutionStartInfo with the final outcome of the execution.
    """

    status: InvocationStatus = InvocationStatus.SUCCEEDED
    end_timestamp: datetime.datetime | None = None
    error: ErrorObject | None = None
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

js has more fields....

// aws-durable-execution-sdk-js/src/types/plugin.ts
export interface ExecutionEndInfo extends InvocationInfo {
    status: "SUCCEEDED" | "FAILED";
    executionResult?: unknown;
    executionError?: Error;
    executionInput: unknown;
    operations: Record<string, Operation>;
}

Is the idea of including these that the execution span's attributes can include aggregate counts ("3 steps", "1 retry") without the plugin maintaining its own parallel state across events?



class DurableExecutionPlugin(ABC):
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

how does context extractor work?

there is no context_extractor parameter on @durable_execution or here.

so how will the plugins get the values for traceid/parent? is the idea to use the $ENVs, but the context_extractor interface makes me think that's not the idea?

https://github.com/aws/aws-durable-execution-sdk-js/blob/otel-instrumentation-design-v2/packages/aws-durable-execution-sdk-js/src/documents/PYTHON_SDK_OTEL_PLUGIN.md#context-propagation--configurable-context_extractor

this might be unimplemented in the current JS implementation also, we should align and disambiguate the plan here.

maybe?

  1. Add lambda_context: LambdaContext | None to InvocationStartInfo
  2. Add context_extractor parameter to @durable_execution
  3. Call context_extractor in the executor before on_invocation_start fires and attach/detach the returned OTel Context around the invocation

the extractor must receive lambda_context and fire per-invocation.

"""Base class for plugins. Override only the methods you need."""

def on_execution_start(self, info: ExecutionStartInfo) -> None:
    """Hook for the start of an execution; does nothing by default.

    Args:
        info: details of the execution that is starting.
    """

def on_execution_end(self, info: ExecutionEndInfo) -> None:
    """Hook for the end of an execution; does nothing by default.

    Args:
        info: details and outcome of the execution that ended.
    """

def on_invocation_start(self, info: InvocationStartInfo) -> None:
    """Hook for the start of an invocation; does nothing by default.

    Args:
        info: details of the invocation that is starting.
    """

def on_invocation_end(self, info: InvocationEndInfo) -> None:
    """Hook for the end of an invocation; does nothing by default.

    Args:
        info: details and outcome of the invocation that ended.
    """

def on_operation_start(self, info: OperationStartInfo) -> None:
    """Hook for the start of an operation; does nothing by default.

    Args:
        info: details of the operation that is starting.
    """

def on_operation_end(self, info: OperationEndInfo) -> None:
    """Hook for the end of an operation; does nothing by default.

    Args:
        info: details and outcome of the operation that ended.
    """

def on_operation_attempt_start(self, info: AttemptStartInfo) -> None:
    """Hook for the start of an operation attempt; does nothing by default.

    Args:
        info: details of the attempt that is starting, including its number.
    """

def on_operation_attempt_end(self, info: AttemptEndInfo) -> None:
    """Hook for the end of an operation attempt; does nothing by default.

    Args:
        info: details and outcome of the attempt that ended.
    """

# Todo: further discussions required to finalize the following interface
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yes. Also not sure whether the ABC is actually necessary here, unless there's going to be at least 1 abstractmethod.

# def enrich_log_context(self, info: OperationStartInfo | None) -> Dict[str, Any] | None: pass
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

TODO: create tracking issue



class PluginExecutor:
_DEFAULT_MAX_WORKERS = 4

def __init__(
    self,
    plugins: list[DurableExecutionPlugin] | None,
    max_workers: int | None = None,
):
    """Remember the plugins and create a worker pool if there are any.

    Args:
        plugins: plugins to notify; None is treated as an empty list.
        max_workers: size of the thread pool. A falsy value (None or 0)
            selects _DEFAULT_MAX_WORKERS.
    """
    self.plugins = plugins or []
    self._pending_futures: list = []

    # Without plugins there is nothing to dispatch, so no pool is created.
    pool: ThreadPoolExecutor | None
    if self.plugins:
        pool = ThreadPoolExecutor(
            max_workers=max_workers or self._DEFAULT_MAX_WORKERS,
            thread_name_prefix="plugin-executor",
        )
    else:
        pool = None
    self._executor = pool

def close(self) -> None:
"""Shut down the thread pool, waiting for pending tasks to complete."""
if self._executor is None:
return
self.flush()
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

doesn't the next line shutdown(wait=True) effectively do the flush?

https://docs.python.org/3/library/concurrent.futures.html#concurrent.futures.Executor.shutdown

if wait is True then this method will not return until all the pending futures are done executing and the resources associated with the executor have been freed. .. Regardless of the value of wait, the entire Python program will not exit until all pending futures are done executing.

self._executor.shutdown(wait=True)
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

so this will wait to flush all before closing.

maybe make clear on the contract somewhere for DurableExecutionPlugin (or on the decorator) so consumers know when or if they have to flush themselves or not.

so on_invocation_end and on_execution_end are guaranteed to complete before the handler returns? and operation-level hooks (on_operation_*, on_operation_attempt_*) are dispatched asynchronously?


def flush(self) -> None:
    """Wait for all pending plugin tasks to complete. Useful for testing.

    Only the futures that were pending when flush() was entered are waited
    on. Futures registered while the wait is in progress stay in the
    pending list and are picked up by the next flush().

    Raises:
        Whatever ``Future.result()`` raises for a pending future (for
        example when a future was cancelled).
    """
    # Detach the current batch before waiting. Iterating and then calling
    # clear() on the shared list could discard futures that another thread
    # appended after the iteration finished but before the clear ran.
    pending, self._pending_futures = self._pending_futures, []
    for future in pending:
        future.result()

def _dispatch_plugin(self, plugin: DurableExecutionPlugin, info) -> None:
    """Route ``info`` to the matching hook of ``plugin``.

    Runs inside the thread pool. Any exception raised while routing or by
    the hook itself is logged and then ignored.

    Args:
        plugin: the plugin whose hook should be called.
        info: one of the *StartInfo / *EndInfo payload objects.
    """
    try:
        # The info classes inherit from one another, so the most derived
        # types must be tested first; the order of this table matters.
        routes = (
            (ExecutionEndInfo, "on_execution_end"),
            (InvocationEndInfo, "on_invocation_end"),
            (ExecutionStartInfo, "on_execution_start"),
            (InvocationStartInfo, "on_invocation_start"),
            (AttemptEndInfo, "on_operation_attempt_end"),
            (OperationEndInfo, "on_operation_end"),
            (AttemptStartInfo, "on_operation_attempt_start"),
            (OperationStartInfo, "on_operation_start"),
        )
        for info_type, hook_name in routes:
            if isinstance(info, info_type):
                # Look the hook up only once a route matched.
                getattr(plugin, hook_name)(info)
                break
        else:
            raise ValueError(f"Unknown info type: {type(info)}")
    except Exception:
        # log and ignore the exception
        logger.exception("Plugin %s exception ignored", plugin.__class__.__name__)

def execute_plugins(self, info):
if not self.plugins:
return
for plugin in self.plugins:
future = self._executor.submit(self._dispatch_plugin, plugin, info)
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this means strictly speaking plugins could execute out of order.

at present I don't think the design covers ordering or preserving it explicitly? But I'm thinking specifically for parent -> child relationships, order might well be important here?

It looks like the js runs synchronously on the handler thread (i.e sequentially) https://github.com/aws/aws-durable-execution-sdk-js/blob/otel-instrumentation-design-v2/packages/aws-durable-execution-sdk-js/src/utils/plugin/plugin-runner.ts

i.e js preserves ordering

self._pending_futures.append(future)
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

doesn't this need a lock?

_pending_futures mutates from the main thread (for on_invocation_start/end), and on the background thread (on_operation_action etc.), and flush() iterates + clear non atomically.

but I'm also not sure to what degree this goes away depending on where we end up for synchronous/ordering, is the tracking list necessary?


def on_invocation_start(
    self,
    durable_execution_arn: str,
    context: LambdaContext | None,
    execution_operation: Operation | None,
    is_replaying: bool,
) -> None:
    """Notify plugins that an invocation (and possibly an execution) starts.

    An ExecutionStartInfo is dispatched first, but only when the invocation
    is not a replay; an InvocationStartInfo is always dispatched.

    Args:
        durable_execution_arn: ARN reported as ``execution_arn``.
        context: Lambda context supplying the request id; may be None.
        execution_operation: operation supplying the start timestamp; may
            be None.
        is_replaying: True to suppress the execution-start notification.
    """
    # Both payloads carry exactly the same three fields.
    shared_fields = {
        "request_id": None if not context else context.aws_request_id,
        "execution_arn": durable_execution_arn,
        "start_timestamp": (
            None if not execution_operation else execution_operation.start_timestamp
        ),
    }

    if not is_replaying:
        self.execute_plugins(ExecutionStartInfo(**shared_fields))

    self.execute_plugins(InvocationStartInfo(**shared_fields))

def on_invocation_end(
    self,
    durable_execution_arn: str | None,
    context: LambdaContext | None,
    execution_operation: Operation | None,
    output: "DurableExecutionInvocationOutput",
) -> None:
    """Notify plugins that an invocation (and possibly the execution) ended.

    An InvocationEndInfo is always dispatched. An ExecutionEndInfo follows
    when the output status is SUCCEEDED or FAILED.

    Args:
        durable_execution_arn: ARN reported as ``execution_arn``.
        context: Lambda context supplying the request id; may be None.
        execution_operation: operation supplying the start and end
            timestamps; may be None.
        output: invocation output supplying ``status`` and ``error``.
    """
    start_timestamp = (
        execution_operation.start_timestamp if execution_operation else None
    )
    # the actual end timestamp may be unknown because it's not checkpointed yet
    # Fall back to a timezone-aware "now" so the value is unambiguous.
    # NOTE(review): assumes checkpointed timestamps are timezone-aware too,
    # so start and end can be compared -- confirm.
    end_timestamp: datetime.datetime = (
        execution_operation.end_timestamp if execution_operation else None
    ) or datetime.datetime.now(datetime.timezone.utc)
    request_id = context.aws_request_id if context else None

    self.execute_plugins(
        InvocationEndInfo(
            request_id=request_id,
            execution_arn=durable_execution_arn,
            start_timestamp=start_timestamp,
            status=output.status,
            end_timestamp=end_timestamp,
            error=output.error,
        )
    )

    # Only a final status ends the execution itself.
    if output.status in (InvocationStatus.SUCCEEDED, InvocationStatus.FAILED):
        self.execute_plugins(
            ExecutionEndInfo(
                request_id=request_id,
                execution_arn=durable_execution_arn,
                start_timestamp=start_timestamp,
                status=output.status,
                end_timestamp=end_timestamp,
                error=output.error,
            )
        )

def on_operation_action(self, operation: Operation | None, update: OperationUpdate):
"""Execute any registered plugins for a given operation before it is updated.

Args:
update: the operation update that is pending checkpoint
"""
if update.action is not OperationAction.START:
return

self.execute_plugins(
OperationStartInfo(
operation_id=update.operation_id,
operation_type=update.operation_type,
sub_type=update.sub_type,
name=update.name,
parent_id=update.parent_id,
start_timestamp=datetime.datetime.now(),
)
)

if update.operation_type is OperationType.STEP:
attempt = (
operation.step_details.attempt
if operation and operation.step_details
else 1
)
self.execute_plugins(
AttemptStartInfo(
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

is there an OperationStartInfo and a AttemptStartInfo for each attempt?

from #370

  • onOperationStart, onOperationEnd, onOperationAttemptStart,
    onOperationAttemptEnd fire at most once per operation/attempt.
  • onOperationAttemptStart and onOperationAttemptEnd will manage a span that is a child of the span managed by onOperationStart and onOperationEnd. That is retry attempts for steps and wait-for-condition operations.

I could easily be mistracing this, is this what happens? Maybe it's an artifact of the previous PR #381 introducing the READY logic so every step retry checkpoints START?

  1. operation_start (attempt 1 START)
  2. attempt_start=1 (attempt 1 START)
  3. attempt_end=1 (RETRY → PENDING)
  4. operation_start (attempt 2 START, across invocations)
  5. attempt_start (attempt 2 START)
  6. attempt_end=2 (SUCCEED)
  7. operation_end

might be an idea to introduce a functional test to verify/confirm the behaviour.

operation_id=update.operation_id,
operation_type=update.operation_type,
sub_type=update.sub_type,
name=update.name,
parent_id=update.parent_id,
start_timestamp=datetime.datetime.now(),
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

utc

attempt=attempt,
)
)

def on_operation_update(self, operation):
    """Notify plugins about an operation after it has been updated.

    Updates such as STARTED might be omitted because START and the
    completion action (e.g. SUCCEED/FAIL) may be checkpointed in one batch,
    in which case the backend returns only the resulting status
    (e.g. SUCCEEDED/PENDING/FAILED).

    An AttemptEndInfo is dispatched when the operation has step details and
    its status is terminal or PENDING. An OperationEndInfo is dispatched
    when the status is terminal.

    Args:
        operation: the operation that has just been checkpointed
    """
    status = operation.status
    step = operation.step_details
    terminal = self._is_terminal_status(status)

    # Fields shared by both payloads.
    common = {
        "operation_id": operation.operation_id,
        "operation_type": operation.operation_type,
        "sub_type": operation.sub_type,
        "name": operation.name,
        "parent_id": operation.parent_id,
        "start_timestamp": operation.start_timestamp,
    }

    # PENDING in addition to terminal status
    if step and (terminal or status is OperationStatus.PENDING):
        self.execute_plugins(
            AttemptEndInfo(
                **common,
                end_timestamp=operation.end_timestamp,
                attempt=step.attempt,
                succeeded=status is OperationStatus.SUCCEEDED,
                error=step.error,
            )
        )

    if not terminal:
        return

    self.execute_plugins(
        OperationEndInfo(
            **common,
            end_timestamp=operation.end_timestamp,
            status=status,
            error=self._extract_error(operation),
            attempt=step.attempt if step else None,
        )
    )

@staticmethod
def _extract_error(operation: Operation):
    """Return the first error found in the operation's detail objects.

    The detail objects are checked in a fixed order: step, callback,
    chained invoke, context.

    Args:
        operation: the operation to inspect.

    Returns:
        The first truthy ``error`` found, or None when there is none.
    """
    detail_attrs = (
        "step_details",
        "callback_details",
        "chained_invoke_details",
        "context_details",
    )
    for attr in detail_attrs:
        details = getattr(operation, attr)
        if details and details.error:
            return details.error
    return None
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In state, CheckpointedResult.create_from_operation(cls, operation) already does something of the sort, so this here is replicating logic. The existing logic extracts both result and error).

but state imports plugin and we'd like to avoid circular reference.

so we should probably refactor and move this existing logic so everything can use it.

it might be an idea to do this as a no-func change refactor separately in a preceding PR first, but (I'm typing quickly here just to give the idea) something like:

   @dataclass(frozen=True)
   class Operation:
       ...

       def get_error(self) -> ErrorObject | None:
           match self.operation_type:
               case OperationType.STEP:
                   return self.step_details.error if self.step_details else None
               case OperationType.CALLBACK:
                   return self.callback_details.error if self.callback_details else None
               case OperationType.CHAINED_INVOKE:
                   return self.chained_invoke_details.error if self.chained_invoke_details else None
               case OperationType.CONTEXT:
                   return self.context_details.error if self.context_details else None
               case _:
                   return None

       def get_result(self) -> str | None:
           ... 

and then in CheckpointedResult:

   return cls(
       operation=operation,
       status=operation.status,
       result=operation.get_result(),
       error=operation.get_error(),
   )

and here in plugin:

   OperationEndInfo(
       **params,
       ...,
       error=operation.get_error(),
   )


@staticmethod
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

rather than have this ambient bit of unattached logic, we can keep it with the other is_succeeded, is_failed logic on the Operation object, so we have

   @dataclass(frozen=True)
   class Operation:
       _TERMINAL_STATUSES: ClassVar[frozenset[OperationStatus]] = frozenset({
           OperationStatus.SUCCEEDED,
           OperationStatus.FAILED,
           OperationStatus.CANCELLED,
           OperationStatus.TIMED_OUT,
           OperationStatus.STOPPED,
       })

       # ... fields ...

       def is_terminal(self) -> bool:
           """Return True if this operation is in a terminal state."""
           return self.status in Operation._TERMINAL_STATUSES

or maybe on the enum itself, but I'm not super fond of the forward reference lol:

  class OperationStatus(Enum):
      STARTED = "STARTED"
      PENDING = "PENDING"
      READY = "READY"
      SUCCEEDED = "SUCCEEDED"
      FAILED = "FAILED"
      CANCELLED = "CANCELLED"
      TIMED_OUT = "TIMED_OUT"
      STOPPED = "STOPPED"

      def is_terminal(self) -> bool:
          return self in _TERMINAL_STATUSES # this would be a fwd ref

def _is_terminal_status(status):
    """Tell whether ``status`` is one of the terminal operation statuses.

    Args:
        status: the operation status to classify.

    Returns:
        True for SUCCEEDED, FAILED, TIMED_OUT, CANCELLED or STOPPED;
        False for anything else.
    """
    terminal_statuses = (
        OperationStatus.SUCCEEDED,
        OperationStatus.FAILED,
        OperationStatus.TIMED_OUT,
        OperationStatus.CANCELLED,
        OperationStatus.STOPPED,
    )
    return status in terminal_statuses
Loading
Loading