Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 21 additions & 25 deletions chatkit/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -581,33 +581,18 @@ async def _process_streaming_impl(
thread = await self.store.load_thread(
request.params.thread_id, context=context
)
items = await self.store.load_thread_items(
thread.id, None, 1, "desc", context
)
tool_call = next(
(
item
for item in items.data
if isinstance(item, ClientToolCallItem)
and item.status == "pending"
),
None,
)
if not tool_call:
raise ValueError(
f"Last thread item in {thread.id} was not a ClientToolCallItem"
)
tool_call = await self._load_pending_client_tool_call(thread, context)
if tool_call:
tool_call.output = request.params.result
tool_call.status = "completed"

tool_call.output = request.params.result
tool_call.status = "completed"
await self.store.save_item(thread.id, tool_call, context=context)

await self.store.save_item(thread.id, tool_call, context=context)

# Safety against dangling pending tool calls if there are
# multiple in a row, which should be impossible, and
# integrations should ultimately filter out pending tool calls
# when creating input response messages.
await self._cleanup_pending_client_tool_call(thread, context)
# Safety against dangling pending tool calls if there are
# multiple in a row, which should be impossible, and
# integrations should ultimately filter out pending tool calls
# when creating input response messages.
await self._cleanup_pending_client_tool_call(thread, context)

async for event in self._process_events(
thread,
Expand Down Expand Up @@ -732,6 +717,17 @@ async def _cleanup_pending_client_tool_call(
thread.id, tool_call.id, context=context
)

async def _load_pending_client_tool_call(
    self, thread: ThreadMetadata, context: TContext
) -> ClientToolCallItem | None:
    """Return the newest pending client tool call in *thread*, or ``None``.

    Only the most recent ``DEFAULT_PAGE_SIZE`` items are inspected (loaded
    in descending order), so a pending call buried deeper than one page is
    not found — by design, since pending calls should sit near the tail.
    """
    page = await self.store.load_thread_items(
        thread.id, None, DEFAULT_PAGE_SIZE, "desc", context
    )
    return next(
        (
            candidate
            for candidate in page.data
            if isinstance(candidate, ClientToolCallItem)
            and candidate.status == "pending"
        ),
        None,
    )

async def _process_new_thread_item_respond(
self,
thread: ThreadMetadata,
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[project]
name = "openai-chatkit"
version = "1.6.3"
version = "1.6.4"
description = "A ChatKit backend SDK."
readme = "README.md"
requires-python = ">=3.10"
Expand Down
119 changes: 119 additions & 0 deletions tests/test_chatkit_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@
LockedStatus,
Page,
ProgressUpdateEvent,
SDKHiddenContextItem,
SyncCustomActionResponse,
Thread,
ThreadAddClientToolOutputParams,
Expand Down Expand Up @@ -859,6 +860,124 @@ async def responder(
assert events[1].item.type == "assistant_message"


async def test_add_client_tool_output_finds_pending_tool_call_before_latest_item():
    # Responder: emits a client tool call on user input, and an assistant
    # message when resumed with no input (i.e. after tool output arrives).
    async def responder(
        thread: ThreadMetadata, input: UserMessageItem | None, context: Any
    ) -> AsyncIterator[ThreadStreamEvent]:
        if input is None:
            yield ThreadItemDoneEvent(
                item=AssistantMessageItem(
                    id="msg_2",
                    content=[
                        AssistantMessageContent(text="Glad the tool call succeeded!")
                    ],
                    created_at=datetime.now(),
                    thread_id=thread.id,
                ),
            )
        elif isinstance(input, UserMessageItem):
            yield ThreadItemDoneEvent(
                item=ClientToolCallItem(
                    id="msg_1",
                    created_at=datetime.now(),
                    name="tool_call_1",
                    arguments={"arg1": "val1"},
                    call_id="tool_call_1",
                    thread_id=thread.id,
                ),
            )

    with make_server(responder) as server:
        create_events = await server.process_streaming(
            ThreadsCreateReq(
                params=ThreadCreateParams(
                    input=UserMessageInput(
                        content=[UserMessageTextContent(text="Hello, world!")],
                        attachments=[],
                        inference_options=InferenceOptions(),
                    )
                )
            )
        )
        thread = next(
            ev.thread for ev in create_events if isinstance(ev, ThreadCreatedEvent)
        )

        # Insert a hidden-context item AFTER the pending tool call so the
        # pending call is no longer the latest item in the thread.
        hidden = SDKHiddenContextItem(
            id="hidden_1",
            created_at=datetime.now(),
            thread_id=thread.id,
            content="The user cancelled the stream.",
        )
        await server.store.add_thread_item(thread.id, hidden, DEFAULT_CONTEXT)

        output_events = await server.process_streaming(
            ThreadsAddClientToolOutputReq(
                params=ThreadAddClientToolOutputParams(
                    thread_id=thread.id,
                    result={"text": "Wow!"},
                )
            )
        )

        # The pending call must still be located and completed.
        tool_call = await server.store.load_item(thread.id, "msg_1", DEFAULT_CONTEXT)
        assert isinstance(tool_call, ClientToolCallItem)
        assert tool_call.status == "completed"
        assert tool_call.output == {"text": "Wow!"}
        assert len(output_events) == 2
        assert output_events[0].type == "stream_options"
        assert output_events[1].type == "thread.item.done"
        assert output_events[1].item.type == "assistant_message"


async def test_add_client_tool_output_without_pending_tool_call_continues_inference():
    # Responder: only reacts to a resume (input is None) — the tool-output
    # request should continue inference even with no pending tool call.
    async def responder(
        thread: ThreadMetadata, input: UserMessageItem | None, context: Any
    ) -> AsyncIterator[ThreadStreamEvent]:
        if input is not None:
            return
        yield ThreadItemDoneEvent(
            item=AssistantMessageItem(
                id="msg_1",
                content=[AssistantMessageContent(text="Continued")],
                created_at=datetime.now(),
                thread_id=thread.id,
            ),
        )

    with make_server(responder) as server:
        create_events = await server.process_streaming(
            ThreadsCreateReq(
                params=ThreadCreateParams(
                    input=UserMessageInput(
                        content=[UserMessageTextContent(text="Hello, world!")],
                        attachments=[],
                        inference_options=InferenceOptions(),
                    )
                )
            )
        )
        thread = next(
            ev.thread for ev in create_events if isinstance(ev, ThreadCreatedEvent)
        )

        output_events = await server.process_streaming(
            ThreadsAddClientToolOutputReq(
                params=ThreadAddClientToolOutputParams(
                    thread_id=thread.id,
                    result={"text": "Wow!"},
                )
            )
        )

        assert len(output_events) == 2
        assert output_events[0].type == "stream_options"
        assert output_events[1].type == "thread.item.done"
        assert output_events[1].item.type == "assistant_message"


async def test_respond_with_tool_status():
async def responder(
thread: ThreadMetadata, input: UserMessageItem | None, context: Any
Expand Down
2 changes: 1 addition & 1 deletion uv.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading