@@ -1,7 +1,8 @@
-from typing import Any, Dict, List, Optional
+from typing import Any, Dict, List, Optional, cast
 
+from langchain_core.chat_history import InMemoryChatMessageHistory
 from langchain_core.language_models.chat_models import BaseChatModel
-from langchain_core.messages import BaseMessage
+from langchain_core.messages import BaseMessage, HumanMessage
 from ldai import LDMessage, log
 from ldai.providers.runner import Runner
 from ldai.providers.types import LDAIMetrics, RunnerResult
@@ -26,7 +27,9 @@ class LangChainModelRunner(Runner):
 
     def __init__(self, llm: BaseChatModel, config_messages: Optional[List[LDMessage]] = None):
         self._llm = llm
-        self._config_messages: List[LDMessage] = list(config_messages or [])
+        self._chat_history = InMemoryChatMessageHistory(
+            messages=cast(List[BaseMessage], convert_messages_to_langchain(config_messages or []))
+        )
 
     def get_llm(self) -> BaseChatModel:
         """
@@ -44,26 +47,29 @@ async def run(
         """
         Run the LangChain model with the given input.
 
-        Prepends any config messages (system prompt, instructions, etc.) stored
-        at construction time before the user message.
-
         :param input: A string prompt
         :param output_type: Optional JSON schema dict requesting structured output.
             When provided, ``parsed`` on the returned :class:`RunnerResult` is
             populated with the parsed JSON document.
         :return: :class:`RunnerResult` containing ``content``, ``metrics``,
             ``raw`` and (when ``output_type`` is set) ``parsed``.
         """
-        messages = self._config_messages + [LDMessage(role='user', content=input)]
+        langchain_messages = self._chat_history.messages + [HumanMessage(content=input)]
 
         if output_type is not None:
-            return await self._run_structured(messages, output_type)
-        return await self._run_completion(messages)
+            result = await self._run_structured(langchain_messages, output_type)
+        else:
+            result = await self._run_completion(langchain_messages)
+
+        if result.metrics.success and result.content:
+            self._chat_history.add_user_message(input)
+            self._chat_history.add_ai_message(result.content)
+
+        return result
 
-    async def _run_completion(self, messages: List[LDMessage]) -> RunnerResult:
+    async def _run_completion(self, messages: List[BaseMessage]) -> RunnerResult:
         try:
-            langchain_messages = convert_messages_to_langchain(messages)
-            response: BaseMessage = await self._llm.ainvoke(langchain_messages)
+            response: BaseMessage = await self._llm.ainvoke(messages)
             metrics = get_ai_metrics_from_response(response)
 
             content: str = ''
@@ -90,13 +96,12 @@ async def _run_completion(self, messages: List[LDMessage]) -> RunnerResult:
 
     async def _run_structured(
         self,
-        messages: List[LDMessage],
+        messages: List[BaseMessage],
         output_type: Dict[str, Any],
     ) -> RunnerResult:
         try:
-            langchain_messages = convert_messages_to_langchain(messages)
             structured_llm = self._llm.with_structured_output(output_type, include_raw=True)
-            response = await structured_llm.ainvoke(langchain_messages)
+            response = await structured_llm.ainvoke(messages)
 
             if not isinstance(response, dict):
                 log.warning(f'Structured output did not return a dict. Got: {type(response)}')
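As a review aid, here is a minimal sketch of the history accumulation the runner now performs. The import path of `LangChainModelRunner` is an assumption (file paths are not shown in this copy of the diff), and it assumes `get_ai_metrics_from_response` reports success for plain replies, as the tests below suggest; `FakeListChatModel` is langchain_core's canned-response test model, so no API key is needed.

```python
import asyncio

from langchain_core.language_models import FakeListChatModel

from ldai import LDMessage
from ldai.providers.langchain import LangChainModelRunner  # hypothetical import path


async def main() -> None:
    # Canned-response chat model from langchain_core; replays responses in order.
    llm = FakeListChatModel(responses=['Paris.', 'Roughly 2.1 million.'])
    runner = LangChainModelRunner(
        llm,
        config_messages=[LDMessage(role='system', content='Answer briefly.')],
    )

    await runner.run('What is the capital of France?')
    # The first exchange succeeded, so it now lives in the runner's chat
    # history and the second call sends [system, human, ai, human].
    second = await runner.run('What is its population?')
    print(second.content)


asyncio.run(main())
```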
@@ -262,6 +262,59 @@ async def test_returns_success_false_when_model_invocation_throws_error(self, mock_llm):
         assert result.metrics.success is False
         assert result.content == ''
 
+    @pytest.mark.asyncio
+    async def test_accumulates_history_across_successful_calls(self, mock_llm):
+        """Should include prior exchange in messages on subsequent calls."""
+        mock_llm.ainvoke = AsyncMock(side_effect=[
+            AIMessage(content='First response'),
+            AIMessage(content='Second response'),
+        ])
+        provider = LangChainModelRunner(mock_llm)
+
+        await provider.run('First question')
+        await provider.run('Second question')
+
+        second_call_messages = mock_llm.ainvoke.call_args_list[1][0][0]
+        roles = [type(m).__name__ for m in second_call_messages]
+        assert roles == ['HumanMessage', 'AIMessage', 'HumanMessage']
+        assert second_call_messages[0].content == 'First question'
+        assert second_call_messages[1].content == 'First response'
+        assert second_call_messages[2].content == 'Second question'
+
+    @pytest.mark.asyncio
+    async def test_does_not_accumulate_history_on_failed_call(self, mock_llm):
+        """Should not add to history when the call fails."""
+        mock_llm.ainvoke = AsyncMock(side_effect=Exception('Model error'))
+        provider = LangChainModelRunner(mock_llm)
+
+        await provider.run('Hello')
+
+        mock_llm.ainvoke = AsyncMock(return_value=AIMessage(content='Recovery'))
+        await provider.run('Try again')
+
+        second_call_messages = mock_llm.ainvoke.call_args_list[0][0][0]
+        assert len(second_call_messages) == 1
+        assert second_call_messages[0].content == 'Try again'
+
+    @pytest.mark.asyncio
+    async def test_prepends_config_messages_before_history(self, mock_llm):
+        """Should send config messages before history on every call."""
+        mock_llm.ainvoke = AsyncMock(side_effect=[
+            AIMessage(content='Answer 1'),
+            AIMessage(content='Answer 2'),
+        ])
+        config_messages = [LDMessage(role='system', content='You are helpful.')]
+        provider = LangChainModelRunner(mock_llm, config_messages=config_messages)
+
+        await provider.run('Q1')
+        await provider.run('Q2')
+
+        second_call_messages = mock_llm.ainvoke.call_args_list[1][0][0]
+        assert second_call_messages[0].content == 'You are helpful.'
+        assert second_call_messages[1].content == 'Q1'
+        assert second_call_messages[2].content == 'Answer 1'
+        assert second_call_messages[3].content == 'Q2'
+
 
 class TestRunStructured:
     """Tests for run() with structured output."""
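One reviewer note on the `isinstance(response, dict)` guard kept in `_run_structured`: with `include_raw=True`, LangChain's structured-output runnable returns a dict rather than a bare message. A sketch of the shape, reusing the diff's `llm` and `messages` names inside an async context (the schema is illustrative only):

```python
structured_llm = llm.with_structured_output(
    {'type': 'object', 'properties': {'answer': {'type': 'string'}}},
    include_raw=True,
)
response = await structured_llm.ainvoke(messages)
# Per langchain_core's include_raw contract, response has three keys:
#   'raw'           -> the untouched AIMessage (used here for metrics)
#   'parsed'        -> the schema-conformant payload, or None on failure
#   'parsing_error' -> the exception raised while parsing, if any
```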
@@ -33,7 +33,7 @@ def __init__(
         self._client = client
         self._model_name = model_name
         self._parameters = parameters
-        self._config_messages: List[LDMessage] = list(config_messages or [])
+        self._history: List[LDMessage] = list(config_messages or [])
 
     async def run(
         self,
@@ -43,21 +43,26 @@ async def run(
         """
         Run the OpenAI model with the given input.
 
-        Prepends any config messages (system prompt, instructions, etc.) stored
-        at construction time before the user message.
-
         :param input: A string prompt
        :param output_type: Optional JSON schema dict requesting structured output.
            When provided, ``parsed`` on the returned :class:`RunnerResult` is
            populated with the parsed JSON document.
        :return: :class:`RunnerResult` containing ``content``, ``metrics``,
            ``raw`` and (when ``output_type`` is set) ``parsed``.
         """
-        messages = self._config_messages + [LDMessage(role='user', content=input)]
+        user_message = LDMessage(role='user', content=input)
+        messages = self._history + [user_message]
 
         if output_type is not None:
-            return await self._run_structured(messages, output_type)
-        return await self._run_completion(messages)
+            result = await self._run_structured(messages, output_type)
+        else:
+            result = await self._run_completion(messages)
+
+        if result.metrics.success and result.content:
+            self._history.append(user_message)
+            self._history.append(LDMessage(role='assistant', content=result.content))
+
+        return result
 
     async def _run_completion(self, messages: List[LDMessage]) -> RunnerResult:
         try:
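The OpenAI runner mirrors the same pattern with a plain `LDMessage` list in place of LangChain's chat history. A minimal usage sketch, assuming the runner's import path (not shown in this diff) and a configured `AsyncOpenAI` client:

```python
import asyncio

from openai import AsyncOpenAI  # requires OPENAI_API_KEY in the environment

from ldai.providers.openai import OpenAIModelRunner  # hypothetical import path


async def main() -> None:
    runner = OpenAIModelRunner(AsyncOpenAI(), 'gpt-4o', {'temperature': 0})

    await runner.run('Name one prime number.')
    # The first call succeeded, so this request body carries
    # [user, assistant, user] rather than just the new prompt.
    result = await runner.run('Name a larger one.')
    print(result.content)


asyncio.run(main())
```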
@@ -204,6 +204,61 @@ async def test_returns_unsuccessful_response_when_exception_thrown(self, mock_client):
         assert result.content == ''
         assert result.metrics.success is False
 
+    @pytest.mark.asyncio
+    async def test_accumulates_history_across_successful_calls(self, mock_client):
+        """Should include prior exchange in messages on subsequent calls."""
+        def make_response(text: str):
+            r = MagicMock()
+            r.context_wrapper = None
+            r.choices = [MagicMock()]
+            r.choices[0].message = MagicMock()
+            r.choices[0].message.content = text
+            r.usage = None
+            return r
+
+        mock_client.chat = MagicMock()
+        mock_client.chat.completions = MagicMock()
+        mock_client.chat.completions.create = AsyncMock(side_effect=[
+            make_response('First response'),
+            make_response('Second response'),
+        ])
+
+        provider = OpenAIModelRunner(mock_client, 'gpt-4o', {})
+        await provider.run('First question')
+        await provider.run('Second question')
+
+        second_call_messages = mock_client.chat.completions.create.call_args_list[1].kwargs['messages']
+        assert second_call_messages == [
+            {'role': 'user', 'content': 'First question'},
+            {'role': 'assistant', 'content': 'First response'},
+            {'role': 'user', 'content': 'Second question'},
+        ]
+
+    @pytest.mark.asyncio
+    async def test_does_not_accumulate_history_on_failed_call(self, mock_client):
+        """Should not add to history when the call fails."""
+        mock_client.chat = MagicMock()
+        mock_client.chat.completions = MagicMock()
+        mock_client.chat.completions.create = AsyncMock(side_effect=Exception('API Error'))
+
+        provider = OpenAIModelRunner(mock_client, 'gpt-4o', {})
+        await provider.run('Hello!')
+
+        def make_ok_response():
+            r = MagicMock()
+            r.context_wrapper = None
+            r.choices = [MagicMock()]
+            r.choices[0].message = MagicMock()
+            r.choices[0].message.content = 'Recovery'
+            r.usage = None
+            return r
+
+        mock_client.chat.completions.create = AsyncMock(return_value=make_ok_response())
+        await provider.run('Try again')
+
+        second_call_messages = mock_client.chat.completions.create.call_args.kwargs['messages']
+        assert second_call_messages == [{'role': 'user', 'content': 'Try again'}]
+
 
 class TestRunStructured:
     """Tests for the unified run() method (structured-output path)."""
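For completeness, the caller-facing contract of the structured-output path, which both runners share. A sketch inside an async context, with `runner` built as in the examples above and an illustrative schema:

```python
schema = {
    'type': 'object',
    'properties': {'answer': {'type': 'string'}},
    'required': ['answer'],
}

result = await runner.run('Reply with a JSON object.', output_type=schema)
# On success, result.parsed holds the decoded JSON document and the
# exchange joins the history; on failure, metrics.success is False and
# the history is left untouched.
print(result.parsed)
```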