diff --git a/packages/ai-providers/server-ai-langchain/src/ldai_langchain/langchain_model_runner.py b/packages/ai-providers/server-ai-langchain/src/ldai_langchain/langchain_model_runner.py
index 6b0fc24..2f5514a 100644
--- a/packages/ai-providers/server-ai-langchain/src/ldai_langchain/langchain_model_runner.py
+++ b/packages/ai-providers/server-ai-langchain/src/ldai_langchain/langchain_model_runner.py
@@ -1,7 +1,8 @@
-from typing import Any, Dict, List, Optional
+from typing import Any, Dict, List, Optional, cast
 
+from langchain_core.chat_history import InMemoryChatMessageHistory
 from langchain_core.language_models.chat_models import BaseChatModel
-from langchain_core.messages import BaseMessage
+from langchain_core.messages import BaseMessage, HumanMessage
 from ldai import LDMessage, log
 from ldai.providers.runner import Runner
 from ldai.providers.types import LDAIMetrics, RunnerResult
@@ -26,7 +27,9 @@ class LangChainModelRunner(Runner):
 
     def __init__(self, llm: BaseChatModel, config_messages: Optional[List[LDMessage]] = None):
         self._llm = llm
-        self._config_messages: List[LDMessage] = list(config_messages or [])
+        self._chat_history = InMemoryChatMessageHistory(
+            messages=cast(List[BaseMessage], convert_messages_to_langchain(config_messages or []))
+        )
 
     def get_llm(self) -> BaseChatModel:
         """
@@ -44,9 +47,6 @@ async def run(
         """
         Run the LangChain model with the given input.
 
-        Prepends any config messages (system prompt, instructions, etc.) stored
-        at construction time before the user message.
-
         :param input: A string prompt
         :param output_type: Optional JSON schema dict requesting structured output.
             When provided, ``parsed`` on the returned :class:`RunnerResult` is
@@ -54,16 +54,22 @@
         :return: :class:`RunnerResult` containing ``content``, ``metrics``, ``raw``
             and (when ``output_type`` is set) ``parsed``.
         """
-        messages = self._config_messages + [LDMessage(role='user', content=input)]
+        langchain_messages = self._chat_history.messages + [HumanMessage(content=input)]
 
         if output_type is not None:
-            return await self._run_structured(messages, output_type)
-        return await self._run_completion(messages)
+            result = await self._run_structured(langchain_messages, output_type)
+        else:
+            result = await self._run_completion(langchain_messages)
+
+        if result.metrics.success and result.content:
+            self._chat_history.add_user_message(input)
+            self._chat_history.add_ai_message(result.content)
+
+        return result
 
-    async def _run_completion(self, messages: List[LDMessage]) -> RunnerResult:
+    async def _run_completion(self, messages: List[BaseMessage]) -> RunnerResult:
         try:
-            langchain_messages = convert_messages_to_langchain(messages)
-            response: BaseMessage = await self._llm.ainvoke(langchain_messages)
+            response: BaseMessage = await self._llm.ainvoke(messages)
 
             metrics = get_ai_metrics_from_response(response)
             content: str = ''
@@ -90,13 +96,12 @@
 
     async def _run_structured(
         self,
-        messages: List[LDMessage],
+        messages: List[BaseMessage],
         output_type: Dict[str, Any],
     ) -> RunnerResult:
         try:
-            langchain_messages = convert_messages_to_langchain(messages)
             structured_llm = self._llm.with_structured_output(output_type, include_raw=True)
-            response = await structured_llm.ainvoke(langchain_messages)
+            response = await structured_llm.ainvoke(messages)
 
             if not isinstance(response, dict):
                 log.warning(f'Structured output did not return a dict. Got: {type(response)}')
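Reviewer note: the runner now delegates conversation state to LangChain's `InMemoryChatMessageHistory` instead of holding a plain list of config messages. For anyone unfamiliar with that primitive, here is a minimal standalone sketch of the behavior the diff relies on, using only `langchain_core` (the message contents are made up):

```python
# Sketch of the InMemoryChatMessageHistory primitive this change builds on.
# The runner seeds the history with converted config messages, then appends
# one user/assistant pair per successful turn via add_user_message /
# add_ai_message.
from langchain_core.chat_history import InMemoryChatMessageHistory
from langchain_core.messages import SystemMessage

history = InMemoryChatMessageHistory(messages=[SystemMessage(content='You are helpful.')])

# What run() does after a successful call:
history.add_user_message('First question')
history.add_ai_message('First response')

# What the next run() sends ahead of the new HumanMessage:
print([type(m).__name__ for m in history.messages])
# ['SystemMessage', 'HumanMessage', 'AIMessage']
```

Seeding via the `messages` field is what keeps config messages permanently at the front: turns are appended after them, never before.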
diff --git a/packages/ai-providers/server-ai-langchain/tests/test_langchain_provider.py b/packages/ai-providers/server-ai-langchain/tests/test_langchain_provider.py
index a0ee176..df1c60a 100644
--- a/packages/ai-providers/server-ai-langchain/tests/test_langchain_provider.py
+++ b/packages/ai-providers/server-ai-langchain/tests/test_langchain_provider.py
@@ -262,6 +262,59 @@ async def test_returns_success_false_when_model_invocation_throws_error(self, mo
         assert result.metrics.success is False
         assert result.content == ''
 
+    @pytest.mark.asyncio
+    async def test_accumulates_history_across_successful_calls(self, mock_llm):
+        """Should include prior exchange in messages on subsequent calls."""
+        mock_llm.ainvoke = AsyncMock(side_effect=[
+            AIMessage(content='First response'),
+            AIMessage(content='Second response'),
+        ])
+        provider = LangChainModelRunner(mock_llm)
+
+        await provider.run('First question')
+        await provider.run('Second question')
+
+        second_call_messages = mock_llm.ainvoke.call_args_list[1][0][0]
+        roles = [type(m).__name__ for m in second_call_messages]
+        assert roles == ['HumanMessage', 'AIMessage', 'HumanMessage']
+        assert second_call_messages[0].content == 'First question'
+        assert second_call_messages[1].content == 'First response'
+        assert second_call_messages[2].content == 'Second question'
+
+    @pytest.mark.asyncio
+    async def test_does_not_accumulate_history_on_failed_call(self, mock_llm):
+        """Should not add to history when the call fails."""
+        mock_llm.ainvoke = AsyncMock(side_effect=Exception('Model error'))
+        provider = LangChainModelRunner(mock_llm)
+
+        await provider.run('Hello')
+
+        mock_llm.ainvoke = AsyncMock(return_value=AIMessage(content='Recovery'))
+        await provider.run('Try again')
+
+        second_call_messages = mock_llm.ainvoke.call_args_list[0][0][0]
+        assert len(second_call_messages) == 1
+        assert second_call_messages[0].content == 'Try again'
+
+    @pytest.mark.asyncio
+    async def test_prepends_config_messages_before_history(self, mock_llm):
+        """Should send config messages before history on every call."""
+        mock_llm.ainvoke = AsyncMock(side_effect=[
+            AIMessage(content='Answer 1'),
+            AIMessage(content='Answer 2'),
+        ])
+        config_messages = [LDMessage(role='system', content='You are helpful.')]
+        provider = LangChainModelRunner(mock_llm, config_messages=config_messages)
+
+        await provider.run('Q1')
+        await provider.run('Q2')
+
+        second_call_messages = mock_llm.ainvoke.call_args_list[1][0][0]
+        assert second_call_messages[0].content == 'You are helpful.'
+        assert second_call_messages[1].content == 'Q1'
+        assert second_call_messages[2].content == 'Answer 1'
+        assert second_call_messages[3].content == 'Q2'
+
 
 class TestRunStructured:
     """Tests for run() with structured output."""
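Reviewer note on the test mechanics: these cases rely on `AsyncMock` consuming an iterable `side_effect` one item per call, and on replacing the mock before the retry, which is why the failed-call test reads `call_args_list[0]` on the new mock. A quick standalone illustration of those `unittest.mock` semantics, with placeholder values:

```python
# AsyncMock pops items from an iterable side_effect in order, and raises
# when side_effect is an exception instance.
import asyncio
from unittest.mock import AsyncMock

mock = AsyncMock(side_effect=['first', 'second'])
assert asyncio.run(mock()) == 'first'
assert asyncio.run(mock()) == 'second'

failing = AsyncMock(side_effect=RuntimeError('Model error'))
try:
    asyncio.run(failing())
except RuntimeError:
    pass  # the runner catches this and reports metrics.success == False
```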
diff --git a/packages/ai-providers/server-ai-openai/src/ldai_openai/openai_model_runner.py b/packages/ai-providers/server-ai-openai/src/ldai_openai/openai_model_runner.py
index 1ef775d..02bca9a 100644
--- a/packages/ai-providers/server-ai-openai/src/ldai_openai/openai_model_runner.py
+++ b/packages/ai-providers/server-ai-openai/src/ldai_openai/openai_model_runner.py
@@ -33,7 +33,7 @@ def __init__(
         self._client = client
         self._model_name = model_name
         self._parameters = parameters
-        self._config_messages: List[LDMessage] = list(config_messages or [])
+        self._history: List[LDMessage] = list(config_messages or [])
 
     async def run(
         self,
@@ -43,9 +43,6 @@
         """
         Run the OpenAI model with the given input.
 
-        Prepends any config messages (system prompt, instructions, etc.) stored
-        at construction time before the user message.
-
         :param input: A string prompt
         :param output_type: Optional JSON schema dict requesting structured output.
             When provided, ``parsed`` on the returned :class:`RunnerResult` is
@@ -53,11 +50,19 @@
         :return: :class:`RunnerResult` containing ``content``, ``metrics``, ``raw``
             and (when ``output_type`` is set) ``parsed``.
         """
-        messages = self._config_messages + [LDMessage(role='user', content=input)]
+        user_message = LDMessage(role='user', content=input)
+        messages = self._history + [user_message]
 
         if output_type is not None:
-            return await self._run_structured(messages, output_type)
-        return await self._run_completion(messages)
+            result = await self._run_structured(messages, output_type)
+        else:
+            result = await self._run_completion(messages)
+
+        if result.metrics.success and result.content:
+            self._history.append(user_message)
+            self._history.append(LDMessage(role='assistant', content=result.content))
+
+        return result
 
     async def _run_completion(self, messages: List[LDMessage]) -> RunnerResult:
         try:
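The OpenAI runner mirrors the LangChain change with a plain `List[LDMessage]`, committing the pending turn only when `result.metrics.success` and `result.content` are both truthy. A hypothetical multi-turn usage sketch; the import path is inferred from the file layout above, the model name is a placeholder, and an OpenAI API key is assumed to be configured:

```python
# Hypothetical multi-turn usage of the runner after this change.
import asyncio

from openai import AsyncOpenAI
from ldai_openai.openai_model_runner import OpenAIModelRunner  # path inferred from this diff

async def main() -> None:
    runner = OpenAIModelRunner(AsyncOpenAI(), 'gpt-4o', {})
    first = await runner.run('Name one feature-flag SDK.')
    # The successful turn above is now part of the runner's history, so this
    # follow-up request is sent with the prior question and answer prepended.
    follow_up = await runner.run('Summarize your previous answer in five words.')
    print(follow_up.content)

asyncio.run(main())
```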
diff --git a/packages/ai-providers/server-ai-openai/tests/test_openai_provider.py b/packages/ai-providers/server-ai-openai/tests/test_openai_provider.py
index 4a1eb5f..edfa840 100644
--- a/packages/ai-providers/server-ai-openai/tests/test_openai_provider.py
+++ b/packages/ai-providers/server-ai-openai/tests/test_openai_provider.py
@@ -204,6 +204,61 @@ async def test_returns_unsuccessful_response_when_exception_thrown(self, mock_cl
         assert result.content == ''
         assert result.metrics.success is False
 
+    @pytest.mark.asyncio
+    async def test_accumulates_history_across_successful_calls(self, mock_client):
+        """Should include prior exchange in messages on subsequent calls."""
+        def make_response(text: str):
+            r = MagicMock()
+            r.context_wrapper = None
+            r.choices = [MagicMock()]
+            r.choices[0].message = MagicMock()
+            r.choices[0].message.content = text
+            r.usage = None
+            return r
+
+        mock_client.chat = MagicMock()
+        mock_client.chat.completions = MagicMock()
+        mock_client.chat.completions.create = AsyncMock(side_effect=[
+            make_response('First response'),
+            make_response('Second response'),
+        ])
+
+        provider = OpenAIModelRunner(mock_client, 'gpt-4o', {})
+        await provider.run('First question')
+        await provider.run('Second question')
+
+        second_call_messages = mock_client.chat.completions.create.call_args_list[1].kwargs['messages']
+        assert second_call_messages == [
+            {'role': 'user', 'content': 'First question'},
+            {'role': 'assistant', 'content': 'First response'},
+            {'role': 'user', 'content': 'Second question'},
+        ]
+
+    @pytest.mark.asyncio
+    async def test_does_not_accumulate_history_on_failed_call(self, mock_client):
+        """Should not add to history when the call fails."""
+        mock_client.chat = MagicMock()
+        mock_client.chat.completions = MagicMock()
+        mock_client.chat.completions.create = AsyncMock(side_effect=Exception('API Error'))
+
+        provider = OpenAIModelRunner(mock_client, 'gpt-4o', {})
+        await provider.run('Hello!')
+
+        def make_ok_response():
+            r = MagicMock()
+            r.context_wrapper = None
+            r.choices = [MagicMock()]
+            r.choices[0].message = MagicMock()
+            r.choices[0].message.content = 'Recovery'
+            r.usage = None
+            return r
+
+        mock_client.chat.completions.create = AsyncMock(return_value=make_ok_response())
+        await provider.run('Try again')
+
+        second_call_messages = mock_client.chat.completions.create.call_args.kwargs['messages']
+        assert second_call_messages == [{'role': 'user', 'content': 'Try again'}]
+
 
 class TestRunStructured:
     """Tests for the unified run() method (structured-output path)."""
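Taken together, the new tests pin down one invariant shared by both runners. Distilled into plain Python with hypothetical stand-in names (not the SDK's API):

```python
# The shared invariant: a turn enters history only when the call both
# succeeded and produced content, so a failed call leaves the next request
# exactly as if it never happened.
from typing import Dict, List

history: List[Dict[str, str]] = []  # seeded from config_messages in both runners

def commit_turn(success: bool, prompt: str, content: str) -> None:
    if success and content:
        history.append({'role': 'user', 'content': prompt})
        history.append({'role': 'assistant', 'content': content})

commit_turn(False, 'Hello', '')             # failed call: history unchanged
commit_turn(True, 'Try again', 'Recovery')  # success: both messages persisted
assert history == [
    {'role': 'user', 'content': 'Try again'},
    {'role': 'assistant', 'content': 'Recovery'},
]
```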