From 128bdb1a04ded708fdc675f32794f6d746123170 Mon Sep 17 00:00:00 2001 From: Gaurav Madarkal Date: Wed, 22 Apr 2026 00:12:38 -0700 Subject: [PATCH 1/2] feat(train): Add wait_timeout parameter to train() Updated trainers: SFT, DPO, RLAIF, RLVR, and BaseTrainer. --- sagemaker-train/src/sagemaker/train/base_trainer.py | 2 +- sagemaker-train/src/sagemaker/train/dpo_trainer.py | 11 +++++++++-- sagemaker-train/src/sagemaker/train/rlaif_trainer.py | 10 ++++++++-- sagemaker-train/src/sagemaker/train/rlvr_trainer.py | 10 ++++++++-- sagemaker-train/src/sagemaker/train/sft_trainer.py | 10 ++++++++-- 5 files changed, 34 insertions(+), 9 deletions(-) diff --git a/sagemaker-train/src/sagemaker/train/base_trainer.py b/sagemaker-train/src/sagemaker/train/base_trainer.py index a422dc3240..d54b6450c9 100644 --- a/sagemaker-train/src/sagemaker/train/base_trainer.py +++ b/sagemaker-train/src/sagemaker/train/base_trainer.py @@ -76,6 +76,6 @@ def _is_nova_model_for_telemetry(self) -> bool: return _is_nova_model(model_name) if model_name else False @abstractmethod - def train(self, input_data_config: List[InputData], wait: bool = True, logs: bool = True): + def train(self, input_data_config: List[InputData], wait: bool = True, logs: bool = True, wait_timeout: Optional[int] = None): """Common training method that calls the specific implementation.""" pass diff --git a/sagemaker-train/src/sagemaker/train/dpo_trainer.py b/sagemaker-train/src/sagemaker/train/dpo_trainer.py index b9f7449354..7e8b4747c5 100644 --- a/sagemaker-train/src/sagemaker/train/dpo_trainer.py +++ b/sagemaker-train/src/sagemaker/train/dpo_trainer.py @@ -180,7 +180,8 @@ def _process_hyperparameters(self): def train(self, training_dataset: Optional[Union[str, DataSet]] = None, validation_dataset: Optional[Union[str, DataSet]] = None, - wait: bool = True): + wait: bool = True, + wait_timeout: Optional[int] = None): """Execute the DPO training job. Parameters: @@ -192,6 +193,9 @@ def train(self, Can be an S3 URI, dataset ARN, or DataSet object. wait (bool): Whether to wait for the training job to complete. Defaults to True. + wait_timeout (Optional[int]): + Maximum time in seconds to wait for the training job to complete. Only used when wait=True. + If None, uses the default timeout from the wait utility. Returns: TrainingJob: The SageMaker training job object. @@ -276,7 +280,10 @@ def train(self, from sagemaker.train.common_utils.trainer_wait import wait as _wait from sagemaker.core.utils.exceptions import TimeoutExceededError try : - _wait(training_job) + wait_kwargs = {} + if wait_timeout is not None: + wait_kwargs['timeout'] = wait_timeout + _wait(training_job, **wait_kwargs) except TimeoutExceededError as e: logger.error("Error: %s", e) diff --git a/sagemaker-train/src/sagemaker/train/rlaif_trainer.py b/sagemaker-train/src/sagemaker/train/rlaif_trainer.py index d4cbc7cf8f..a9136e2742 100644 --- a/sagemaker-train/src/sagemaker/train/rlaif_trainer.py +++ b/sagemaker-train/src/sagemaker/train/rlaif_trainer.py @@ -197,7 +197,7 @@ def _validate_reward_model_id(self, reward_model_id): @_telemetry_emitter(feature=Feature.MODEL_CUSTOMIZATION, func_name="RLAIFTrainer.train") - def train(self, training_dataset: Optional[Union[str, DataSet]] = None, validation_dataset: Optional[Union[str, DataSet]] = None, wait: bool = True): + def train(self, training_dataset: Optional[Union[str, DataSet]] = None, validation_dataset: Optional[Union[str, DataSet]] = None, wait: bool = True, wait_timeout: Optional[int] = None): """Execute the RLAIF training job. Parameters: @@ -209,6 +209,9 @@ def train(self, training_dataset: Optional[Union[str, DataSet]] = None, validati Can be an S3 URI, dataset ARN, or DataSet object. wait (bool): Whether to wait for the training job to complete. Defaults to True. + wait_timeout (Optional[int]): + Maximum time in seconds to wait for the training job to complete. Only used when wait=True. + If None, uses the default timeout from the wait utility. Returns: TrainingJob: The SageMaker training job object. @@ -295,7 +298,10 @@ def train(self, training_dataset: Optional[Union[str, DataSet]] = None, validati from sagemaker.train.common_utils.trainer_wait import wait as _wait from sagemaker.core.utils.exceptions import TimeoutExceededError try : - _wait(training_job) + wait_kwargs = {} + if wait_timeout is not None: + wait_kwargs['timeout'] = wait_timeout + _wait(training_job, **wait_kwargs) except TimeoutExceededError as e: logger.error("Error: %s", e) diff --git a/sagemaker-train/src/sagemaker/train/rlvr_trainer.py b/sagemaker-train/src/sagemaker/train/rlvr_trainer.py index 93a5105f8e..c496222bf4 100644 --- a/sagemaker-train/src/sagemaker/train/rlvr_trainer.py +++ b/sagemaker-train/src/sagemaker/train/rlvr_trainer.py @@ -183,7 +183,7 @@ def _process_hyperparameters(self): @_telemetry_emitter(feature=Feature.MODEL_CUSTOMIZATION, func_name="RLVRTrainer.train") def train(self, training_dataset: Optional[Union[str, DataSet]] = None, - validation_dataset: Optional[Union[str, DataSet]] = None, wait: bool = True): + validation_dataset: Optional[Union[str, DataSet]] = None, wait: bool = True, wait_timeout: Optional[int] = None): """Execute the RLVR training job. Parameters: @@ -195,6 +195,9 @@ def train(self, training_dataset: Optional[Union[str, DataSet]] = None, Can be an S3 URI, dataset ARN, or DataSet object. wait (bool): Whether to wait for the training job to complete. Defaults to True. + wait_timeout (Optional[int]): + Maximum time in seconds to wait for the training job to complete. Only used when wait=True. + If None, uses the default timeout from the wait utility. Returns: TrainingJob: The SageMaker training job object. @@ -283,7 +286,10 @@ def train(self, training_dataset: Optional[Union[str, DataSet]] = None, from sagemaker.train.common_utils.trainer_wait import wait as _wait from sagemaker.core.utils.exceptions import TimeoutExceededError try: - _wait(training_job) + wait_kwargs = {} + if wait_timeout is not None: + wait_kwargs['timeout'] = wait_timeout + _wait(training_job, **wait_kwargs) except TimeoutExceededError as e: logger.error("Error: %s", e) diff --git a/sagemaker-train/src/sagemaker/train/sft_trainer.py b/sagemaker-train/src/sagemaker/train/sft_trainer.py index cc67469406..136231bd6f 100644 --- a/sagemaker-train/src/sagemaker/train/sft_trainer.py +++ b/sagemaker-train/src/sagemaker/train/sft_trainer.py @@ -180,7 +180,7 @@ def _process_hyperparameters(self): self.hyperparameters._specs.pop('validation_data_path', None) @_telemetry_emitter(feature=Feature.MODEL_CUSTOMIZATION, func_name="SFTTrainer.train") - def train(self, training_dataset: Optional[Union[str, DataSet]] = None, validation_dataset: Optional[Union[str, DataSet]] = None, wait: bool = True): + def train(self, training_dataset: Optional[Union[str, DataSet]] = None, validation_dataset: Optional[Union[str, DataSet]] = None, wait: bool = True, wait_timeout: Optional[int] = None): """Execute the SFT training job. Parameters: @@ -192,6 +192,9 @@ def train(self, training_dataset: Optional[Union[str, DataSet]] = None, validati Can be an S3 URI, dataset ARN, or DataSet object. wait (bool): Whether to wait for the training job to complete. Defaults to True. + wait_timeout (Optional[int]): + Maximum time in seconds to wait for the training job to complete. Only used when wait=True. + If None, uses the default timeout from the wait utility. Returns: TrainingJob: The SageMaker training job object. @@ -277,7 +280,10 @@ def train(self, training_dataset: Optional[Union[str, DataSet]] = None, validati from sagemaker.train.common_utils.trainer_wait import wait as _wait from sagemaker.core.utils.exceptions import TimeoutExceededError try : - _wait(training_job) + wait_kwargs = {} + if wait_timeout is not None: + wait_kwargs['timeout'] = wait_timeout + _wait(training_job, **wait_kwargs) except TimeoutExceededError as e: logger.error("Error: %s", e) From 8e66bdc34e70146610e06783e5b18c32e2e37bce Mon Sep 17 00:00:00 2001 From: Gaurav Madarkal Date: Wed, 22 Apr 2026 09:41:54 -0700 Subject: [PATCH 2/2] feat(train): added unit tests for wait_timeout --- .../tests/unit/train/test_dpo_trainer.py | 129 ++++++++++++++++++ .../tests/unit/train/test_rlaif_trainer.py | 129 ++++++++++++++++++ .../tests/unit/train/test_rlvr_trainer.py | 129 ++++++++++++++++++ .../tests/unit/train/test_sft_trainer.py | 129 ++++++++++++++++++ 4 files changed, 516 insertions(+) diff --git a/sagemaker-train/tests/unit/train/test_dpo_trainer.py b/sagemaker-train/tests/unit/train/test_dpo_trainer.py index 93a4b18fa9..2d5cf2246a 100644 --- a/sagemaker-train/tests/unit/train/test_dpo_trainer.py +++ b/sagemaker-train/tests/unit/train/test_dpo_trainer.py @@ -378,3 +378,132 @@ def test_accepts_stopping_condition(self, mock_finetuning, mock_validate): assert trainer.stopping_condition == stopping_condition assert trainer.stopping_condition.max_runtime_in_seconds == 14400 + + @patch('sagemaker.train.common_utils.trainer_wait.wait') + @patch('sagemaker.train.dpo_trainer._resolve_model_and_name') + @patch('sagemaker.train.dpo_trainer._get_fine_tuning_options_and_model_arn') + @patch('sagemaker.train.dpo_trainer.TrainDefaults.get_role') + @patch('sagemaker.train.dpo_trainer.TrainDefaults.get_sagemaker_session') + @patch('sagemaker.train.dpo_trainer._get_unique_name') + @patch('sagemaker.train.dpo_trainer._validate_and_resolve_model_package_group') + @patch('sagemaker.train.dpo_trainer._create_input_data_config') + @patch('sagemaker.train.dpo_trainer._convert_input_data_to_channels') + @patch('sagemaker.train.dpo_trainer._create_output_config') + @patch('sagemaker.train.dpo_trainer._create_serverless_config') + @patch('sagemaker.train.dpo_trainer._create_mlflow_config') + @patch('sagemaker.train.dpo_trainer._create_model_package_config') + @patch('sagemaker.core.resources.TrainingJob.create') + def test_train_passes_wait_timeout(self, mock_training_job_create, mock_model_package_config, + mock_mlflow_config, mock_serverless_config, mock_output_config, + mock_convert_channels, mock_input_config, mock_validate_group, + mock_unique_name, mock_get_sagemaker_session, mock_get_role, + mock_get_options, mock_resolve_model, mock_wait): + """Test that wait_timeout is passed to _wait as timeout kwarg.""" + mock_validate_group.return_value = "test-group" + mock_resolve_model.return_value = ("test-model", "test-model") + mock_get_sagemaker_session.return_value = Mock() + mock_fine_tuning_options = Mock() + mock_fine_tuning_options.to_dict.return_value = {"learning_rate": "0.001"} + mock_get_options.return_value = (mock_fine_tuning_options, "model-arn", False) + mock_get_role.return_value = "test-role" + mock_unique_name.return_value = "test-job-name" + mock_input_config.return_value = [Mock()] + mock_convert_channels.return_value = [Mock()] + mock_output_config.return_value = Mock() + mock_serverless_config.return_value = Mock() + mock_mlflow_config.return_value = Mock() + mock_model_package_config.return_value = Mock() + mock_training_job = Mock() + mock_training_job.arn = "arn:aws:sagemaker:us-east-1:123456789012:training-job/test-job" + mock_training_job_create.return_value = mock_training_job + + trainer = DPOTrainer(model="test-model", model_package_group="test-group", training_dataset="s3://bucket/train") + trainer.train(wait=True, wait_timeout=600) + + mock_wait.assert_called_once_with(mock_training_job, timeout=600) + + @patch('sagemaker.train.common_utils.trainer_wait.wait') + @patch('sagemaker.train.dpo_trainer._resolve_model_and_name') + @patch('sagemaker.train.dpo_trainer._get_fine_tuning_options_and_model_arn') + @patch('sagemaker.train.dpo_trainer.TrainDefaults.get_role') + @patch('sagemaker.train.dpo_trainer.TrainDefaults.get_sagemaker_session') + @patch('sagemaker.train.dpo_trainer._get_unique_name') + @patch('sagemaker.train.dpo_trainer._validate_and_resolve_model_package_group') + @patch('sagemaker.train.dpo_trainer._create_input_data_config') + @patch('sagemaker.train.dpo_trainer._convert_input_data_to_channels') + @patch('sagemaker.train.dpo_trainer._create_output_config') + @patch('sagemaker.train.dpo_trainer._create_serverless_config') + @patch('sagemaker.train.dpo_trainer._create_mlflow_config') + @patch('sagemaker.train.dpo_trainer._create_model_package_config') + @patch('sagemaker.core.resources.TrainingJob.create') + def test_train_without_wait_timeout_uses_default(self, mock_training_job_create, mock_model_package_config, + mock_mlflow_config, mock_serverless_config, mock_output_config, + mock_convert_channels, mock_input_config, mock_validate_group, + mock_unique_name, mock_get_sagemaker_session, mock_get_role, + mock_get_options, mock_resolve_model, mock_wait): + """Test that _wait is called without timeout kwarg when wait_timeout is None.""" + mock_validate_group.return_value = "test-group" + mock_resolve_model.return_value = ("test-model", "test-model") + mock_get_sagemaker_session.return_value = Mock() + mock_fine_tuning_options = Mock() + mock_fine_tuning_options.to_dict.return_value = {"learning_rate": "0.001"} + mock_get_options.return_value = (mock_fine_tuning_options, "model-arn", False) + mock_get_role.return_value = "test-role" + mock_unique_name.return_value = "test-job-name" + mock_input_config.return_value = [Mock()] + mock_convert_channels.return_value = [Mock()] + mock_output_config.return_value = Mock() + mock_serverless_config.return_value = Mock() + mock_mlflow_config.return_value = Mock() + mock_model_package_config.return_value = Mock() + mock_training_job = Mock() + mock_training_job.arn = "arn:aws:sagemaker:us-east-1:123456789012:training-job/test-job" + mock_training_job_create.return_value = mock_training_job + + trainer = DPOTrainer(model="test-model", model_package_group="test-group", training_dataset="s3://bucket/train") + trainer.train(wait=True) + + mock_wait.assert_called_once_with(mock_training_job) + + @patch('sagemaker.train.common_utils.trainer_wait.wait') + @patch('sagemaker.train.dpo_trainer._resolve_model_and_name') + @patch('sagemaker.train.dpo_trainer._get_fine_tuning_options_and_model_arn') + @patch('sagemaker.train.dpo_trainer.TrainDefaults.get_role') + @patch('sagemaker.train.dpo_trainer.TrainDefaults.get_sagemaker_session') + @patch('sagemaker.train.dpo_trainer._get_unique_name') + @patch('sagemaker.train.dpo_trainer._validate_and_resolve_model_package_group') + @patch('sagemaker.train.dpo_trainer._create_input_data_config') + @patch('sagemaker.train.dpo_trainer._convert_input_data_to_channels') + @patch('sagemaker.train.dpo_trainer._create_output_config') + @patch('sagemaker.train.dpo_trainer._create_serverless_config') + @patch('sagemaker.train.dpo_trainer._create_mlflow_config') + @patch('sagemaker.train.dpo_trainer._create_model_package_config') + @patch('sagemaker.core.resources.TrainingJob.create') + def test_train_wait_false_skips_wait(self, mock_training_job_create, mock_model_package_config, + mock_mlflow_config, mock_serverless_config, mock_output_config, + mock_convert_channels, mock_input_config, mock_validate_group, + mock_unique_name, mock_get_sagemaker_session, mock_get_role, + mock_get_options, mock_resolve_model, mock_wait): + """Test that _wait is not called when wait=False.""" + mock_validate_group.return_value = "test-group" + mock_resolve_model.return_value = ("test-model", "test-model") + mock_get_sagemaker_session.return_value = Mock() + mock_fine_tuning_options = Mock() + mock_fine_tuning_options.to_dict.return_value = {"learning_rate": "0.001"} + mock_get_options.return_value = (mock_fine_tuning_options, "model-arn", False) + mock_get_role.return_value = "test-role" + mock_unique_name.return_value = "test-job-name" + mock_input_config.return_value = [Mock()] + mock_convert_channels.return_value = [Mock()] + mock_output_config.return_value = Mock() + mock_serverless_config.return_value = Mock() + mock_mlflow_config.return_value = Mock() + mock_model_package_config.return_value = Mock() + mock_training_job = Mock() + mock_training_job.arn = "arn:aws:sagemaker:us-east-1:123456789012:training-job/test-job" + mock_training_job_create.return_value = mock_training_job + + trainer = DPOTrainer(model="test-model", model_package_group="test-group", training_dataset="s3://bucket/train") + trainer.train(wait=False, wait_timeout=600) + + mock_wait.assert_not_called() \ No newline at end of file diff --git a/sagemaker-train/tests/unit/train/test_rlaif_trainer.py b/sagemaker-train/tests/unit/train/test_rlaif_trainer.py index be8b9b96b6..24448ebbe6 100644 --- a/sagemaker-train/tests/unit/train/test_rlaif_trainer.py +++ b/sagemaker-train/tests/unit/train/test_rlaif_trainer.py @@ -554,3 +554,132 @@ def test_accepts_stopping_condition(self, mock_finetuning, mock_validate): assert trainer.stopping_condition == stopping_condition assert trainer.stopping_condition.max_runtime_in_seconds == 86400 + + @patch('sagemaker.train.common_utils.trainer_wait.wait') + @patch('sagemaker.train.common_utils.finetune_utils._get_beta_session') + @patch('sagemaker.train.rlaif_trainer._resolve_model_and_name') + @patch('sagemaker.train.rlaif_trainer._get_fine_tuning_options_and_model_arn') + @patch('sagemaker.train.rlaif_trainer.TrainDefaults.get_role') + @patch('sagemaker.train.rlaif_trainer.TrainDefaults.get_sagemaker_session') + @patch('sagemaker.train.rlaif_trainer._get_unique_name') + @patch('sagemaker.train.rlaif_trainer._validate_and_resolve_model_package_group') + @patch('sagemaker.train.rlaif_trainer._create_input_data_config') + @patch('sagemaker.train.rlaif_trainer._convert_input_data_to_channels') + @patch('sagemaker.train.rlaif_trainer._create_output_config') + @patch('sagemaker.train.rlaif_trainer._create_mlflow_config') + @patch('sagemaker.train.rlaif_trainer._create_model_package_config') + @patch('sagemaker.core.resources.TrainingJob.create') + def test_train_passes_wait_timeout(self, mock_training_job_create, mock_model_package_config, + mock_mlflow_config, mock_output_config, mock_convert_channels, + mock_input_config, mock_validate_group, mock_unique_name, + mock_get_sagemaker_session, mock_get_role, mock_get_options, + mock_resolve_model, mock_get_session, mock_wait): + """Test that wait_timeout is passed to _wait as timeout kwarg.""" + mock_validate_group.return_value = "test-group" + mock_resolve_model.return_value = ("test-model", "test-model") + mock_get_session.return_value = Mock() + mock_get_sagemaker_session.return_value = Mock() + mock_fine_tuning_options = Mock() + mock_fine_tuning_options.to_dict.return_value = {"learning_rate": "0.001"} + mock_get_options.return_value = (mock_fine_tuning_options, "model-arn", False) + mock_get_role.return_value = "test-role" + mock_unique_name.return_value = "test-job-name" + mock_input_config.return_value = [Mock()] + mock_convert_channels.return_value = [Mock()] + mock_output_config.return_value = Mock() + mock_mlflow_config.return_value = Mock() + mock_model_package_config.return_value = Mock() + mock_training_job = Mock() + mock_training_job.arn = "arn:aws:sagemaker:us-east-1:123456789012:training-job/test-job" + mock_training_job_create.return_value = mock_training_job + + trainer = RLAIFTrainer(model="test-model", model_package_group="test-group", training_dataset="s3://bucket/train") + trainer.train(wait=True, wait_timeout=600) + + mock_wait.assert_called_once_with(mock_training_job, timeout=600) + + @patch('sagemaker.train.common_utils.trainer_wait.wait') + @patch('sagemaker.train.common_utils.finetune_utils._get_beta_session') + @patch('sagemaker.train.rlaif_trainer._resolve_model_and_name') + @patch('sagemaker.train.rlaif_trainer._get_fine_tuning_options_and_model_arn') + @patch('sagemaker.train.rlaif_trainer.TrainDefaults.get_role') + @patch('sagemaker.train.rlaif_trainer.TrainDefaults.get_sagemaker_session') + @patch('sagemaker.train.rlaif_trainer._get_unique_name') + @patch('sagemaker.train.rlaif_trainer._validate_and_resolve_model_package_group') + @patch('sagemaker.train.rlaif_trainer._create_input_data_config') + @patch('sagemaker.train.rlaif_trainer._convert_input_data_to_channels') + @patch('sagemaker.train.rlaif_trainer._create_output_config') + @patch('sagemaker.train.rlaif_trainer._create_mlflow_config') + @patch('sagemaker.train.rlaif_trainer._create_model_package_config') + @patch('sagemaker.core.resources.TrainingJob.create') + def test_train_without_wait_timeout_uses_default(self, mock_training_job_create, mock_model_package_config, + mock_mlflow_config, mock_output_config, mock_convert_channels, + mock_input_config, mock_validate_group, mock_unique_name, + mock_get_sagemaker_session, mock_get_role, mock_get_options, + mock_resolve_model, mock_get_session, mock_wait): + """Test that _wait is called without timeout kwarg when wait_timeout is None.""" + mock_validate_group.return_value = "test-group" + mock_resolve_model.return_value = ("test-model", "test-model") + mock_get_session.return_value = Mock() + mock_get_sagemaker_session.return_value = Mock() + mock_fine_tuning_options = Mock() + mock_fine_tuning_options.to_dict.return_value = {"learning_rate": "0.001"} + mock_get_options.return_value = (mock_fine_tuning_options, "model-arn", False) + mock_get_role.return_value = "test-role" + mock_unique_name.return_value = "test-job-name" + mock_input_config.return_value = [Mock()] + mock_convert_channels.return_value = [Mock()] + mock_output_config.return_value = Mock() + mock_mlflow_config.return_value = Mock() + mock_model_package_config.return_value = Mock() + mock_training_job = Mock() + mock_training_job.arn = "arn:aws:sagemaker:us-east-1:123456789012:training-job/test-job" + mock_training_job_create.return_value = mock_training_job + + trainer = RLAIFTrainer(model="test-model", model_package_group="test-group", training_dataset="s3://bucket/train") + trainer.train(wait=True) + + mock_wait.assert_called_once_with(mock_training_job) + + @patch('sagemaker.train.common_utils.trainer_wait.wait') + @patch('sagemaker.train.common_utils.finetune_utils._get_beta_session') + @patch('sagemaker.train.rlaif_trainer._resolve_model_and_name') + @patch('sagemaker.train.rlaif_trainer._get_fine_tuning_options_and_model_arn') + @patch('sagemaker.train.rlaif_trainer.TrainDefaults.get_role') + @patch('sagemaker.train.rlaif_trainer.TrainDefaults.get_sagemaker_session') + @patch('sagemaker.train.rlaif_trainer._get_unique_name') + @patch('sagemaker.train.rlaif_trainer._validate_and_resolve_model_package_group') + @patch('sagemaker.train.rlaif_trainer._create_input_data_config') + @patch('sagemaker.train.rlaif_trainer._convert_input_data_to_channels') + @patch('sagemaker.train.rlaif_trainer._create_output_config') + @patch('sagemaker.train.rlaif_trainer._create_mlflow_config') + @patch('sagemaker.train.rlaif_trainer._create_model_package_config') + @patch('sagemaker.core.resources.TrainingJob.create') + def test_train_wait_false_skips_wait(self, mock_training_job_create, mock_model_package_config, + mock_mlflow_config, mock_output_config, mock_convert_channels, + mock_input_config, mock_validate_group, mock_unique_name, + mock_get_sagemaker_session, mock_get_role, mock_get_options, + mock_resolve_model, mock_get_session, mock_wait): + """Test that _wait is not called when wait=False.""" + mock_validate_group.return_value = "test-group" + mock_resolve_model.return_value = ("test-model", "test-model") + mock_get_session.return_value = Mock() + mock_get_sagemaker_session.return_value = Mock() + mock_fine_tuning_options = Mock() + mock_fine_tuning_options.to_dict.return_value = {"learning_rate": "0.001"} + mock_get_options.return_value = (mock_fine_tuning_options, "model-arn", False) + mock_get_role.return_value = "test-role" + mock_unique_name.return_value = "test-job-name" + mock_input_config.return_value = [Mock()] + mock_convert_channels.return_value = [Mock()] + mock_output_config.return_value = Mock() + mock_mlflow_config.return_value = Mock() + mock_model_package_config.return_value = Mock() + mock_training_job = Mock() + mock_training_job.arn = "arn:aws:sagemaker:us-east-1:123456789012:training-job/test-job" + mock_training_job_create.return_value = mock_training_job + + trainer = RLAIFTrainer(model="test-model", model_package_group="test-group", training_dataset="s3://bucket/train") + trainer.train(wait=False, wait_timeout=600) + + mock_wait.assert_not_called() \ No newline at end of file diff --git a/sagemaker-train/tests/unit/train/test_rlvr_trainer.py b/sagemaker-train/tests/unit/train/test_rlvr_trainer.py index 4ee785285e..e16c5c1c69 100644 --- a/sagemaker-train/tests/unit/train/test_rlvr_trainer.py +++ b/sagemaker-train/tests/unit/train/test_rlvr_trainer.py @@ -381,3 +381,132 @@ def test_accepts_stopping_condition(self, mock_finetuning, mock_validate): assert trainer.stopping_condition == stopping_condition assert trainer.stopping_condition.max_runtime_in_seconds == 259200 + + @patch('sagemaker.train.common_utils.trainer_wait.wait') + @patch('sagemaker.train.common_utils.finetune_utils._get_beta_session') + @patch('sagemaker.train.rlvr_trainer._resolve_model_and_name') + @patch('sagemaker.train.rlvr_trainer._get_fine_tuning_options_and_model_arn') + @patch('sagemaker.train.rlvr_trainer.TrainDefaults.get_role') + @patch('sagemaker.train.rlvr_trainer.TrainDefaults.get_sagemaker_session') + @patch('sagemaker.train.rlvr_trainer._get_unique_name') + @patch('sagemaker.train.rlvr_trainer._validate_and_resolve_model_package_group') + @patch('sagemaker.train.rlvr_trainer._create_input_data_config') + @patch('sagemaker.train.rlvr_trainer._convert_input_data_to_channels') + @patch('sagemaker.train.rlvr_trainer._create_output_config') + @patch('sagemaker.train.rlvr_trainer._create_mlflow_config') + @patch('sagemaker.train.rlvr_trainer._create_model_package_config') + @patch('sagemaker.core.resources.TrainingJob.create') + def test_train_passes_wait_timeout(self, mock_training_job_create, mock_model_package_config, + mock_mlflow_config, mock_output_config, mock_convert_channels, + mock_input_config, mock_validate_group, mock_unique_name, + mock_get_sagemaker_session, mock_get_role, mock_get_options, + mock_resolve_model, mock_get_session, mock_wait): + """Test that wait_timeout is passed to _wait as timeout kwarg.""" + mock_validate_group.return_value = "test-group" + mock_resolve_model.return_value = ("test-model", "test-model") + mock_get_session.return_value = Mock() + mock_get_sagemaker_session.return_value = Mock() + mock_fine_tuning_options = Mock() + mock_fine_tuning_options.to_dict.return_value = {"learning_rate": "0.001"} + mock_get_options.return_value = (mock_fine_tuning_options, "model-arn", False) + mock_get_role.return_value = "test-role" + mock_unique_name.return_value = "test-job-name" + mock_input_config.return_value = [Mock()] + mock_convert_channels.return_value = [Mock()] + mock_output_config.return_value = Mock() + mock_mlflow_config.return_value = Mock() + mock_model_package_config.return_value = Mock() + mock_training_job = Mock() + mock_training_job.arn = "arn:aws:sagemaker:us-east-1:123456789012:training-job/test-job" + mock_training_job_create.return_value = mock_training_job + + trainer = RLVRTrainer(model="test-model", model_package_group="test-group", training_dataset="s3://bucket/train") + trainer.train(wait=True, wait_timeout=600) + + mock_wait.assert_called_once_with(mock_training_job, timeout=600) + + @patch('sagemaker.train.common_utils.trainer_wait.wait') + @patch('sagemaker.train.common_utils.finetune_utils._get_beta_session') + @patch('sagemaker.train.rlvr_trainer._resolve_model_and_name') + @patch('sagemaker.train.rlvr_trainer._get_fine_tuning_options_and_model_arn') + @patch('sagemaker.train.rlvr_trainer.TrainDefaults.get_role') + @patch('sagemaker.train.rlvr_trainer.TrainDefaults.get_sagemaker_session') + @patch('sagemaker.train.rlvr_trainer._get_unique_name') + @patch('sagemaker.train.rlvr_trainer._validate_and_resolve_model_package_group') + @patch('sagemaker.train.rlvr_trainer._create_input_data_config') + @patch('sagemaker.train.rlvr_trainer._convert_input_data_to_channels') + @patch('sagemaker.train.rlvr_trainer._create_output_config') + @patch('sagemaker.train.rlvr_trainer._create_mlflow_config') + @patch('sagemaker.train.rlvr_trainer._create_model_package_config') + @patch('sagemaker.core.resources.TrainingJob.create') + def test_train_without_wait_timeout_uses_default(self, mock_training_job_create, mock_model_package_config, + mock_mlflow_config, mock_output_config, mock_convert_channels, + mock_input_config, mock_validate_group, mock_unique_name, + mock_get_sagemaker_session, mock_get_role, mock_get_options, + mock_resolve_model, mock_get_session, mock_wait): + """Test that _wait is called without timeout kwarg when wait_timeout is None.""" + mock_validate_group.return_value = "test-group" + mock_resolve_model.return_value = ("test-model", "test-model") + mock_get_session.return_value = Mock() + mock_get_sagemaker_session.return_value = Mock() + mock_fine_tuning_options = Mock() + mock_fine_tuning_options.to_dict.return_value = {"learning_rate": "0.001"} + mock_get_options.return_value = (mock_fine_tuning_options, "model-arn", False) + mock_get_role.return_value = "test-role" + mock_unique_name.return_value = "test-job-name" + mock_input_config.return_value = [Mock()] + mock_convert_channels.return_value = [Mock()] + mock_output_config.return_value = Mock() + mock_mlflow_config.return_value = Mock() + mock_model_package_config.return_value = Mock() + mock_training_job = Mock() + mock_training_job.arn = "arn:aws:sagemaker:us-east-1:123456789012:training-job/test-job" + mock_training_job_create.return_value = mock_training_job + + trainer = RLVRTrainer(model="test-model", model_package_group="test-group", training_dataset="s3://bucket/train") + trainer.train(wait=True) + + mock_wait.assert_called_once_with(mock_training_job) + + @patch('sagemaker.train.common_utils.trainer_wait.wait') + @patch('sagemaker.train.common_utils.finetune_utils._get_beta_session') + @patch('sagemaker.train.rlvr_trainer._resolve_model_and_name') + @patch('sagemaker.train.rlvr_trainer._get_fine_tuning_options_and_model_arn') + @patch('sagemaker.train.rlvr_trainer.TrainDefaults.get_role') + @patch('sagemaker.train.rlvr_trainer.TrainDefaults.get_sagemaker_session') + @patch('sagemaker.train.rlvr_trainer._get_unique_name') + @patch('sagemaker.train.rlvr_trainer._validate_and_resolve_model_package_group') + @patch('sagemaker.train.rlvr_trainer._create_input_data_config') + @patch('sagemaker.train.rlvr_trainer._convert_input_data_to_channels') + @patch('sagemaker.train.rlvr_trainer._create_output_config') + @patch('sagemaker.train.rlvr_trainer._create_mlflow_config') + @patch('sagemaker.train.rlvr_trainer._create_model_package_config') + @patch('sagemaker.core.resources.TrainingJob.create') + def test_train_wait_false_skips_wait(self, mock_training_job_create, mock_model_package_config, + mock_mlflow_config, mock_output_config, mock_convert_channels, + mock_input_config, mock_validate_group, mock_unique_name, + mock_get_sagemaker_session, mock_get_role, mock_get_options, + mock_resolve_model, mock_get_session, mock_wait): + """Test that _wait is not called when wait=False.""" + mock_validate_group.return_value = "test-group" + mock_resolve_model.return_value = ("test-model", "test-model") + mock_get_session.return_value = Mock() + mock_get_sagemaker_session.return_value = Mock() + mock_fine_tuning_options = Mock() + mock_fine_tuning_options.to_dict.return_value = {"learning_rate": "0.001"} + mock_get_options.return_value = (mock_fine_tuning_options, "model-arn", False) + mock_get_role.return_value = "test-role" + mock_unique_name.return_value = "test-job-name" + mock_input_config.return_value = [Mock()] + mock_convert_channels.return_value = [Mock()] + mock_output_config.return_value = Mock() + mock_mlflow_config.return_value = Mock() + mock_model_package_config.return_value = Mock() + mock_training_job = Mock() + mock_training_job.arn = "arn:aws:sagemaker:us-east-1:123456789012:training-job/test-job" + mock_training_job_create.return_value = mock_training_job + + trainer = RLVRTrainer(model="test-model", model_package_group="test-group", training_dataset="s3://bucket/train") + trainer.train(wait=False, wait_timeout=600) + + mock_wait.assert_not_called() \ No newline at end of file diff --git a/sagemaker-train/tests/unit/train/test_sft_trainer.py b/sagemaker-train/tests/unit/train/test_sft_trainer.py index 6af829e1a7..a2473ebfd0 100644 --- a/sagemaker-train/tests/unit/train/test_sft_trainer.py +++ b/sagemaker-train/tests/unit/train/test_sft_trainer.py @@ -392,3 +392,132 @@ def test_default_stopping_condition_is_none(self, mock_finetuning, mock_validate trainer = SFTTrainer(model="test-model", model_package_group="test-group") assert trainer.stopping_condition is None + + @patch('sagemaker.train.common_utils.trainer_wait.wait') + @patch('sagemaker.train.common_utils.finetune_utils._get_beta_session') + @patch('sagemaker.train.sft_trainer._resolve_model_and_name') + @patch('sagemaker.train.sft_trainer._get_fine_tuning_options_and_model_arn') + @patch('sagemaker.train.sft_trainer.TrainDefaults.get_role') + @patch('sagemaker.train.sft_trainer.TrainDefaults.get_sagemaker_session') + @patch('sagemaker.train.sft_trainer._get_unique_name') + @patch('sagemaker.train.sft_trainer._validate_and_resolve_model_package_group') + @patch('sagemaker.train.sft_trainer._create_input_data_config') + @patch('sagemaker.train.sft_trainer._convert_input_data_to_channels') + @patch('sagemaker.train.sft_trainer._create_output_config') + @patch('sagemaker.train.sft_trainer._create_mlflow_config') + @patch('sagemaker.train.sft_trainer._create_model_package_config') + @patch('sagemaker.core.resources.TrainingJob.create') + def test_train_passes_wait_timeout(self, mock_training_job_create, mock_model_package_config, + mock_mlflow_config, mock_output_config, mock_convert_channels, + mock_input_config, mock_validate_group, mock_unique_name, + mock_get_sagemaker_session, mock_get_role, mock_get_options, + mock_resolve_model, mock_get_session, mock_wait): + """Test that wait_timeout is passed to _wait as timeout kwarg.""" + mock_validate_group.return_value = "test-group" + mock_resolve_model.return_value = ("test-model", "test-model") + mock_get_session.return_value = Mock() + mock_get_sagemaker_session.return_value = Mock() + mock_fine_tuning_options = Mock() + mock_fine_tuning_options.to_dict.return_value = {"learning_rate": "0.001"} + mock_get_options.return_value = (mock_fine_tuning_options, "model-arn", False) + mock_get_role.return_value = "test-role" + mock_unique_name.return_value = "test-job-name" + mock_input_config.return_value = [Mock()] + mock_convert_channels.return_value = [Mock()] + mock_output_config.return_value = Mock() + mock_mlflow_config.return_value = Mock() + mock_model_package_config.return_value = Mock() + mock_training_job = Mock() + mock_training_job.arn = "arn:aws:sagemaker:us-east-1:123456789012:training-job/test-job" + mock_training_job_create.return_value = mock_training_job + + trainer = SFTTrainer(model="test-model", model_package_group="test-group", training_dataset="s3://bucket/train") + trainer.train(wait=True, wait_timeout=600) + + mock_wait.assert_called_once_with(mock_training_job, timeout=600) + + @patch('sagemaker.train.common_utils.trainer_wait.wait') + @patch('sagemaker.train.common_utils.finetune_utils._get_beta_session') + @patch('sagemaker.train.sft_trainer._resolve_model_and_name') + @patch('sagemaker.train.sft_trainer._get_fine_tuning_options_and_model_arn') + @patch('sagemaker.train.sft_trainer.TrainDefaults.get_role') + @patch('sagemaker.train.sft_trainer.TrainDefaults.get_sagemaker_session') + @patch('sagemaker.train.sft_trainer._get_unique_name') + @patch('sagemaker.train.sft_trainer._validate_and_resolve_model_package_group') + @patch('sagemaker.train.sft_trainer._create_input_data_config') + @patch('sagemaker.train.sft_trainer._convert_input_data_to_channels') + @patch('sagemaker.train.sft_trainer._create_output_config') + @patch('sagemaker.train.sft_trainer._create_mlflow_config') + @patch('sagemaker.train.sft_trainer._create_model_package_config') + @patch('sagemaker.core.resources.TrainingJob.create') + def test_train_without_wait_timeout_uses_default(self, mock_training_job_create, mock_model_package_config, + mock_mlflow_config, mock_output_config, mock_convert_channels, + mock_input_config, mock_validate_group, mock_unique_name, + mock_get_sagemaker_session, mock_get_role, mock_get_options, + mock_resolve_model, mock_get_session, mock_wait): + """Test that _wait is called without timeout kwarg when wait_timeout is None.""" + mock_validate_group.return_value = "test-group" + mock_resolve_model.return_value = ("test-model", "test-model") + mock_get_session.return_value = Mock() + mock_get_sagemaker_session.return_value = Mock() + mock_fine_tuning_options = Mock() + mock_fine_tuning_options.to_dict.return_value = {"learning_rate": "0.001"} + mock_get_options.return_value = (mock_fine_tuning_options, "model-arn", False) + mock_get_role.return_value = "test-role" + mock_unique_name.return_value = "test-job-name" + mock_input_config.return_value = [Mock()] + mock_convert_channels.return_value = [Mock()] + mock_output_config.return_value = Mock() + mock_mlflow_config.return_value = Mock() + mock_model_package_config.return_value = Mock() + mock_training_job = Mock() + mock_training_job.arn = "arn:aws:sagemaker:us-east-1:123456789012:training-job/test-job" + mock_training_job_create.return_value = mock_training_job + + trainer = SFTTrainer(model="test-model", model_package_group="test-group", training_dataset="s3://bucket/train") + trainer.train(wait=True) + + mock_wait.assert_called_once_with(mock_training_job) + + @patch('sagemaker.train.common_utils.trainer_wait.wait') + @patch('sagemaker.train.common_utils.finetune_utils._get_beta_session') + @patch('sagemaker.train.sft_trainer._resolve_model_and_name') + @patch('sagemaker.train.sft_trainer._get_fine_tuning_options_and_model_arn') + @patch('sagemaker.train.sft_trainer.TrainDefaults.get_role') + @patch('sagemaker.train.sft_trainer.TrainDefaults.get_sagemaker_session') + @patch('sagemaker.train.sft_trainer._get_unique_name') + @patch('sagemaker.train.sft_trainer._validate_and_resolve_model_package_group') + @patch('sagemaker.train.sft_trainer._create_input_data_config') + @patch('sagemaker.train.sft_trainer._convert_input_data_to_channels') + @patch('sagemaker.train.sft_trainer._create_output_config') + @patch('sagemaker.train.sft_trainer._create_mlflow_config') + @patch('sagemaker.train.sft_trainer._create_model_package_config') + @patch('sagemaker.core.resources.TrainingJob.create') + def test_train_wait_false_skips_wait(self, mock_training_job_create, mock_model_package_config, + mock_mlflow_config, mock_output_config, mock_convert_channels, + mock_input_config, mock_validate_group, mock_unique_name, + mock_get_sagemaker_session, mock_get_role, mock_get_options, + mock_resolve_model, mock_get_session, mock_wait): + """Test that _wait is not called when wait=False.""" + mock_validate_group.return_value = "test-group" + mock_resolve_model.return_value = ("test-model", "test-model") + mock_get_session.return_value = Mock() + mock_get_sagemaker_session.return_value = Mock() + mock_fine_tuning_options = Mock() + mock_fine_tuning_options.to_dict.return_value = {"learning_rate": "0.001"} + mock_get_options.return_value = (mock_fine_tuning_options, "model-arn", False) + mock_get_role.return_value = "test-role" + mock_unique_name.return_value = "test-job-name" + mock_input_config.return_value = [Mock()] + mock_convert_channels.return_value = [Mock()] + mock_output_config.return_value = Mock() + mock_mlflow_config.return_value = Mock() + mock_model_package_config.return_value = Mock() + mock_training_job = Mock() + mock_training_job.arn = "arn:aws:sagemaker:us-east-1:123456789012:training-job/test-job" + mock_training_job_create.return_value = mock_training_job + + trainer = SFTTrainer(model="test-model", model_package_group="test-group", training_dataset="s3://bucket/train") + trainer.train(wait=False, wait_timeout=600) + + mock_wait.assert_not_called() \ No newline at end of file