def _is_complex_anyof(v: dict) -> bool:
    """Check if anyOf contains complex types (refs, oneOf, or nested objects).

    Args:
        v: A property schema that may carry an ``anyOf`` list.

    Returns:
        ``True`` when at least one non-null ``anyOf`` entry is a ``$ref``, has
        nested ``properties``, or wraps a ``oneOf``; ``False`` otherwise,
        including when ``anyOf`` is absent or empty.
    """
    for sub_schema in v.get("anyOf", []):
        # JSON Schema allows boolean sub-schemas (``True`` / ``False``). They
        # carry no structure, so they can never make the union complex, and
        # calling ``.get`` on them would raise.
        if not isinstance(sub_schema, dict):
            continue
        # Skip null types - they just indicate optionality
        if sub_schema.get("type") == "null":
            continue
        # Check for references, nested properties, or oneOf branches
        # (don't recursively check allOf). oneOf appears in Pydantic discriminated
        # unions: ``Annotated[A | B, Field(discriminator=...)] | None``.
        if "$ref" in sub_schema or "properties" in sub_schema or "oneOf" in sub_schema:
            return True
    return False


def _flatten_discriminated_union(v: dict, defs: dict) -> dict:
    """Normalise Pydantic discriminated-union schemas for tool-calling APIs.

    Pydantic emits ``Annotated[A | B, Field(discriminator="kind")]`` as either:

    - Required: ``{"discriminator": {...}, "oneOf": [{"$ref": ...}, ...]}``
    - Optional: ``{"anyOf": [{"discriminator": {...}, "oneOf": [...]}, {"type": "null"}]}``

    The OAS-3 ``discriminator`` keyword and the JSON Schema ``oneOf`` keyword
    both fall outside the schema subset accepted by tool-calling APIs (Ollama,
    OpenAI strict mode). The ``Literal`` constraint on the tag field carries the
    discriminator signal, so we can safely drop ``discriminator`` and emit
    ``anyOf`` — equivalent here because the tag field forces uniqueness.

    Flattening is single-level. Discriminated unions nested inside an inlined
    branch (e.g. a Pydantic model whose own field is another discriminated
    union) are not recursively flattened — that case is tracked alongside the
    recursive ``$ref`` resolution work in #911.

    When a ``oneOf`` wrapper sits inside ``anyOf``, only its branches are kept;
    any other keys on that wrapper (``discriminator`` included) are discarded.

    Args:
        v: The property schema to normalise.
        defs: The ``$defs`` / ``definitions`` mapping used to resolve ``$ref``
            branches.

    Returns:
        A new schema dict with ``oneOf`` rewritten to ``anyOf``, branches
        inlined, and ``discriminator`` stripped. The result shares no mutable
        state with ``v`` or ``defs``, so neither the call nor later in-place
        edits of the result can alter the input; callers must reassign the
        result.
    """

    def _inline(branch: dict) -> dict:
        # Swap a ``$ref`` branch for a private copy of its definition. If the
        # resolver yields nothing usable, keep a copy of the branch untouched.
        if isinstance(branch, dict) and "$ref" in branch:
            ref_schema = _resolve_ref(branch["$ref"], defs)
            if ref_schema:
                return copy.deepcopy(ref_schema)
        return copy.deepcopy(branch)

    # Deep-copy the surviving keys so pass-through values (e.g. the
    # ``{"type": "null"}`` entry of an Optional union) are not aliased to the
    # caller's schema.
    out = {
        key: copy.deepcopy(value)
        for key, value in v.items()
        if key not in ("oneOf", "discriminator")
    }

    # Top-level discriminated union (required parameter case). Append rather
    # than overwrite so a defensive ``oneOf`` + ``anyOf`` co-occurrence does
    # not silently drop the existing ``anyOf`` entries. Pydantic does not
    # currently emit both at the same level, but the helper is callable in
    # isolation and should not lose data.
    if "oneOf" in v:
        out["anyOf"] = list(out.get("anyOf", [])) + [_inline(b) for b in v["oneOf"]]

    # Nested oneOf inside anyOf (Optional discriminated union case)
    if "anyOf" in out:
        flattened: list = []
        for sub in out["anyOf"]:
            # Boolean sub-schemas are legal JSON Schema; pass them through
            # instead of failing on the ``in`` test.
            if isinstance(sub, dict) and "oneOf" in sub:
                flattened.extend(_inline(branch) for branch in sub["oneOf"])
            else:
                flattened.append(sub)
        out["anyOf"] = flattened

    return out
"""End-to-end tests for discriminated-union tool parameters.

Covers issue #989: a tool parameter typed as a Pydantic discriminated union
``Annotated[A | B, Field(discriminator="kind")]`` (with or without ``| None``)
must not collapse to ``{"type": "string"}``. The schema produced by
``convert_function_to_ollama_tool`` is consumed by every backend
(Ollama, OpenAI, Watsonx, HuggingFace, LiteLLM), so the union structure must
be preserved and the OAS-3 ``discriminator`` keyword must be stripped from
the output (the JSON Schema subset accepted by tool-calling APIs does not
include it; the ``Literal`` tag fields carry the discriminator signal).
"""

import json
from typing import Annotated, Literal

import pytest
from pydantic import BaseModel, Field, ValidationError

from mellea.backends.tools import (
    MelleaTool,
    convert_function_to_ollama_tool,
    validate_tool_arguments,
)


# Union arm tagged ``kind == "cat"``; carries no fields beyond the shared ``name``.
class Cat(BaseModel):
    kind: Literal["cat"]
    name: str


# Union arm tagged ``kind == "dog"``; adds ``breed``.
class Dog(BaseModel):
    kind: Literal["dog"]
    name: str
    breed: str


# Third union arm (``kind == "fish"``), used only by the three-way test; adds ``species``.
class Fish(BaseModel):
    kind: Literal["fish"]
    name: str
    species: str


class Email(BaseModel):
    """Non-discriminated nested model for the no-op regression test."""

    to: str
    subject: str


def act(pet: Annotated[Cat | Dog, Field(discriminator="kind")]) -> str:
    """Act on a pet.

    Args:
        pet: the pet to act on
    """
    return "ok"


def act_optional(
    pet: Annotated[Cat | Dog, Field(discriminator="kind")] | None = None,
) -> str:
    """Act on an optional pet.

    Args:
        pet: the pet to act on, may be omitted
    """
    return "ok"


def _pet_schema(func) -> dict:
    """Convert ``func`` and return the ``pet`` parameter schema."""
    tool = convert_function_to_ollama_tool(func)
    assert tool.function is not None
    assert tool.function.parameters is not None
    return tool.function.parameters.model_dump(exclude_none=True)["properties"]["pet"]


def _has_branch(schema: dict, kind_value: str, *, must_have: set[str]) -> bool:
    """Check that ``schema`` contains an inlined ``anyOf`` branch for ``kind_value``.

    Only ``anyOf`` is searched. The output schema must contain ``anyOf`` only,
    never ``oneOf`` — accepting ``oneOf`` here would silently mask a regression
    of the discriminator-flattening pre-pass.

    Args:
        schema: The parameter schema to inspect.
        kind_value: The tag value (``const`` or ``enum`` member of the ``kind``
            property) identifying the wanted branch.
        must_have: Property names the matching branch is required to declare.

    Returns:
        ``True`` if the first branch tagged ``kind_value`` declares every
        property in ``must_have``; ``False`` if it does not or no branch matches.
    """
    branches = schema.get("anyOf", [])
    for branch in branches:
        props = branch.get("properties", {})
        kind = props.get("kind", {})
        # A single-value ``Literal`` may be rendered as ``const`` or as a
        # one-element ``enum``; accept either spelling.
        if kind.get("const") == kind_value or kind_value in (kind.get("enum") or []):
            return must_have.issubset(set(props.keys()))
    return False


class TestDiscriminatedUnionSchema:
    """Schema-shape assertions for discriminated-union tool parameters."""

    def test_required_union_does_not_collapse_to_string(self):
        """The discriminated union must not be flattened to a primitive."""
        pet = _pet_schema(act)
        assert pet.get("type") != "string", (
            f"discriminated union collapsed to a string schema: {pet!r}"
        )

    def test_required_union_preserves_branches(self):
        """Both Cat and Dog branches must survive as inlined object schemas."""
        pet = _pet_schema(act)
        # The flattening pre-pass rewrites ``oneOf`` to ``anyOf``; a surviving
        # ``oneOf`` means the pre-pass regressed, so it must not be accepted.
        assert "anyOf" in pet, f"expected anyOf in {pet!r}"
        assert "oneOf" not in pet, f"oneOf should be rewritten to anyOf: {pet!r}"
        assert _has_branch(pet, "cat", must_have={"kind", "name"}), (
            f"Cat branch missing or unresolved: {pet!r}"
        )
        assert _has_branch(pet, "dog", must_have={"kind", "name", "breed"}), (
            f"Dog branch missing or unresolved: {pet!r}"
        )

    def test_required_union_strips_discriminator_keyword(self):
        """OAS-3 ``discriminator`` is rejected by Ollama / OpenAI strict mode.

        The ``Literal`` constraint on ``kind`` already carries the tag signal,
        so the OAS keyword adds no semantic value but is actively harmful.
        """
        pet = _pet_schema(act)
        assert "discriminator" not in pet, (
            f"discriminator keyword should be stripped from output: {pet!r}"
        )

    def test_required_union_no_dangling_refs(self):
        """No ``$ref`` should leak into the output for the issue reproducer."""
        rendered = json.dumps(_pet_schema(act))
        assert "$ref" not in rendered, f"unresolved $ref in tool schema: {rendered}"

    def test_optional_union_does_not_collapse_to_string(self):
        """The Optional variant also must not flatten to a primitive."""
        pet = _pet_schema(act_optional)
        # Either pet is itself a discriminated union schema with a null branch,
        # or it is anyOf:[<union>, null]. Either way, "string" alone is wrong.
        assert pet.get("type") != "string", (
            f"optional discriminated union collapsed to a string schema: {pet!r}"
        )

    def test_optional_union_preserves_branches(self):
        """The Optional variant must preserve both inlined object branches."""
        pet = _pet_schema(act_optional)
        assert _has_branch(pet, "cat", must_have={"kind", "name"}), (
            f"Cat branch missing in optional variant: {pet!r}"
        )
        assert _has_branch(pet, "dog", must_have={"kind", "name", "breed"}), (
            f"Dog branch missing in optional variant: {pet!r}"
        )

    def test_optional_union_drops_from_required(self):
        """The optional parameter must not be in the function's required list."""
        tool = convert_function_to_ollama_tool(act_optional)
        assert tool.function is not None
        assert tool.function.parameters is not None
        params = tool.function.parameters.model_dump(exclude_none=True)
        assert "pet" not in params.get("required", []), (
            f"optional 'pet' should not be required: {params}"
        )

    def test_optional_union_strips_discriminator_keyword(self):
        """The Optional variant must also drop the OAS-3 ``discriminator``.

        The required variant strips it via the top-level ``oneOf`` path; the
        optional variant strips it implicitly when the wrapper sub-schema is
        replaced by its expanded branches. Asserted explicitly so a refactor
        that re-introduces the wrapper does not slip past silently.
        """
        rendered = json.dumps(_pet_schema(act_optional))
        assert "discriminator" not in rendered, (
            f"discriminator keyword should be stripped from optional output: {rendered}"
        )

    def test_three_way_union_preserves_all_branches(self):
        """A three-arm discriminated union must preserve all three branches."""

        def act_three(
            pet: Annotated[Cat | Dog | Fish, Field(discriminator="kind")],
        ) -> str:
            """Act on a three-way pet.

            Args:
                pet: the pet to act on
            """
            return "ok"

        pet = _pet_schema(act_three)
        assert _has_branch(pet, "cat", must_have={"kind", "name"}), (
            f"Cat branch missing in three-way union: {pet!r}"
        )
        assert _has_branch(pet, "dog", must_have={"kind", "name", "breed"}), (
            f"Dog branch missing in three-way union: {pet!r}"
        )
        assert _has_branch(pet, "fish", must_have={"kind", "name", "species"}), (
            f"Fish branch missing in three-way union: {pet!r}"
        )

    def test_non_discriminated_optional_unchanged(self):
        """Non-discriminated ``Optional[Email]`` must still flow through unchanged.

        Regression guard: the new pre-pass must be a no-op for plain
        ``$ref`` + ``| None`` shapes that the existing inliner already
        handles. Pydantic emits this as
        ``{"anyOf": [{"$ref": "..."}, {"type": "null"}]}`` — no ``oneOf``
        in any sub-schema, so the pre-pass should not activate.
        """

        def send(email: Email | None = None) -> str:
            """Send an email.

            Args:
                email: optional email payload
            """
            return "sent"

        tool = convert_function_to_ollama_tool(send)
        assert tool.function is not None
        assert tool.function.parameters is not None
        rendered = tool.function.parameters.model_dump(exclude_none=True)
        email_schema = rendered["properties"]["email"]
        # The existing complex-anyOf path inlines the $ref and preserves the
        # full object schema with properties. The exact shape is owned by the
        # pre-existing logic; we only assert the pre-pass did not collapse it.
        assert email_schema.get("type") != "string", (
            f"non-discriminated Optional collapsed: {email_schema!r}"
        )
        assert "email" not in rendered.get("required", []), (
            f"optional email should not be required: {rendered}"
        )


class TestDiscriminatedUnionValidation:
    """``validate_tool_arguments`` must round-trip a valid discriminated payload."""

    def test_strict_accepts_valid_dog(self):
        """A correctly-shaped dog dict should pass strict validation."""
        mt = MelleaTool.from_callable(act)
        validate_tool_arguments(
            mt, {"pet": {"kind": "dog", "name": "Rex", "breed": "lab"}}, strict=True
        )

    def test_strict_accepts_valid_cat(self):
        """A correctly-shaped cat dict should pass strict validation."""
        mt = MelleaTool.from_callable(act)
        validate_tool_arguments(
            mt, {"pet": {"kind": "cat", "name": "Whiskers"}}, strict=True
        )

    def test_strict_rejects_bare_string(self):
        """A bare string was the bug's silent-pass: must now be rejected."""
        mt = MelleaTool.from_callable(act)
        with pytest.raises(ValidationError):
            validate_tool_arguments(mt, {"pet": "just a string"}, strict=True)

    def test_strict_rejects_missing_discriminator(self):
        """A dict without the ``kind`` discriminator must be rejected."""
        mt = MelleaTool.from_callable(act)
        with pytest.raises(ValidationError):
            validate_tool_arguments(mt, {"pet": {"name": "Rex"}}, strict=True)

    def test_optional_accepts_omitted(self):
        """The optional variant accepts the parameter being omitted."""
        mt = MelleaTool.from_callable(act_optional)
        validate_tool_arguments(mt, {}, strict=True)

    def test_optional_accepts_valid_payload(self):
        """The optional variant accepts a valid payload."""
        mt = MelleaTool.from_callable(act_optional)
        validate_tool_arguments(
            mt, {"pet": {"kind": "dog", "name": "Rex", "breed": "lab"}}, strict=True
        )


if __name__ == "__main__":
    pytest.main([__file__, "-v"])
def test_anyof_with_oneof_branch_is_complex(self):
    """anyOf containing a oneOf branch (Optional discriminated union) is complex.

    Pydantic emits ``Annotated[Cat | Dog, Field(discriminator="kind")] | None``
    as ``{"anyOf": [{"discriminator": {...}, "oneOf": [$refs]}, {"type": "null"}]}``.
    Without descending into ``oneOf`` we drop the union structure and ship
    ``{"type": "string"}`` to the backend (issue #989).
    """
    # Build the schema inside-out: the tagged wrapper first, then the
    # Optional ``anyOf`` envelope around it.
    tagged_union = {
        "discriminator": {
            "propertyName": "kind",
            "mapping": {"cat": "#/$defs/Cat", "dog": "#/$defs/Dog"},
        },
        "oneOf": [{"$ref": "#/$defs/Cat"}, {"$ref": "#/$defs/Dog"}],
    }
    optional_union = {"anyOf": [tagged_union, {"type": "null"}]}
    assert _is_complex_anyof(optional_union) is True


def test_anyof_with_plain_oneof_branch_is_complex(self):
    """anyOf containing a oneOf branch without discriminator metadata is complex."""
    bare_union = {"oneOf": [{"$ref": "#/$defs/Cat"}, {"$ref": "#/$defs/Dog"}]}
    optional_union = {"anyOf": [bare_union, {"type": "null"}]}
    assert _is_complex_anyof(optional_union) is True


if __name__ == "__main__":
    pytest.main([__file__])