diff --git a/src/aws_durable_execution_sdk_python/operation/step.py b/src/aws_durable_execution_sdk_python/operation/step.py index d9719a9..8a418fb 100644 --- a/src/aws_durable_execution_sdk_python/operation/step.py +++ b/src/aws_durable_execution_sdk_python/operation/step.py @@ -152,8 +152,8 @@ def check_result_status(self) -> CheckResult[T]: ): return CheckResult.create_is_ready_to_execute(checkpointed_result) - # Create START checkpoint if not exists - if not checkpointed_result.is_existent(): + # Create START checkpoint if nonexistent or READY + if not checkpointed_result.is_existent() or checkpointed_result.is_ready(): start_operation: OperationUpdate = OperationUpdate.create_step_start( identifier=self.operation_identifier, ) diff --git a/src/aws_durable_execution_sdk_python/state.py b/src/aws_durable_execution_sdk_python/state.py index 416b395..8317550 100644 --- a/src/aws_durable_execution_sdk_python/state.py +++ b/src/aws_durable_execution_sdk_python/state.py @@ -175,6 +175,13 @@ def is_pending(self) -> bool: return False return op.status is OperationStatus.PENDING + def is_ready(self) -> bool: + """Return True if the checkpointed operation is READY.""" + op = self.operation + if not op: + return False + return op.status is OperationStatus.READY + def is_timed_out(self) -> bool: """Return True if the checkpointed operation is TIMED_OUT.""" op = self.operation diff --git a/tests/operation/step_test.py b/tests/operation/step_test.py index a7e38a8..75ed768 100644 --- a/tests/operation/step_test.py +++ b/tests/operation/step_test.py @@ -894,3 +894,58 @@ def test_step_executes_function_when_second_check_returns_started(): mock_state.get_checkpoint_result.call_count == 1 ) # Only one check for AT_LEAST_ONCE assert mock_state.create_checkpoint.call_count == 2 # START + SUCCEED checkpoints + + +def test_step_creates_start_checkpoint_when_status_is_ready(): + """Test that create_checkpoint is called with START action when the step is in READY status.""" + mock_state = Mock(spec=ExecutionState) + mock_state.durable_execution_arn = "test_arn" + + # Simulate a step that is in READY status (e.g., returned from a previous checkpoint) + ready_op = Operation( + operation_id="step_ready_1", + operation_type=OperationType.STEP, + status=OperationStatus.READY, + step_details=StepDetails(attempt=0), + ) + ready_result = CheckpointedResult.create_from_operation(ready_op) + + # After creating the sync START checkpoint, the refreshed result returns STARTED + started_op = Operation( + operation_id="step_ready_1", + operation_type=OperationType.STEP, + status=OperationStatus.STARTED, + step_details=StepDetails(attempt=0), + ) + started_result = CheckpointedResult.create_from_operation(started_op) + mock_state.get_checkpoint_result.side_effect = [ready_result, started_result] + + config = StepConfig(step_semantics=StepSemantics.AT_MOST_ONCE_PER_RETRY) + mock_callable = Mock(return_value="ready_step_result") + mock_logger = Mock(spec=Logger) + mock_logger.with_log_info.return_value = mock_logger + + result = step_handler( + mock_callable, + mock_state, + OperationIdentifier("step_ready_1", None, "test_step"), + config, + mock_logger, + ) + + assert result == "ready_step_result" + mock_callable.assert_called_once() + + # Verify START checkpoint was created + start_call = mock_state.create_checkpoint.call_args_list[0] + start_operation = start_call[1]["operation_update"] + assert start_operation.operation_id == "step_ready_1" + assert start_operation.operation_type is OperationType.STEP + assert start_operation.sub_type is OperationSubType.STEP + assert start_operation.action is OperationAction.START + + # Verify SUCCEED checkpoint was also created after execution + assert mock_state.create_checkpoint.call_count == 2 + success_call = mock_state.create_checkpoint.call_args_list[1] + success_operation = success_call[1]["operation_update"] + assert success_operation.action is OperationAction.SUCCEED diff --git a/tests/state_test.py b/tests/state_test.py index 2ea0ab3..0152ca6 100644 --- a/tests/state_test.py +++ b/tests/state_test.py @@ -332,6 +332,21 @@ def test_checkpointerd_result_is_pending(): assert result_no_op.is_pending() is False +def test_checkpointerd_result_is_ready(): + """Test CheckpointedResult.is_ready method.""" + operation = Operation( + operation_id="op1", + operation_type=OperationType.STEP, + status=OperationStatus.READY, + ) + result = CheckpointedResult.create_from_operation(operation) + assert result.is_ready() is True + + # Test with no operation + result_no_op = CheckpointedResult.create_not_found() + assert result_no_op.is_ready() is False + + def test_checkpointed_result_is_started(): """Test CheckpointedResult.is_started method.""" operation = Operation(