diff --git a/src/opltools/cli.py b/src/opltools/cli.py index 4a35d92..b497ffb 100644 --- a/src/opltools/cli.py +++ b/src/opltools/cli.py @@ -1,21 +1,28 @@ import sys import argparse from pydantic import ValidationError +import yaml from pydantic_yaml import parse_yaml_raw_as -from .schema import Library +from opltools.schema import Library + +UNIQUE_FIELDS = ["name"] +UNIQUE_WARNING_FIELDS = ["reference", "implementation"] def cmd_validate(args): - try: - with open(args.file) as f: - raw = f.read() - except OSError as e: - print(f"Error reading file: {e}", file=sys.stderr) - return 1 try: - parse_yaml_raw_as(Library, raw) + with open(args.file, "r") as f: + raw = f.read() + lib = parse_yaml_raw_as(Library, raw) + Library.model_validate( + lib, + context={ + "unique_error_fields": args.unique_error_field, + "unique_warning_fields": args.unique_warning_field, + }, + ) print(f"{args.file}: OK") return 0 except ValidationError as e: @@ -35,6 +42,21 @@ def main(): "validate", help="Validate a YAML file against the Library schema" ) validate_parser.add_argument("file", help="YAML file to validate") + # Add unique error fields + validate_parser.add_argument( + "--unique-error-field", + action="append", + help="Field that must be unique across all entries (can be specified multiple times)", + ) + validate_parser.add_argument( + "--unique-warning-field", + action="append", + help="Field that should be unique across all entries (can be specified multiple times)", + ) + # specify default unique fields if not provided + validate_parser.set_defaults( + unique_error_field=UNIQUE_FIELDS, unique_warning_field=UNIQUE_WARNING_FIELDS + ) args = parser.parse_args() diff --git a/src/opltools/schema.py b/src/opltools/schema.py index a2c5b5b..fe88338 100644 --- a/src/opltools/schema.py +++ b/src/opltools/schema.py @@ -1,7 +1,15 @@ from enum import Enum from typing import Any from typing_extensions import Self -from pydantic import BaseModel, RootModel, ConfigDict, model_validator +from 
typing import List, Dict, Set +from pydantic import ( + BaseModel, + RootModel, + ConfigDict, + model_validator, + ValidationInfo, + field_validator, +) from .yesnosome import YesNoSome from .utils import ValueRange, union_range @@ -93,6 +101,15 @@ class Usage(BaseModel): code: str +def forbid_value(field: str, forbidden: str): + def validator(cls, v: str): + if v == forbidden: + raise ValueError(f"{field} cannot be '{forbidden}'") + return v + + return field_validator(field)(validator) + + class Implementation(Thing): type: OPLType = OPLType.implementation name: str @@ -102,6 +119,8 @@ class Implementation(Thing): evaluation_time: set[str] | None = None requirements: str | list[str] | None = None + + _v = forbid_value("name", "template") # to prevent copy-paste errors + class ProblemLike(Thing): name: str @@ -123,6 +142,8 @@ class ProblemLike(Thing): code_examples: set[str] | None = None source: set[str] | None = None + + _v = forbid_value("name", "template") # to prevent copy-paste errors + def __hash__(self): return hash((self.type, self.name)) @@ -141,6 +162,65 @@ class Generator(ProblemLike): type: OPLType = OPLType.generator +class ValidationRule: + def __init__( + self, + field_name: str, + group: List[OPLType] | None, + error_on_duplicate: bool = True, + ): + self.field_name = field_name + self.group = group + self.error_on_duplicate = error_on_duplicate + self.seen = set() + self.duplicates = set() + + def update_seen(self, entry: Thing): + if self.group is None or entry.type in self.group: + value = getattr(entry, self.field_name, None) + if value is None: + return + if value in self.seen: + self.duplicates.add(value) + else: + self.seen.add(value) + + def _process_duplicates(self): + if self.duplicates: + if self.error_on_duplicate: + print( + f"::error::Duplicate values for field '{self.field_name}': {self.duplicates}" + ) + return False + else: + print( + f"::warning::Duplicate values for field '{self.field_name}': {self.duplicates}" + ) + return True +
+ +class Validator: + def __init__(self, duplicate_settings: List[Dict[str, Any]]): + rules = [] + for setting in duplicate_settings: + field_name = setting["field_name"] + group = setting.get("group", None) + error_on_duplicate = setting.get("error_on_duplicate", True) + rules.append(ValidationRule(field_name, group, error_on_duplicate)) + self.rules = rules + + def update_seen(self, entry: Thing): + for rule in self.rules: + rule.update_seen(entry) + + def process_duplicates(self): + all_valid = True + for rule in self.rules: + if not rule._process_duplicates(): + all_valid = False + return all_valid + + class Library(RootModel): root: dict[str, Problem | Generator | Suite | Implementation] = {} @@ -173,12 +253,39 @@ def _percolate_set(self, thing: Any, children: set | None, property: str): thing_set.update(child_set) @model_validator(mode="after") - def _validate(self) -> Self: + def _validate(self, info: ValidationInfo) -> Self: + # Check for duplicates and # First check and fixup all problems for id, thing in self.root.items(): if isinstance(thing, Problem) and thing.implementations: self._percolate_set(thing, thing.implementations, "evaluation_time") + # Then check and fixup all suites because changes from the problems need to propagate to the suites + context = info.context or {} + duplicate_settings = list(context.get("duplicate_settings", [])) + duplicate_settings += [ + {"field_name": f} for f in context.get("unique_error_fields") or [] + ] + duplicate_settings += [ + {"field_name": f, "error_on_duplicate": False} + for f in context.get("unique_warning_fields") or [] + ] + validator = Validator(duplicate_settings) + + # First check and fixup all problems + for id, thing in self.root.items(): + validator.update_seen(thing) + if isinstance(thing, Problem) and thing.implementations: + self._percolate_set(thing, thing.implementations, "evaluation_time") + + if not validator.process_duplicates(): + raise ValueError( + "Duplicate values found in fields: " + + ", ".join( + rule.field_name for rule in validator.rules if rule.duplicates + ) + ) + # Then check and fixup all suites because changes from the problems need to propagate to the suites for id, thing in self.root.items(): if
isinstance(thing, Suite) and thing.problems: @@ -191,7 +292,7 @@ def _validate(self) -> Self: raise ValueError( f"Suite {id} references problem with id '{problem_id}' but id is a {self.root[problem_id].type.name}." ) self._percolate_set(thing, thing.problems, "fidelity_levels") self._percolate_set(thing, thing.problems, "variables") self._percolate_set(thing, thing.problems, "constraints") diff --git a/tests/test_library.py b/tests/test_library.py index 542bc7b..2c9be02 100644 --- a/tests/test_library.py +++ b/tests/test_library.py @@ -21,13 +21,15 @@ def test_single_problem(self): assert isinstance(lib.root["p1"], Problem) def test_multiple_things(self): - lib = Library(root={ - "p1": Problem(name="P1"), - "p2": Problem(name="P2"), - "g1": Generator(name="G1"), - "s1": Suite(name="S1", problems={"p1", "p2"}), - "impl1": Implementation(name="impl1", description="d"), - }) + lib = Library( + root={ + "p1": Problem(name="P1"), + "p2": Problem(name="P2"), + "g1": Generator(name="G1"), + "s1": Suite(name="S1", problems={"p1", "p2"}), + "impl1": Implementation(name="impl1", description="d"), + } + ) assert len(lib.root) == 5 assert isinstance(lib.root["p1"], Problem) assert isinstance(lib.root["g1"], Generator) @@ -36,63 +38,135 @@ def test_multiple_things(self): def test_suite_references_missing_problem(self): with pytest.raises(ValidationError, match="undefined id"): - Library(root={ - "s1": Suite(name="S1", problems={"does-not-exist"}), - }) + Library( + root={ + "s1": Suite(name="S1", problems={"does-not-exist"}), + } + ) def test_suite_references_non_problem(self): with pytest.raises(ValidationError, match="but id is a"): - Library(root={ - "g1": Generator(name="G1"), - "s1": Suite(name="S1", problems={"g1"}), - }) + Library( + root={ + "g1": Generator(name="G1"), + "s1": Suite(name="S1", problems={"g1"}), + } + ) def test_suite_with_no_problems_is_valid(self): lib = Library(root={"s1": Suite(name="S1")}) assert lib.root["s1"].problems is None def
test_fixup_fidelity_populates_from_problems(self): - lib = Library(root={ - "p1": Problem(name="P1", fidelity_levels={1, 2}), - "p2": Problem(name="P2", fidelity_levels={2, 3}), - "s1": Suite(name="S1", problems={"p1", "p2"}), - }) + lib = Library( + root={ + "p1": Problem(name="P1", fidelity_levels={1, 2}), + "p2": Problem(name="P2", fidelity_levels={2, 3}), + "s1": Suite(name="S1", problems={"p1", "p2"}), + } + ) assert lib.root["s1"].fidelity_levels == {1, 2, 3} def test_fixup_fidelity_extends_existing(self): - lib = Library(root={ - "p1": Problem(name="P1", fidelity_levels={5}), - "s1": Suite(name="S1", problems={"p1"}, fidelity_levels={10}), - }) + lib = Library( + root={ + "p1": Problem(name="P1", fidelity_levels={5}), + "s1": Suite(name="S1", problems={"p1"}, fidelity_levels={10}), + } + ) assert lib.root["s1"].fidelity_levels == {5, 10} def test_fixup_fidelity_with_problems_without_levels(self): - lib = Library(root={ - "p1": Problem(name="P1"), - "p2": Problem(name="P2", fidelity_levels={7}), - "s1": Suite(name="S1", problems={"p1", "p2"}), - }) + lib = Library( + root={ + "p1": Problem(name="P1"), + "p2": Problem(name="P2", fidelity_levels={7}), + "s1": Suite(name="S1", problems={"p1", "p2"}), + } + ) assert lib.root["s1"].fidelity_levels == {7} def test_fixup_fidelity_all_problems_without_levels(self): - lib = Library(root={ - "p1": Problem(name="P1"), - "s1": Suite(name="S1", problems={"p1"}), - }) + lib = Library( + root={ + "p1": Problem(name="P1"), + "s1": Suite(name="S1", problems={"p1"}), + } + ) assert lib.root["s1"].fidelity_levels == set() def test_fixup_evaluation_time_percolates_from_implementation_to_suite(self): - lib = Library(root={ - "impl1": Implementation( - name="impl1", description="d", evaluation_time={"fast"} - ), - "impl2": Implementation( - name="impl2", description="d", evaluation_time={"8 minutes"} - ), - "p1": Problem(name="P1", implementations={"impl1"}), - "p2": Problem(name="P2", implementations={"impl2"}), - "s1": 
Suite(name="S1", problems={"p1", "p2"}), - }) + lib = Library( + root={ + "impl1": Implementation( + name="impl1", description="d", evaluation_time={"fast"} + ), + "impl2": Implementation( + name="impl2", description="d", evaluation_time={"8 minutes"} + ), + "p1": Problem(name="P1", implementations={"impl1"}), + "p2": Problem(name="P2", implementations={"impl2"}), + "s1": Suite(name="S1", problems={"p1", "p2"}), + } + ) assert lib.root["p1"].evaluation_time == {"fast"} assert lib.root["p2"].evaluation_time == {"8 minutes"} assert lib.root["s1"].evaluation_time == {"fast", "8 minutes"} + + +class TestLibraryValidation: + def test_invalid_root_type(self): + with pytest.raises(ValidationError): + Library(root="not a dict") + + def test_invalid_entry_type(self): + with pytest.raises(ValidationError): + Library(root={"p1": "not a problem"}) + + def test_suite_references_nonexistent_problem(self): + with pytest.raises(ValidationError): + Library(root={"s1": Suite(name="S1", problems={"p1"})}) + + def test_suite_references_non_problem(self): + with pytest.raises(ValidationError): + Library( + root={ + "g1": Generator(name="G1"), + "s1": Suite(name="S1", problems={"g1"}), + } + ) + + def test_valid_library(self): + lib = Library( + root={ + "p1": Problem(name="P1", fidelity_levels={1}), + "p2": Problem(name="P2", fidelity_levels={2}), + "s1": Suite(name="S1", problems={"p1", "p2"}), + "g1": Generator(name="G1"), + "impl1": Implementation(name="impl1", description="d"), + } + ) + assert isinstance(lib, Library) + + def test_duplicates(self): + lib = Library( + root={ + "p1": Problem(name="P1", fidelity_levels={1}), + "p2": Problem(name="P1", fidelity_levels={2}), # duplicate name + "s1": Suite(name="S1", problems={"p1", "p2"}), + "g1": Generator(name="G1"), + "impl1": Implementation(name="impl1", description="d"), + "impl2": Implementation( + name="impl1", description="d" + ), # duplicate name + } + ) + assert isinstance(lib, Library) + with pytest.raises(ValidationError): 
+ Library.model_validate( + lib, + context={ + "unique_error_fields": ["name"], + "unique_warning_fields": [], + }, + ) diff --git a/utils/validate_yaml.py b/utils/validate_yaml.py index 34f1899..03abaa5 100644 --- a/utils/validate_yaml.py +++ b/utils/validate_yaml.py @@ -8,7 +8,28 @@ sys.path.insert(0, str(parent)) # Now you can import normally -from yaml_to_html import default_columns as REQUIRED_FIELDS +# from yaml_to_html import default_columns as REQUIRED_FIELDS +REQUIRED_FIELDS = [ + "name", + "textual description", + "suite/generator/single", + "objectives", + "dimensionality", + "variable type", + "constraints", + "dynamic", + "noise", + "multi-fidelity", + "source (real-world/artificial)", + "reference", + "implementation", +] + + +from pydantic import ValidationError +from pydantic_yaml import parse_yaml_raw_as +from src.opltools.schema import Library + OPTIONAL_FIELDS = ["multimodal"] UNIQUE_FIELDS = ["name"] @@ -29,94 +50,68 @@ def read_data(filepath): return 1, None -def check_format(data): - num_problems = len(data) - if len(data) < 1: - print("::error::YAML file should contain at least one top level entry.") - return False - print(f"::notice::YAML file contains {num_problems} top-level entries.") - unique_fields = [] - for i, entry in enumerate(data): - if not isinstance(entry, dict): - print(f"::error::Entry {i} is not a dictionary.") - return False - unique_fields.append({k: v for k, v in entry.items() if k in UNIQUE_FIELDS}) - for k in UNIQUE_FIELDS: - values = [entry[k] for entry in unique_fields] - if len(values) != len(set(values)): - print(f"::error::Field '{k}' must be unique across all entries.") - return False - return True - - -def check_fields(data): - missing = [field for field in REQUIRED_FIELDS if field not in data] - if missing: - print(f"::error::Missing required fields: {', '.join(missing)}") - return False - new_fields = [ - field for field in data if field not in REQUIRED_FIELDS + OPTIONAL_FIELDS - ] - if new_fields: - 
print(f"::warning::New field added: {', '.join(new_fields)}") - # Check that the name is not still template - if data.get("name") == "template": - print( - "::error::Please change the 'name' field from 'template' to a unique name." - ) - return False - # Check non-empty fields - empty_fields = [ - field - for field in NON_EMPTY_FIELDS - if data.get(field, None) is None or data.get(field, "").strip() == "" - ] - if empty_fields: - print( - f"::error::The following fields cannot be empty: {', '.join(empty_fields)}" - ) - return False - return True - - -def check_novelty(data, checked_data): - for field in UNIQUE_FIELDS + UNIQUE_WARNING_FIELDS: - # skip empty fields - if not data.get(field): +def update_seen(fields, seen, duplicates, entry): + entry_type = entry.get("type", "unknown") + for field in fields: + value = entry.get(field, None) + if value is None: continue - existing_values = { - entry.get(field) for entry in checked_data if isinstance(entry, dict) - } - if data.get(field) in existing_values: - if field in UNIQUE_WARNING_FIELDS: - print( - f"::warning::Field '{field}' with value '{data.get(field)}' already exists. Consider choosing a unique value." - ) - continue - elif field in UNIQUE_FIELDS: - print( - f"::error::Field '{field}' with value '{data.get(field)}' already exists. Please choose a unique value." 
- ) - return False - return True + seen_value = f"{entry_type}:{value}" + if seen_value in seen[field]: + duplicates[field].add(seen_value) + else: + seen[field].add(seen_value) + return seen, duplicates + + +def check_duplicates(data, warning_fields, error_fields): + # Run checks for each entry and collect duplicates + fields = set(warning_fields + error_fields) + seen = {field: set() for field in fields} + duplicates = {field: set() for field in fields} + for _, entry in data.items(): + seen, duplicates = update_seen(fields, seen, duplicates, entry) + + duplicate_warnings = { + field: list(dups) + for field, dups in duplicates.items() + if dups and field in warning_fields + } + if len(duplicate_warnings) > 0: + print(f"::warning::Duplication warnings {duplicate_warnings}") + duplicate_errors = { + field: list(dups) + for field, dups in duplicates.items() + if dups and field in error_fields + } + if len(duplicate_errors) > 0: + print(f"::error::Duplication errors {duplicate_errors}") + return len(duplicate_errors) == 0 + + +def check_parsing(filepath): + try: + with open(filepath, "r") as f: + raw = f.read() + parse_yaml_raw_as(Library, raw) + return True + except ValidationError as e: + print(f"::error::YAML parsing error: {e}") + return False def validate_yaml(filepath): + status = check_parsing(filepath) + if not status: + sys.exit(1) + status, data = read_data(filepath) - if status != 0: + if status != 0 or data is None: sys.exit(1) - if not check_format(data): + if not check_duplicates( + data, warning_fields=UNIQUE_WARNING_FIELDS, error_fields=UNIQUE_FIELDS + ): sys.exit(1) - assert data is not None - - checked_data = [] - - for i, new_data in enumerate(data): # Iterate through each top-level entry - # Check required and unique fields - if not check_fields(new_data) or not check_novelty(new_data, checked_data): - print(f"::error::Validation failed for entry {i+1}.") - sys.exit(1) - checked_data.append(new_data) # Add to checked data for novelty checks # YAML 
is valid if we reach this point print("YAML syntax is valid.")