From 0757090d0bfd166fbdb1827154799a4576b89153 Mon Sep 17 00:00:00 2001 From: Nigel Jones Date: Wed, 20 May 2026 11:31:12 +0100 Subject: [PATCH 1/2] fix: preserve discriminated-union schema in tool conversion MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit A tool parameter typed as a Pydantic discriminated union (``Annotated[A | B, Field(discriminator="kind")]``, with or without ``| None``) currently collapses to ``{"type": "string"}`` in the schema emitted by ``convert_function_to_ollama_tool``. Because that schema is shared by every backend (Ollama, OpenAI, Watsonx, HuggingFace, LiteLLM), discriminated-union tool parameters are silently broken across the board: the model sees a string and hallucinates a payload, and ``validate_tool_arguments`` rejects valid dicts. Pydantic emits the union as ``oneOf`` plus an OAS-3 ``discriminator`` keyword, neither of which is in the JSON Schema subset accepted by tool-calling APIs. The existing inliner only descends into ``anyOf`` and ``$ref``, so the structure falls through to the primitive-flattening branch. Fix: add a pre-pass that flattens the discriminated-union shapes — both top-level (required) and nested-in-anyOf (Optional) — to plain ``anyOf`` of inlined object schemas, with the OAS ``discriminator`` keyword stripped. The ``Literal`` constraints on the tag field already carry the discriminator signal. Also extends ``_is_complex_anyof`` to detect ``oneOf`` branches defensively. Resolves #989. 
Assisted-by: Claude Code Signed-off-by: Nigel Jones --- mellea/backends/tools.py | 62 +++++- .../test_discriminated_union_tools.py | 183 ++++++++++++++++++ test/backends/test_schema_helpers.py | 32 +++ 3 files changed, 274 insertions(+), 3 deletions(-) create mode 100644 test/backends/test_discriminated_union_tools.py diff --git a/mellea/backends/tools.py b/mellea/backends/tools.py index 645eb6679..a32374067 100644 --- a/mellea/backends/tools.py +++ b/mellea/backends/tools.py @@ -975,18 +975,66 @@ def _resolve_ref(ref_path: str, defs: dict) -> dict: def _is_complex_anyof(v: dict) -> bool: - """Check if anyOf contains complex types (refs or nested objects).""" + """Check if anyOf contains complex types (refs, oneOf, or nested objects).""" any_of_schemas = v.get("anyOf", []) for sub_schema in any_of_schemas: # Skip null types - they just indicate optionality if sub_schema.get("type") == "null": continue - # Check for references or nested properties (don't recursively check allOf) - if "$ref" in sub_schema or "properties" in sub_schema: + # Check for references, nested properties, or oneOf branches + # (don't recursively check allOf). oneOf appears in Pydantic discriminated + # unions: ``Annotated[A | B, Field(discriminator=...)] | None``. + if "$ref" in sub_schema or "properties" in sub_schema or "oneOf" in sub_schema: return True return False +def _flatten_discriminated_union(v: dict, defs: dict) -> dict: + """Normalise Pydantic discriminated-union schemas for tool-calling APIs. + + Pydantic emits ``Annotated[A | B, Field(discriminator="kind")]`` as either: + + - Required: ``{"discriminator": {...}, "oneOf": [{"$ref": ...}, ...]}`` + - Optional: ``{"anyOf": [{"discriminator": {...}, "oneOf": [...]}, {"type": "null"}]}`` + + The OAS-3 ``discriminator`` keyword and the JSON Schema ``oneOf`` keyword + both fall outside the schema subset accepted by tool-calling APIs (Ollama, + OpenAI strict mode). 
The ``Literal`` constraint on the tag field carries the + discriminator signal, so we can safely drop ``discriminator`` and emit + ``anyOf`` — equivalent here because the tag field forces uniqueness. + + This rewrites the schema in place: ``oneOf`` becomes ``anyOf`` with branches + inlined, ``discriminator`` is stripped, and Optional unions flatten to + ``{"anyOf": [...inlined_branches, {"type": "null"}]}``. + """ + + def _inline(branch: dict) -> dict: + if "$ref" in branch: + ref_schema = _resolve_ref(branch["$ref"], defs) + if ref_schema: + return copy.deepcopy(ref_schema) + return copy.deepcopy(branch) + + # Top-level discriminated union (required parameter case) + if "oneOf" in v: + branches = [_inline(b) for b in v["oneOf"]] + v = {kk: vv for kk, vv in v.items() if kk not in ("oneOf", "discriminator")} + v["anyOf"] = branches + + # Nested oneOf inside anyOf (Optional discriminated union case) + if "anyOf" in v: + flattened: list[dict] = [] + for sub in v["anyOf"]: + if "oneOf" in sub: + for branch in sub["oneOf"]: + flattened.append(_inline(branch)) + else: + flattened.append(sub) + v["anyOf"] = flattened + + return v + + # https://github.com/ollama/ollama-python/blob/60e7b2f9ce710eeb57ef2986c46ea612ae7516af/ollama/_utils.py#L56-L90 def convert_function_to_ollama_tool( func: Callable, name: str | None = None @@ -1021,6 +1069,14 @@ def convert_function_to_ollama_tool( defs = schema.get("$defs", schema.get("definitions", {})) for k, v in schema.get("properties", {}).items(): + # Pre-pass: flatten Pydantic discriminated unions (oneOf + discriminator) + # into plain anyOf with branches inlined. See _flatten_discriminated_union. 
+ if "oneOf" in v or ( + "anyOf" in v and any("oneOf" in s for s in v.get("anyOf", [])) + ): + v = _flatten_discriminated_union(v, defs) + schema["properties"][k] = v + # First pass: inline all $refs (at top level and within anyOf) if "$ref" in v: # Resolve the reference and inline it diff --git a/test/backends/test_discriminated_union_tools.py b/test/backends/test_discriminated_union_tools.py new file mode 100644 index 000000000..a07adf181 --- /dev/null +++ b/test/backends/test_discriminated_union_tools.py @@ -0,0 +1,183 @@ +"""End-to-end tests for discriminated-union tool parameters. + +Covers issue #989: a tool parameter typed as a Pydantic discriminated union +``Annotated[A | B, Field(discriminator="kind")]`` (with or without ``| None``) +must not collapse to ``{"type": "string"}``. The schema produced by +``convert_function_to_ollama_tool`` is consumed by every backend +(Ollama, OpenAI, Watsonx, HuggingFace, LiteLLM), so the union structure must +be preserved and the OAS-3 ``discriminator`` keyword must be stripped from +the output (the JSON Schema subset accepted by tool-calling APIs does not +include it; the ``Literal`` tag fields carry the discriminator signal). +""" + +from typing import Annotated, Literal + +import pytest +from pydantic import BaseModel, Field, ValidationError + +from mellea.backends.tools import ( + MelleaTool, + convert_function_to_ollama_tool, + validate_tool_arguments, +) + + +class Cat(BaseModel): + kind: Literal["cat"] + name: str + + +class Dog(BaseModel): + kind: Literal["dog"] + name: str + breed: str + + +def act(pet: Annotated[Cat | Dog, Field(discriminator="kind")]) -> str: + """Act on a pet. + + Args: + pet: the pet to act on + """ + return "ok" + + +def act_optional( + pet: Annotated[Cat | Dog, Field(discriminator="kind")] | None = None, +) -> str: + """Act on an optional pet. 
+ + Args: + pet: the pet to act on, may be omitted + """ + return "ok" + + +def _pet_schema(func) -> dict: + """Convert ``func`` and return the ``pet`` parameter schema.""" + tool = convert_function_to_ollama_tool(func) + return tool.function.parameters.model_dump(exclude_none=True)["properties"]["pet"] + + +def _has_branch(schema: dict, kind_value: str, *, must_have: set[str]) -> bool: + """Check that ``schema`` contains an inlined object branch for ``kind_value``.""" + branches = schema.get("anyOf") or schema.get("oneOf") or [] + for branch in branches: + props = branch.get("properties", {}) + kind = props.get("kind", {}) + if kind.get("const") == kind_value or kind_value in (kind.get("enum") or []): + return must_have.issubset(set(props.keys())) + return False + + +class TestDiscriminatedUnionSchema: + """Schema-shape assertions for discriminated-union tool parameters.""" + + def test_required_union_does_not_collapse_to_string(self): + """The discriminated union must not be flattened to a primitive.""" + pet = _pet_schema(act) + assert pet.get("type") != "string", ( + f"discriminated union collapsed to a string schema: {pet!r}" + ) + + def test_required_union_preserves_branches(self): + """Both Cat and Dog branches must survive as inlined object schemas.""" + pet = _pet_schema(act) + assert "anyOf" in pet or "oneOf" in pet, f"expected anyOf/oneOf in {pet!r}" + assert _has_branch(pet, "cat", must_have={"kind", "name"}), ( + f"Cat branch missing or unresolved: {pet!r}" + ) + assert _has_branch(pet, "dog", must_have={"kind", "name", "breed"}), ( + f"Dog branch missing or unresolved: {pet!r}" + ) + + def test_required_union_strips_discriminator_keyword(self): + """OAS-3 ``discriminator`` is rejected by Ollama / OpenAI strict mode. + + The ``Literal`` constraint on ``kind`` already carries the tag signal, + so the OAS keyword adds no semantic value but is actively harmful. 
+ """ + pet = _pet_schema(act) + assert "discriminator" not in pet, ( + f"discriminator keyword should be stripped from output: {pet!r}" + ) + + def test_required_union_no_dangling_refs(self): + """No ``$ref`` should leak into the output for the issue reproducer.""" + import json + + rendered = json.dumps(_pet_schema(act)) + assert "$ref" not in rendered, f"unresolved $ref in tool schema: {rendered}" + + def test_optional_union_does_not_collapse_to_string(self): + """The Optional variant also must not flatten to a primitive.""" + pet = _pet_schema(act_optional) + # Either pet is itself a discriminated union schema with a null branch, + # or it is anyOf:[discriminated-union, null]. Either way, "string" alone is wrong. + assert pet.get("type") != "string", ( + f"optional discriminated union collapsed to a string schema: {pet!r}" + ) + + def test_optional_union_preserves_branches(self): + """The Optional variant must preserve both inlined object branches.""" + pet = _pet_schema(act_optional) + assert _has_branch(pet, "cat", must_have={"kind", "name"}), ( + f"Cat branch missing in optional variant: {pet!r}" + ) + assert _has_branch(pet, "dog", must_have={"kind", "name", "breed"}), ( + f"Dog branch missing in optional variant: {pet!r}" + ) + + def test_optional_union_drops_from_required(self): + """The optional parameter must not be in the function's required list.""" + tool = convert_function_to_ollama_tool(act_optional) + params = tool.function.parameters.model_dump(exclude_none=True) + assert "pet" not in params.get("required", []), ( + f"optional 'pet' should not be required: {params}" + ) + + +class TestDiscriminatedUnionValidation: + """``validate_tool_arguments`` must round-trip a valid discriminated payload.""" + + def test_strict_accepts_valid_dog(self): + """A correctly-shaped dog dict should pass strict validation.""" + mt = MelleaTool.from_callable(act) + validate_tool_arguments( + mt, {"pet": {"kind": "dog", "name": "Rex", "breed": "lab"}}, strict=True + ) + + def 
test_strict_accepts_valid_cat(self): + """A correctly-shaped cat dict should pass strict validation.""" + mt = MelleaTool.from_callable(act) + validate_tool_arguments( + mt, {"pet": {"kind": "cat", "name": "Whiskers"}}, strict=True + ) + + def test_strict_rejects_bare_string(self): + """A bare string was the bug's silent-pass: must now be rejected.""" + mt = MelleaTool.from_callable(act) + with pytest.raises(ValidationError): + validate_tool_arguments(mt, {"pet": "just a string"}, strict=True) + + def test_strict_rejects_missing_discriminator(self): + """A dict without the ``kind`` discriminator must be rejected.""" + mt = MelleaTool.from_callable(act) + with pytest.raises(ValidationError): + validate_tool_arguments(mt, {"pet": {"name": "Rex"}}, strict=True) + + def test_optional_accepts_omitted(self): + """The optional variant accepts the parameter being omitted.""" + mt = MelleaTool.from_callable(act_optional) + validate_tool_arguments(mt, {}, strict=True) + + def test_optional_accepts_valid_payload(self): + """The optional variant accepts a valid payload.""" + mt = MelleaTool.from_callable(act_optional) + validate_tool_arguments( + mt, {"pet": {"kind": "dog", "name": "Rex", "breed": "lab"}}, strict=True + ) + + +if __name__ == "__main__": + pytest.main([__file__, "-v"]) diff --git a/test/backends/test_schema_helpers.py b/test/backends/test_schema_helpers.py index 73dced767..193ca6c28 100644 --- a/test/backends/test_schema_helpers.py +++ b/test/backends/test_schema_helpers.py @@ -163,6 +163,38 @@ def test_complex_nested_structure(self): } assert _is_complex_anyof(schema) is True + def test_anyof_with_oneof_branch_is_complex(self): + """anyOf containing a oneOf branch (Optional discriminated union) is complex. + + Pydantic emits ``Annotated[Cat | Dog, Field(discriminator="kind")] | None`` + as ``{"anyOf": [{"discriminator": {...}, "oneOf": [$refs]}, {"type": "null"}]}``. 
+ Without descending into ``oneOf`` we drop the union structure and ship + ``{"type": "string"}`` to the backend (issue #989). + """ + schema = { + "anyOf": [ + { + "discriminator": { + "propertyName": "kind", + "mapping": {"cat": "#/$defs/Cat", "dog": "#/$defs/Dog"}, + }, + "oneOf": [{"$ref": "#/$defs/Cat"}, {"$ref": "#/$defs/Dog"}], + }, + {"type": "null"}, + ] + } + assert _is_complex_anyof(schema) is True + + def test_anyof_with_plain_oneof_branch_is_complex(self): + """anyOf containing a oneOf branch without discriminator metadata is complex.""" + schema = { + "anyOf": [ + {"oneOf": [{"$ref": "#/$defs/Cat"}, {"$ref": "#/$defs/Dog"}]}, + {"type": "null"}, + ] + } + assert _is_complex_anyof(schema) is True + if __name__ == "__main__": pytest.main([__file__]) From c311c2cf8033c690e83b661fb624232b8f375b4a Mon Sep 17 00:00:00 2001 From: Nigel Jones Date: Wed, 20 May 2026 12:34:53 +0100 Subject: [PATCH 2/2] fix: address review feedback on discriminated-union flattening MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Rewrite ``_flatten_discriminated_union`` to be non-mutating and update the docstring; the previous "rewrites in place" claim was misleading because the required-union path returned a new dict. Document the single-level limitation (nested discriminated unions are not recursively flattened — tracked alongside #911). - Defensively merge ``oneOf`` into any pre-existing top-level ``anyOf`` rather than overwriting, so the helper is safe to call in isolation even on shapes Pydantic does not currently emit. - Drop a redundant ``v.get("anyOf", [])`` default whose key existence was already guaranteed by the surrounding guard. Tests: - ``test_optional_union_strips_discriminator_keyword`` — pin the implicit-strip in the optional path so a refactor can't silently reintroduce the OAS-3 keyword. - ``test_three_way_union_preserves_all_branches`` — three-arm discriminated unions are common in command-pattern tools. 
- ``test_non_discriminated_optional_unchanged`` — regression guard for the existing ``Optional[Email]`` flow; the new pre-pass must be a no-op. - Tighten ``_has_branch`` to ``anyOf`` only; accepting ``oneOf`` as a fallback would silently mask a regression of the flattening pre-pass. - Move ``import json`` to module top. Assisted-by: Claude Code Signed-off-by: Nigel Jones --- mellea/backends/tools.py | 36 ++++--- .../test_discriminated_union_tools.py | 99 ++++++++++++++++++- 2 files changed, 117 insertions(+), 18 deletions(-) diff --git a/mellea/backends/tools.py b/mellea/backends/tools.py index a32374067..c81a59b6f 100644 --- a/mellea/backends/tools.py +++ b/mellea/backends/tools.py @@ -1003,9 +1003,14 @@ def _flatten_discriminated_union(v: dict, defs: dict) -> dict: discriminator signal, so we can safely drop ``discriminator`` and emit ``anyOf`` — equivalent here because the tag field forces uniqueness. - This rewrites the schema in place: ``oneOf`` becomes ``anyOf`` with branches - inlined, ``discriminator`` is stripped, and Optional unions flatten to - ``{"anyOf": [...inlined_branches, {"type": "null"}]}``. + Returns a new schema dict with ``oneOf`` rewritten to ``anyOf``, branches + inlined, and ``discriminator`` stripped. The input is not mutated; callers + must reassign the result. + + Flattening is single-level. Discriminated unions nested inside an inlined + branch (e.g. a Pydantic model whose own field is another discriminated + union) are not recursively flattened — that case is tracked alongside the + recursive ``$ref`` resolution work in #911. """ def _inline(branch: dict) -> dict: @@ -1015,24 +1020,29 @@ def _inline(branch: dict) -> dict: return copy.deepcopy(ref_schema) return copy.deepcopy(branch) - # Top-level discriminated union (required parameter case) + out = {kk: vv for kk, vv in v.items() if kk not in ("oneOf", "discriminator")} + + # Top-level discriminated union (required parameter case). 
Append rather + # than overwrite so a defensive ``oneOf`` + ``anyOf`` co-occurrence does + # not silently drop the existing ``anyOf`` entries. Pydantic does not + # currently emit both at the same level, but the helper is callable in + # isolation and should not lose data. if "oneOf" in v: - branches = [_inline(b) for b in v["oneOf"]] - v = {kk: vv for kk, vv in v.items() if kk not in ("oneOf", "discriminator")} - v["anyOf"] = branches + existing = list(out.get("anyOf", [])) + out["anyOf"] = existing + [_inline(b) for b in v["oneOf"]] # Nested oneOf inside anyOf (Optional discriminated union case) - if "anyOf" in v: + if "anyOf" in out: flattened: list[dict] = [] - for sub in v["anyOf"]: + for sub in out["anyOf"]: if "oneOf" in sub: for branch in sub["oneOf"]: flattened.append(_inline(branch)) else: flattened.append(sub) - v["anyOf"] = flattened + out["anyOf"] = flattened - return v + return out # https://github.com/ollama/ollama-python/blob/60e7b2f9ce710eeb57ef2986c46ea612ae7516af/ollama/_utils.py#L56-L90 @@ -1071,9 +1081,7 @@ def convert_function_to_ollama_tool( for k, v in schema.get("properties", {}).items(): # Pre-pass: flatten Pydantic discriminated unions (oneOf + discriminator) # into plain anyOf with branches inlined. See _flatten_discriminated_union. - if "oneOf" in v or ( - "anyOf" in v and any("oneOf" in s for s in v.get("anyOf", [])) - ): + if "oneOf" in v or ("anyOf" in v and any("oneOf" in s for s in v["anyOf"])): v = _flatten_discriminated_union(v, defs) schema["properties"][k] = v diff --git a/test/backends/test_discriminated_union_tools.py b/test/backends/test_discriminated_union_tools.py index a07adf181..92305ae06 100644 --- a/test/backends/test_discriminated_union_tools.py +++ b/test/backends/test_discriminated_union_tools.py @@ -10,6 +10,7 @@ include it; the ``Literal`` tag fields carry the discriminator signal). 
""" +import json from typing import Annotated, Literal import pytest @@ -33,6 +34,19 @@ class Dog(BaseModel): breed: str +class Fish(BaseModel): + kind: Literal["fish"] + name: str + species: str + + +class Email(BaseModel): + """Non-discriminated nested model for the no-op regression test.""" + + to: str + subject: str + + def act(pet: Annotated[Cat | Dog, Field(discriminator="kind")]) -> str: """Act on a pet. @@ -56,12 +70,19 @@ def act_optional( def _pet_schema(func) -> dict: """Convert ``func`` and return the ``pet`` parameter schema.""" tool = convert_function_to_ollama_tool(func) + assert tool.function is not None + assert tool.function.parameters is not None return tool.function.parameters.model_dump(exclude_none=True)["properties"]["pet"] def _has_branch(schema: dict, kind_value: str, *, must_have: set[str]) -> bool: - """Check that ``schema`` contains an inlined object branch for ``kind_value``.""" - branches = schema.get("anyOf") or schema.get("oneOf") or [] + """Check that ``schema`` contains an inlined ``anyOf`` branch for ``kind_value``. + + After the fix lands the output schema must contain ``anyOf`` only, never + ``oneOf`` — accepting ``oneOf`` here would silently mask a regression of + the discriminator-flattening pre-pass. 
+ """ + branches = schema.get("anyOf", []) for branch in branches: props = branch.get("properties", {}) kind = props.get("kind", {}) @@ -104,8 +125,6 @@ def test_required_union_strips_discriminator_keyword(self): def test_required_union_no_dangling_refs(self): """No ``$ref`` should leak into the output for the issue reproducer.""" - import json - rendered = json.dumps(_pet_schema(act)) assert "$ref" not in rendered, f"unresolved $ref in tool schema: {rendered}" @@ -131,11 +150,83 @@ def test_optional_union_preserves_branches(self): def test_optional_union_drops_from_required(self): """The optional parameter must not be in the function's required list.""" tool = convert_function_to_ollama_tool(act_optional) + assert tool.function is not None + assert tool.function.parameters is not None params = tool.function.parameters.model_dump(exclude_none=True) assert "pet" not in params.get("required", []), ( f"optional 'pet' should not be required: {params}" ) + def test_optional_union_strips_discriminator_keyword(self): + """The Optional variant must also drop the OAS-3 ``discriminator``. + + The required variant strips it via the top-level ``oneOf`` path; the + optional variant strips it implicitly when the wrapper sub-schema is + replaced by its expanded branches. Asserted explicitly so a refactor + that re-introduces the wrapper does not slip past silently. + """ + rendered = json.dumps(_pet_schema(act_optional)) + assert "discriminator" not in rendered, ( + f"discriminator keyword should be stripped from optional output: {rendered}" + ) + + def test_three_way_union_preserves_all_branches(self): + """A three-arm discriminated union must preserve all three branches.""" + + def act_three( + pet: Annotated[Cat | Dog | Fish, Field(discriminator="kind")], + ) -> str: + """Act on a three-way pet. 
+ + Args: + pet: the pet to act on + """ + return "ok" + + pet = _pet_schema(act_three) + assert _has_branch(pet, "cat", must_have={"kind", "name"}), ( + f"Cat branch missing in three-way union: {pet!r}" + ) + assert _has_branch(pet, "dog", must_have={"kind", "name", "breed"}), ( + f"Dog branch missing in three-way union: {pet!r}" + ) + assert _has_branch(pet, "fish", must_have={"kind", "name", "species"}), ( + f"Fish branch missing in three-way union: {pet!r}" + ) + + def test_non_discriminated_optional_unchanged(self): + """Non-discriminated ``Optional[Email]`` must still flow through unchanged. + + Regression guard: the new pre-pass must be a no-op for plain + ``$ref`` + ``| None`` shapes that the existing inliner already + handles. Pydantic emits this as + ``{"anyOf": [{"$ref": "..."}, {"type": "null"}]}`` — no ``oneOf`` + in any sub-schema, so the pre-pass should not activate. + """ + + def send(email: Email | None = None) -> str: + """Send an email. + + Args: + email: optional email payload + """ + return "sent" + + tool = convert_function_to_ollama_tool(send) + assert tool.function is not None + assert tool.function.parameters is not None + rendered = tool.function.parameters.model_dump(exclude_none=True) + email_schema = rendered["properties"]["email"] + # The existing complex-anyOf path inlines the $ref and preserves the + # full object schema with properties. The exact shape is owned by the + # pre-existing logic; we only assert the pre-pass did not collapse it. + assert email_schema.get("type") != "string", ( + f"non-discriminated Optional collapsed: {email_schema!r}" + ) + assert "email" not in rendered.get("required", []), ( + f"optional email should not be required: {rendered}" + ) + class TestDiscriminatedUnionValidation: """``validate_tool_arguments`` must round-trip a valid discriminated payload."""