From a0a5b3dc96801a377a2168e08754a854e6266db2 Mon Sep 17 00:00:00 2001 From: Henrik Finsberg Date: Fri, 22 May 2026 23:37:01 +0200 Subject: [PATCH] Add subset argument to download datasets --- src/mritk/datasets.py | 49 +++++++++++++++++++ tests/test_datasets.py | 104 +++++++++++++++++++++++++++++++++++++++++ 2 files changed, 153 insertions(+) diff --git a/src/mritk/datasets.py b/src/mritk/datasets.py index ab8c77f..66e9ea2 100644 --- a/src/mritk/datasets.py +++ b/src/mritk/datasets.py @@ -213,6 +213,17 @@ def add_arguments( help=f"Dataset to download (choices: {', '.join(choices)})", ) download_parser.add_argument("-o", "--outdir", type=Path, help="Output directory to download test data") + download_parser.add_argument( + "--subset", + action="append", + metavar="STR", + default=None, + dest="subset", + help=( + "Download only a subset of files. Can be specified multiple times. " + "Accepts filenames with or without extension (e.g. README or README.md)." + ), + ) subparsers.add_parser("list", help="List available datasets") info_parser = subparsers.add_parser("info", help="Show detailed information about a dataset") @@ -227,6 +238,37 @@ def add_arguments( extra_args_cb(info_parser) +def filter_links_by_subset(links: dict[str, str], subsets: list[str]) -> dict[str, str] | None: + """Filter a links dict to only include entries matching the requested subsets. + + Each subset entry is matched by exact filename or by stem (filename without extension). + Returns the filtered dict, or None if any subset could not be matched. + """ + # Build a stem -> filename mapping for efficient lookup + stem_to_filename: dict[str, str] = {Path(k).stem: k for k in links} + + filtered: dict[str, str] = {} + missing: list[str] = [] + + for subset in subsets: + if subset in links: + # Exact match (e.g. "README.md") + filtered[subset] = links[subset] + elif Path(subset).stem in stem_to_filename: + # Stem match (e.g. "README" -> "README.md", "mesh-data" -> "mesh-data.zip") + key = stem_to_filename[Path(subset).stem] + filtered[key] = links[key] + else: + missing.append(subset) + + if missing: + available = ", ".join(links.keys()) + logger.error(f"The following subset(s) were not found: {', '.join(missing)}. Available files: {available}") + return None + + return filtered + + def dispatch(args): subcommand = args.pop("datasets-command", None) if subcommand == "list": @@ -234,6 +276,7 @@ def dispatch(args): elif subcommand == "download": dataset = args.pop("dataset") outdir = args.pop("outdir") + subsets = args.pop("subset", None) if outdir is None: logger.error("Output directory (-o or --outdir) is required for downloading datasets.") return @@ -244,6 +287,12 @@ def dispatch(args): return links = datasets[dataset].links + + if subsets is not None: + links = filter_links_by_subset(links, subsets) + if links is None: + return + download_multiple(links, outdir) elif subcommand == "info": diff --git a/tests/test_datasets.py b/tests/test_datasets.py index ee89632..2600ef0 100644 --- a/tests/test_datasets.py +++ b/tests/test_datasets.py @@ -152,3 +152,107 @@ def test_download_multiple(mock_download_data, mock_executor, tmp_path): successful = mritk.datasets.download_multiple(urls, tmp_path) assert len(successful) == 2 + + +# --- Tests for filter_links_by_subset --- + + +@pytest.fixture +def sample_links(): + return { + "README.md": "http://example.com/README.md", + "mesh-data.zip": "http://example.com/mesh-data.zip", + "archive.zip": "http://example.com/archive.zip", + } + + +def test_filter_links_exact_match(sample_links): + """Exact filename (with extension) should match correctly.""" + result = mritk.datasets.filter_links_by_subset(sample_links, ["README.md"]) + assert result == {"README.md": sample_links["README.md"]} + + +def test_filter_links_stem_match(sample_links): + """Filename without extension should match the full filename.""" + result = mritk.datasets.filter_links_by_subset(sample_links, ["README"]) + assert result == {"README.md": sample_links["README.md"]} + + +def test_filter_links_stem_match_zip(sample_links): + """Stem of a zip file should match the full filename.""" + result = mritk.datasets.filter_links_by_subset(sample_links, ["mesh-data"]) + assert result == {"mesh-data.zip": sample_links["mesh-data.zip"]} + + +def test_filter_links_multiple_subsets(sample_links): + """Multiple subsets (mixed exact and stem) should all be resolved.""" + result = mritk.datasets.filter_links_by_subset(sample_links, ["README.md", "mesh-data"]) + assert result == { + "README.md": sample_links["README.md"], + "mesh-data.zip": sample_links["mesh-data.zip"], + } + + +def test_filter_links_missing_subset_returns_none(sample_links, caplog): + """A subset that does not exist should return None and log an error.""" + result = mritk.datasets.filter_links_by_subset(sample_links, ["nonexistent"]) + assert result is None + assert "nonexistent" in caplog.text + + +def test_filter_links_partial_missing_returns_none(sample_links, caplog): + """If even one subset is missing, nothing should be downloaded (return None).""" + result = mritk.datasets.filter_links_by_subset(sample_links, ["README", "nonexistent"]) + assert result is None + assert "nonexistent" in caplog.text + + +# --- Tests for dispatch with --subset --- + + +@patch("mritk.datasets.get_datasets") +@patch("mritk.datasets.download_multiple") +def test_dispatch_download_with_subset_stem(mock_download_multiple, mock_get_datasets, mock_datasets): + """--subset with stem only should filter links and call download_multiple.""" + mock_get_datasets.return_value = mock_datasets + + mritk.cli.main(["datasets", "download", "test-data", "-o", "/tmp", "--subset", "file1"]) + + expected_links = {"file1.txt": mock_datasets["test-data"].links["file1.txt"]} + mock_download_multiple.assert_called_once_with(expected_links, Path("/tmp")) + + +@patch("mritk.datasets.get_datasets") +@patch("mritk.datasets.download_multiple") +def test_dispatch_download_with_subset_exact(mock_download_multiple, mock_get_datasets, mock_datasets): + """--subset with exact filename should filter links correctly.""" + mock_get_datasets.return_value = mock_datasets + + mritk.cli.main(["datasets", "download", "test-data", "-o", "/tmp", "--subset", "file1.txt"]) + + expected_links = {"file1.txt": mock_datasets["test-data"].links["file1.txt"]} + mock_download_multiple.assert_called_once_with(expected_links, Path("/tmp")) + + +@patch("mritk.datasets.get_datasets") +@patch("mritk.datasets.download_multiple") +def test_dispatch_download_with_multiple_subsets(mock_download_multiple, mock_get_datasets, mock_datasets): + """Multiple --subset flags should download only the requested files.""" + mock_get_datasets.return_value = mock_datasets + + mritk.cli.main(["datasets", "download", "test-data", "-o", "/tmp", "--subset", "file1", "--subset", "archive"]) + + expected_links = mock_datasets["test-data"].links # both files + mock_download_multiple.assert_called_once_with(expected_links, Path("/tmp")) + + +@patch("mritk.datasets.get_datasets") +@patch("mritk.datasets.download_multiple") +def test_dispatch_download_subset_not_found(mock_download_multiple, mock_get_datasets, mock_datasets, caplog): + """If a --subset does not exist, download_multiple must NOT be called.""" + mock_get_datasets.return_value = mock_datasets + + mritk.cli.main(["datasets", "download", "test-data", "-o", "/tmp", "--subset", "does-not-exist"]) + + mock_download_multiple.assert_not_called() + assert "does-not-exist" in caplog.text