diff --git a/examples/ai_modinput_app/bin/agentic_weather.py b/examples/ai_modinput_app/bin/agentic_weather.py index 768688c2..54856c56 100644 --- a/examples/ai_modinput_app/bin/agentic_weather.py +++ b/examples/ai_modinput_app/bin/agentic_weather.py @@ -20,6 +20,8 @@ from _collections_abc import dict_items from typing import final, override +from splunklib.ai.messages import AIMessage, ContentBlock, TextBlock + # ! NOTE: This insert is only needed for splunk-sdk-python CI/CD to work. # ! Remove this if you're modifying this example locally. sys.path.insert(0, "/splunklib-deps") @@ -95,9 +97,9 @@ def stream_events(self, inputs: InputDefinition, ew: EventWriter) -> None: weather_events += list(reader) for weather_event in weather_events: - weather_event["human_readable"] = asyncio.run( - self.invoke_agent(weather_event) - ) + result = asyncio.run(self.invoke_agent(weather_event)) + weather_event["human_readable"] = self.parse_content(result) + logger.debug(f"{weather_event=}") event = Event( @@ -112,7 +114,7 @@ def stream_events(self, inputs: InputDefinition, ew: EventWriter) -> None: logger.debug(f"Finishing enrichment for {input_name} at {csv_file_path}") - async def invoke_agent(self, weather_event: dict[str, str | int]) -> str: + async def invoke_agent(self, weather_event: dict[str, str | int]) -> AIMessage: if not self.service: raise AssertionError("No Splunk connection available") @@ -127,7 +129,27 @@ async def invoke_agent(self, weather_event: dict[str, str | int]) -> str: data=weather_event, ) logger.debug(f"{response=}") - return response.final_message.content + return response.final_message + + def _parse_content_block(self, block: str | ContentBlock) -> str | None: + match block: + case TextBlock(): + return block.text + case str(): + return block + case _: + return None + + def parse_content(self, message: AIMessage) -> str: + """Parses the content from AIMessage and builds a single string our of it""" + if isinstance(message.content, str): + return 
message.content + + return " ".join( + parsed_block + for block in message.content + if (parsed_block := self._parse_content_block(block)) + ) if __name__ == "__main__": diff --git a/splunklib/ai/engines/langchain.py b/splunklib/ai/engines/langchain.py index 76fa100b..08e67fa7 100644 --- a/splunklib/ai/engines/langchain.py +++ b/splunklib/ai/engines/langchain.py @@ -77,7 +77,9 @@ AgentResponse, AIMessage, BaseMessage, + ContentBlock, HumanMessage, + OpaqueBlock, OutputT, StructuredOutputCall, StructuredOutputMessage, @@ -87,6 +89,7 @@ SubagentStructuredResult, SubagentTextResult, SystemMessage, + TextBlock, ToolCall, ToolFailureResult, ToolMessage, @@ -951,7 +954,7 @@ async def awrap_tool_call( return LC_ToolMessage( name=_normalize_agent_name(call.name), tool_call_id=call.id, - content=content, + content=_map_content_to_langchain(content), status=status, artifact=sdk_result, ) @@ -1085,7 +1088,10 @@ def _convert_model_response_to_model_result( # This invariant is asserted via ModelResponse.__post_init__ assert len(resp.message.structured_output_calls) <= 1 - lc_message = LC_AIMessage(content=resp.message.content) + lc_message = LC_AIMessage( + content=_map_content_to_langchain(resp.message.content), + additional_kwargs=resp.message.extras or {}, + ) # This field can't be set via __init__() lc_message.tool_calls = [_map_tool_call_to_langchain(c) for c in resp.message.calls] @@ -1160,7 +1166,7 @@ def _convert_tool_message_to_lc( name=name, tool_call_id=message.call_id, status=status, - content=content, + content=_map_content_to_langchain(content), artifact=artifact, ) @@ -1243,9 +1249,10 @@ def _convert_model_result_from_lc(model_response: LC_ModelCallResult) -> ModelRe ai_message = model_response structured_response = None + additional_kwargs = cast(dict[str, Any], ai_message.additional_kwargs) return ModelResponse( message=AIMessage( - content=ai_message.content.__str__(), + content=_map_content_from_langchain(ai_message.content), # pyright: 
ignore[reportUnknownArgumentType] calls=[ _map_tool_call_from_langchain(tc) for tc in ai_message.tool_calls @@ -1260,6 +1267,7 @@ def _convert_model_result_from_lc(model_response: LC_ModelCallResult) -> ModelRe for tc in ai_message.tool_calls if tc["name"].startswith(TOOL_STRATEGY_TOOL_PREFIX) ], + extras=additional_kwargs, ), structured_output=structured_response, ) @@ -1433,7 +1441,10 @@ def _agent_as_tool(agent: BaseAgent[OutputT]) -> StructuredTool: async def invoke_agent( message: HumanMessage, thread_id: str | None - ) -> tuple[OutputT | str, SubagentStructuredResult | SubagentTextResult]: + ) -> tuple[ + OutputT | str | list[str | ContentBlock], + SubagentStructuredResult | SubagentTextResult, + ]: result = await agent.invoke([message], thread_id=thread_id) if agent.output_schema: @@ -1452,13 +1463,19 @@ async def invoke_agent( async def _run( # pyright: ignore[reportRedeclaration] content: str, thread_id: str - ) -> tuple[OutputT | str, SubagentStructuredResult | SubagentTextResult]: + ) -> tuple[ + OutputT | str | list[str | ContentBlock], + SubagentStructuredResult | SubagentTextResult, + ]: return await invoke_agent(HumanMessage(content=content), thread_id) else: async def _run( # pyright: ignore[reportRedeclaration] content: str, - ) -> tuple[OutputT | str, SubagentStructuredResult | SubagentTextResult]: + ) -> tuple[ + OutputT | str | list[str | ContentBlock], + SubagentStructuredResult | SubagentTextResult, + ]: return await invoke_agent(HumanMessage(content=content), None) return StructuredTool.from_function( @@ -1471,7 +1488,10 @@ async def _run( # pyright: ignore[reportRedeclaration] async def invoke_agent_structured( content: BaseModel, thread_id: str | None - ) -> tuple[OutputT | str, SubagentStructuredResult | SubagentTextResult]: + ) -> tuple[ + OutputT | str | list[str | ContentBlock], + SubagentStructuredResult | SubagentTextResult, + ]: result = await agent.invoke_with_data( instructions="Follow the system prompt.", 
data=content.model_dump(), @@ -1492,7 +1512,10 @@ async def invoke_agent_structured( async def _run( **kwargs: Any, # noqa: ANN401 - ) -> tuple[OutputT | str, SubagentStructuredResult | SubagentTextResult]: + ) -> tuple[ + OutputT | str | list[str | ContentBlock], + SubagentStructuredResult | SubagentTextResult, + ]: content: BaseModel = kwargs["content"] thread_id: str = kwargs["thread_id"] return await invoke_agent_structured(content, thread_id) @@ -1512,7 +1535,10 @@ async def _run( async def _run( **kwargs: Any, # noqa: ANN401 - ) -> tuple[OutputT | str, SubagentStructuredResult | SubagentTextResult]: + ) -> tuple[ + OutputT | str | list[str | ContentBlock], + SubagentStructuredResult | SubagentTextResult, + ]: content = InputSchema(**kwargs) return await invoke_agent_structured(content, None) @@ -1564,11 +1590,66 @@ def _map_tool_call_to_langchain(call: ToolCall | SubagentCall) -> LC_ToolCall: return LC_ToolCall(id=call.id, name=name, args=args) +def _map_content_from_langchain( + content: str | list[str | dict[str, Any]], +) -> str | list[str | ContentBlock]: + if isinstance(content, str): + return content + + result_content = [_map_content_block_from_langchain(b) for b in content] + + return result_content + + +def _map_content_block_from_langchain( + block: str | dict[str, Any], +) -> str | ContentBlock: + if isinstance(block, str): + return block + + match block.get("type"): + case "text": + return TextBlock( + text=block["text"], + extras=block.get("extras"), + ) + case _: + # NOTE: we return data we're not handling + # as opaque content blocks so they + # are preserved and sent back to the LLM + return OpaqueBlock(data=block) + + +def _map_content_to_langchain( + content: str | list[str | ContentBlock], +) -> str | list[str | dict[str, Any]]: + if isinstance(content, str): + return content + + result_content = [_map_content_block_to_langchain(b) for b in content] + + return result_content + + +def _map_content_block_to_langchain(block: str | 
ContentBlock) -> str | dict[str, Any]: + if isinstance(block, str): + return block + + match block: + case TextBlock(): + result: dict[str, Any] = {"type": "text", "text": block.text} + if block.extras: + result["extras"] = block.extras + return result + case OpaqueBlock(): + return block.data + + def _map_message_from_langchain(message: LC_BaseMessage) -> BaseMessage: match message: case LC_AIMessage(): return AIMessage( - content=message.content.__str__(), + content=_map_content_from_langchain(message.content), # pyright: ignore[reportUnknownArgumentType] calls=[ _map_tool_call_from_langchain(tc) for tc in message.tool_calls @@ -1583,6 +1664,7 @@ def _map_message_from_langchain(message: LC_BaseMessage) -> BaseMessage: for tc in message.tool_calls if tc["name"].startswith(TOOL_STRATEGY_TOOL_PREFIX) ], + extras=cast(dict[str, Any], message.additional_kwargs), ) case LC_HumanMessage(): return HumanMessage(content=message.content.__str__()) @@ -1597,7 +1679,10 @@ def _map_message_from_langchain(message: LC_BaseMessage) -> BaseMessage: def _map_message_to_langchain(message: BaseMessage) -> LC_AnyMessage: match message: case AIMessage(): - lc_message = LC_AIMessage(content=message.content) + lc_message = LC_AIMessage( + content=_map_content_to_langchain(message.content), + additional_kwargs=message.extras or {}, + ) # This field can't be set via constructor lc_message.tool_calls = [ _map_tool_call_to_langchain(c) for c in message.calls diff --git a/splunklib/ai/messages.py b/splunklib/ai/messages.py index 04db32b6..bdacfb93 100644 --- a/splunklib/ai/messages.py +++ b/splunklib/ai/messages.py @@ -21,6 +21,31 @@ from splunklib.ai.tools import ToolType +@dataclass(frozen=True) +class TextBlock: + """Plain text content block returned by a model.""" + + text: str + # TODO: should we have the id here as well? + # Provider-specific extras (e.g. Gemini thought signature on text blocks). 
+ extras: dict[str, Any] | None = field(default=None) + + +@dataclass(frozen=True) +class OpaqueBlock: + """Content block of an unrecognized or unsupported type. + + The raw provider dict is preserved in `data` so it can be sent back + to the model unchanged on subsequent calls. + """ + + data: dict[str, Any] + + +# Type alias for all content block variants. +ContentBlock = TextBlock | OpaqueBlock + + @dataclass(frozen=True) class ToolCall: name: str @@ -85,12 +110,15 @@ class AIMessage(BaseMessage): """ role: Literal["assistant"] = field(default="assistant", init=False) - content: str + content: str | list[str | ContentBlock] calls: Sequence[ToolCall | SubagentCall] structured_output_calls: Sequence[StructuredOutputCall] = field( default_factory=tuple ) + # Backend-specific metadata (e.g. provider additional_kwargs) not + # representable in the standard fields. Opaque to callers. + extras: dict[str, Any] | None = field(default=None) @dataclass(frozen=True) @@ -120,7 +148,7 @@ class SubagentTextResult: Returned by subagent calls that don't have an output schema. 
""" - content: str + content: str | list[str | ContentBlock] @dataclass(frozen=True) diff --git a/tests/ai_testlib.py b/tests/ai_testlib.py index 631fd16f..7c77ba4e 100644 --- a/tests/ai_testlib.py +++ b/tests/ai_testlib.py @@ -1,4 +1,7 @@ from typing import override +from warnings import warn + +from splunklib.ai.messages import AIMessage, ContentBlock, SubagentTextResult, TextBlock from splunklib.ai.model import PredefinedModel from tests.ai_test_model import InternalAIModel, TestLLMSettings, create_model from tests.testlib import SDKTestCase @@ -18,6 +21,27 @@ def setUp(self) -> None: app.delete() self.restart_splunk() + def _parse_content_block(self, block: str | ContentBlock) -> str | None: + match block: + case TextBlock(): + return block.text + case str(): + return block + case _: + warn(f"Skipping OpaqueBlock when parsing the AIMessage.content") + return None + + def parse_content(self, message: AIMessage | SubagentTextResult) -> str: + """Parses the content from AIMessage and builds a single string our of it""" + if isinstance(message.content, str): + return message.content + + return " ".join( + parsed_block + for block in message.content + if (parsed_block := self._parse_content_block(block)) + ) + @property def test_llm_settings(self) -> TestLLMSettings: client_id: str = self.opts.kwargs["internal_ai_client_id"] diff --git a/tests/integration/ai/test_agent.py b/tests/integration/ai/test_agent.py index 1f9ea591..1d0b8ab8 100644 --- a/tests/integration/ai/test_agent.py +++ b/tests/integration/ai/test_agent.py @@ -64,7 +64,12 @@ async def test_agent_with_openai_round_trip(self): ] ) - response = result.final_message.content.strip().lower().replace(".", "") + response = ( + self.parse_content(result.final_message) + .strip() + .lower() + .replace(".", "") + ) assert result.structured_output is None, ( "The structured output should not be populated" ) @@ -152,7 +157,7 @@ class Person(BaseModel): response = result.structured_output - last_message = 
result.final_message.content + last_message = self.parse_content(result.final_message) assert type(response) == Person, "Response is not of type Person" assert response.name != "", "Name field is empty" @@ -220,7 +225,7 @@ class NicknameGeneratorInput(BaseModel): ) assert subagent_message, "No subagent message found in response" - response = result.final_message.content + response = self.parse_content(result.final_message) assert "Chris-zilla" in response, "Agent did generate valid nickname" @pytest.mark.asyncio @@ -264,7 +269,7 @@ async def test_subagent_without_input_schema(self): assert first_ai_message.calls[0].args.lower() == "chris" assert first_ai_message.calls[0].thread_id is None, "unexpected thread_id" - response = result.final_message.content + response = self.parse_content(result.final_message) assert "Chris-zilla" in response, "Agent did generate valid nickname" @pytest.mark.asyncio @@ -305,7 +310,7 @@ class Person(BaseModel): ] ) - response = result.final_message.content + response = self.parse_content(result.final_message) assert "Chris-zilla" in response, "Agent did generate valid nickname" # TODO: unskip the test once we switch to a better model diff --git a/tests/integration/ai/test_agent_mcp_tools.py b/tests/integration/ai/test_agent_mcp_tools.py index ac2dd880..ecbf4fde 100644 --- a/tests/integration/ai/test_agent_mcp_tools.py +++ b/tests/integration/ai/test_agent_mcp_tools.py @@ -91,7 +91,7 @@ async def test_tool_execution_structured_output(self) -> None: assert tool_message, "No tool message found in response" assert tool_message.name == "temperature", "Invalid tool name" - response = result.final_message.content + response = self.parse_content(result.final_message) assert "31.5" in response, "Invalid LLM response" @patch( @@ -188,7 +188,7 @@ async def test_multiple_and_concurrent_tool_calls(self) -> None: ] ) - response = result.final_message.content + response = self.parse_content(result.final_message) assert "31.5" in response, "Invalid LLM 
response" assert "30.0" in response, "Invalid LLM response" assert "25.5" in response, "Invalid LLM response" @@ -205,7 +205,7 @@ async def test_multiple_and_concurrent_tool_calls(self) -> None: ) ] ) - response = result.final_message.content + response = self.parse_content(result.final_message) assert "28.5" in response, "Invalid LLM response" # Make sure MCP was alive during entire Agent lifetime. @@ -377,7 +377,7 @@ async def dispatch( assert tool_message, "No tool message found in response" assert tool_message.name == "temperature", "Invalid tool name" - response = result.final_message.content + response = self.parse_content(result.final_message) assert "31.5" in response, "Invalid LLM response" assert trace_id == agent.trace_id @@ -417,7 +417,12 @@ async def test_remote_tools_mcp_app_unavailable(self) -> None: [HumanMessage(content="What is your name? Answer in one word")] ) - response = result.final_message.content.strip().lower().replace(".", "") + response = ( + self.parse_content(result.final_message) + .strip() + .lower() + .replace(".", "") + ) assert "stefan" in response @patch( @@ -492,7 +497,7 @@ async def lifespan(app: Starlette) -> AsyncGenerator[None, Any]: assert type(tool_messages[0].result) is ToolFailureResult assert type(tool_messages[1].result) is ToolResult - response = result.final_message.content + response = self.parse_content(result.final_message) assert "31.5" in response, "Invalid LLM response" @patch( @@ -595,7 +600,7 @@ async def lifespan(app: Starlette) -> AsyncGenerator[None, Any]: ) assert found_tool_message, "missing ToolMessage in agent response" - response = result.final_message.content + response = self.parse_content(result.final_message) assert "31.5" in response, "Invalid LLM response" @patch( @@ -656,7 +661,7 @@ async def middleware( assert tool_message, "No tool message found in response" assert tool_message.name == "temperature", "Invalid tool name" - response = result.final_message.content + response = 
self.parse_content(result.final_message) assert "22" in response, "Invalid LLM response" diff --git a/tests/integration/ai/test_anthropic_agent.py b/tests/integration/ai/test_anthropic_agent.py index eeed9349..3f88c58f 100644 --- a/tests/integration/ai/test_anthropic_agent.py +++ b/tests/integration/ai/test_anthropic_agent.py @@ -47,6 +47,11 @@ async def test_agent_with_anthropic_round_trip(self): [HumanMessage(content="What is your name? Answer in one word")] ) - response = result.final_message.content.strip().lower().replace(".", "") + response = ( + self.parse_content(result.final_message) + .strip() + .lower() + .replace(".", "") + ) assert result.structured_output is None assert "stefan" in response diff --git a/tests/integration/ai/test_conversation_store.py b/tests/integration/ai/test_conversation_store.py index a5c10b34..b3d28038 100644 --- a/tests/integration/ai/test_conversation_store.py +++ b/tests/integration/ai/test_conversation_store.py @@ -12,8 +12,8 @@ # License for the specific language governing permissions and limitations # under the License. 
-from pydantic import BaseModel, Field import pytest +from pydantic import BaseModel, Field from splunklib.ai import Agent from splunklib.ai.conversation_store import InMemoryStore @@ -44,7 +44,7 @@ async def test_agent_does_not_remember_state_without_store(self) -> None: result = await agent.invoke([HumanMessage(content="What is my name?")]) - response = result.final_message.content + response = self.parse_content(result.final_message) assert "Chris" not in response, "Agent remembered the name" @@ -101,7 +101,7 @@ async def _agent_middleware( result = await agent.invoke([HumanMessage(content="What is my name?")]) - response = result.final_message.content + response = self.parse_content(result.final_message) assert "Chris" in response, "Agent did not remember the name" @@ -146,7 +146,7 @@ async def _agent_middleware( result = await agent.invoke([HumanMessage(content="What is my name?")]) - response = result.final_message.content + response = self.parse_content(result.final_message) assert "Mike" in response, "Agent did not remember the name" @@ -185,7 +185,7 @@ async def _model_middleware( [HumanMessage(content="What is my name?")], thread_id="2", ) - response = result.final_message.content + response = self.parse_content(result.final_message) assert "Mike" not in response, ( "Agent remembered the name from a different thread_id" ) @@ -219,14 +219,14 @@ async def test_thread_id_in_constructor(self) -> None: [HumanMessage(content="What is my name?")], thread_id="2", ) - response = result.final_message.content + response = self.parse_content(result.final_message) assert "Mike" in response, "Agent did not remember the name" # When thread_id not specified the one from the agent constructor is used. 
result = await agent.invoke( [HumanMessage(content="What is my name?")], ) - response = result.final_message.content + response = self.parse_content(result.final_message) assert "Mike" in response, "Agent did not remember the name" # Now use the same conversation_store in a different agent with same thread_ids. @@ -242,21 +242,21 @@ async def test_thread_id_in_constructor(self) -> None: [HumanMessage(content="What is my name?")], thread_id="1", ) - response = result.final_message.content + response = self.parse_content(result.final_message) assert "Chris" in response, "Agent did not remember the name" result = await agent.invoke( [HumanMessage(content="What is my name?")], thread_id="2", ) - response = result.final_message.content + response = self.parse_content(result.final_message) assert "Mike" in response, "Agent did not remember the name" # When thread_id not specified the one from the agent constructor is used. result = await agent.invoke( [HumanMessage(content="What is my name?")], ) - response = result.final_message.content + response = self.parse_content(result.final_message) assert "Mike" in response, "Agent did not remember the name" @@ -328,7 +328,7 @@ async def _model_middleware( assert isinstance(third_ai_msg.calls[0], SubagentCall) assert thread_id == third_ai_msg.calls[0].thread_id, "missing thread_id" - assert "chris" in resp.final_message.content.lower() + assert "chris" in self.parse_content(resp.final_message).lower() # TODO: unskip the test once we switch to a better model @pytest.mark.asyncio @@ -403,4 +403,4 @@ class MemoryAgentInput(BaseModel): assert isinstance(third_ai_msg.calls[0], SubagentCall) assert thread_id == third_ai_msg.calls[0].thread_id, "invalid thread_id" - assert "chris" in resp.final_message.content.lower() + assert "chris" in self.parse_content(resp.final_message).lower() diff --git a/tests/integration/ai/test_hooks.py b/tests/integration/ai/test_hooks.py index ad22a75b..45c34894 100644 --- 
a/tests/integration/ai/test_hooks.py +++ b/tests/integration/ai/test_hooks.py @@ -29,8 +29,14 @@ before_agent, before_model, ) -from splunklib.ai.messages import AIMessage, AgentResponse, HumanMessage -from splunklib.ai.middleware import AgentRequest, ModelMiddlewareHandler, ModelRequest, ModelResponse, model_middleware +from splunklib.ai.messages import AgentResponse, AIMessage, HumanMessage +from splunklib.ai.middleware import ( + AgentRequest, + ModelMiddlewareHandler, + ModelRequest, + ModelResponse, + model_middleware, +) from tests.ai_testlib import AITestCase @@ -62,7 +68,7 @@ def test_hook_after(resp: ModelResponse) -> None: nonlocal hook_calls hook_calls += 1 - response = resp.message.content.strip().lower().replace(".", "") + response = self.parse_content(resp.message).strip().lower().replace(".", "") assert "stefan" == response @after_model @@ -70,7 +76,7 @@ async def test_async_hook_after(resp: ModelResponse) -> None: nonlocal hook_calls hook_calls += 1 - response = resp.message.content.strip().lower().replace(".", "") + response = self.parse_content(resp.message).strip().lower().replace(".", "") assert "stefan" == response async with Agent( @@ -92,7 +98,12 @@ async def test_async_hook_after(resp: ModelResponse) -> None: ] ) - response = result.final_message.content.strip().lower().replace(".", "") + response = ( + self.parse_content(result.final_message) + .strip() + .lower() + .replace(".", "") + ) assert "stefan" == response assert hook_calls == 4 @@ -159,7 +170,12 @@ async def after_async_agent_hook(resp: AgentResponse) -> None: ] ) - response = result.final_message.content.strip().lower().replace(".", "") + response = ( + self.parse_content(result.final_message) + .strip() + .lower() + .replace(".", "") + ) assert '{"name":"stefan"}' == response assert hook_calls == 4 @@ -197,10 +213,12 @@ async def test_agent_loop_stop_conditions_conversation_limit(self) -> None: with pytest.raises( StepsLimitExceededException, match="Steps limit of 2 exceeded" ): 
- _ = await agent.invoke([ - HumanMessage(content="hi, my name is Chris"), - HumanMessage(content="What is my name?"), - ]) + _ = await agent.invoke( + [ + HumanMessage(content="hi, my name is Chris"), + HumanMessage(content="What is my name?"), + ] + ) @pytest.mark.asyncio async def test_agent_loop_stop_conditions_conversation_limit_with_checkpointer( @@ -220,13 +238,17 @@ async def test_agent_loop_stop_conditions_conversation_limit_with_checkpointer( with pytest.raises( StepsLimitExceededException, match="Steps limit of 2 exceeded" ): - _ = await agent.invoke([ - HumanMessage(content="What is my name?"), - HumanMessage(content="Are you sure?"), - ]) + _ = await agent.invoke( + [ + HumanMessage(content="What is my name?"), + HumanMessage(content="Are you sure?"), + ] + ) @pytest.mark.asyncio - async def test_agent_loop_stop_conditions_steps_accumulate_across_invokes(self) -> None: + async def test_agent_loop_stop_conditions_steps_accumulate_across_invokes( + self, + ) -> None: pytest.importorskip("langchain_openai") step_limit = StepLimitMiddleware(2) diff --git a/tests/integration/ai/test_middleware.py b/tests/integration/ai/test_middleware.py index d699bb5b..1278de33 100644 --- a/tests/integration/ai/test_middleware.py +++ b/tests/integration/ai/test_middleware.py @@ -95,7 +95,7 @@ async def test_middleware( [HumanMessage(content="What is the weather like today in Krakow?")] ) - response = res.final_message.content + response = self.parse_content(res.final_message) assert "31.5" in response assert middleware_called, "Middleware was not called" @@ -162,7 +162,7 @@ async def test_middleware( [HumanMessage(content="What is the weather like today in Krakow?")] ) - response = res.final_message.content + response = self.parse_content(res.final_message) assert "31.5" in response assert middleware_called, "Middleware was not called" @@ -200,7 +200,7 @@ async def test_middleware( [HumanMessage(content="What is the weather like today in Kraków?")] ) - response = 
res.final_message.content + response = self.parse_content(res.final_message) assert "0.5" in response, "Invalid response from LLM" tool_message = next( @@ -253,7 +253,7 @@ async def second_middleware( res = await agent.invoke( [HumanMessage(content="What is the weather like today in Krakow?")] ) - assert "31.5" in res.final_message.content + assert "31.5" in self.parse_content(res.final_message) assert first_called, "First middleware was called after the second" assert second_called, "Second middleware was called before the first" @@ -295,7 +295,7 @@ async def model_test_middleware( res = await agent.invoke( [HumanMessage(content="What is the weather like today in Krakow?")] ) - assert "31.5" in res.final_message.content + assert "31.5" in self.parse_content(res.final_message) assert tool_called assert model_called @@ -349,7 +349,7 @@ async def subagent_middleware( tool_result = await agent.invoke( [HumanMessage(content="What is the weather like today in Krakow?")] ) - assert "31.5" in tool_result.final_message.content + assert "31.5" in self.parse_content(tool_result.final_message) class NicknameGeneratorInput(BaseModel): name: str = Field(description="The person's full name", min_length=1) @@ -377,7 +377,7 @@ class NicknameGeneratorInput(BaseModel): subagent_result = await supervisor.invoke( [HumanMessage(content="Generate a nickname for Chris")] ) - assert "Chris-zilla" in subagent_result.final_message.content + assert "Chris-zilla" in self.parse_content(subagent_result.final_message) assert model_called assert tool_called @@ -442,7 +442,7 @@ async def test_middleware( ) assert subagent_message, "No subagent message found in response" - response = result.final_message.content + response = self.parse_content(result.final_message) assert "Chris-zilla" in response, "Agent did generate valid nickname" assert middleware_called, "Middleware was not called" @@ -492,7 +492,7 @@ async def test_middleware( [HumanMessage(content="Generate a nickname for Chris")] ) - 
response = result.final_message.content + response = self.parse_content(result.final_message) assert "Chris-superstar" in response, "Invalid response from LLM" subagent_message = next( @@ -500,7 +500,7 @@ async def test_middleware( ) assert subagent_message, "SubagentMessage not found in messages" assert isinstance(subagent_message.result, SubagentTextResult) - assert subagent_message.result.content == "Chris-superstar", ( + assert self.parse_content(subagent_message.result) == "Chris-superstar", ( "Invalid response from subagent" ) assert middleware_called, "Middleware was not called" @@ -621,7 +621,7 @@ async def test_middleware( [HumanMessage(content="Generate a nickname for Chris")] ) - response = result.final_message.content + response = self.parse_content(result.final_message) assert "Chris-zilla" in response, "Agent did generate valid nickname" assert middleware_called, "Middleware was not called" @@ -657,7 +657,7 @@ async def test_middleware( ] ) - response = res.final_message.content + response = self.parse_content(res.final_message) assert "My response is made up" == response assert middleware_called, "Middleware was not called" @@ -715,7 +715,7 @@ async def mutating_middleware( res = await agent.invoke( [HumanMessage(content="What is the capital of Germany?")] ) - assert "Paris" in res.final_message.content + assert "Paris" in self.parse_content(res.final_message) @patch( "splunklib.ai.agent._testing_local_tools_path", @@ -752,7 +752,7 @@ async def mutating_middleware( [HumanMessage(content="What is the weather like today in Berlin?")] ) # Berlin returns 22.1C; Krakow returns 31.5C - assert "31.5" in res.final_message.content + assert "31.5" in self.parse_content(res.final_message) @pytest.mark.asyncio async def test_subagent_middleware_arg_mutation_reaches_subagent(self) -> None: @@ -796,7 +796,7 @@ async def mutating_middleware( result = await supervisor.invoke( [HumanMessage(content="Generate a nickname for Bob")] ) - assert "Alice-zilla" in 
result.final_message.content + assert "Alice-zilla" in self.parse_content(result.final_message) @pytest.mark.asyncio async def test_model_middleware_structured_output(self) -> None: diff --git a/tests/integration/ai/test_structured_output.py b/tests/integration/ai/test_structured_output.py index 9f793f95..67f4c572 100644 --- a/tests/integration/ai/test_structured_output.py +++ b/tests/integration/ai/test_structured_output.py @@ -302,7 +302,7 @@ async def _model_middleware( ) try: - Person.model_validate_json(e.message.content) + Person.model_validate_json(self.parse_content(e.message)) raise AssertionError( "args are valid, but got an StructuredOutputGenerationException" ) @@ -314,7 +314,7 @@ async def _model_middleware( assert after_first_model_call, "generation error did not happen" assert resp.structured_output is not None, "missing structured_output" assert ( - Person.model_validate_json(resp.message.content) + Person.model_validate_json(self.parse_content(resp.message)) == resp.structured_output ), "invalid structured output" @@ -345,7 +345,7 @@ async def _model_middleware( assert len(result.final_message.structured_output_calls) == 0 assert ( - Person.model_validate_json(result.final_message.content) + Person.model_validate_json(self.parse_content(result.final_message)) == result.structured_output ) @@ -827,7 +827,9 @@ async def _model_middleware( assert "ALL letters must be capitalized" in e.error.validation_error assert len(e.message.structured_output_calls) == 0 - args = PersonNotRestricted.model_validate_json(e.message.content) + args = PersonNotRestricted.model_validate_json( + self.parse_content(e.message) + ) args.name = args.name.upper() return ModelResponse( diff --git a/tests/system/test_apps/ai_agentic_test_app/bin/agentic_endpoint.py b/tests/system/test_apps/ai_agentic_test_app/bin/agentic_endpoint.py index 2b02b387..a81f1ff0 100644 --- a/tests/system/test_apps/ai_agentic_test_app/bin/agentic_endpoint.py +++ 
b/tests/system/test_apps/ai_agentic_test_app/bin/agentic_endpoint.py @@ -21,7 +21,13 @@ from typing import override from splunklib.ai.agent import Agent -from splunklib.ai.messages import HumanMessage +from splunklib.ai.messages import ( + AIMessage, + ContentBlock, + HumanMessage, + SubagentTextResult, + TextBlock, +) from splunklib.ai.tool_settings import ToolSettings from tests.cre_testlib import CRETestHandler @@ -58,5 +64,30 @@ async def run(self) -> None: [HumanMessage(content="What is your name? Answer in one word")] ) - response = result.final_message.content.strip().lower().replace(".", "") + response = ( + self.parse_content(result.final_message) + .strip() + .lower() + .replace(".", "") + ) self.response.write(response) + + def _parse_content_block(self, block: str | ContentBlock) -> str | None: + match block: + case TextBlock(): + return block.text + case str(): + return block + case _: + return None + + def parse_content(self, message: AIMessage | SubagentTextResult) -> str: + """Parses the content from AIMessage and builds a single string our of it""" + if isinstance(message.content, str): + return message.content + + return " ".join( + parsed_block + for block in message.content + if (parsed_block := self._parse_content_block(block)) + ) diff --git a/tests/system/test_apps/ai_agentic_test_local_tools_app/bin/agentic_app_tools_endpoint.py b/tests/system/test_apps/ai_agentic_test_local_tools_app/bin/agentic_app_tools_endpoint.py index 044233e6..80928dc8 100644 --- a/tests/system/test_apps/ai_agentic_test_local_tools_app/bin/agentic_app_tools_endpoint.py +++ b/tests/system/test_apps/ai_agentic_test_local_tools_app/bin/agentic_app_tools_endpoint.py @@ -22,7 +22,13 @@ from typing import override from splunklib.ai.agent import Agent -from splunklib.ai.messages import HumanMessage +from splunklib.ai.messages import ( + AIMessage, + ContentBlock, + HumanMessage, + SubagentTextResult, + TextBlock, +) from splunklib.ai.tool_settings import ToolSettings from 
tests.cre_testlib import CRETestHandler @@ -75,5 +81,30 @@ async def run(self) -> None: [HumanMessage(content="What is your name? Answer in one word")] ) - response = result.final_message.content.strip().lower().replace(".", "") + response = ( + self.parse_content(result.final_message) + .strip() + .lower() + .replace(".", "") + ) self.response.write(response) + + def _parse_content_block(self, block: str | ContentBlock) -> str | None: + match block: + case TextBlock(): + return block.text + case str(): + return block + case _: + return None + + def parse_content(self, message: AIMessage | SubagentTextResult) -> str: + """Parses the content from AIMessage and builds a single string out of it""" + if isinstance(message.content, str): + return message.content + + return " ".join( + parsed_block + for block in message.content + if (parsed_block := self._parse_content_block(block)) + ) diff --git a/tests/unit/ai/engine/test_langchain_backend.py b/tests/unit/ai/engine/test_langchain_backend.py index c02426bd..d4319896 100644 --- a/tests/unit/ai/engine/test_langchain_backend.py +++ b/tests/unit/ai/engine/test_langchain_backend.py @@ -30,10 +30,12 @@ from splunklib.ai.messages import ( AIMessage, HumanMessage, + OpaqueBlock, SubagentCall, SubagentFailureResult, SubagentMessage, SystemMessage, + TextBlock, ToolCall, ToolFailureResult, ToolMessage, @@ -56,6 +58,95 @@ def test_map_message_from_langchain_ai_with_tool_calls(self) -> None: ToolCall(name="lookup", args={"q": "test"}, id="tc-1", type=ToolType.REMOTE) ] + def test_map_message_from_langchain_ai_with_text_content_block(self) -> None: + text_block = { + "type": "text", + "text": "test-content-block", + "extras": { + # simulate gemini model returning thought signature in extra field of text content block + "signature": "EjQKMgEMOdbHDmsQ+BTM6duYJ43i5npxkpn28Ir0VjD1p6w4fUqIdYszIcWx+XcqAW1a8E+Q" + }, + } + message = LC_AIMessage(content=[text_block], tool_calls=[]) + + mapped = lc._map_message_from_langchain(message) + 
assert isinstance(mapped, AIMessage) + assert isinstance(mapped.content[0], TextBlock) + assert mapped.content[0].text == "test-content-block" + + def test_map_message_from_langchain_ai_with_list_of_str(self) -> None: + message = LC_AIMessage(content=["one", "two"], tool_calls=[]) + + mapped = lc._map_message_from_langchain(message) + + assert isinstance(mapped, AIMessage) + assert mapped.content == ["one", "two"] + + def test_map_message_from_langchain_ai_with_other_content_block(self) -> None: + content_block = { + "type": "image", + } + message = LC_AIMessage(content=[content_block], tool_calls=[]) + + mapped = lc._map_message_from_langchain(message) + + assert isinstance(mapped, AIMessage) + assert isinstance(mapped.content[0], OpaqueBlock) + assert mapped.content[0].data == content_block + + def test_map_message_from_langchain_ai_with_mixed_content(self) -> None: + content_block = { + "type": "image", + } + text_block = { + "type": "text", + "text": "test", + } + message = LC_AIMessage( + content=[content_block, text_block, "test"], tool_calls=[] + ) + + mapped = lc._map_message_from_langchain(message) + + assert isinstance(mapped, AIMessage) + assert isinstance(mapped.content[0], OpaqueBlock) + assert mapped.content[0].data == content_block + assert isinstance(mapped.content[1], TextBlock) + assert mapped.content[1].text == "test" + assert mapped.content[2] == "test" + + def test_map_message_from_langchain_ai_tool_call_with_additional_kwargs( + self, + ) -> None: + tool_call = LC_ToolCall( + name=f"__local-startup_time", + args={"q": "test"}, + id="tc-2", + ) + # simulate gemini models returning thought signature in additional kwargs + # when calling tools. 
+ additional_kwargs = { + "function_call": {"name": "__local-startup_time", "arguments": "{}"}, + "__gemini_function_call_thought_signatures__": { + "28e28045-9846-4c9c-ab46-97f33bff5a9c": "EjQKMgEMOdbHH9gTl8BkX2uMM52753GCboanCcnUp9XB896IdThnG42GB8lRSkqGGxVbv5JY" + }, + } + message = LC_AIMessage( + content="done", tool_calls=[tool_call], additional_kwargs=additional_kwargs + ) + mapped = lc._map_message_from_langchain(message) + assert isinstance(mapped, AIMessage) + assert mapped.calls == [ + ToolCall( + name="startup_time", + args={"q": "test"}, + id="tc-2", + type=ToolType.LOCAL, + ) + ] + assert mapped.extras == additional_kwargs + def test_map_message_from_langchain_ai_with_agent_call(self) -> None: tool_call = LC_ToolCall( name=f"{lc.AGENT_PREFIX}assistant", @@ -159,6 +250,69 @@ def test_map_message_to_langchain_ai(self) -> None: assert mapped.content == "hi" assert mapped.tool_calls == [LC_ToolCall(name="lookup", args={}, id="tc-1")] + def test_map_message_to_langchain_ai_with_text_content_block(self) -> None: + extras = { + "signature": "EjQKMgEMOdbHDmsQ+BTM6duYJ43i5npxkpn28Ir0VjD1p6w4fUqIdYszIcWx+XcqAW1a8E+Q" + } + message = AIMessage( + content=[ + TextBlock( + text="test-content-block", + extras=extras, + ) + ], + calls=[], + ) + mapped = lc._map_message_to_langchain(message) + + assert isinstance(mapped, LC_AIMessage) + assert isinstance(mapped.content[0], dict) + assert mapped.content[0]["type"] == "text" + assert mapped.content[0]["text"] == "test-content-block" + assert mapped.content[0]["extras"] == extras + + def test_map_message_to_langchain_ai_with_list_of_str(self) -> None: + message = AIMessage( + content=["one", "two"], + calls=[], + ) + mapped = lc._map_message_to_langchain(message) + + assert isinstance(mapped, LC_AIMessage) + assert mapped.content == ["one", "two"] + + def test_map_message_to_langchain_ai_with_opaque_content_block(self) -> None: + some_data = {"type": "unsupported"} + message = AIMessage( + 
content=[OpaqueBlock(data=some_data)], + calls=[], + ) + mapped = lc._map_message_to_langchain(message) + + assert isinstance(mapped, LC_AIMessage) + assert isinstance(mapped.content[0], dict) + assert mapped.content[0]["type"] == "unsupported" + + def test_map_message_to_langchain_ai_with_mixed_content_block(self) -> None: + some_data = {"type": "unsupported"} + message = AIMessage( + content=[ + OpaqueBlock(data=some_data), + TextBlock(text="test-content-block"), + "test", + ], + calls=[], + ) + mapped = lc._map_message_to_langchain(message) + + assert isinstance(mapped, LC_AIMessage) + assert isinstance(mapped.content[0], dict) + assert mapped.content[0]["type"] == "unsupported" + assert isinstance(mapped.content[1], dict) + assert mapped.content[1]["type"] == "text" + assert mapped.content[1]["text"] == "test-content-block" + assert mapped.content[2] == "test" + def test_map_message_to_langchain_ai_with_agent_call(self) -> None: message = AIMessage( content="hi", @@ -182,6 +336,42 @@ def test_map_message_to_langchain_ai_with_agent_call(self) -> None: ) ] + def test_map_message_to_langchain_ai_with_tool_call_with_thought_signature( + self, + ) -> None: + extras = { + "function_call": { + "name": "__local-startup_time", + "arguments": '{"q": "test"}', + }, + "__gemini_function_call_thought_signatures__": { + "28e28045-9846-4c9c-ab46-97f33bff5a9c": "EjQKMgEMOdbHH9gTl8BkX2uMM52753GCboanCcnUp9XB896IdThnG42GB8lRSkqGGxVbv5JY" + }, + } + message = AIMessage( + content="hi", + calls=[ + ToolCall( + name="startup_time", + args={"q": "test"}, + id="tc-2", + type=ToolType.LOCAL, + ) + ], + extras=extras, + ) + mapped = lc._map_message_to_langchain(message) + + assert isinstance(mapped, LC_AIMessage) + assert mapped.tool_calls == [ + LC_ToolCall( + name=f"__local-startup_time", + args={"q": "test"}, + id="tc-2", + ) + ] + assert mapped.additional_kwargs == extras + def test_map_message_to_langchain_human(self) -> None: message = HumanMessage(content="hello") mapped = 
lc._map_message_to_langchain(message)