diff --git a/temporalio/contrib/google_adk_agents/_model.py b/temporalio/contrib/google_adk_agents/_model.py
index 8b32a7432..1992d0f4c 100644
--- a/temporalio/contrib/google_adk_agents/_model.py
+++ b/temporalio/contrib/google_adk_agents/_model.py
@@ -1,4 +1,5 @@
 from collections.abc import AsyncGenerator, Callable
+from dataclasses import dataclass
 from datetime import timedelta
 
 from google.adk.models import BaseLlm, LLMRegistry
@@ -7,6 +8,8 @@
 
 import temporalio.workflow
 from temporalio import activity, workflow
+from temporalio.contrib.workflow_streams import WorkflowStreamClient
+from temporalio.exceptions import ApplicationError
 from temporalio.workflow import ActivityConfig
@@ -36,6 +39,58 @@ async def invoke_model(llm_request: LlmRequest) -> list[LlmResponse]:
     ]
 
 
+@dataclass
+class StreamingInvokeInput:
+    """Input for :func:`invoke_model_streaming`."""
+
+    llm_request: LlmRequest
+    streaming_topic: str
+    streaming_batch_interval: timedelta
+
+
+@activity.defn
+async def invoke_model_streaming(
+    input: StreamingInvokeInput,
+) -> list[LlmResponse]:
+    """Streaming-aware model activity.
+
+    .. warning::
+        Streaming support is experimental and may change in future
+        versions.
+
+    Calls the LLM with ``stream=True`` and returns the collected list of
+    raw ``LlmResponse`` chunks, which ``TemporalModel.generate_content_async``
+    then yields to the caller.
+
+    Each response is also published to the workflow's stream on
+    ``streaming_topic`` so external consumers (UIs, tracing, etc.)
+    can observe responses as they arrive.
+    """
+    llm_request = input.llm_request
+    if llm_request.model is None:
+        raise ValueError("No model name provided, could not create LLM.")
+
+    llm = LLMRegistry.new_llm(llm_request.model)
+    if not llm:
+        raise ValueError(f"Failed to create LLM for model: {llm_request.model}")
+
+    responses: list[LlmResponse] = []
+
+    stream = WorkflowStreamClient.from_within_activity(
+        batch_interval=input.streaming_batch_interval,
+    )
+    events = stream.topic(input.streaming_topic, type=LlmResponse)
+    async with stream:
+        async for response in llm.generate_content_async(
+            llm_request=llm_request, stream=True
+        ):
+            activity.heartbeat()
+            responses.append(response)
+            events.publish(response)
+
+    return responses
+
+
 class TemporalModel(BaseLlm):
     """A Temporal-based LLM model that executes model invocations as activities."""
 
@@ -45,9 +100,15 @@ def __init__(
         activity_config: ActivityConfig | None = None,
         *,
         summary_fn: Callable[[LlmRequest], str | None] | None = None,
+        streaming_topic: str | None = None,
+        streaming_batch_interval: timedelta = timedelta(milliseconds=100),
     ) -> None:
         """Initialize the TemporalModel.
 
+        Streaming is selected by the caller via the ADK
+        ``generate_content_async(stream=True)`` argument; no plugin-level
+        flag is needed.
+
         Args:
             model_name: The name of the model to use.
             activity_config: Configuration options for the activity execution.
@@ -56,6 +117,19 @@ def __init__(
                 deterministic as it is called during workflow execution.
                 If the callable raises, the exception will propagate and
                 fail the workflow task.
+            streaming_topic: Stream topic to publish raw
+                ``LlmResponse`` chunks to when streaming. Required when
+                callers invoke ``generate_content_async(stream=True)``;
+                if ``None``, the streaming call raises before scheduling
+                an activity. The workflow must host a
+                :class:`temporalio.contrib.workflow_streams.WorkflowStream`
+                to receive the published chunks; otherwise the
+                signals are unhandled and dropped. Streaming support is
+                experimental and may change in future versions.
+            streaming_batch_interval: Interval between automatic
+                flushes for the stream publisher used by the streaming
+                activity. Streaming support is experimental and may
+                change in future versions.
 
         Raises:
             ValueError: If both ``ActivityConfig["summary"]`` and
                 ``summary_fn`` are set.
@@ -63,6 +137,8 @@ def __init__(
         super().__init__(model=model_name)
         self._model_name = model_name
         self._summary_fn = summary_fn
+        self._streaming_topic = streaming_topic
+        self._streaming_batch_interval = streaming_batch_interval
         self._activity_config = ActivityConfig(
             start_to_close_timeout=timedelta(seconds=60)
         )
@@ -80,7 +156,10 @@ async def generate_content_async(
         Args:
             llm_request: The LLM request containing model parameters and
                 content.
-            stream: Whether to stream the response (currently ignored).
+            stream: Whether to use the streaming activity. When ``True``,
+                each chunk is also published to ``streaming_topic``, which
+                must be configured on this model. Streaming support is
+                experimental and may change in future versions.
 
         Yields:
             The responses from the model.
@@ -103,10 +182,28 @@ async def generate_content_async(
         agent_name = llm_request.config.labels.get("adk_agent_name")
         if agent_name:
             config["summary"] = agent_name
-        responses = await workflow.execute_activity(
-            invoke_model,
-            args=[llm_request],
-            **config,
-        )
+
+        if stream:
+            if self._streaming_topic is None:
+                raise ApplicationError(
+                    "generate_content_async(stream=True) requires "
+                    "TemporalModel(streaming_topic=...) to be set.",
+                    non_retryable=True,
+                )
+            responses = await workflow.execute_activity(
+                invoke_model_streaming,
+                StreamingInvokeInput(
+                    llm_request=llm_request,
+                    streaming_topic=self._streaming_topic,
+                    streaming_batch_interval=self._streaming_batch_interval,
+                ),
+                **config,
+            )
+        else:
+            responses = await workflow.execute_activity(
+                invoke_model,
+                args=[llm_request],
+                **config,
+            )
         for response in responses:
             yield response
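For orientation, here is a minimal sketch (not part of this diff) of a workflow that opts into the new streaming path by calling the model directly, rather than through an ADK runner. The workflow class, topic name, model id, and label value are illustrative assumptions, and it presumes the worker's client is configured with `GoogleAdkPlugin` so `invoke_model_streaming` is registered.

```python
# Sketch only: class name, topic, model id, and label value are assumptions.
from datetime import timedelta

from google.adk.models.llm_request import LlmRequest
from google.genai.types import Content, GenerateContentConfig, Part

from temporalio import workflow
from temporalio.contrib.google_adk_agents import TemporalModel
from temporalio.contrib.workflow_streams import WorkflowStream


@workflow.defn
class DirectStreamingWorkflow:
    def __init__(self) -> None:
        # Host a WorkflowStream so the streaming activity's publishes are
        # received instead of being dropped as unhandled signals.
        self.stream = WorkflowStream()

    @workflow.run
    async def run(self, prompt: str) -> str:
        model = TemporalModel(
            "gemini-2.0-flash",
            streaming_topic="events",
            streaming_batch_interval=timedelta(milliseconds=100),
        )
        request = LlmRequest(
            model="gemini-2.0-flash",
            contents=[Content(role="user", parts=[Part(text=prompt)])],
            # The adk_agent_name label normally comes from the ADK runner
            # and becomes the activity summary; set it manually here.
            config=GenerateContentConfig(labels={"adk_agent_name": "direct"}),
        )
        chunks: list[str] = []
        # stream=True routes to invoke_model_streaming; each chunk is also
        # published to the "events" topic for external consumers.
        async for response in model.generate_content_async(request, stream=True):
            if response.content and response.content.parts:
                chunks.extend(p.text for p in response.content.parts if p.text)
        return "".join(chunks)
```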
diff --git a/temporalio/contrib/google_adk_agents/_plugin.py b/temporalio/contrib/google_adk_agents/_plugin.py
index 9be321398..7344485c8 100644
--- a/temporalio/contrib/google_adk_agents/_plugin.py
+++ b/temporalio/contrib/google_adk_agents/_plugin.py
@@ -3,12 +3,16 @@
 import dataclasses
 import time
 import uuid
-from collections.abc import AsyncIterator
+from collections.abc import AsyncIterator, Callable
 from contextlib import asynccontextmanager
+from typing import Any
 
 from temporalio import workflow
 from temporalio.contrib.google_adk_agents._mcp import TemporalMcpToolSetProvider
-from temporalio.contrib.google_adk_agents._model import invoke_model
+from temporalio.contrib.google_adk_agents._model import (
+    invoke_model,
+    invoke_model_streaming,
+)
 from temporalio.contrib.pydantic import (
     PydanticPayloadConverter,
     ToJsonOptions,
@@ -95,7 +99,13 @@ def workflow_runner(runner: WorkflowRunner | None) -> WorkflowRunner:
             )
         return runner
 
-    new_activities = [invoke_model]
+    # Annotate explicitly: invoke_model and invoke_model_streaming have
+    # different signatures, so the inferred list element type would not
+    # satisfy SimplePlugin's activities parameter.
+    new_activities: list[Callable[..., Any]] = [
+        invoke_model,
+        invoke_model_streaming,
+    ]
     if toolset_providers is not None:
         for toolset_provider in toolset_providers:
             new_activities.extend(toolset_provider._get_activities())
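Before the tests, a hedged sketch of how these plugin-supplied activities reach a worker in practice. The server address and task queue are illustrative, `DirectStreamingWorkflow` refers to the sketch above, and passing `plugins=` to `Client.connect` mirrors the `client.config()["plugins"]` pattern the tests below use.

```python
import asyncio

from temporalio.client import Client
from temporalio.contrib.google_adk_agents import GoogleAdkPlugin
from temporalio.worker import Worker


async def main() -> None:
    # The plugin contributes invoke_model and invoke_model_streaming plus
    # the pydantic payload converter, so the worker does not need to list
    # the ADK model activities explicitly.
    client = await Client.connect("localhost:7233", plugins=[GoogleAdkPlugin()])
    worker = Worker(
        client,
        task_queue="adk-agents",
        workflows=[DirectStreamingWorkflow],  # from the sketch above
    )
    await worker.run()


if __name__ == "__main__":
    asyncio.run(main())
```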
diff --git a/tests/contrib/google_adk_agents/test_adk_streaming.py b/tests/contrib/google_adk_agents/test_adk_streaming.py
new file mode 100644
index 000000000..30aecd9f4
--- /dev/null
+++ b/tests/contrib/google_adk_agents/test_adk_streaming.py
@@ -0,0 +1,195 @@
+"""Integration tests for ADK streaming support.
+
+Verifies that the streaming model activity publishes raw ``LlmResponse``
+chunks via the WorkflowStream broker. Non-streaming behavior is covered
+by ``test_google_adk_agents.py``.
+"""
+
+import asyncio
+import uuid
+from collections.abc import AsyncGenerator
+from datetime import timedelta
+
+import pytest
+from google.adk import Agent
+from google.adk.agents.run_config import RunConfig, StreamingMode
+from google.adk.models import BaseLlm, LLMRegistry
+from google.adk.models.llm_request import LlmRequest
+from google.adk.models.llm_response import LlmResponse
+from google.adk.runners import InMemoryRunner
+from google.genai.types import Content, Part
+
+from temporalio import workflow
+from temporalio.client import Client, WorkflowFailureError
+from temporalio.contrib.google_adk_agents import GoogleAdkPlugin, TemporalModel
+from temporalio.contrib.workflow_streams import WorkflowStream, WorkflowStreamClient
+from temporalio.worker import Worker
+
+
+class StreamingTestModel(BaseLlm):
+    """Test model that yields multiple partial responses to simulate streaming."""
+
+    @classmethod
+    def supported_models(cls) -> list[str]:
+        return ["streaming_test_model"]
+
+    async def generate_content_async(
+        self, llm_request: LlmRequest, stream: bool = False
+    ) -> AsyncGenerator[LlmResponse, None]:
+        # The streaming activity must call us with stream=True; if a
+        # regression drops the flag, this test should fail.
+        if not stream:
+            raise AssertionError(
+                "StreamingTestModel.generate_content_async requires stream=True"
+            )
+        yield LlmResponse(content=Content(role="model", parts=[Part(text="Hello ")]))
+        yield LlmResponse(content=Content(role="model", parts=[Part(text="world!")]))
+
+
+@workflow.defn
+class StreamingAdkWorkflow:
+    """Test workflow that opts into streaming via RunConfig.streaming_mode."""
+
+    @workflow.init
+    def __init__(self, prompt: str) -> None:
+        self.stream = WorkflowStream()
+
+    @workflow.run
+    async def run(self, prompt: str) -> str:
+        model = TemporalModel("streaming_test_model", streaming_topic="events")
+        agent = Agent(
+            name="test_agent",
+            model=model,
+            instruction="You are a test agent.",
+        )
+
+        runner = InMemoryRunner(agent=agent, app_name="test-app")
+        session = await runner.session_service.create_session(
+            app_name="test-app", user_id="test"
+        )
+
+        final_text = ""
+        async for event in runner.run_async(
+            user_id="test",
+            session_id=session.id,
+            new_message=Content(role="user", parts=[Part(text=prompt)]),
+            run_config=RunConfig(streaming_mode=StreamingMode.SSE),
+        ):
+            if event.content and event.content.parts:
+                for part in event.content.parts:
+                    if part.text:
+                        final_text = part.text
+
+        return final_text
+
+
+@pytest.mark.asyncio
+async def test_streaming_publishes_events(client: Client):
+    """Streaming activity publishes raw LlmResponse chunks to the topic."""
+    LLMRegistry.register(StreamingTestModel)
+
+    new_config = client.config()
+    new_config["plugins"] = [GoogleAdkPlugin()]
+    client = Client(**new_config)
+
+    workflow_id = f"adk-streaming-test-{uuid.uuid4()}"
+
+    async with Worker(
+        client,
+        task_queue="adk-streaming-test",
+        workflows=[StreamingAdkWorkflow],
+        max_cached_workflows=0,
+    ):
+        handle = await client.start_workflow(
+            StreamingAdkWorkflow.run,
+            "Hello",
+            id=workflow_id,
+            task_queue="adk-streaming-test",
+            execution_timeout=timedelta(seconds=30),
+        )
+
+        stream = WorkflowStreamClient.create(client, workflow_id)
+        responses: list[LlmResponse] = []
+
+        async def collect_events() -> None:
+            async for item in stream.subscribe(
+                ["events"],
+                from_offset=0,
+                result_type=LlmResponse,
+                poll_cooldown=timedelta(milliseconds=50),
+            ):
+                responses.append(item.data)
+                if len(responses) >= 2:
+                    break
+
+        collect_task = asyncio.create_task(collect_events())
+        result = await handle.result()
+        await asyncio.wait_for(collect_task, timeout=10.0)
+
+        # The workflow assembles streamed parts; the last part it
+        # observes is "world!".
+        assert result == "world!"
+
+    texts: list[str] = []
+    for r in responses:
+        if r.content and r.content.parts:
+            for part in r.content.parts:
+                if part.text:
+                    texts.append(part.text)
+    assert texts == ["Hello ", "world!"], f"Unexpected text deltas: {texts}"
+
+
+@workflow.defn
+class StreamingAdkRequiresTopicWorkflow:
+    """Calls ``generate_content_async(stream=True)`` without configuring
+    ``streaming_topic``; the call must raise before any activity is
+    scheduled."""
+
+    @workflow.run
+    async def run(self, prompt: str) -> str:
+        model = TemporalModel("streaming_test_model")
+        agent = Agent(
+            name="test_agent",
+            model=model,
+            instruction="You are a test agent.",
+        )
+        runner = InMemoryRunner(agent=agent, app_name="test-app")
+        session = await runner.session_service.create_session(
+            app_name="test-app", user_id="test"
+        )
+        async for _ in runner.run_async(
+            user_id="test",
+            session_id=session.id,
+            new_message=Content(role="user", parts=[Part(text=prompt)]),
+            run_config=RunConfig(streaming_mode=StreamingMode.SSE),
+        ):
+            pass
+        return "should not reach"
+
+
+@pytest.mark.asyncio
+async def test_streaming_requires_topic(client: Client):
+    """``stream=True`` fails fast when no streaming topic was configured
+    on ``TemporalModel``. The error is raised in the workflow before any
+    streaming activity is scheduled."""
+    LLMRegistry.register(StreamingTestModel)
+
+    new_config = client.config()
+    new_config["plugins"] = [GoogleAdkPlugin()]
+    client = Client(**new_config)
+
+    async with Worker(
+        client,
+        task_queue="adk-streaming-requires-topic",
+        workflows=[StreamingAdkRequiresTopicWorkflow],
+        max_cached_workflows=0,
+    ):
+        with pytest.raises(WorkflowFailureError) as exc_info:
+            await client.execute_workflow(
+                StreamingAdkRequiresTopicWorkflow.run,
+                "Hi",
+                id=f"adk-streaming-requires-topic-{uuid.uuid4()}",
+                task_queue="adk-streaming-requires-topic",
+                execution_timeout=timedelta(seconds=30),
+            )
+
+        assert "streaming_topic" in str(exc_info.value.cause)
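Finally, since the streaming topic exists for consumers outside the workflow, here is a hedged sketch of an external process (a CLI or UI, say) tailing the streamed chunks. It reuses the `WorkflowStreamClient.create`/`subscribe` calls exercised by the tests above; the address, workflow id, and topic name are illustrative, and the plugin is assumed to supply the pydantic converter needed to decode `LlmResponse` payloads.

```python
import asyncio
from datetime import timedelta

from google.adk.models.llm_response import LlmResponse

from temporalio.client import Client
from temporalio.contrib.google_adk_agents import GoogleAdkPlugin
from temporalio.contrib.workflow_streams import WorkflowStreamClient


async def tail_responses(workflow_id: str) -> None:
    # The plugin configures the pydantic payload converter used to
    # decode the published LlmResponse chunks.
    client = await Client.connect("localhost:7233", plugins=[GoogleAdkPlugin()])
    stream = WorkflowStreamClient.create(client, workflow_id)
    # Polls until cancelled; a real consumer would persist offsets and
    # stop once the workflow completes.
    async for item in stream.subscribe(
        ["events"],
        from_offset=0,
        result_type=LlmResponse,
        poll_cooldown=timedelta(milliseconds=200),
    ):
        response = item.data
        if response.content and response.content.parts:
            for part in response.content.parts:
                if part.text:
                    print(part.text, end="", flush=True)


if __name__ == "__main__":
    asyncio.run(tail_responses("my-adk-workflow-id"))
```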