diff --git a/sagemaker-train/src/sagemaker/train/common_utils/mlflow_url_utils.py b/sagemaker-train/src/sagemaker/train/common_utils/mlflow_url_utils.py new file mode 100644 index 0000000000..b39dab31c1 --- /dev/null +++ b/sagemaker-train/src/sagemaker/train/common_utils/mlflow_url_utils.py @@ -0,0 +1,60 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +"""Shared MLflow presigned URL utilities.""" + +import logging +from typing import Optional + +logger = logging.getLogger(__name__) + + +def get_presigned_mlflow_experiment_url( + mlflow_resource_arn: str, + mlflow_experiment_name: Optional[str] = None, +) -> Optional[str]: + """Generate a presigned MLflow URL, optionally deep-linked to an experiment. + + Args: + mlflow_resource_arn: MLflow tracking server or app ARN. + mlflow_experiment_name: Optional experiment name for deep-linking. + + Returns: + Presigned URL with experiment fragment, or base URL, or None on failure. + """ + try: + from sagemaker.core.utils.utils import SageMakerClient + + sm_client = SageMakerClient().sagemaker_client + response = sm_client.create_presigned_mlflow_app_url(Arn=mlflow_resource_arn) + base_url = response.get("AuthorizedUrl") + if not base_url: + return None + + if mlflow_experiment_name: + try: + import mlflow + from mlflow.tracking import MlflowClient + + mlflow.set_tracking_uri(mlflow_resource_arn) + experiment = MlflowClient( + tracking_uri=mlflow_resource_arn + ).get_experiment_by_name(mlflow_experiment_name) + if experiment: + return f"{base_url}#/experiments/{experiment.experiment_id}" + except Exception as e: + logger.debug(f"Failed to resolve MLflow experiment '{mlflow_experiment_name}': {e}") + + return base_url + except Exception as e: + logger.debug(f"Failed to generate MLflow experiment URL: {e}") + return None diff --git a/sagemaker-train/src/sagemaker/train/common_utils/trainer_wait.py b/sagemaker-train/src/sagemaker/train/common_utils/trainer_wait.py index 59adcdfbfc..a7b41b5b1e 100644 --- a/sagemaker-train/src/sagemaker/train/common_utils/trainer_wait.py +++ b/sagemaker-train/src/sagemaker/train/common_utils/trainer_wait.py @@ -182,35 +182,17 @@ def get_mlflow_url(training_job) -> str: if not hasattr(training_job, 'mlflow_config') or _is_unassigned_attribute(training_job.mlflow_config): raise ValueError("Training job does not have MLflow configured") - import os - from mlflow.tracking import MlflowClient - import mlflow - from sagemaker.core.utils.utils import SageMakerClient + from sagemaker.train.common_utils.mlflow_url_utils import get_presigned_mlflow_experiment_url mlflow_arn = training_job.mlflow_config.mlflow_resource_arn exp_name = training_job.mlflow_config.mlflow_experiment_name + if _is_unassigned_attribute(exp_name): + exp_name = None - # Get presigned base URL - sm_client = SageMakerClient().sagemaker_client - response = sm_client.create_presigned_mlflow_app_url(Arn=mlflow_arn) - base_url = response.get('AuthorizedUrl') - - # Try to get experiment ID and append to URL - try: - os.environ['MLFLOW_TRACKING_URI'] = mlflow_arn - mlflow.set_tracking_uri(mlflow_arn) - - mlflow_client = MlflowClient(tracking_uri=mlflow_arn) - experiment = mlflow_client.get_experiment_by_name(exp_name) - - if experiment: - # Format: base_url#/experiments/{id} - # The base_url already has /auth?authToken=... - return f"{base_url}#/experiments/{experiment.experiment_id}" - except Exception: - pass - - return base_url + url = get_presigned_mlflow_experiment_url(mlflow_arn, exp_name) + if url is None: + raise ValueError("Failed to generate presigned MLflow URL") + return url diff --git a/sagemaker-train/src/sagemaker/train/evaluate/execution.py b/sagemaker-train/src/sagemaker/train/evaluate/execution.py index d5e50f86b5..496e56e9cb 100644 --- a/sagemaker-train/src/sagemaker/train/evaluate/execution.py +++ b/sagemaker-train/src/sagemaker/train/evaluate/execution.py @@ -920,6 +920,28 @@ def wait( # Create console with Jupyter support console = Console(force_jupyter=True) + # MLflow link caching (presigned URLs expire after 5 min) + mlflow_link_cache = {'url': None, 'timestamp': 0} + + def get_cached_mlflow_url(): + """Get cached MLflow URL, regenerating every 4 minutes.""" + from sagemaker.train.common_utils.trainer_wait import _is_unassigned_attribute + from sagemaker.train.common_utils.mlflow_url_utils import get_presigned_mlflow_experiment_url + + current_time = time.time() + if mlflow_link_cache['url'] is None or (current_time - mlflow_link_cache['timestamp']) > 240: + pe = self._pipeline_execution + mlflow_cfg = getattr(pe, 'm_lflow_config', None) if pe else None + if mlflow_cfg and not _is_unassigned_attribute(mlflow_cfg): + arn = getattr(mlflow_cfg, 'mlflow_resource_arn', None) + if arn and not _is_unassigned_attribute(arn): + exp_name = getattr(mlflow_cfg, 'mlflow_experiment_name', None) + if exp_name and _is_unassigned_attribute(exp_name): + exp_name = None + mlflow_link_cache['url'] = get_presigned_mlflow_experiment_url(arn, exp_name) + mlflow_link_cache['timestamp'] = current_time + return mlflow_link_cache['url'] + while True: clear_output(wait=True) self.refresh() @@ -960,6 +982,10 @@ def wait( links.append(f"[bright_blue underline][link={pipeline_url}]🔗 Pipeline Execution (Studio)[/link][/bright_blue underline]") except Exception: pass + # Add MLflow experiment link if available + cached_mlflow_url = get_cached_mlflow_url() + if cached_mlflow_url: + links.append(f"[bright_blue underline][link={cached_mlflow_url}]🔗 MLflow Experiment[/link][/bright_blue underline]") if links: header_table.add_row("Links", " | ".join(links)) diff --git a/sagemaker-train/tests/unit/train/common_utils/test_trainer_wait.py b/sagemaker-train/tests/unit/train/common_utils/test_trainer_wait.py index 7bcff8fa0c..3a2cdd031f 100644 --- a/sagemaker-train/tests/unit/train/common_utils/test_trainer_wait.py +++ b/sagemaker-train/tests/unit/train/common_utils/test_trainer_wait.py @@ -25,6 +25,7 @@ _is_unassigned_attribute, _calculate_training_progress, _calculate_transition_duration, + get_mlflow_url, wait ) @@ -489,3 +490,66 @@ def test_wait_metrics_exception_non_jupyter(self, mock_is_jupyter, mock_setup_ml # Should complete successfully despite metrics exception training_job.refresh.assert_called() + + +class TestGetMlflowUrl: + """Test cases for get_mlflow_url function.""" + + @patch("sagemaker.train.common_utils.mlflow_url_utils.get_presigned_mlflow_experiment_url") + def test_delegates_to_shared_helper(self, mock_helper): + """Test that get_mlflow_url extracts config and delegates to shared helper.""" + mock_helper.return_value = "https://mlflow.example.com/auth?token=abc#/experiments/42" + + training_job = MagicMock() + training_job.mlflow_config.mlflow_resource_arn = "arn:aws:sagemaker:us-west-2:123:mlflow-app/test" + training_job.mlflow_config.mlflow_experiment_name = "my-experiment" + + result = get_mlflow_url(training_job) + + mock_helper.assert_called_once_with( + "arn:aws:sagemaker:us-west-2:123:mlflow-app/test", + "my-experiment", + ) + assert result == "https://mlflow.example.com/auth?token=abc#/experiments/42" + + @patch("sagemaker.train.common_utils.trainer_wait.TrainingJob") + @patch("sagemaker.train.common_utils.mlflow_url_utils.get_presigned_mlflow_experiment_url") + def test_accepts_job_name_string(self, mock_helper, mock_tj_class): + """Test that a string job name is resolved via TrainingJob.get().""" + mock_helper.return_value = "https://mlflow.example.com/auth" + mock_tj = MagicMock() + mock_tj.mlflow_config.mlflow_resource_arn = "arn:aws:sagemaker:us-west-2:123:mlflow-app/test" + mock_tj.mlflow_config.mlflow_experiment_name = None + mock_tj_class.get.return_value = mock_tj + + result = get_mlflow_url("my-training-job") + + mock_tj_class.get.assert_called_once_with(training_job_name="my-training-job") + assert result == "https://mlflow.example.com/auth" + + def test_raises_when_no_mlflow_config(self): + """Test raises ValueError when training job has no mlflow config.""" + training_job = MagicMock() + training_job.mlflow_config = MockUnassignedAttribute() + + with pytest.raises(ValueError, match="does not have MLflow configured"): + get_mlflow_url(training_job) + + def test_raises_when_mlflow_config_missing(self): + """Test raises ValueError when training job lacks mlflow_config attribute.""" + training_job = MagicMock(spec=[]) # no attributes + + with pytest.raises(ValueError, match="does not have MLflow configured"): + get_mlflow_url(training_job) + + @patch("sagemaker.train.common_utils.mlflow_url_utils.get_presigned_mlflow_experiment_url") + def test_raises_when_helper_returns_none(self, mock_helper): + """Test raises ValueError when presigned URL generation fails.""" + mock_helper.return_value = None + + training_job = MagicMock() + training_job.mlflow_config.mlflow_resource_arn = "arn:aws:sagemaker:us-west-2:123:mlflow-app/test" + training_job.mlflow_config.mlflow_experiment_name = "exp" + + with pytest.raises(ValueError, match="Failed to generate presigned MLflow URL"): + get_mlflow_url(training_job) diff --git a/sagemaker-train/tests/unit/train/evaluate/test_execution.py b/sagemaker-train/tests/unit/train/evaluate/test_execution.py index 89118b4b2e..5883207f32 100644 --- a/sagemaker-train/tests/unit/train/evaluate/test_execution.py +++ b/sagemaker-train/tests/unit/train/evaluate/test_execution.py @@ -38,6 +38,7 @@ _create_execution_from_pipeline_execution, _extract_output_s3_location_from_steps, ) +from sagemaker.train.common_utils.mlflow_url_utils import get_presigned_mlflow_experiment_url from sagemaker.train.evaluate.constants import EvalType, _get_pipeline_name, _get_pipeline_name_prefix # Test constants @@ -1465,3 +1466,181 @@ def test_complete_get_workflow(self, mock_pe_class, mock_session): # Additional tests for improved coverage - removed as they don't add significant value + + +# ============================================================================ +# Tests for MLflow Link Functions +# ============================================================================ + +MOCK_MLFLOW_ARN = "arn:aws:sagemaker:us-west-2:123456789012:mlflow-tracking-server/test-server" +MOCK_PRESIGNED_URL = "https://mlflow.example.com/auth?authToken=abc123" + + +class TestGetMlflowExperimentUrl: + """Tests for get_presigned_mlflow_experiment_url function.""" + + @patch("sagemaker.core.utils.utils.SageMakerClient") + def test_returns_deep_link_with_experiment(self, mock_sm_client_cls): + """Test returns URL with experiment fragment when experiment exists.""" + mock_client = MagicMock() + mock_client.sagemaker_client.create_presigned_mlflow_app_url.return_value = { + "AuthorizedUrl": MOCK_PRESIGNED_URL + } + mock_sm_client_cls.return_value = mock_client + + mock_experiment = MagicMock() + mock_experiment.experiment_id = "42" + + with patch("mlflow.set_tracking_uri"), \ + patch("mlflow.tracking.MlflowClient") as mock_mlflow_client: + mock_mlflow_client.return_value.get_experiment_by_name.return_value = mock_experiment + + result = get_presigned_mlflow_experiment_url(MOCK_MLFLOW_ARN, "my-experiment") + + assert result == f"{MOCK_PRESIGNED_URL}#/experiments/42" + + @patch("sagemaker.core.utils.utils.SageMakerClient") + def test_returns_base_url_when_no_experiment_name(self, mock_sm_client_cls): + """Test returns base URL when experiment name is None.""" + mock_client = MagicMock() + mock_client.sagemaker_client.create_presigned_mlflow_app_url.return_value = { + "AuthorizedUrl": MOCK_PRESIGNED_URL + } + mock_sm_client_cls.return_value = mock_client + + result = get_presigned_mlflow_experiment_url(MOCK_MLFLOW_ARN, None) + + assert result == MOCK_PRESIGNED_URL + + @patch("sagemaker.core.utils.utils.SageMakerClient") + def test_returns_base_url_when_experiment_lookup_fails(self, mock_sm_client_cls): + """Test falls back to base URL when MLflow experiment lookup raises.""" + mock_client = MagicMock() + mock_client.sagemaker_client.create_presigned_mlflow_app_url.return_value = { + "AuthorizedUrl": MOCK_PRESIGNED_URL + } + mock_sm_client_cls.return_value = mock_client + + with patch("mlflow.set_tracking_uri"), \ + patch("mlflow.tracking.MlflowClient") as mock_mlflow_client: + mock_mlflow_client.return_value.get_experiment_by_name.side_effect = Exception("connection error") + + result = get_presigned_mlflow_experiment_url(MOCK_MLFLOW_ARN, "my-experiment") + + assert result == MOCK_PRESIGNED_URL + + @patch("sagemaker.core.utils.utils.SageMakerClient") + def test_returns_base_url_when_experiment_not_found(self, mock_sm_client_cls): + """Test falls back to base URL when experiment doesn't exist.""" + mock_client = MagicMock() + mock_client.sagemaker_client.create_presigned_mlflow_app_url.return_value = { + "AuthorizedUrl": MOCK_PRESIGNED_URL + } + mock_sm_client_cls.return_value = mock_client + + with patch("mlflow.set_tracking_uri"), \ + patch("mlflow.tracking.MlflowClient") as mock_mlflow_client: + mock_mlflow_client.return_value.get_experiment_by_name.return_value = None + + result = get_presigned_mlflow_experiment_url(MOCK_MLFLOW_ARN, "nonexistent") + + assert result == MOCK_PRESIGNED_URL + + @patch("sagemaker.core.utils.utils.SageMakerClient") + def test_returns_none_when_presigned_url_fails(self, mock_sm_client_cls): + """Test returns None when create_presigned_mlflow_app_url raises.""" + mock_client = MagicMock() + mock_client.sagemaker_client.create_presigned_mlflow_app_url.side_effect = Exception("access denied") + mock_sm_client_cls.return_value = mock_client + + result = get_presigned_mlflow_experiment_url(MOCK_MLFLOW_ARN, "my-experiment") + + assert result is None + + @patch("sagemaker.core.utils.utils.SageMakerClient") + def test_returns_none_when_authorized_url_empty(self, mock_sm_client_cls): + """Test returns None when AuthorizedUrl is missing from response.""" + mock_client = MagicMock() + mock_client.sagemaker_client.create_presigned_mlflow_app_url.return_value = {} + mock_sm_client_cls.return_value = mock_client + + result = get_presigned_mlflow_experiment_url(MOCK_MLFLOW_ARN, "my-experiment") + + assert result is None + + +class TestGetCachedMlflowUrl: + """Tests for the get_cached_mlflow_url closure inside wait().""" + + @patch("sagemaker.core.utils.utils.SageMakerClient") + def test_cache_hit_within_240s(self, mock_sm_client_cls): + """Test that cached URL is returned within 240s window.""" + mock_client = MagicMock() + mock_client.sagemaker_client.create_presigned_mlflow_app_url.return_value = { + "AuthorizedUrl": MOCK_PRESIGNED_URL + } + mock_sm_client_cls.return_value = mock_client + + mlflow_link_cache = {"url": None, "timestamp": 0} + + def get_cached(current_time): + if mlflow_link_cache["url"] is None or (current_time - mlflow_link_cache["timestamp"]) > 240: + mlflow_link_cache["url"] = get_presigned_mlflow_experiment_url(MOCK_MLFLOW_ARN, None) + mlflow_link_cache["timestamp"] = current_time + return mlflow_link_cache["url"] + + url1 = get_cached(1000.0) + assert url1 == MOCK_PRESIGNED_URL + assert mock_client.sagemaker_client.create_presigned_mlflow_app_url.call_count == 1 + + # Second call within 240s returns cached — no additional API call + url2 = get_cached(1100.0) + assert url2 == url1 + assert mock_client.sagemaker_client.create_presigned_mlflow_app_url.call_count == 1 + + @patch("sagemaker.core.utils.utils.SageMakerClient") + def test_cache_refresh_after_240s(self, mock_sm_client_cls): + """Test that URL is regenerated after 240s.""" + mock_client = MagicMock() + mock_client.sagemaker_client.create_presigned_mlflow_app_url.side_effect = [ + {"AuthorizedUrl": MOCK_PRESIGNED_URL}, + {"AuthorizedUrl": "https://mlflow.example.com/auth?authToken=newtoken"}, + ] + mock_sm_client_cls.return_value = mock_client + + mlflow_link_cache = {"url": None, "timestamp": 0} + + def get_cached(current_time): + if mlflow_link_cache["url"] is None or (current_time - mlflow_link_cache["timestamp"]) > 240: + mlflow_link_cache["url"] = get_presigned_mlflow_experiment_url(MOCK_MLFLOW_ARN, None) + mlflow_link_cache["timestamp"] = current_time + return mlflow_link_cache["url"] + + url1 = get_cached(1000.0) + assert url1 == MOCK_PRESIGNED_URL + + # After 241 seconds, should refresh + url2 = get_cached(1241.0) + assert url2 == "https://mlflow.example.com/auth?authToken=newtoken" + assert mock_client.sagemaker_client.create_presigned_mlflow_app_url.call_count == 2 + + def test_returns_none_when_no_mlflow_config(self): + """Test returns None when pipeline execution has no mlflow config.""" + from sagemaker.train.common_utils.trainer_wait import _is_unassigned_attribute + + mock_pe = MagicMock() + mock_pe.m_lflow_config = MockUnassigned() + + mlflow_cfg = getattr(mock_pe, "m_lflow_config", None) + assert _is_unassigned_attribute(mlflow_cfg) is True + + def test_returns_none_when_arn_is_unassigned(self): + """Test returns None when mlflow_resource_arn is Unassigned.""" + from sagemaker.train.common_utils.trainer_wait import _is_unassigned_attribute + + mock_mlflow_cfg = MagicMock() + mock_mlflow_cfg.mlflow_resource_arn = MockUnassigned() + mock_mlflow_cfg.__class__ = type("MLflowConfiguration", (), {}) + + assert not _is_unassigned_attribute(mock_mlflow_cfg) + assert _is_unassigned_attribute(mock_mlflow_cfg.mlflow_resource_arn)