From 6f9280c286b072576614eed1b627494938ad9745 Mon Sep 17 00:00:00 2001 From: poofeth <91919121+poofeth@users.noreply.github.com> Date: Sun, 10 May 2026 23:37:30 -0500 Subject: [PATCH] Support cross-entity checkpoint forks --- src/art/serverless/backend.py | 26 +++++- tests/unit/test_serverless_fork_checkpoint.py | 92 +++++++++++++++++++ 2 files changed, 116 insertions(+), 2 deletions(-) create mode 100644 tests/unit/test_serverless_fork_checkpoint.py diff --git a/src/art/serverless/backend.py b/src/art/serverless/backend.py index fc95300d..0d45e50e 100644 --- a/src/art/serverless/backend.py +++ b/src/art/serverless/backend.py @@ -40,6 +40,20 @@ def _extract_step_from_wandb_artifact(artifact: "wandb.Artifact") -> int | None: return None +def _wandb_checkpoint_collection_path( + *, + from_model: str, + from_project: str, + model_entity: str | None, + default_entity: str | None, + from_entity: str | None = None, +) -> str: + resolved_entity = from_entity or model_entity or default_entity + if resolved_entity is None: + raise ValueError("A W&B entity is required to locate the source checkpoint") + return f"{resolved_entity}/{from_project}/{from_model}" + + _UPSTREAM_TRAIN_METRIC_KEYS = { "reward": "reward", "reward_std_dev": "reward_std_dev", @@ -728,6 +742,7 @@ async def _experimental_fork_checkpoint( model: "Model", from_model: str, from_project: str | None = None, + from_entity: str | None = None, from_s3_bucket: str | None = None, not_after_step: int | None = None, verbose: bool = False, @@ -746,6 +761,8 @@ async def _experimental_fork_checkpoint( model: The destination model to fork to. from_model: The name of the source model to fork from. from_project: The project of the source model. Defaults to model.project. + from_entity: Optional W&B entity for the source model. Defaults to the + destination model's entity, then the W&B default entity. from_s3_bucket: Optional S3 bucket to pull the checkpoint from. not_after_step: If provided, uses the latest checkpoint <= this step. verbose: Whether to print verbose output. @@ -812,12 +829,17 @@ async def _experimental_fork_checkpoint( else: # Pull from W&B artifacts api = wandb.Api(api_key=self._client.api_key) # ty:ignore[possibly-missing-attribute] - from_entity = model.entity or api.default_entity # Iterate all artifact versions to find the best step. # We avoid relying on the W&B `:latest` alias because it # may not correspond to the highest training step. - collection_path = f"{from_entity}/{from_project}/{from_model}" + collection_path = _wandb_checkpoint_collection_path( + from_model=from_model, + from_project=from_project, + from_entity=from_entity, + model_entity=model.entity, + default_entity=api.default_entity, + ) versions = api.artifacts("lora", collection_path) best_step: int | None = None diff --git a/tests/unit/test_serverless_fork_checkpoint.py b/tests/unit/test_serverless_fork_checkpoint.py new file mode 100644 index 00000000..4b4e70b6 --- /dev/null +++ b/tests/unit/test_serverless_fork_checkpoint.py @@ -0,0 +1,92 @@ +import sys +from types import SimpleNamespace + +import pytest + +from art.serverless.backend import ( + ServerlessBackend, + _wandb_checkpoint_collection_path, +) + + +def test_checkpoint_collection_path_prefers_explicit_source_entity(): + path = _wandb_checkpoint_collection_path( + from_model="source-model", + from_project="source-project", + from_entity="source-entity", + model_entity="destination-entity", + default_entity="default-entity", + ) + + assert path == "source-entity/source-project/source-model" + + +def test_checkpoint_collection_path_falls_back_to_destination_entity(): + path = _wandb_checkpoint_collection_path( + from_model="source-model", + from_project="source-project", + from_entity=None, + model_entity="destination-entity", + default_entity="default-entity", + ) + + assert path == "destination-entity/source-project/source-model" + + +def test_checkpoint_collection_path_falls_back_to_default_entity(): + path = _wandb_checkpoint_collection_path( + from_model="source-model", + from_project="source-project", + from_entity=None, + model_entity=None, + default_entity="default-entity", + ) + + assert path == "default-entity/source-project/source-model" + + +def test_checkpoint_collection_path_requires_an_entity(): + with pytest.raises(ValueError, match="W&B entity"): + _wandb_checkpoint_collection_path( + from_model="source-model", + from_project="source-project", + from_entity=None, + model_entity=None, + default_entity=None, + ) + + +@pytest.mark.asyncio +async def test_fork_checkpoint_uses_explicit_source_entity(monkeypatch): + artifact_calls = [] + + class FakeApi: + default_entity = "default-entity" + + def __init__(self, api_key): + assert api_key == "test-api-key" + + def artifacts(self, artifact_type, collection_path): + artifact_calls.append((artifact_type, collection_path)) + return [] + + fake_wandb = SimpleNamespace(Api=FakeApi) + monkeypatch.setitem(sys.modules, "wandb", fake_wandb) + + backend = ServerlessBackend.__new__(ServerlessBackend) + backend._client = SimpleNamespace(api_key="test-api-key") + model = SimpleNamespace( + entity="destination-entity", + project="destination-project", + name="destination-model", + ) + + with pytest.raises(ValueError, match="No checkpoints found"): + await backend._experimental_fork_checkpoint( + model, + from_model="source-model", + from_project="source-project", + from_entity="source-entity", + ) + + assert artifact_calls == [("lora", "source-entity/source-project/source-model")]