Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 27 additions & 5 deletions examples/ai_modinput_app/bin/agentic_weather.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@
from _collections_abc import dict_items
from typing import final, override

from splunklib.ai.messages import AIMessage, ContentBlock, TextBlock

# ! NOTE: This insert is only needed for splunk-sdk-python CI/CD to work.
# ! Remove this if you're modifying this example locally.
sys.path.insert(0, "/splunklib-deps")
Expand Down Expand Up @@ -95,9 +97,9 @@ def stream_events(self, inputs: InputDefinition, ew: EventWriter) -> None:
weather_events += list(reader)

for weather_event in weather_events:
weather_event["human_readable"] = asyncio.run(
self.invoke_agent(weather_event)
)
result = asyncio.run(self.invoke_agent(weather_event))
weather_event["human_readable"] = self.parse_content(result)

logger.debug(f"{weather_event=}")

event = Event(
Expand All @@ -112,7 +114,7 @@ def stream_events(self, inputs: InputDefinition, ew: EventWriter) -> None:

logger.debug(f"Finishing enrichment for {input_name} at {csv_file_path}")

async def invoke_agent(self, weather_event: dict[str, str | int]) -> str:
async def invoke_agent(self, weather_event: dict[str, str | int]) -> AIMessage:
if not self.service:
raise AssertionError("No Splunk connection available")

Expand All @@ -127,7 +129,27 @@ async def invoke_agent(self, weather_event: dict[str, str | int]) -> str:
data=weather_event,
)
logger.debug(f"{response=}")
return response.final_message.content
return response.final_message

def _parse_content_block(self, block: str | ContentBlock) -> str | None:
    """Return the human-readable text carried by one content block, if any.

    Plain strings and TextBlocks yield text; every other block type
    (e.g. OpaqueBlock) has no textual representation and yields None.
    """
    if isinstance(block, TextBlock):
        return block.text
    if isinstance(block, str):
        return block
    # Unrecognized block types carry no plain text.
    return None

def parse_content(self, message: AIMessage) -> str:
    """Flatten an AIMessage's content into a single space-joined string.

    String content is returned unchanged; block lists are converted
    block-by-block, dropping blocks that carry no text.
    """
    content = message.content
    if isinstance(content, str):
        return content

    texts: list[str] = []
    for block in content:
        parsed = self._parse_content_block(block)
        if parsed:
            texts.append(parsed)
    return " ".join(texts)


if __name__ == "__main__":
Expand Down
109 changes: 97 additions & 12 deletions splunklib/ai/engines/langchain.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,9 @@
AgentResponse,
AIMessage,
BaseMessage,
ContentBlock,
HumanMessage,
OpaqueBlock,
OutputT,
StructuredOutputCall,
StructuredOutputMessage,
Expand All @@ -87,6 +89,7 @@
SubagentStructuredResult,
SubagentTextResult,
SystemMessage,
TextBlock,
ToolCall,
ToolFailureResult,
ToolMessage,
Expand Down Expand Up @@ -951,7 +954,7 @@ async def awrap_tool_call(
return LC_ToolMessage(
name=_normalize_agent_name(call.name),
tool_call_id=call.id,
content=content,
content=_map_content_to_langchain(content),
status=status,
artifact=sdk_result,
)
Expand Down Expand Up @@ -1085,7 +1088,10 @@ def _convert_model_response_to_model_result(
# This invariant is asserted via ModelResponse.__post_init__
assert len(resp.message.structured_output_calls) <= 1

lc_message = LC_AIMessage(content=resp.message.content)
lc_message = LC_AIMessage(
content=_map_content_to_langchain(resp.message.content),
additional_kwargs=resp.message.extras or {},
)
# This field can't be set via __init__()
lc_message.tool_calls = [_map_tool_call_to_langchain(c) for c in resp.message.calls]

Expand Down Expand Up @@ -1160,7 +1166,7 @@ def _convert_tool_message_to_lc(
name=name,
tool_call_id=message.call_id,
status=status,
content=content,
content=_map_content_to_langchain(content),
artifact=artifact,
)

Expand Down Expand Up @@ -1243,9 +1249,10 @@ def _convert_model_result_from_lc(model_response: LC_ModelCallResult) -> ModelRe
ai_message = model_response
structured_response = None

additional_kwargs = cast(dict[str, Any], ai_message.additional_kwargs)
return ModelResponse(
message=AIMessage(
content=ai_message.content.__str__(),
content=_map_content_from_langchain(ai_message.content), # pyright: ignore[reportUnknownArgumentType]
calls=[
_map_tool_call_from_langchain(tc)
for tc in ai_message.tool_calls
Expand All @@ -1260,6 +1267,7 @@ def _convert_model_result_from_lc(model_response: LC_ModelCallResult) -> ModelRe
for tc in ai_message.tool_calls
if tc["name"].startswith(TOOL_STRATEGY_TOOL_PREFIX)
],
extras=additional_kwargs,
),
structured_output=structured_response,
)
Expand Down Expand Up @@ -1433,7 +1441,10 @@ def _agent_as_tool(agent: BaseAgent[OutputT]) -> StructuredTool:

async def invoke_agent(
message: HumanMessage, thread_id: str | None
) -> tuple[OutputT | str, SubagentStructuredResult | SubagentTextResult]:
) -> tuple[
OutputT | str | list[str | ContentBlock],
SubagentStructuredResult | SubagentTextResult,
]:
result = await agent.invoke([message], thread_id=thread_id)

if agent.output_schema:
Expand All @@ -1452,13 +1463,19 @@ async def invoke_agent(

async def _run( # pyright: ignore[reportRedeclaration]
content: str, thread_id: str
) -> tuple[OutputT | str, SubagentStructuredResult | SubagentTextResult]:
) -> tuple[
OutputT | str | list[str | ContentBlock],
SubagentStructuredResult | SubagentTextResult,
]:
return await invoke_agent(HumanMessage(content=content), thread_id)
else:

async def _run( # pyright: ignore[reportRedeclaration]
content: str,
) -> tuple[OutputT | str, SubagentStructuredResult | SubagentTextResult]:
) -> tuple[
OutputT | str | list[str | ContentBlock],
SubagentStructuredResult | SubagentTextResult,
]:
return await invoke_agent(HumanMessage(content=content), None)

return StructuredTool.from_function(
Expand All @@ -1471,7 +1488,10 @@ async def _run( # pyright: ignore[reportRedeclaration]

async def invoke_agent_structured(
content: BaseModel, thread_id: str | None
) -> tuple[OutputT | str, SubagentStructuredResult | SubagentTextResult]:
) -> tuple[
OutputT | str | list[str | ContentBlock],
SubagentStructuredResult | SubagentTextResult,
]:
result = await agent.invoke_with_data(
instructions="Follow the system prompt.",
data=content.model_dump(),
Expand All @@ -1492,7 +1512,10 @@ async def invoke_agent_structured(

async def _run(
**kwargs: Any, # noqa: ANN401
) -> tuple[OutputT | str, SubagentStructuredResult | SubagentTextResult]:
) -> tuple[
OutputT | str | list[str | ContentBlock],
SubagentStructuredResult | SubagentTextResult,
]:
content: BaseModel = kwargs["content"]
thread_id: str = kwargs["thread_id"]
return await invoke_agent_structured(content, thread_id)
Expand All @@ -1512,7 +1535,10 @@ async def _run(

async def _run(
**kwargs: Any, # noqa: ANN401
) -> tuple[OutputT | str, SubagentStructuredResult | SubagentTextResult]:
) -> tuple[
OutputT | str | list[str | ContentBlock],
SubagentStructuredResult | SubagentTextResult,
]:
content = InputSchema(**kwargs)
return await invoke_agent_structured(content, None)

Expand Down Expand Up @@ -1564,11 +1590,66 @@ def _map_tool_call_to_langchain(call: ToolCall | SubagentCall) -> LC_ToolCall:
return LC_ToolCall(id=call.id, name=name, args=args)


def _map_content_from_langchain(
    content: str | list[str | dict[str, Any]],
) -> str | list[str | ContentBlock]:
    """Convert LangChain message content into the SDK's content shape.

    Plain-string content passes through untouched; a list of blocks is
    converted element-by-element.
    """
    if isinstance(content, str):
        return content

    return [_map_content_block_from_langchain(item) for item in content]


def _map_content_block_from_langchain(
    block: str | dict[str, Any],
) -> str | ContentBlock:
    """Convert one LangChain content block into its SDK representation."""
    if isinstance(block, str):
        return block

    if block.get("type") == "text":
        return TextBlock(
            text=block["text"],
            extras=block.get("extras"),
        )

    # NOTE: block types we do not handle are wrapped verbatim as opaque
    # content blocks so their payload is preserved and can be sent back
    # to the LLM unchanged.
    return OpaqueBlock(data=block)


def _map_content_to_langchain(
    content: str | list[str | ContentBlock],
) -> str | list[str | dict[str, Any]]:
    """Convert SDK message content into LangChain's content shape.

    Plain-string content passes through untouched; a list of blocks is
    converted element-by-element.
    """
    if isinstance(content, str):
        return content

    return [_map_content_block_to_langchain(item) for item in content]


def _map_content_block_to_langchain(block: str | ContentBlock) -> str | dict[str, Any]:
    """Convert one SDK content block into LangChain's str/dict form."""
    if isinstance(block, str):
        return block

    if isinstance(block, TextBlock):
        payload: dict[str, Any] = {"type": "text", "text": block.text}
        # Only attach extras when present, mirroring the provider format.
        if block.extras:
            payload["extras"] = block.extras
        return payload

    if isinstance(block, OpaqueBlock):
        # Hand the preserved provider dict back unchanged.
        return block.data
    # ContentBlock is TextBlock | OpaqueBlock, so this point is unreachable
    # for well-typed input (falls through to an implicit None, as before).


def _map_message_from_langchain(message: LC_BaseMessage) -> BaseMessage:
match message:
case LC_AIMessage():
return AIMessage(
content=message.content.__str__(),
content=_map_content_from_langchain(message.content), # pyright: ignore[reportUnknownArgumentType]
calls=[
_map_tool_call_from_langchain(tc)
for tc in message.tool_calls
Expand All @@ -1583,6 +1664,7 @@ def _map_message_from_langchain(message: LC_BaseMessage) -> BaseMessage:
for tc in message.tool_calls
if tc["name"].startswith(TOOL_STRATEGY_TOOL_PREFIX)
],
extras=cast(dict[str, Any], message.additional_kwargs),
)
case LC_HumanMessage():
return HumanMessage(content=message.content.__str__())
Expand All @@ -1597,7 +1679,10 @@ def _map_message_from_langchain(message: LC_BaseMessage) -> BaseMessage:
def _map_message_to_langchain(message: BaseMessage) -> LC_AnyMessage:
match message:
case AIMessage():
lc_message = LC_AIMessage(content=message.content)
lc_message = LC_AIMessage(
content=_map_content_to_langchain(message.content),
additional_kwargs=message.extras or {},
)
# This field can't be set via constructor
lc_message.tool_calls = [
_map_tool_call_to_langchain(c) for c in message.calls
Expand Down
32 changes: 30 additions & 2 deletions splunklib/ai/messages.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,31 @@
from splunklib.ai.tools import ToolType


@dataclass(frozen=True)
class TextBlock:
"""Plain text content block returned by a model."""

text: str
# TODO: should we have the id here as well?
# Provider-specific extras (e.g. Gemini thought signature on text blocks).
extras: dict[str, Any] | None = field(default=None)


@dataclass(frozen=True)
class OpaqueBlock:
    """A content block whose type the SDK does not recognize.

    The provider's raw dict is kept verbatim in ``data`` so it can be
    round-tripped back to the model unchanged on later calls.
    """

    data: dict[str, Any]


# Type alias covering every supported content block variant.
ContentBlock = TextBlock | OpaqueBlock


@dataclass(frozen=True)
class ToolCall:
name: str
Expand Down Expand Up @@ -85,12 +110,15 @@ class AIMessage(BaseMessage):
"""

role: Literal["assistant"] = field(default="assistant", init=False)
content: str
content: str | list[str | ContentBlock]

calls: Sequence[ToolCall | SubagentCall]
structured_output_calls: Sequence[StructuredOutputCall] = field(
default_factory=tuple
)
# Backend-specific metadata (e.g. provider additional_kwargs) not
# representable in the standard fields. Opaque to callers.
extras: dict[str, Any] | None = field(default=None)


@dataclass(frozen=True)
Expand Down Expand Up @@ -120,7 +148,7 @@ class SubagentTextResult:
Returned by subagent calls that don't have an output schema.
"""

content: str
content: str | list[str | ContentBlock]


@dataclass(frozen=True)
Expand Down
24 changes: 24 additions & 0 deletions tests/ai_testlib.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
from typing import override
from warnings import warn

from splunklib.ai.messages import AIMessage, ContentBlock, SubagentTextResult, TextBlock
from splunklib.ai.model import PredefinedModel
from tests.ai_test_model import InternalAIModel, TestLLMSettings, create_model
from tests.testlib import SDKTestCase
Expand All @@ -18,6 +21,27 @@ def setUp(self) -> None:
app.delete()
self.restart_splunk()

def _parse_content_block(self, block: str | ContentBlock) -> str | None:
match block:
case TextBlock():
return block.text
case str():
return block
case _:
warn(f"Skipping OpaqueBlock when parsing the AIMessage.content")
return None

def parse_content(self, message: AIMessage | SubagentTextResult) -> str:
    """Flatten a message's content into a single space-joined string.

    String content is returned unchanged; block lists are converted
    block-by-block, dropping blocks that carry no text.
    """
    content = message.content
    if isinstance(content, str):
        return content

    texts: list[str] = []
    for block in content:
        parsed = self._parse_content_block(block)
        if parsed:
            texts.append(parsed)
    return " ".join(texts)

@property
def test_llm_settings(self) -> TestLLMSettings:
client_id: str = self.opts.kwargs["internal_ai_client_id"]
Expand Down
Loading