diff --git a/CHANGELOG.md b/CHANGELOG.md index 935890f..6854f5d 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -2,6 +2,11 @@ All notable changes to `uipath_llm_client` (core package) will be documented in this file. +## [1.10.0] - 2026-04-23 + +### Added +- `uipath.llm_client.utils.sampling` module exposing `DISABLED_SAMPLING_PARAMS`, `disabled_params_from_model_details`, `is_disabled_value`, and `strip_disabled_kwargs`. The helpers use the langchain-openai-style `disabled_params` format (`{name: None | [values]}`) so they compose with langchain-openai's existing `_filter_disabled_params` path. `disabled_params_from_model_details` derives the disabled-param map from a discovery-endpoint `modelDetails` dict (today: `shouldSkipTemperature=True` disables the full sampling set — temperature, top_p, top_k, frequency/presence penalty, seed, logit_bias, logprobs, top_logprobs). + ## [1.9.9] - 2026-04-23 ### Changed diff --git a/packages/uipath_langchain_client/CHANGELOG.md b/packages/uipath_langchain_client/CHANGELOG.md index 5701366..372b957 100644 --- a/packages/uipath_langchain_client/CHANGELOG.md +++ b/packages/uipath_langchain_client/CHANGELOG.md @@ -2,6 +2,22 @@ All notable changes to `uipath_langchain_client` will be documented in this file. +## [1.10.0] - 2026-04-23 + +### Added +- `model_details` and `disabled_params` fields on `UiPathBaseLLMClient`, plus a single `@model_validator(mode="after")` validator `setup_model_info` that (1) keeps the factory-supplied `model_details` or, when absent, fetches it from `client_settings.get_model_info`, and (2) sets `disabled_params` to the merge of what the caller passed and what `disabled_params_from_model_details` derives — user keys win on conflicts, so callers can override any derived entry by name. +- `disabled_params` uses the langchain-openai shape (`{name: None | [values]}`), so subclasses inheriting from `ChatOpenAI` / `AzureChatOpenAI` also benefit from the native `_filter_disabled_params` path inside `bind_tools`. +- Runtime stripping in the four `_generate`/`_agenerate`/`_stream`/`_astream` wrappers on `UiPathBaseChatModel` delegates to `uipath.llm_client.utils.sampling.strip_disabled_kwargs`, driven by the instance's `disabled_params`. A warning is logged via `self.logger` for each stripped key when a logger is configured. Fixes `anthropic.claude-opus-4-7` rejecting any sampling parameter passed via `.invoke()` / `.ainvoke()` / streaming calls. + +### Removed +- The now-redundant `disabled_params` field declaration on `UiPathChat` (the field is inherited from `UiPathBaseLLMClient`). + +### Changed +- Bumped the `uipath-llm-client` floor to `>=1.10.0` to match the release that adds `uipath.llm_client.utils.sampling`. + +### Known follow-up +- Init-time values set on the instance (`UiPathChat(model="anthropic.claude-opus-4-7", temperature=0.5)`) still flow into the outgoing request body via `_default_params` / the vendor SDK. The runtime invoke-time strip handles `.invoke(..., temperature=...)`; a follow-up will plug the init-time leak using the already-populated `disabled_params`.
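+
+### Example
+
+A minimal sketch of the new behavior. The factory normally forwards `model_details` from discovery; here it is passed explicitly to pin the gateway flag, and client settings are assumed to resolve from the environment:
+
+```python
+from uipath_langchain_client.clients.normalized.chat_models import UiPathChat
+
+llm = UiPathChat(
+    model="anthropic.claude-opus-4-7",
+    model_details={"shouldSkipTemperature": True},
+)
+# setup_model_info derived the full sampling set from shouldSkipTemperature.
+assert llm.disabled_params is not None
+assert "temperature" in llm.disabled_params
+
+# temperature / top_p are stripped before the request is built (with a
+# warning per key if self.logger is configured); max_tokens is not a
+# sampling param and passes through untouched.
+llm.invoke("hi", temperature=0.3, top_p=0.9, max_tokens=100)
+```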
+ ## [1.9.9] - 2026-04-23 ### Changed diff --git a/packages/uipath_langchain_client/pyproject.toml b/packages/uipath_langchain_client/pyproject.toml index bca0acc..056c4ef 100644 --- a/packages/uipath_langchain_client/pyproject.toml +++ b/packages/uipath_langchain_client/pyproject.toml @@ -6,7 +6,7 @@ readme = "README.md" requires-python = ">=3.11" dependencies = [ "langchain>=1.2.15,<2.0.0", - "uipath-llm-client>=1.9.9,<2.0.0", + "uipath-llm-client>=1.10.0,<2.0.0", ] [project.optional-dependencies] diff --git a/packages/uipath_langchain_client/src/uipath_langchain_client/__version__.py b/packages/uipath_langchain_client/src/uipath_langchain_client/__version__.py index b612be4..afec5ef 100644 --- a/packages/uipath_langchain_client/src/uipath_langchain_client/__version__.py +++ b/packages/uipath_langchain_client/src/uipath_langchain_client/__version__.py @@ -1,3 +1,3 @@ __title__ = "UiPath LangChain Client" __description__ = "A Python client for interacting with UiPath's LLM services via LangChain." -__version__ = "1.9.9" +__version__ = "1.10.0" diff --git a/packages/uipath_langchain_client/src/uipath_langchain_client/base_client.py b/packages/uipath_langchain_client/src/uipath_langchain_client/base_client.py index 91274bd..a3ee366 100644 --- a/packages/uipath_langchain_client/src/uipath_langchain_client/base_client.py +++ b/packages/uipath_langchain_client/src/uipath_langchain_client/base_client.py @@ -27,7 +27,7 @@ from abc import ABC from collections.abc import AsyncGenerator, Generator, Mapping, Sequence from functools import cached_property -from typing import Any, ClassVar, Literal +from typing import Any, ClassVar, Literal, Self from httpx import URL, Response from langchain_core.callbacks import ( @@ -38,7 +38,7 @@ from langchain_core.language_models.chat_models import BaseChatModel from langchain_core.messages import BaseMessage from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult -from pydantic import AliasChoices, BaseModel, ConfigDict, Field +from pydantic import AliasChoices, BaseModel, ConfigDict, Field, model_validator from uipath.llm_client.httpx_client import ( UiPathHttpxAsyncClient, @@ -49,6 +49,10 @@ get_captured_response_headers, set_captured_response_headers, ) +from uipath.llm_client.utils.sampling import ( + disabled_params_from_model_details, + strip_disabled_kwargs, +) from uipath_langchain_client.settings import ( UiPathAPIConfig, UiPathBaseSettings, @@ -108,6 +112,19 @@ class UiPathBaseLLMClient(BaseModel, ABC): description="Settings for the UiPath client (defaults based on UIPATH_LLM_SERVICE env var)", ) + model_details: dict[str, Any] | None = Field( + default=None, + description="Per-model capability flags sourced from the discovery endpoint " + "(e.g. {'shouldSkipTemperature': True}). Passed through by the factory; " + "resolved from client_settings.get_model_info otherwise.", + ) + disabled_params: dict[str, Any] | None = Field( + default=None, + description="langchain-openai-style map of parameters that must not be sent to " + "this model. Keys are param names; values are None (always disabled) or a list " + "of disallowed values. Derived from ``model_details`` when not provided.", + ) + default_headers: Mapping[str, str] | None = Field( default=None, description="Caller-supplied request headers. 
Merged on top of `class_default_headers`; " @@ -140,6 +157,39 @@ class UiPathBaseLLMClient(BaseModel, ABC): description="Logger for request/response logging", ) + @model_validator(mode="after") + def setup_model_info(self) -> Self: + """Resolve ``model_details`` from discovery and merge ``disabled_params``. + + Runs after pydantic has validated the fields, so ``self.client_settings`` + (with its ``default_factory``) and ``self.model_name`` are already live. + + ``model_details`` is resolved once: caller-forwarded value wins, then a + lookup against ``client_settings.get_model_info`` (backed by the + class-cached discovery response), else an empty mapping on failure. + + ``disabled_params`` is the merge of what the caller passed and what we + can derive from ``model_details`` (via + ``disabled_params_from_model_details``). User-provided keys win on + conflicts, so callers can override a derived entry by name. + """ + if self.model_details is None: + try: + info = self.client_settings.get_model_info( + self.model_name, + byo_connection_id=self.byo_connection_id, + ) + self.model_details = info.get("modelDetails") or {} + except Exception: + self.model_details = {} + + derived = disabled_params_from_model_details(self.model_details) or {} + user_provided = self.disabled_params or {} + merged = {**derived, **user_provided} + self.disabled_params = merged or None + + return self + @cached_property def uipath_sync_client(self) -> UiPathHttpxClient: """Here we instantiate a synchronous HTTP client with the proper authentication pipeline, retry logic, logging etc.""" @@ -364,6 +414,12 @@ def _generate( run_manager: CallbackManagerForLLMRun | None = None, **kwargs: Any, ) -> ChatResult: + kwargs = strip_disabled_kwargs( + kwargs, + disabled_params=self.disabled_params, + model_name=self.model_name, + logger=self.logger, + ) set_captured_response_headers({}) try: result = self._uipath_generate(messages, stop=stop, run_manager=run_manager, **kwargs) @@ -389,6 +445,12 @@ async def _agenerate( run_manager: AsyncCallbackManagerForLLMRun | None = None, **kwargs: Any, ) -> ChatResult: + kwargs = strip_disabled_kwargs( + kwargs, + disabled_params=self.disabled_params, + model_name=self.model_name, + logger=self.logger, + ) set_captured_response_headers({}) try: result = await self._uipath_agenerate( @@ -416,6 +478,12 @@ def _stream( run_manager: CallbackManagerForLLMRun | None = None, **kwargs: Any, ) -> Generator[ChatGenerationChunk, None, None]: + kwargs = strip_disabled_kwargs( + kwargs, + disabled_params=self.disabled_params, + model_name=self.model_name, + logger=self.logger, + ) set_captured_response_headers({}) try: first = True @@ -446,6 +514,12 @@ async def _astream( run_manager: AsyncCallbackManagerForLLMRun | None = None, **kwargs: Any, ) -> AsyncGenerator[ChatGenerationChunk, None]: + kwargs = strip_disabled_kwargs( + kwargs, + disabled_params=self.disabled_params, + model_name=self.model_name, + logger=self.logger, + ) set_captured_response_headers({}) try: first = True diff --git a/packages/uipath_langchain_client/src/uipath_langchain_client/clients/litellm/embeddings.py b/packages/uipath_langchain_client/src/uipath_langchain_client/clients/litellm/embeddings.py index 9648e22..263823b 100644 --- a/packages/uipath_langchain_client/src/uipath_langchain_client/clients/litellm/embeddings.py +++ b/packages/uipath_langchain_client/src/uipath_langchain_client/clients/litellm/embeddings.py @@ -10,8 +10,6 @@ >>> vectors = embeddings.embed_documents(["Hello world"]) """ -from __future__ import 
annotations - from pydantic import Field, model_validator from typing_extensions import Self diff --git a/packages/uipath_langchain_client/src/uipath_langchain_client/clients/normalized/chat_models.py b/packages/uipath_langchain_client/src/uipath_langchain_client/clients/normalized/chat_models.py index 88a7040..bfaeb15 100644 --- a/packages/uipath_langchain_client/src/uipath_langchain_client/clients/normalized/chat_models.py +++ b/packages/uipath_langchain_client/src/uipath_langchain_client/clients/normalized/chat_models.py @@ -64,9 +64,9 @@ from langchain_core.utils.pydantic import is_basemodel_subclass from pydantic import AliasChoices, BaseModel, Field +from uipath.llm_client.utils.model_family import is_anthropic_model_name from uipath_langchain_client.base_client import UiPathBaseChatModel from uipath_langchain_client.settings import ApiType, RoutingMode, UiPathAPIConfig -from uipath_langchain_client.utils import is_anthropic_model_name _DictOrPydanticClass = Union[dict[str, Any], type[BaseModel], type] _DictOrPydantic = Union[dict[str, Any], BaseModel] @@ -179,7 +179,6 @@ class UiPathChat(UiPathBaseChatModel): seed: int | None = None model_kwargs: dict[str, Any] = Field(default_factory=dict) - disabled_params: dict[str, Any] | None = None # OpenAI logit_bias: dict[str, int] | None = None diff --git a/packages/uipath_langchain_client/src/uipath_langchain_client/factory.py b/packages/uipath_langchain_client/src/uipath_langchain_client/factory.py index 8ef3b99..f0e0576 100644 --- a/packages/uipath_langchain_client/src/uipath_langchain_client/factory.py +++ b/packages/uipath_langchain_client/src/uipath_langchain_client/factory.py @@ -22,6 +22,7 @@ from typing import Any +from uipath.llm_client.utils.model_family import is_anthropic_model_name from uipath_langchain_client.base_client import ( UiPathBaseChatModel, UiPathBaseEmbeddings, @@ -36,7 +37,6 @@ VendorType, get_default_client_settings, ) -from uipath_langchain_client.utils import is_anthropic_model_name def get_chat_model( @@ -85,12 +85,14 @@ def get_chat_model( vendor_type=vendor_type, ) model_family = model_info.get("modelFamily", None) + model_details = model_info.get("modelDetails") or {} if custom_class is not None: return custom_class( model=model_name, settings=client_settings, byo_connection_id=byo_connection_id, + model_details=model_details, **model_kwargs, ) @@ -103,6 +105,7 @@ def get_chat_model( model=model_name, settings=client_settings, byo_connection_id=byo_connection_id, + model_details=model_details, **model_kwargs, ) @@ -138,6 +141,7 @@ def get_chat_model( settings=client_settings, api_flavor=api_flavor, byo_connection_id=byo_connection_id, + model_details=model_details, **model_kwargs, ) else: @@ -150,6 +154,7 @@ def get_chat_model( settings=client_settings, api_flavor=api_flavor, byo_connection_id=byo_connection_id, + model_details=model_details, **model_kwargs, ) case VendorType.VERTEXAI: @@ -163,6 +168,7 @@ def get_chat_model( settings=client_settings, vendor_type=discovered_vendor_type, byo_connection_id=byo_connection_id, + model_details=model_details, **model_kwargs, ) @@ -174,6 +180,7 @@ def get_chat_model( model=model_name, settings=client_settings, byo_connection_id=byo_connection_id, + model_details=model_details, **model_kwargs, ) case VendorType.AWSBEDROCK: @@ -188,6 +195,7 @@ def get_chat_model( model=model_name, settings=client_settings, byo_connection_id=byo_connection_id, + model_details=model_details, **model_kwargs, ) @@ -200,6 +208,7 @@ def get_chat_model( model=model_name, 
settings=client_settings, byo_connection_id=byo_connection_id, + model_details=model_details, **model_kwargs, ) @@ -211,6 +220,7 @@ def get_chat_model( model=model_name, settings=client_settings, byo_connection_id=byo_connection_id, + model_details=model_details, **model_kwargs, ) diff --git a/packages/uipath_langchain_client/src/uipath_langchain_client/utils.py b/packages/uipath_langchain_client/src/uipath_langchain_client/utils.py index 0607b05..b49e05a 100644 --- a/packages/uipath_langchain_client/src/uipath_langchain_client/utils.py +++ b/packages/uipath_langchain_client/src/uipath_langchain_client/utils.py @@ -13,10 +13,6 @@ UiPathTooManyRequestsError, UiPathUnprocessableEntityError, ) -from uipath.llm_client.utils.model_family import ( - ANTHROPIC_MODEL_NAME_KEYWORDS, - is_anthropic_model_name, -) from uipath.llm_client.utils.retry import RetryConfig __all__ = [ @@ -34,6 +30,4 @@ "UiPathServiceUnavailableError", "UiPathGatewayTimeoutError", "UiPathTooManyRequestsError", - "ANTHROPIC_MODEL_NAME_KEYWORDS", - "is_anthropic_model_name", ] diff --git a/src/uipath/llm_client/__version__.py b/src/uipath/llm_client/__version__.py index bd9af64..e4a0dca 100644 --- a/src/uipath/llm_client/__version__.py +++ b/src/uipath/llm_client/__version__.py @@ -1,3 +1,3 @@ __title__ = "UiPath LLM Client" __description__ = "A Python client for interacting with UiPath's LLM services." -__version__ = "1.9.9" +__version__ = "1.10.0" diff --git a/src/uipath/llm_client/clients/normalized/completions.py b/src/uipath/llm_client/clients/normalized/completions.py index 35b91a3..4bd63b5 100644 --- a/src/uipath/llm_client/clients/normalized/completions.py +++ b/src/uipath/llm_client/clients/normalized/completions.py @@ -1,7 +1,5 @@ """Completions endpoint for the UiPath Normalized API.""" -from __future__ import annotations - import json from collections.abc import AsyncGenerator, Callable, Generator, Sequence from typing import Any, Union, get_args, get_origin, get_type_hints diff --git a/src/uipath/llm_client/clients/normalized/embeddings.py b/src/uipath/llm_client/clients/normalized/embeddings.py index 9caf92c..44f6596 100644 --- a/src/uipath/llm_client/clients/normalized/embeddings.py +++ b/src/uipath/llm_client/clients/normalized/embeddings.py @@ -3,8 +3,6 @@ Provides synchronous and asynchronous methods for generating text embeddings. """ -from __future__ import annotations - from typing import Any from uipath.llm_client.clients.normalized.types import ( diff --git a/src/uipath/llm_client/utils/sampling.py b/src/uipath/llm_client/utils/sampling.py new file mode 100644 index 0000000..18d1d51 --- /dev/null +++ b/src/uipath/llm_client/utils/sampling.py @@ -0,0 +1,96 @@ +"""Helpers for the ``disabled_params`` convention. + +``disabled_params`` is the langchain-openai-style declaration that certain +parameters must not be sent to a model. It maps param names to either: + +- ``None``: the parameter is always disabled, regardless of its value. +- ``list[Any]``: the parameter is disabled only when its value is in the list. + +We reuse this shape so that classes inheriting from +``langchain_openai.BaseChatOpenAI`` also benefit from its native +``_filter_disabled_params`` path inside ``bind_tools``. + +The sampling-specific knowledge lives in ``disabled_params_from_model_details``: +when the gateway's discovery endpoint advertises +``modelDetails.shouldSkipTemperature: true`` on a reasoning-style model (e.g. +``anthropic.claude-opus-4-7``), the entire sampling set gets disabled. 
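+
+A quick sketch of the helpers working together (values are illustrative)::
+
+    >>> disabled = disabled_params_from_model_details({"shouldSkipTemperature": True})
+    >>> is_disabled_value(0.7, disabled["temperature"])
+    True
+    >>> strip_disabled_kwargs(
+    ...     {"temperature": 0.7, "max_tokens": 100},
+    ...     disabled_params=disabled,
+    ...     model_name="anthropic.claude-opus-4-7",
+    ...     logger=None,
+    ... )
+    {'max_tokens': 100}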
+""" + +from collections.abc import Mapping +from logging import Logger +from typing import Any + +# Parameters the gateway rejects when ``shouldSkipTemperature`` is true. +# ``n`` (candidate count) is intentionally NOT here — it is not a sampling knob. +DISABLED_SAMPLING_PARAMS: tuple[str, ...] = ( + "temperature", + "top_p", + "top_k", + "frequency_penalty", + "presence_penalty", + "seed", + "logit_bias", + "logprobs", + "top_logprobs", +) + + +def disabled_params_from_model_details( + model_details: Mapping[str, Any] | None, +) -> dict[str, Any] | None: + """Derive ``disabled_params`` from a discovery-endpoint ``modelDetails`` dict. + + Returns None when no capability flags warrant disabling anything, so callers + can distinguish "nothing to disable" from "disabled empty mapping". + """ + if not model_details: + return None + disabled: dict[str, Any] = {} + if model_details.get("shouldSkipTemperature"): + for param in DISABLED_SAMPLING_PARAMS: + disabled[param] = None + # Future gateway flags (e.g. per-param ``shouldSkipTopP``) can extend this. + return disabled or None + + +def is_disabled_value(value: Any, disabled_spec: Any) -> bool: + """Match the langchain-openai ``_filter_disabled_params`` semantics. + + ``disabled_spec`` is either None (always disabled) or an iterable of values + (disabled only when ``value`` is in the iterable). + """ + if disabled_spec is None: + return True + try: + return value in disabled_spec + except TypeError: + return False + + +def strip_disabled_kwargs( + kwargs: Mapping[str, Any], + *, + disabled_params: Mapping[str, Any] | None, + model_name: str, + logger: Logger | None, +) -> dict[str, Any]: + """Return a copy of ``kwargs`` with entries matching ``disabled_params`` removed. + + Uses the same matching rule as langchain-openai: a key is stripped when it + is in ``disabled_params`` AND either the spec is None or the kwarg value + matches one of the listed disabled values. Logs a warning per strip if a + logger is supplied; silent otherwise. + """ + out = dict(kwargs) + if not disabled_params: + return out + for key in list(out.keys()): + if key in disabled_params and is_disabled_value(out[key], disabled_params[key]): + if logger is not None: + logger.warning( + "Stripping disabled invocation param %r for model %r", + key, + model_name, + ) + out.pop(key, None) + return out diff --git a/tests/langchain/test_disabled_sampling_params.py b/tests/langchain/test_disabled_sampling_params.py new file mode 100644 index 0000000..712400d --- /dev/null +++ b/tests/langchain/test_disabled_sampling_params.py @@ -0,0 +1,558 @@ +"""Unit tests for the ``disabled_params`` + ``model_details`` wiring. + +Covers two layers: + +1. **Metadata resolution** via ``@model_validator(mode="after")`` on + ``UiPathBaseLLMClient``: ``model_details`` is forwarded by the factory or + fetched from ``client_settings.get_model_info``; ``disabled_params`` is the + merge of anything the caller passed and what we can derive from + ``model_details``. + +2. **Invocation-time stripping** via ``strip_disabled_kwargs`` wired into + ``_generate``/``_agenerate``/``_stream``/``_astream`` on + ``UiPathBaseChatModel``. + +Tests monkeypatch ``client_settings.get_model_info`` and stub the +``_uipath_generate``/``_uipath_agenerate`` methods so no HTTP is ever made. 
+""" + +import logging +from typing import Any + +import pytest +from langchain_core.messages import AIMessage +from langchain_core.outputs import ChatGeneration, ChatResult +from uipath_langchain_client.clients.normalized.chat_models import UiPathChat +from uipath_langchain_client.clients.openai.chat_models import ( + UiPathAzureChatOpenAI, + UiPathChatOpenAI, +) +from uipath_langchain_client.factory import get_chat_model + +from uipath.llm_client.settings import UiPathBaseSettings +from uipath.llm_client.utils.sampling import DISABLED_SAMPLING_PARAMS + +# --------------------------------------------------------------------------- # +# helpers +# --------------------------------------------------------------------------- # + + +def _stub_model_info( + monkeypatch: pytest.MonkeyPatch, + settings: UiPathBaseSettings, + *, + model_details: dict[str, Any] | None = None, + extra: dict[str, Any] | None = None, + raises: BaseException | None = None, +) -> None: + """Replace ``client_settings.get_model_info`` with a stub.""" + + def _stub(model_name: str, **kwargs: Any) -> dict[str, Any]: + if raises is not None: + raise raises + info: dict[str, Any] = { + "modelName": model_name, + "vendor": "AwsBedrock", + "modelSubscriptionType": "UiPathOwned", + "modelDetails": model_details, + } + if extra: + info.update(extra) + return info + + monkeypatch.setattr(settings, "get_model_info", _stub) + + +def _stub_generate( + monkeypatch: pytest.MonkeyPatch, instance: UiPathChat, captured: dict[str, Any] +) -> None: + def _stub( + messages: Any, stop: Any = None, run_manager: Any = None, **kwargs: Any + ) -> ChatResult: + captured.update(kwargs) + captured["__stop__"] = stop + return ChatResult(generations=[ChatGeneration(message=AIMessage(content="ok"))]) + + monkeypatch.setattr(instance, "_uipath_generate", _stub) + + +def _stub_agenerate( + monkeypatch: pytest.MonkeyPatch, instance: UiPathChat, captured: dict[str, Any] +) -> None: + async def _stub( + messages: Any, stop: Any = None, run_manager: Any = None, **kwargs: Any + ) -> ChatResult: + captured.update(kwargs) + captured["__stop__"] = stop + return ChatResult(generations=[ChatGeneration(message=AIMessage(content="ok"))]) + + monkeypatch.setattr(instance, "_uipath_agenerate", _stub) + + +# --------------------------------------------------------------------------- # +# metadata resolution (model_details + disabled_params) +# --------------------------------------------------------------------------- # + + +def test_disabled_params_derived_from_model_details_flag( + client_settings: UiPathBaseSettings, +) -> None: + llm = UiPathChat( + model="anthropic.claude-opus-4-7", + settings=client_settings, + model_details={"shouldSkipTemperature": True}, + ) + assert llm.disabled_params is not None + assert set(llm.disabled_params) == set(DISABLED_SAMPLING_PARAMS) + + +def test_user_provided_disabled_params_merges_with_derived( + client_settings: UiPathBaseSettings, +) -> None: + # modelDetails derives the sampling set; caller adds an extra key + # (logit_bias is already in the sampling set, so we use stream_usage + # to demonstrate a truly additive merge). 
+ llm = UiPathChat( + model="anthropic.claude-opus-4-7", + settings=client_settings, + model_details={"shouldSkipTemperature": True}, + disabled_params={"stream_usage": None}, + ) + assert llm.disabled_params is not None + assert set(llm.disabled_params) == set(DISABLED_SAMPLING_PARAMS) | {"stream_usage"} + + +def test_user_provided_disabled_params_overrides_derived_entry( + client_settings: UiPathBaseSettings, +) -> None: + # If the caller supplies a narrower spec for an already-derived key + # (e.g. only disable temperature when value is 0.0), their spec wins. + llm = UiPathChat( + model="anthropic.claude-opus-4-7", + settings=client_settings, + model_details={"shouldSkipTemperature": True}, + disabled_params={"temperature": [0.0]}, + ) + assert llm.disabled_params is not None + assert llm.disabled_params["temperature"] == [0.0] + # Other derived keys remain disabled unconditionally. + assert llm.disabled_params["top_p"] is None + + +def test_no_disabled_params_when_flag_absent( + client_settings: UiPathBaseSettings, +) -> None: + llm = UiPathChat( + model="some-chatty-model", + settings=client_settings, + model_details={}, + ) + assert llm.disabled_params is None + + +# --------------------------------------------------------------------------- # +# invocation-time stripping +# --------------------------------------------------------------------------- # + + +def test_invoke_strips_sampling_kwargs_when_flag_set( + monkeypatch: pytest.MonkeyPatch, client_settings: UiPathBaseSettings +) -> None: + llm = UiPathChat( + model="anthropic.claude-opus-4-7", + settings=client_settings, + model_details={"shouldSkipTemperature": True}, + ) + captured: dict[str, Any] = {} + _stub_generate(monkeypatch, llm, captured) + + llm.invoke("hi", temperature=0.3, top_p=0.9, top_k=5, seed=42, max_tokens=100) + + for p in ("temperature", "top_p", "top_k", "seed"): + assert p not in captured, f"{p} should have been stripped" + assert captured.get("max_tokens") == 100 + + +def test_invoke_strips_every_listed_sampling_param( + monkeypatch: pytest.MonkeyPatch, client_settings: UiPathBaseSettings +) -> None: + llm = UiPathChat( + model="anthropic.claude-opus-4-7", + settings=client_settings, + model_details={"shouldSkipTemperature": True}, + ) + captured: dict[str, Any] = {} + _stub_generate(monkeypatch, llm, captured) + + kwargs: dict[str, Any] = {p: 0.1 for p in DISABLED_SAMPLING_PARAMS} + kwargs["max_tokens"] = 50 + llm.invoke("x", **kwargs) # type: ignore[arg-type] + + for p in DISABLED_SAMPLING_PARAMS: + assert p not in captured + assert captured["max_tokens"] == 50 + + +def test_n_is_not_stripped( + monkeypatch: pytest.MonkeyPatch, client_settings: UiPathBaseSettings +) -> None: + # `n` (candidate count) is intentionally NOT part of DISABLED_SAMPLING_PARAMS. 
+ llm = UiPathChat( + model="anthropic.claude-opus-4-7", + settings=client_settings, + model_details={"shouldSkipTemperature": True}, + ) + captured: dict[str, Any] = {} + _stub_generate(monkeypatch, llm, captured) + + llm.invoke("x", n=3) + assert captured.get("n") == 3 + + +@pytest.mark.asyncio +async def test_ainvoke_strips_sampling_kwargs_when_flag_set( + monkeypatch: pytest.MonkeyPatch, client_settings: UiPathBaseSettings +) -> None: + llm = UiPathChat( + model="anthropic.claude-opus-4-7", + settings=client_settings, + model_details={"shouldSkipTemperature": True}, + ) + captured: dict[str, Any] = {} + _stub_agenerate(monkeypatch, llm, captured) + + await llm.ainvoke("hi", temperature=0.3, top_p=0.9) + + assert "temperature" not in captured + assert "top_p" not in captured + + +def test_invoke_preserves_kwargs_when_flag_absent( + monkeypatch: pytest.MonkeyPatch, client_settings: UiPathBaseSettings +) -> None: + llm = UiPathChat( + model="some-chatty-model", + settings=client_settings, + model_details={}, + ) + captured: dict[str, Any] = {} + _stub_generate(monkeypatch, llm, captured) + + llm.invoke("hi", temperature=0.3, top_p=0.9) + + assert captured["temperature"] == 0.3 + assert captured["top_p"] == 0.9 + + +def test_invoke_honors_user_supplied_disabled_params( + monkeypatch: pytest.MonkeyPatch, client_settings: UiPathBaseSettings +) -> None: + llm = UiPathChat( + model="some-chatty-model", + settings=client_settings, + model_details={}, + disabled_params={"frequency_penalty": None}, + ) + captured: dict[str, Any] = {} + _stub_generate(monkeypatch, llm, captured) + + llm.invoke("x", temperature=0.3, frequency_penalty=1.0) + + assert captured.get("temperature") == 0.3 + assert "frequency_penalty" not in captured + + +def test_invoke_honors_disabled_params_value_list( + monkeypatch: pytest.MonkeyPatch, client_settings: UiPathBaseSettings +) -> None: + # langchain-openai semantics: list spec means "disabled when value is in list". 
+ llm = UiPathChat( + model="some-chatty-model", + settings=client_settings, + model_details={}, + disabled_params={"temperature": [0.0, 1.5]}, + ) + captured: dict[str, Any] = {} + _stub_generate(monkeypatch, llm, captured) + + llm.invoke("x", temperature=1.5) # matches -> stripped + assert "temperature" not in captured + + captured.clear() + llm.invoke("x", temperature=0.7) # does not match -> preserved + assert captured.get("temperature") == 0.7 + + +# --------------------------------------------------------------------------- # +# warning gating via self.logger +# --------------------------------------------------------------------------- # + + +def test_warning_logged_when_logger_set( + monkeypatch: pytest.MonkeyPatch, + client_settings: UiPathBaseSettings, + caplog: pytest.LogCaptureFixture, +) -> None: + logger = logging.getLogger("uipath.test.skip-sampling") + llm = UiPathChat( + model="anthropic.claude-opus-4-7", + settings=client_settings, + model_details={"shouldSkipTemperature": True}, + logger=logger, + ) + captured: dict[str, Any] = {} + _stub_generate(monkeypatch, llm, captured) + + with caplog.at_level(logging.WARNING, logger=logger.name): + llm.invoke("x", temperature=0.3) + + assert any( + "temperature" in rec.getMessage() and "disabled" in rec.getMessage() + for rec in caplog.records + ), "expected a warning mentioning 'temperature' and 'disabled'" + + +def test_no_warning_when_logger_is_none( + monkeypatch: pytest.MonkeyPatch, + client_settings: UiPathBaseSettings, + caplog: pytest.LogCaptureFixture, +) -> None: + llm = UiPathChat( + model="anthropic.claude-opus-4-7", + settings=client_settings, + model_details={"shouldSkipTemperature": True}, + logger=None, + ) + captured: dict[str, Any] = {} + _stub_generate(monkeypatch, llm, captured) + + with caplog.at_level(logging.DEBUG): + llm.invoke("x", temperature=0.3) + + assert "temperature" not in captured + assert not any("disabled invocation param" in rec.getMessage() for rec in caplog.records) + + +def test_no_warning_when_nothing_to_strip( + monkeypatch: pytest.MonkeyPatch, + client_settings: UiPathBaseSettings, + caplog: pytest.LogCaptureFixture, +) -> None: + logger = logging.getLogger("uipath.test.skip-sampling-quiet") + llm = UiPathChat( + model="anthropic.claude-opus-4-7", + settings=client_settings, + model_details={"shouldSkipTemperature": True}, + logger=logger, + ) + captured: dict[str, Any] = {} + _stub_generate(monkeypatch, llm, captured) + + with caplog.at_level(logging.WARNING, logger=logger.name): + llm.invoke("x", max_tokens=50) + + assert not any("disabled invocation param" in rec.getMessage() for rec in caplog.records) + + +# --------------------------------------------------------------------------- # +# discovery fallback +# --------------------------------------------------------------------------- # + + +def test_validator_fetches_model_details_when_not_provided( + monkeypatch: pytest.MonkeyPatch, client_settings: UiPathBaseSettings +) -> None: + _stub_model_info(monkeypatch, client_settings, model_details={"shouldSkipTemperature": True}) + llm = UiPathChat( + model="anthropic.claude-opus-4-7", + settings=client_settings, + ) + # model_details resolved from discovery; disabled_params derived. 
+ assert llm.model_details == {"shouldSkipTemperature": True} + assert llm.disabled_params is not None + assert "temperature" in llm.disabled_params + + +def test_validator_swallows_discovery_errors( + monkeypatch: pytest.MonkeyPatch, client_settings: UiPathBaseSettings +) -> None: + _stub_model_info(monkeypatch, client_settings, raises=RuntimeError("boom")) + llm = UiPathChat( + model="anthropic.claude-opus-4-7", + settings=client_settings, + temperature=0.5, + ) + # Discovery failure => model_details is {} and nothing is stripped. + assert llm.model_details == {} + assert llm.disabled_params is None + assert llm.temperature == 0.5 + + +# --------------------------------------------------------------------------- # +# factory forwarding +# --------------------------------------------------------------------------- # + + +def test_factory_forwards_model_details( + monkeypatch: pytest.MonkeyPatch, client_settings: UiPathBaseSettings +) -> None: + from uipath_langchain_client.settings import RoutingMode + + _stub_model_info(monkeypatch, client_settings, model_details={"shouldSkipTemperature": True}) + + llm = get_chat_model( + "anthropic.claude-opus-4-7", + client_settings=client_settings, + routing_mode=RoutingMode.NORMALIZED, + ) + + assert isinstance(llm, UiPathChat) + assert llm.model_details == {"shouldSkipTemperature": True} + assert llm.disabled_params is not None + assert "temperature" in llm.disabled_params + + +def test_factory_forwards_empty_dict_when_no_model_details( + monkeypatch: pytest.MonkeyPatch, client_settings: UiPathBaseSettings +) -> None: + from uipath_langchain_client.settings import RoutingMode + + _stub_model_info(monkeypatch, client_settings, model_details=None) + + llm = get_chat_model( + "gpt-4o", + client_settings=client_settings, + routing_mode=RoutingMode.NORMALIZED, + ) + + assert isinstance(llm, UiPathChat) + assert llm.model_details == {} + assert llm.disabled_params is None + + +# --------------------------------------------------------------------------- # +# interop with langchain-openai's native ``disabled_params`` +# --------------------------------------------------------------------------- # +# +# ``BaseChatOpenAI`` (and thus ``UiPathChatOpenAI``) already declares +# ``disabled_params: dict[str, Any] | None`` with the same shape we use. These +# tests pin down: +# - Our ``setup_model_info`` derives + merges correctly on an OpenAI subclass. +# - Caller-supplied keys (e.g. langchain-openai's classic ``parallel_tool_calls``) +# survive the merge alongside the gateway-derived sampling set. +# - ``UiPathAzureChatOpenAI``'s native auto-init of +# ``{"parallel_tool_calls": None}`` (fires only when ``disabled_params is None``) +# still works when we have nothing to contribute (flag absent). + + +def test_openai_subclass_derives_disabled_params_from_model_details( + client_settings: UiPathBaseSettings, +) -> None: + llm = UiPathChatOpenAI( + model="some-reasoning-openai-model", + settings=client_settings, + model_details={"shouldSkipTemperature": True}, + ) + assert llm.disabled_params is not None + assert set(llm.disabled_params) == set(DISABLED_SAMPLING_PARAMS) + + +def test_openai_subclass_merges_user_disabled_params_with_derived( + client_settings: UiPathBaseSettings, +) -> None: + # langchain-openai's classic disabled_params usage: disable parallel_tool_calls. + # Our setup_model_info should merge it with the gateway-derived sampling set. 
+ llm = UiPathChatOpenAI( + model="some-reasoning-openai-model", + settings=client_settings, + model_details={"shouldSkipTemperature": True}, + disabled_params={"parallel_tool_calls": None}, + ) + assert llm.disabled_params is not None + assert set(llm.disabled_params) == set(DISABLED_SAMPLING_PARAMS) | {"parallel_tool_calls"} + + +def test_openai_subclass_user_override_wins_on_conflict( + client_settings: UiPathBaseSettings, +) -> None: + # If the caller narrows a derived key (e.g. "disable temperature only at 0.0"), + # their more specific spec must win over the unconditional None from the + # derivation. + llm = UiPathChatOpenAI( + model="some-reasoning-openai-model", + settings=client_settings, + model_details={"shouldSkipTemperature": True}, + disabled_params={"temperature": [0.0]}, + ) + assert llm.disabled_params is not None + assert llm.disabled_params["temperature"] == [0.0] + # Other sampling-set keys remain unconditionally disabled. + assert llm.disabled_params["top_p"] is None + + +def test_azure_autoinit_parallel_tool_calls_still_fires_without_flag( + client_settings: UiPathBaseSettings, +) -> None: + # AzureChatOpenAI auto-sets {"parallel_tool_calls": None} in its own + # model_validator when disabled_params is None and the model is not gpt-4o. + # With no shouldSkipTemperature, setup_model_info leaves disabled_params as + # None, so Azure's native logic must still fire. + llm = UiPathAzureChatOpenAI( + model="gpt-5.1", # not gpt-4o -> Azure auto-init applies + settings=client_settings, + model_details={}, + ) + assert llm.disabled_params == {"parallel_tool_calls": None} + + +def test_azure_autoinit_parallel_tool_calls_merges_with_our_derivation( + client_settings: UiPathBaseSettings, +) -> None: + # AzureChatOpenAI's own model_validator runs before ours in MRO order and + # sets ``disabled_params = {"parallel_tool_calls": None}`` (for non-gpt-4o + # models). Our setup_model_info then treats that as a caller-provided + # value and merges the derived sampling set on top. Result: BOTH Azure's + # classic parallel_tool_calls restriction AND the gateway's + # shouldSkipTemperature-derived sampling set end up in disabled_params — + # neither convention is lost. + llm = UiPathAzureChatOpenAI( + model="gpt-5.1", # not gpt-4o -> Azure auto-init applies + settings=client_settings, + model_details={"shouldSkipTemperature": True}, + ) + assert llm.disabled_params is not None + assert set(llm.disabled_params) == set(DISABLED_SAMPLING_PARAMS) | {"parallel_tool_calls"} + + +def test_openai_subclass_runtime_strip_honors_merged_disabled_params( + monkeypatch: pytest.MonkeyPatch, client_settings: UiPathBaseSettings +) -> None: + # End-to-end: runtime invoke-time strip on an OpenAI subclass sees the + # merged disabled_params and drops both the derived sampling key AND the + # user-supplied parallel_tool_calls. 
+ llm = UiPathChatOpenAI( + model="some-reasoning-openai-model", + settings=client_settings, + model_details={"shouldSkipTemperature": True}, + disabled_params={"parallel_tool_calls": None}, + ) + captured: dict[str, Any] = {} + + def _stub_uipath_generate( + messages: Any, stop: Any = None, run_manager: Any = None, **kwargs: Any + ) -> ChatResult: + captured.update(kwargs) + return ChatResult(generations=[ChatGeneration(message=AIMessage(content="ok"))]) + + monkeypatch.setattr(llm, "_uipath_generate", _stub_uipath_generate) + + llm.invoke( + "hi", + temperature=0.3, # derived disable + parallel_tool_calls=True, # user-supplied disable + max_tokens=50, # unrelated, survives + ) + + assert "temperature" not in captured + assert "parallel_tool_calls" not in captured + assert captured.get("max_tokens") == 50