From 74a4760ed609c2a999e7ffb746a2d59083a9c770 Mon Sep 17 00:00:00 2001 From: Yousef El-Kurdi Date: Mon, 27 Apr 2026 18:36:49 +0000 Subject: [PATCH 1/6] feat: add context compaction strategies for react framework Adds CompactionStrategy abstraction and KeepLastN implementation to mellea/stdlib/compaction.py, wires an optional compaction parameter into the react() loop, and adds full test coverage in test/stdlib/test_compaction.py. Assisted-by: Claude Code Signed-off-by: Yousef El-Kurdi --- mellea/stdlib/compaction.py | 325 ++++++++++++++++++++++++++++ mellea/stdlib/frameworks/react.py | 13 ++ test/stdlib/test_compaction.py | 344 ++++++++++++++++++++++++++++++ 3 files changed, 682 insertions(+) create mode 100644 mellea/stdlib/compaction.py create mode 100644 test/stdlib/test_compaction.py diff --git a/mellea/stdlib/compaction.py b/mellea/stdlib/compaction.py new file mode 100644 index 000000000..2f9636fb7 --- /dev/null +++ b/mellea/stdlib/compaction.py @@ -0,0 +1,325 @@ +"""Context compaction strategies for the ReACT framework. + +Provides modular, callable strategy objects to compact a ``ChatContext`` that +has grown too large during a react loop. Three strategies are available: + +- ``ClearAll`` — discard the entire conversation body, keeping only the prefix + (everything up to and including the ``ReactInitiator``). +- ``KeepLastN`` — keep the prefix plus the *n* most recent body components. +- ``LLMSummarize`` — ask the backend to summarize old body components into a + single ``Message``, then keep the last *n* body components verbatim. + +All strategies preserve the **prefix** (every component up to and including the +first ``ReactInitiator``) so the model retains its goal and tool definitions. 
+ +Example:: + + from mellea.stdlib.compaction import KeepLastN + from mellea.stdlib.frameworks.react import react + + await react( + goal="...", + context=ChatContext(), + backend=m.backend, + tools=[search_tool], + compaction=KeepLastN(keep_n=5, threshold=20), + ) +""" + +from __future__ import annotations + +import abc + +from mellea.core.backend import Backend +from mellea.core.base import CBlock, Component, ModelOutputThunk +from mellea.core.utils import MelleaLogger +from mellea.stdlib.components.chat import Message, ToolMessage +from mellea.stdlib.components.react import ReactInitiator +from mellea.stdlib.context import ChatContext + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +def rebuild_chat_context( + components: list[Component | CBlock], *, window_size: int | None = None +) -> ChatContext: + """Build a fresh ``ChatContext`` from an ordered list of components. + + Args: + components: Components to add, in chronological order. + window_size: Optional sliding-window size for the new context. + + Returns: + A new ``ChatContext`` containing all *components*. + """ + ctx = ChatContext(window_size=window_size) + for c in components: + ctx = ctx.add(c) + return ctx + + +def _find_prefix_end(components: list[Component | CBlock]) -> int: + """Return the index *after* the first ``ReactInitiator``. + + Everything in ``components[:idx]`` is the prefix that must be preserved by + every compaction strategy. Returns 0 when no ``ReactInitiator`` is found. + """ + for i, c in enumerate(components): + if isinstance(c, ReactInitiator): + return i + 1 + return 0 + + +# --------------------------------------------------------------------------- +# Abstract base +# --------------------------------------------------------------------------- + + +class CompactionStrategy(abc.ABC): + """Abstract base class for context compaction strategies. 
+ + Each strategy carries a ``threshold`` — the component count above which + compaction should fire. The :meth:`should_compact` helper checks this so + callers don't need to track the threshold separately. + + Subclasses implement :meth:`compact` which receives the current + ``ChatContext`` and returns a compacted copy. The method is ``async`` + so that strategies requiring LLM calls (e.g. ``LLMSummarize``) work + transparently; synchronous strategies simply never ``await``. + + Args: + threshold (int): Trigger compaction when the number of context + components exceeds this value. + """ + + def __init__(self, *, threshold: int = 0) -> None: + """Initialize with the component-count threshold.""" + self.threshold = threshold + + def should_compact(self, context: ChatContext) -> bool: + """Return ``True`` when *context* exceeds the configured threshold. + + Args: + context: The context to check. + + Returns: + ``True`` if the number of components exceeds ``self.threshold`` + and ``self.threshold`` is greater than 0. + """ + return self.threshold > 0 and len(context.as_list()) > self.threshold + + async def maybe_compact( + self, + context: ChatContext, + *, + backend: Backend | None = None, + goal: str | None = None, + ) -> ChatContext: + """Compact *context* only if it exceeds the threshold, otherwise return it unchanged. + + Args: + context: The context to check and potentially compact. + backend: The backend (forwarded to :meth:`compact`). + goal: The react goal string (forwarded to :meth:`compact`). + + Returns: + A compacted ``ChatContext`` if the threshold was exceeded, + or the original *context* unchanged. + """ + if self.should_compact(context): + return await self.compact(context, backend=backend, goal=goal) + return context + + @abc.abstractmethod + async def compact( + self, + context: ChatContext, + *, + backend: Backend | None = None, + goal: str | None = None, + ) -> ChatContext: + """Return a compacted copy of *context*. 
+ + Args: + context: The context to compact. + backend: The backend (required by ``LLMSummarize``). + goal: The react goal string (required by ``LLMSummarize``). + + Returns: + A new, compacted ``ChatContext``. + """ + + +# --------------------------------------------------------------------------- +# Concrete strategies +# --------------------------------------------------------------------------- + + +class ClearAll(CompactionStrategy): + """Discard the entire conversation body, keeping only the prefix. + + The prefix is everything up to and including the first ``ReactInitiator``. + + Args: + threshold (int): Trigger compaction when context exceeds this many components. + """ + + async def compact( + self, + context: ChatContext, + *, + backend: Backend | None = None, + goal: str | None = None, + ) -> ChatContext: + """Return a context containing only the prefix.""" + components = context.as_list() + prefix_end = _find_prefix_end(components) + compacted = components[:prefix_end] + + MelleaLogger.get_logger().info( + f"ClearAll: compacted context from {len(components)} to " + f"{len(compacted)} components" + ) + return rebuild_chat_context(compacted, window_size=context._window_size) + + +class KeepLastN(CompactionStrategy): + """Keep the prefix plus the last *keep_n* body components. + + Args: + keep_n (int): Number of recent body components to retain. + threshold (int): Trigger compaction when context exceeds this many components. 
+ """ + + def __init__(self, *, keep_n: int = 5, threshold: int = 0) -> None: + """Initialize with the number of recent body components to keep.""" + super().__init__(threshold=threshold) + self.keep_n = keep_n + + async def compact( + self, + context: ChatContext, + *, + backend: Backend | None = None, + goal: str | None = None, + ) -> ChatContext: + """Return a context with the prefix and the last *keep_n* body components.""" + components = context.as_list() + prefix_end = _find_prefix_end(components) + prefix = components[:prefix_end] + body = components[prefix_end:] + + if len(body) <= self.keep_n: + return context # nothing to compact + + compacted = prefix + body[len(body) - self.keep_n :] # [-0:] is the whole list + + MelleaLogger.get_logger().info( + f"KeepLastN(keep_n={self.keep_n}): compacted context from " + f"{len(components)} to {len(compacted)} components" + ) + return rebuild_chat_context(compacted, window_size=context._window_size) + + +class LLMSummarize(CompactionStrategy): + """Summarize old body components with the LLM, keep last *keep_n* verbatim. + + Requires ``backend`` and ``goal`` to be passed to :meth:`compact`. + + Args: + keep_n (int): Number of recent body components to retain verbatim. + threshold (int): Trigger compaction when context exceeds this many components. + """ + + def __init__(self, *, keep_n: int = 5, threshold: int = 0) -> None: + """Initialize with the number of recent body components to keep.""" + super().__init__(threshold=threshold) + self.keep_n = keep_n + + async def compact( + self, + context: ChatContext, + *, + backend: Backend | None = None, + goal: str | None = None, + ) -> ChatContext: + """Return a context with the prefix, an LLM summary, and recent body components. + + Raises: + ValueError: If *backend* or *goal* are not provided. 
+ """ + if backend is None or goal is None: + raise ValueError( + "LLMSummarize requires both 'backend' and 'goal' arguments" + ) + + from mellea.stdlib import functional as mfuncs + from mellea.stdlib.context import SimpleContext + + components = context.as_list() + prefix_end = _find_prefix_end(components) + prefix = components[:prefix_end] + body = components[prefix_end:] + + if len(body) <= self.keep_n: + return context # nothing to compact + + old = body[: -self.keep_n] if self.keep_n > 0 else body + recent = body[-self.keep_n :] if self.keep_n > 0 else [] + + # Build a textual representation of old components for summarization. + context_lines: list[str] = [] + for c in old: + if isinstance(c, ToolMessage): + context_lines.append(f"tool ({c.name}): {c.content}") + elif isinstance(c, Message): + context_lines.append(f"{c.role}: {c.content}") + elif isinstance(c, ModelOutputThunk): + context_lines.append(f"assistant: {c.value}") + elif isinstance(c, CBlock): + context_lines.append(str(c)) + else: + context_lines.append(str(getattr(c, "content", c))) + + summary_prompt = ( + "You are summarizing research progress to maintain context " + "within token limits.\n\n" + f"GOAL: {goal}\n\n" + "Provide a comprehensive summary of the research context below. 
" + "Your summary should:\n" + "- Preserve ALL specific facts, numbers, names, URLs, and search " + "queries found\n" + "- Note which tools were called and what results were obtained\n" + "- Highlight key findings and any dead ends encountered\n" + "- Be structured clearly so the research can continue seamlessly" + "\n\nContext to summarize:\n" + f"{chr(10).join(context_lines)}" + ) + + summary_action = Message(role="user", content=summary_prompt) + result, _ = await mfuncs.aact( + action=summary_action, + context=SimpleContext(), + backend=backend, + requirements=[], + strategy=None, + await_result=True, + ) + + summary_text = result.value or "" + summary_message = Message( + role="user", + content=( + f"[CONTEXT SUMMARY]\n{summary_text}\n\nContinue working on: {goal}" + ), + ) + + compacted = [*prefix, summary_message, *recent] + + MelleaLogger.get_logger().info( + f"LLMSummarize(keep_n={self.keep_n}): compacted context from " + f"{len(components)} to {len(compacted)} components" + ) + return rebuild_chat_context(compacted, window_size=context._window_size) diff --git a/mellea/stdlib/frameworks/react.py b/mellea/stdlib/frameworks/react.py index 117af4866..7f39bba27 100644 --- a/mellea/stdlib/frameworks/react.py +++ b/mellea/stdlib/frameworks/react.py @@ -15,6 +15,7 @@ from mellea.stdlib import functional as mfuncs # from mellea.stdlib.components.docs.document import Document +from mellea.stdlib.compaction import CompactionStrategy from mellea.stdlib.components.chat import ToolMessage from mellea.stdlib.components.react import ( MELLEA_FINALIZER_TOOL, @@ -36,6 +37,7 @@ async def react( model_options: dict | None = None, tools: list[AbstractMelleaTool] | None, loop_budget: int = 10, + compaction: CompactionStrategy | None = None, ) -> tuple[ComputedModelOutputThunk[str], ChatContext]: """Asynchronous ReACT pattern (Think -> Act -> Observe -> Repeat Until Done); attempts to accomplish the provided goal given the provided tools. 
@@ -47,6 +49,10 @@ async def react( model_options: additional model options, which will upsert into the model/backend's defaults. tools: the list of tools to use loop_budget: the number of steps allowed; use -1 for unlimited + compaction: an optional ``CompactionStrategy`` to apply when the context + exceeds the strategy's configured threshold + (e.g. ``KeepLastN(keep_n=5, threshold=20)``). + Returns: A (ModelOutputThunk, Context) if `return_sampling_results` is `False`, else returns a `SamplingResult`. @@ -79,6 +85,13 @@ async def react( turn_num = 0 while (turn_num < loop_budget) or (loop_budget == -1): turn_num += 1 + + # -- Context compaction -- + if compaction is not None: + context = await compaction.maybe_compact( + context, backend=backend, goal=goal + ) + MelleaLogger.get_logger().info(f"## ReACT TURN NUMBER {turn_num}") step, next_context = await mfuncs.aact( diff --git a/test/stdlib/test_compaction.py b/test/stdlib/test_compaction.py new file mode 100644 index 000000000..9b2ff455d --- /dev/null +++ b/test/stdlib/test_compaction.py @@ -0,0 +1,344 @@ +"""Unit and integration tests for mellea.stdlib.compaction.""" + +from collections.abc import Sequence +from dataclasses import dataclass + +import pytest + +from mellea.core.backend import Backend, BaseModelSubclass +from mellea.core.base import ( + C, + CBlock, + Component, + Context, + GenerateLog, + ModelOutputThunk, + ModelToolCall, +) +from mellea.stdlib.compaction import ( + ClearAll, + KeepLastN, + LLMSummarize, + _find_prefix_end, + rebuild_chat_context, +) +from mellea.stdlib.components.chat import Message +from mellea.stdlib.components.react import ( + MELLEA_FINALIZER_TOOL, + ReactInitiator, + _mellea_finalize_tool, +) +from mellea.stdlib.context import ChatContext +from mellea.stdlib.frameworks.react import react + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +def 
_build_context(components: list[Component | CBlock]) -> ChatContext: + """Build a ChatContext from a list of components.""" + ctx = ChatContext() + for c in components: + ctx = ctx.add(c) + return ctx + + +def _msg(role: Message.Role, content: str) -> Message: + return Message(role=role, content=content) + + +# --------------------------------------------------------------------------- +# rebuild_chat_context +# --------------------------------------------------------------------------- + + +class TestRebuildChatContext: + def test_empty(self): + ctx = rebuild_chat_context([]) + assert ctx.as_list() == [] + + def test_round_trip(self): + components = [_msg("user", "hello"), _msg("assistant", "hi")] + ctx = rebuild_chat_context(components) + result = ctx.as_list() + assert len(result) == 2 + assert all(isinstance(c, Message) for c in result) + + def test_preserves_window_size(self): + ctx = rebuild_chat_context([_msg("user", "a")], window_size=3) + assert ctx._window_size == 3 + + +# --------------------------------------------------------------------------- +# _find_prefix_end +# --------------------------------------------------------------------------- + + +class TestFindPrefixEnd: + def test_no_initiator(self): + components = [_msg("user", "a"), _msg("assistant", "b")] + assert _find_prefix_end(components) == 0 + + def test_initiator_at_start(self): + components = [ReactInitiator("goal", []), _msg("user", "a")] + assert _find_prefix_end(components) == 1 + + def test_initiator_after_system_msg(self): + components = [ + _msg("system", "sys"), + ReactInitiator("goal", []), + _msg("user", "a"), + ] + assert _find_prefix_end(components) == 2 + + +# --------------------------------------------------------------------------- +# should_compact +# --------------------------------------------------------------------------- + + +class TestShouldCompact: + def test_below_threshold(self): + ctx = _build_context([_msg("user", "a"), _msg("assistant", "b")]) + strategy = 
KeepLastN(keep_n=1, threshold=5) + assert strategy.should_compact(ctx) is False + + def test_above_threshold(self): + ctx = _build_context([_msg("user", str(i)) for i in range(10)]) + strategy = KeepLastN(keep_n=1, threshold=5) + assert strategy.should_compact(ctx) is True + + def test_zero_threshold_never_triggers(self): + ctx = _build_context([_msg("user", str(i)) for i in range(10)]) + strategy = KeepLastN(keep_n=1, threshold=0) + assert strategy.should_compact(ctx) is False + + +# --------------------------------------------------------------------------- +# ClearAll +# --------------------------------------------------------------------------- + + +class TestClearAll: + @pytest.mark.asyncio + async def test_keeps_only_prefix(self): + initiator = ReactInitiator("find the answer", []) + components = [initiator, _msg("user", "a"), _msg("assistant", "b")] + ctx = _build_context(components) + + result = await ClearAll().compact(ctx) + result_list = result.as_list() + assert len(result_list) == 1 + assert isinstance(result_list[0], ReactInitiator) + + @pytest.mark.asyncio + async def test_empty_body_is_noop(self): + initiator = ReactInitiator("goal", []) + ctx = _build_context([initiator]) + + result = await ClearAll().compact(ctx) + assert len(result.as_list()) == 1 + + +# --------------------------------------------------------------------------- +# KeepLastN +# --------------------------------------------------------------------------- + + +class TestKeepLastN: + @pytest.mark.asyncio + async def test_keeps_prefix_and_last_n(self): + initiator = ReactInitiator("goal", []) + body = [_msg("user", str(i)) for i in range(10)] + ctx = _build_context([initiator, *body]) + + result = await KeepLastN(keep_n=3).compact(ctx) + result_list = result.as_list() + assert len(result_list) == 4 # 1 prefix + 3 body + assert isinstance(result_list[0], ReactInitiator) + # Last 3 body messages + for i, c in enumerate(result_list[1:]): + assert isinstance(c, Message) + assert c.content 
== str(7 + i) + + @pytest.mark.asyncio + async def test_fewer_than_n_is_noop(self): + initiator = ReactInitiator("goal", []) + body = [_msg("user", "a"), _msg("assistant", "b")] + ctx = _build_context([initiator, *body]) + + result = await KeepLastN(keep_n=5).compact(ctx) + # Should return original context unchanged + assert result is ctx + + @pytest.mark.asyncio + async def test_preserves_window_size(self): + initiator = ReactInitiator("goal", []) + body = [_msg("user", str(i)) for i in range(10)] + ctx = rebuild_chat_context([initiator, *body], window_size=7) + + result = await KeepLastN(keep_n=2).compact(ctx) + assert result._window_size == 7 + + +# --------------------------------------------------------------------------- +# LLMSummarize +# --------------------------------------------------------------------------- + + +@dataclass +class _ScriptedTurn: + """A single scripted backend response.""" + + value: str + tool_calls: dict[str, ModelToolCall] | None = None + + +class ScriptedBackend(Backend): + """Fake backend returning pre-scripted responses.""" + + def __init__(self, script: list[_ScriptedTurn]) -> None: + self._script = iter(script) + + async def _generate_from_context( + self, + action: Component[C] | CBlock, + ctx: Context, + *, + format: type[BaseModelSubclass] | None = None, + model_options: dict | None = None, + tool_calls: bool = False, + ) -> tuple[ModelOutputThunk[C], Context]: + turn = next(self._script) + mot: ModelOutputThunk = ModelOutputThunk( + value=turn.value, tool_calls=turn.tool_calls + ) + mot._generate_log = GenerateLog(is_final_result=True) + return mot, ctx.add(action).add(mot) + + async def generate_from_raw( + self, + actions: Sequence[Component[C] | CBlock], + ctx: Context, + *, + format: type[BaseModelSubclass] | None = None, + model_options: dict | None = None, + tool_calls: bool = False, + ) -> list[ModelOutputThunk]: + raise NotImplementedError + + +class TestLLMSummarize: + @pytest.mark.asyncio + async def 
test_raises_without_backend(self): + ctx = _build_context([ReactInitiator("g", []), _msg("user", "a")]) + with pytest.raises(ValueError, match="backend"): + await LLMSummarize(keep_n=0).compact(ctx) + + @pytest.mark.asyncio + async def test_raises_without_goal(self): + ctx = _build_context([ReactInitiator("g", []), _msg("user", "a")]) + backend = ScriptedBackend([]) + with pytest.raises(ValueError, match="goal"): + await LLMSummarize(keep_n=0).compact(ctx, backend=backend) + + @pytest.mark.asyncio + async def test_summarizes_old_keeps_recent(self): + initiator = ReactInitiator("goal", []) + body = [_msg("user", f"msg-{i}") for i in range(6)] + ctx = _build_context([initiator, *body]) + + # The backend will return one summary when the summarization prompt is sent + backend = ScriptedBackend([_ScriptedTurn(value="Summary of old messages")]) + + result = await LLMSummarize(keep_n=2).compact(ctx, backend=backend, goal="goal") + result_list = result.as_list() + + # prefix (1) + summary message (1) + last 2 body = 4 + assert len(result_list) == 4 + assert isinstance(result_list[0], ReactInitiator) + # Summary message + assert isinstance(result_list[1], Message) + assert "[CONTEXT SUMMARY]" in result_list[1].content + # Recent messages preserved + assert result_list[2].content == "msg-4" + assert result_list[3].content == "msg-5" + + @pytest.mark.asyncio + async def test_fewer_than_n_is_noop(self): + initiator = ReactInitiator("goal", []) + body = [_msg("user", "a")] + ctx = _build_context([initiator, *body]) + backend = ScriptedBackend([]) + + result = await LLMSummarize(keep_n=5).compact(ctx, backend=backend, goal="goal") + assert result is ctx + + +# --------------------------------------------------------------------------- +# Integration: react() with compaction +# --------------------------------------------------------------------------- + + +from mellea.backends.tools import MelleaTool + + +def _make_tool(name: str, return_value: str = "tool_result") -> 
MelleaTool: + def _fn() -> str: + return return_value + + return MelleaTool.from_callable(_fn, name=name) + + +def _final_answer_call(answer: str = "42") -> _ScriptedTurn: + tool = MelleaTool.from_callable(_mellea_finalize_tool, MELLEA_FINALIZER_TOOL) + tc = ModelToolCall(name=MELLEA_FINALIZER_TOOL, func=tool, args={"answer": answer}) + return _ScriptedTurn(value="", tool_calls={MELLEA_FINALIZER_TOOL: tc}) + + +def _tool_call_turn( + tool_name: str, tool: MelleaTool, thought: str = "thinking..." +) -> _ScriptedTurn: + tc = ModelToolCall(name=tool_name, func=tool, args={}) + return _ScriptedTurn(value=thought, tool_calls={tool_name: tc}) + + +class TestReactWithCompaction: + @pytest.mark.asyncio + @pytest.mark.integration + async def test_compaction_triggers_during_react(self): + """Compaction fires when context exceeds threshold, loop still completes.""" + search = _make_tool("search", "found it") + backend = ScriptedBackend( + [ + _tool_call_turn("search", search, "step 1"), + _tool_call_turn("search", search, "step 2"), + _tool_call_turn("search", search, "step 3"), + _final_answer_call("done"), + ] + ) + + result, _ctx = await react( + goal="find info", + context=ChatContext(), + backend=backend, + tools=[search], + loop_budget=10, + compaction=KeepLastN(keep_n=3, threshold=6), + ) + assert result.value == "done" + + @pytest.mark.asyncio + @pytest.mark.integration + async def test_no_compaction_when_disabled(self): + """Without compaction params, react behaves identically to before.""" + backend = ScriptedBackend([_final_answer_call("42")]) + result, _ = await react( + goal="answer", + context=ChatContext(), + backend=backend, + tools=None, + loop_budget=5, + ) + assert result.value == "42" From 455c93d970eed100cf234ba957094fd247849c38 Mon Sep 17 00:00:00 2001 From: Yousef El-Kurdi Date: Tue, 28 Apr 2026 18:58:21 +0000 Subject: [PATCH 2/6] refactor: express compaction threshold as token count Switches `CompactionStrategy.threshold` from a component-count trigger 
to a token-count trigger, read from the most recent `ModelOutputThunk.usage` populated by the backend. This aligns compaction with the real constraint (context size) and sidesteps per-backend tokenizer dependencies by using provider-reported usage; the trade-off is a one-turn lag since usage is recorded at the end of each model call. Also reorders the react loop so compaction runs after the final-answer check, skipping wasted work (and a wasted LLM call for LLMSummarize) on terminal turns. Assisted-by: Claude Code Signed-off-by: Yousef El-Kurdi --- mellea/stdlib/compaction.py | 67 +++++++++++++++++++++------ mellea/stdlib/frameworks/react.py | 12 ++--- test/stdlib/test_compaction.py | 77 ++++++++++++++++++++++++++----- 3 files changed, 125 insertions(+), 31 deletions(-) diff --git a/mellea/stdlib/compaction.py b/mellea/stdlib/compaction.py index 2f9636fb7..1f98d2eef 100644 --- a/mellea/stdlib/compaction.py +++ b/mellea/stdlib/compaction.py @@ -17,12 +17,13 @@ from mellea.stdlib.compaction import KeepLastN from mellea.stdlib.frameworks.react import react + # Compact once the most recent model call reports > 8000 prompt+completion tokens. await react( goal="...", context=ChatContext(), backend=m.backend, tools=[search_tool], - compaction=KeepLastN(keep_n=5, threshold=20), + compaction=KeepLastN(keep_n=5, threshold=8000), ) """ @@ -72,6 +73,26 @@ def _find_prefix_end(components: list[Component | CBlock]) -> int: return 0 +def _last_usage_tokens(context: ChatContext) -> int | None: + """Return ``total_tokens`` from the most recent ``ModelOutputThunk`` with usage. + + Walks *context* back-to-front looking for a ``ModelOutputThunk`` whose + ``usage`` dict has been populated by a backend's ``post_processing``. + Falls back to ``prompt_tokens + completion_tokens`` when ``total_tokens`` + is missing. Returns ``None`` if no usable token count can be recovered — + typically the case before the first model call completes. 
+ """ + for c in reversed(context.as_list()): + if isinstance(c, ModelOutputThunk) and c.usage is not None: + total = c.usage.get("total_tokens") + if total is None: + pt = c.usage.get("prompt_tokens") or 0 + ct = c.usage.get("completion_tokens") or 0 + total = pt + ct + return total if total and total > 0 else None + return None + + # --------------------------------------------------------------------------- # Abstract base # --------------------------------------------------------------------------- @@ -80,9 +101,16 @@ def _find_prefix_end(components: list[Component | CBlock]) -> int: class CompactionStrategy(abc.ABC): """Abstract base class for context compaction strategies. - Each strategy carries a ``threshold`` — the component count above which - compaction should fire. The :meth:`should_compact` helper checks this so - callers don't need to track the threshold separately. + Each strategy carries a ``threshold`` — the token count above which + compaction should fire. The :meth:`should_compact` helper reads the + most recent ``ModelOutputThunk.usage`` populated by the backend and + compares its total token count to ``threshold``. + + Because ``usage`` is recorded when a model call completes, the measured + token count reflects the context as of the *previous* turn — any + components appended since (e.g. a tool response) are not yet included. + In practice this one-turn lag is negligible unless a single tool call + adds a very large payload. Subclasses implement :meth:`compact` which receives the current ``ChatContext`` and returns a compacted copy. The method is ``async`` @@ -90,25 +118,35 @@ class CompactionStrategy(abc.ABC): transparently; synchronous strategies simply never ``await``. Args: - threshold (int): Trigger compaction when the number of context - components exceeds this value. + threshold (int): Trigger compaction when the most recent thunk's + total token usage exceeds this value. ``0`` disables compaction. 
""" def __init__(self, *, threshold: int = 0) -> None: - """Initialize with the component-count threshold.""" + """Initialize with the token-count threshold.""" self.threshold = threshold def should_compact(self, context: ChatContext) -> bool: - """Return ``True`` when *context* exceeds the configured threshold. + """Return ``True`` when the last thunk's token usage exceeds ``threshold``. + + Reads ``total_tokens`` from the most recent ``ModelOutputThunk.usage`` + in *context*. Returns ``False`` when no thunk with usage is present + (e.g. before the first model call) or when ``threshold`` is not + positive. Args: context: The context to check. Returns: - ``True`` if the number of components exceeds ``self.threshold`` + ``True`` if the recovered token count exceeds ``self.threshold`` and ``self.threshold`` is greater than 0. """ - return self.threshold > 0 and len(context.as_list()) > self.threshold + if self.threshold <= 0: + return False + tokens = _last_usage_tokens(context) + if tokens is None: + return False + return tokens > self.threshold async def maybe_compact( self, @@ -163,7 +201,8 @@ class ClearAll(CompactionStrategy): The prefix is everything up to and including the first ``ReactInitiator``. Args: - threshold (int): Trigger compaction when context exceeds this many components. + threshold (int): Trigger compaction when the most recent thunk's total + token usage exceeds this value. """ async def compact( @@ -190,7 +229,8 @@ class KeepLastN(CompactionStrategy): Args: keep_n (int): Number of recent body components to retain. - threshold (int): Trigger compaction when context exceeds this many components. + threshold (int): Trigger compaction when the most recent thunk's total + token usage exceeds this value. """ def __init__(self, *, keep_n: int = 5, threshold: int = 0) -> None: @@ -230,7 +270,8 @@ class LLMSummarize(CompactionStrategy): Args: keep_n (int): Number of recent body components to retain verbatim. 
- threshold (int): Trigger compaction when context exceeds this many components. + threshold (int): Trigger compaction when the most recent thunk's total + token usage exceeds this value. """ def __init__(self, *, keep_n: int = 5, threshold: int = 0) -> None: diff --git a/mellea/stdlib/frameworks/react.py b/mellea/stdlib/frameworks/react.py index 7f39bba27..baf876ea0 100644 --- a/mellea/stdlib/frameworks/react.py +++ b/mellea/stdlib/frameworks/react.py @@ -86,12 +86,6 @@ async def react( while (turn_num < loop_budget) or (loop_budget == -1): turn_num += 1 - # -- Context compaction -- - if compaction is not None: - context = await compaction.maybe_compact( - context, backend=backend, goal=goal - ) - MelleaLogger.get_logger().info(f"## ReACT TURN NUMBER {turn_num}") step, next_context = await mfuncs.aact( @@ -140,4 +134,10 @@ async def react( step._underlying_value = str(tool_responses[0].content) return step, context + # Compact after the final-answer check so terminal turns skip it. + if compaction is not None: + context = await compaction.maybe_compact( + context, backend=backend, goal=goal + ) + raise RuntimeError(f"could not complete react loop in {loop_budget} iterations") diff --git a/test/stdlib/test_compaction.py b/test/stdlib/test_compaction.py index 9b2ff455d..3f4650e0d 100644 --- a/test/stdlib/test_compaction.py +++ b/test/stdlib/test_compaction.py @@ -20,6 +20,7 @@ KeepLastN, LLMSummarize, _find_prefix_end, + _last_usage_tokens, rebuild_chat_context, ) from mellea.stdlib.components.chat import Message @@ -48,6 +49,17 @@ def _msg(role: Message.Role, content: str) -> Message: return Message(role=role, content=content) +def _thunk(total_tokens: int, value: str = "") -> ModelOutputThunk: + """Build a ModelOutputThunk with a populated usage dict.""" + mot = ModelOutputThunk(value=value) + mot.usage = { + "prompt_tokens": total_tokens, + "completion_tokens": 0, + "total_tokens": total_tokens, + } + return mot + + # 
--------------------------------------------------------------------------- # rebuild_chat_context # --------------------------------------------------------------------------- @@ -98,19 +110,48 @@ def test_initiator_after_system_msg(self): # --------------------------------------------------------------------------- +class TestLastUsageTokens: + def test_no_thunk_returns_none(self): + ctx = _build_context([_msg("user", "a"), _msg("assistant", "b")]) + assert _last_usage_tokens(ctx) is None + + def test_thunk_without_usage_returns_none(self): + ctx = _build_context([_msg("user", "a"), ModelOutputThunk(value="b")]) + assert _last_usage_tokens(ctx) is None + + def test_reads_total_tokens(self): + ctx = _build_context([_msg("user", "a"), _thunk(total_tokens=150)]) + assert _last_usage_tokens(ctx) == 150 + + def test_falls_back_to_prompt_plus_completion(self): + mot = ModelOutputThunk(value="x") + mot.usage = {"prompt_tokens": 40, "completion_tokens": 20} + ctx = _build_context([_msg("user", "a"), mot]) + assert _last_usage_tokens(ctx) == 60 + + def test_uses_most_recent_thunk(self): + ctx = _build_context([_thunk(100), _msg("user", "x"), _thunk(500)]) + assert _last_usage_tokens(ctx) == 500 + + class TestShouldCompact: - def test_below_threshold(self): + def test_no_thunk_does_not_trigger(self): ctx = _build_context([_msg("user", "a"), _msg("assistant", "b")]) - strategy = KeepLastN(keep_n=1, threshold=5) + strategy = KeepLastN(keep_n=1, threshold=100) + assert strategy.should_compact(ctx) is False + + def test_below_threshold(self): + ctx = _build_context([_msg("user", "a"), _thunk(total_tokens=50)]) + strategy = KeepLastN(keep_n=1, threshold=100) assert strategy.should_compact(ctx) is False def test_above_threshold(self): - ctx = _build_context([_msg("user", str(i)) for i in range(10)]) - strategy = KeepLastN(keep_n=1, threshold=5) + ctx = _build_context([_msg("user", "a"), _thunk(total_tokens=500)]) + strategy = KeepLastN(keep_n=1, threshold=100) assert 
strategy.should_compact(ctx) is True def test_zero_threshold_never_triggers(self): - ctx = _build_context([_msg("user", str(i)) for i in range(10)]) + ctx = _build_context([_msg("user", "a"), _thunk(total_tokens=10_000)]) strategy = KeepLastN(keep_n=1, threshold=0) assert strategy.should_compact(ctx) is False @@ -193,6 +234,7 @@ class _ScriptedTurn: value: str tool_calls: dict[str, ModelToolCall] | None = None + total_tokens: int | None = None class ScriptedBackend(Backend): @@ -215,6 +257,12 @@ async def _generate_from_context( value=turn.value, tool_calls=turn.tool_calls ) mot._generate_log = GenerateLog(is_final_result=True) + if turn.total_tokens is not None: + mot.usage = { + "prompt_tokens": turn.total_tokens, + "completion_tokens": 0, + "total_tokens": turn.total_tokens, + } return mot, ctx.add(action).add(mot) async def generate_from_raw( @@ -298,23 +346,28 @@ def _final_answer_call(answer: str = "42") -> _ScriptedTurn: def _tool_call_turn( - tool_name: str, tool: MelleaTool, thought: str = "thinking..." 
+ tool_name: str, + tool: MelleaTool, + thought: str = "thinking...", + total_tokens: int | None = None, ) -> _ScriptedTurn: tc = ModelToolCall(name=tool_name, func=tool, args={}) - return _ScriptedTurn(value=thought, tool_calls={tool_name: tc}) + return _ScriptedTurn( + value=thought, tool_calls={tool_name: tc}, total_tokens=total_tokens + ) class TestReactWithCompaction: @pytest.mark.asyncio @pytest.mark.integration async def test_compaction_triggers_during_react(self): - """Compaction fires when context exceeds threshold, loop still completes.""" + """Compaction fires when last thunk's token usage exceeds threshold.""" search = _make_tool("search", "found it") backend = ScriptedBackend( [ - _tool_call_turn("search", search, "step 1"), - _tool_call_turn("search", search, "step 2"), - _tool_call_turn("search", search, "step 3"), + _tool_call_turn("search", search, "step 1", total_tokens=200), + _tool_call_turn("search", search, "step 2", total_tokens=200), + _tool_call_turn("search", search, "step 3", total_tokens=200), _final_answer_call("done"), ] ) @@ -325,7 +378,7 @@ async def test_compaction_triggers_during_react(self): backend=backend, tools=[search], loop_budget=10, - compaction=KeepLastN(keep_n=3, threshold=6), + compaction=KeepLastN(keep_n=3, threshold=100), ) assert result.value == "done" From ca7bea1e9ee428c101aa739d2d4fe7f5f84a5c35 Mon Sep 17 00:00:00 2001 From: ramon-astudillo Date: Thu, 30 Apr 2026 13:18:18 -0400 Subject: [PATCH 3/6] Fix mot.generation.usage --- mellea/stdlib/compaction.py | 8 ++++---- mellea/stdlib/frameworks/react.py | 1 - test/stdlib/test_compaction.py | 10 ++++------ 3 files changed, 8 insertions(+), 11 deletions(-) diff --git a/mellea/stdlib/compaction.py b/mellea/stdlib/compaction.py index 1f98d2eef..20b60f336 100644 --- a/mellea/stdlib/compaction.py +++ b/mellea/stdlib/compaction.py @@ -83,11 +83,11 @@ def _last_usage_tokens(context: ChatContext) -> int | None: typically the case before the first model call completes. 
""" for c in reversed(context.as_list()): - if isinstance(c, ModelOutputThunk) and c.usage is not None: - total = c.usage.get("total_tokens") + if isinstance(c, ModelOutputThunk) and c.generation.usage is not None: + total = c.generation.usage.get("total_tokens") if total is None: - pt = c.usage.get("prompt_tokens") or 0 - ct = c.usage.get("completion_tokens") or 0 + pt = c.generation.usage.get("prompt_tokens") or 0 + ct = c.generation.usage.get("completion_tokens") or 0 total = pt + ct return total if total and total > 0 else None return None diff --git a/mellea/stdlib/frameworks/react.py b/mellea/stdlib/frameworks/react.py index baf876ea0..81dc04146 100644 --- a/mellea/stdlib/frameworks/react.py +++ b/mellea/stdlib/frameworks/react.py @@ -53,7 +53,6 @@ async def react( exceeds the strategy's configured threshold (e.g. ``KeepLastN(keep_n=5, threshold=20)``). - Returns: A (ModelOutputThunk, Context) if `return_sampling_results` is `False`, else returns a `SamplingResult`. diff --git a/test/stdlib/test_compaction.py b/test/stdlib/test_compaction.py index 3f4650e0d..076faa7f6 100644 --- a/test/stdlib/test_compaction.py +++ b/test/stdlib/test_compaction.py @@ -5,6 +5,7 @@ import pytest +from mellea.backends.tools import MelleaTool from mellea.core.backend import Backend, BaseModelSubclass from mellea.core.base import ( C, @@ -52,7 +53,7 @@ def _msg(role: Message.Role, content: str) -> Message: def _thunk(total_tokens: int, value: str = "") -> ModelOutputThunk: """Build a ModelOutputThunk with a populated usage dict.""" mot = ModelOutputThunk(value=value) - mot.usage = { + mot.generation.usage = { "prompt_tokens": total_tokens, "completion_tokens": 0, "total_tokens": total_tokens, @@ -125,7 +126,7 @@ def test_reads_total_tokens(self): def test_falls_back_to_prompt_plus_completion(self): mot = ModelOutputThunk(value="x") - mot.usage = {"prompt_tokens": 40, "completion_tokens": 20} + mot.generation.usage = {"prompt_tokens": 40, "completion_tokens": 20} ctx = 
_build_context([_msg("user", "a"), mot]) assert _last_usage_tokens(ctx) == 60 @@ -258,7 +259,7 @@ async def _generate_from_context( ) mot._generate_log = GenerateLog(is_final_result=True) if turn.total_tokens is not None: - mot.usage = { + mot.generation.usage = { "prompt_tokens": turn.total_tokens, "completion_tokens": 0, "total_tokens": turn.total_tokens, @@ -329,9 +330,6 @@ async def test_fewer_than_n_is_noop(self): # --------------------------------------------------------------------------- -from mellea.backends.tools import MelleaTool - - def _make_tool(name: str, return_value: str = "tool_result") -> MelleaTool: def _fn() -> str: return return_value From 643f5a56cbc0a692683b9d5dde3d58d3ca543285 Mon Sep 17 00:00:00 2001 From: Yousef El-Kurdi Date: Thu, 30 Apr 2026 22:19:03 +0000 Subject: [PATCH 4/6] refactor: relocate compaction module into frameworks package Move the compaction strategies alongside the react framework they serve: - mellea/stdlib/compaction.py -> mellea/stdlib/frameworks/react_compaction.py - test/stdlib/test_compaction.py -> test/stdlib/frameworks/test_react_compaction.py Imports and module docstrings updated accordingly. 
Assisted-by: Claude Code Signed-off-by: Yousef El-Kurdi --- mellea/stdlib/frameworks/react.py | 2 +- .../stdlib/{compaction.py => frameworks/react_compaction.py} | 2 +- .../test_react_compaction.py} | 4 ++-- 3 files changed, 4 insertions(+), 4 deletions(-) rename mellea/stdlib/{compaction.py => frameworks/react_compaction.py} (99%) rename test/stdlib/{test_compaction.py => frameworks/test_react_compaction.py} (99%) diff --git a/mellea/stdlib/frameworks/react.py b/mellea/stdlib/frameworks/react.py index 81dc04146..921781cc0 100644 --- a/mellea/stdlib/frameworks/react.py +++ b/mellea/stdlib/frameworks/react.py @@ -15,7 +15,7 @@ from mellea.stdlib import functional as mfuncs # from mellea.stdlib.components.docs.document import Document -from mellea.stdlib.compaction import CompactionStrategy +from mellea.stdlib.frameworks.react_compaction import CompactionStrategy from mellea.stdlib.components.chat import ToolMessage from mellea.stdlib.components.react import ( MELLEA_FINALIZER_TOOL, diff --git a/mellea/stdlib/compaction.py b/mellea/stdlib/frameworks/react_compaction.py similarity index 99% rename from mellea/stdlib/compaction.py rename to mellea/stdlib/frameworks/react_compaction.py index 20b60f336..111b95524 100644 --- a/mellea/stdlib/compaction.py +++ b/mellea/stdlib/frameworks/react_compaction.py @@ -14,7 +14,7 @@ Example:: - from mellea.stdlib.compaction import KeepLastN + from mellea.stdlib.frameworks.react_compaction import KeepLastN from mellea.stdlib.frameworks.react import react # Compact once the most recent model call reports > 8000 prompt+completion tokens. 
diff --git a/test/stdlib/test_compaction.py b/test/stdlib/frameworks/test_react_compaction.py similarity index 99% rename from test/stdlib/test_compaction.py rename to test/stdlib/frameworks/test_react_compaction.py index 076faa7f6..07e5e44ce 100644 --- a/test/stdlib/test_compaction.py +++ b/test/stdlib/frameworks/test_react_compaction.py @@ -1,4 +1,4 @@ -"""Unit and integration tests for mellea.stdlib.compaction.""" +"""Unit and integration tests for mellea.stdlib.frameworks.react_compaction.""" from collections.abc import Sequence from dataclasses import dataclass @@ -16,7 +16,7 @@ ModelOutputThunk, ModelToolCall, ) -from mellea.stdlib.compaction import ( +from mellea.stdlib.frameworks.react_compaction import ( ClearAll, KeepLastN, LLMSummarize, From 16e7571a529169483afb82f44a3d61da0c8e8bda Mon Sep 17 00:00:00 2001 From: Yousef El-Kurdi Date: Wed, 6 May 2026 09:16:42 -0400 Subject: [PATCH 5/6] docs: add Args/Returns sections to react_compaction compact overrides MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The docstring quality gate (tooling/docs-autogen/audit_coverage.py --quality --threshold 100) requires each documented symbol to have its own Args/Returns sections — inheritance from the abstract parent is not consulted. Six issues were reported against the compact() overrides on ClearAll, KeepLastN, and LLMSummarize. 
Assisted-by: Claude Code Signed-off-by: Yousef El-Kurdi --- mellea/stdlib/frameworks/react_compaction.py | 35 ++++++++++++++++++-- 1 file changed, 33 insertions(+), 2 deletions(-) diff --git a/mellea/stdlib/frameworks/react_compaction.py b/mellea/stdlib/frameworks/react_compaction.py index 111b95524..ccc312b5f 100644 --- a/mellea/stdlib/frameworks/react_compaction.py +++ b/mellea/stdlib/frameworks/react_compaction.py @@ -212,7 +212,16 @@ async def compact( backend: Backend | None = None, goal: str | None = None, ) -> ChatContext: - """Return a context containing only the prefix.""" + """Return a context containing only the prefix. + + Args: + context: The context to compact. + backend: Unused by this strategy; accepted for interface compatibility. + goal: Unused by this strategy; accepted for interface compatibility. + + Returns: + A new ``ChatContext`` containing only the prefix components. + """ components = context.as_list() prefix_end = _find_prefix_end(components) compacted = components[:prefix_end] @@ -245,7 +254,18 @@ async def compact( backend: Backend | None = None, goal: str | None = None, ) -> ChatContext: - """Return a context with the prefix and the last *keep_n* body components.""" + """Return a context with the prefix and the last *keep_n* body components. + + Args: + context: The context to compact. + backend: Unused by this strategy; accepted for interface compatibility. + goal: Unused by this strategy; accepted for interface compatibility. + + Returns: + A new ``ChatContext`` with the prefix plus the most recent *keep_n* + body components, or the original *context* if the body is already + at or below *keep_n* in length. + """ components = context.as_list() prefix_end = _find_prefix_end(components) prefix = components[:prefix_end] @@ -288,6 +308,17 @@ async def compact( ) -> ChatContext: """Return a context with the prefix, an LLM summary, and recent body components. + Args: + context: The context to compact. 
+ backend: Backend used to generate the summary; required. + goal: The react goal string, included in the summary prompt; required. + + Returns: + A new ``ChatContext`` containing the prefix, a single summary + ``Message`` produced by the backend, and the most recent *keep_n* + body components verbatim. Returns the original *context* if the + body is already at or below *keep_n* in length. + Raises: ValueError: If *backend* or *goal* are not provided. """ From 030bcef7496a78f0836c28de94428db965a0ab51 Mon Sep 17 00:00:00 2001 From: Yousef El-Kurdi Date: Fri, 22 May 2026 14:33:25 +0000 Subject: [PATCH 6/6] feat(compaction): per-turn Compactor protocol for ChatContext + ReACT Replaces the original async ``react_compaction`` strategies (ClearAll, KeepLastN, LLMSummarize) with a generic, sync ``Compactor`` protocol that operates on any ``Context``. ``ReACT`` and ``ChatContext`` are rewired around the new protocol; sample callers, tests, and docs are updated. Squash of 29 Mellea-side commits from context_compaction_for_react_2; the BCP eval harness commits in that branch are intentionally excluded. 
mellea/stdlib/context/ becomes a package - Compactor protocol: sync ``compact(ctx, *, backend=None) -> Context`` - WindowCompactor(size, pin_predicate) keep last-N body components; ``size=0`` clears the body and retains only the pinned prefix - ThresholdCompactor(inner, threshold) token-gated wrapper that reads cumulative context size from the most recent ModelOutputThunk's ``generation.usage`` and forwards to ``inner.compact`` only above the gate - LLMSummarizeCompactor(keep_n, pin_predicate, prompt_template) summarizes old body components via the backend; the (async) backend call is hidden behind a sync ``compact()`` via ``_run_coro_blocking`` so the protocol stays sync - PinPredicate API: ``pin_nothing``, ``pin_system``, ``pin_system_and_initial_user``; chat compactors compose freely mellea/stdlib/frameworks/react.py - ``react()`` gains a ``compactor: Compactor | None = None`` per-turn hook; invoked once after each tool observation - The old ``react_compaction`` module is removed mellea/stdlib/components/react.py - ``pin_react_initiator``: a PinPredicate that pins everything up to and including the first ``ReactInitiator`` - ``react_summary_prompt(goal=None, max_tokens_hint=None)``: factory that returns a research-flavoured summary prompt template (with the {conversation} placeholder LLMSummarizeCompactor expects). Optional ``GOAL: `` line and optional ``- Be at most ~N tokens`` bullet when callers want goal anchoring or length-cap hints. mellea/stdlib/context/chat.py - ``ChatContext()`` defaults to no compactor (full history); pass ``compactor=`` or ``window_size=`` for opt-in compaction. Matches upstream main's window_size=None unbounded semantics. 
Test coverage - test/stdlib/test_compactor.py (~500 LOC): protocol semantics; Window / Threshold / LLMSummarize behaviours; pin-predicate edge cases; ``size=0`` collapse; threshold gate edge cases - test/stdlib/frameworks/test_react_framework.py (~210 LOC): react() per-turn hook integration + react_summary_prompt (default, goal interpolation, brace escaping, max_tokens_hint bullet ordering, LLMSummarizeCompactor template-validation) - test/stdlib/test_base_context.py: pin-non-compacting ChatContext in the session-copy operations test (matches new opt-in default) Net diff: 17 files, +381 / -896 lines (drops the old react_compaction.py and its dedicated test file). Backwards-compatible default behaviour preserved: bare ``ChatContext()`` retains full history; ``react()`` without ``compactor=`` behaves identically to today; ``LLMSummarizeCompactor`` defaults to a generic conversation-summary prompt unless callers opt in to the research-flavoured variant via ``react_summary_prompt``. Assisted-by: Claude Code Signed-off-by: Yousef El-Kurdi --- docs/examples/context/README.md | 51 +- docs/examples/context/custom_compactor.py | 63 +++ docs/examples/context/react_compaction.py | 235 +++++++++ docs/examples/context/threshold_compactor.py | 57 ++ docs/examples/context/window_compactor.py | 101 ++++ mellea/stdlib/components/react.py | 96 ++++ mellea/stdlib/context.py | 82 --- mellea/stdlib/context/__init__.py | 45 ++ mellea/stdlib/context/chat.py | 101 ++++ mellea/stdlib/context/compactor.py | 427 +++++++++++++++ mellea/stdlib/context/simple.py | 32 ++ mellea/stdlib/frameworks/react.py | 25 +- mellea/stdlib/frameworks/react_compaction.py | 397 -------------- .../frameworks/test_react_compaction.py | 395 -------------- .../stdlib/frameworks/test_react_framework.py | 212 ++++++++ test/stdlib/test_base_context.py | 19 +- test/stdlib/test_compactor.py | 492 ++++++++++++++++++ 17 files changed, 1934 insertions(+), 896 deletions(-) create mode 100644 
docs/examples/context/custom_compactor.py create mode 100644 docs/examples/context/react_compaction.py create mode 100644 docs/examples/context/threshold_compactor.py create mode 100644 docs/examples/context/window_compactor.py delete mode 100644 mellea/stdlib/context.py create mode 100644 mellea/stdlib/context/__init__.py create mode 100644 mellea/stdlib/context/chat.py create mode 100644 mellea/stdlib/context/compactor.py create mode 100644 mellea/stdlib/context/simple.py delete mode 100644 mellea/stdlib/frameworks/react_compaction.py delete mode 100644 test/stdlib/frameworks/test_react_compaction.py create mode 100644 test/stdlib/test_compactor.py diff --git a/docs/examples/context/README.md b/docs/examples/context/README.md index dde027bc5..e7b8b3752 100644 --- a/docs/examples/context/README.md +++ b/docs/examples/context/README.md @@ -1,13 +1,15 @@ # Context Examples -This directory contains examples demonstrating how to work with Mellea's context system, particularly when using sampling strategies and validation. +This directory contains examples demonstrating how to work with Mellea's context system: inspecting per-attempt contexts produced by sampling strategies, and shrinking contexts with the `Compactor` protocol. ## Files ### contexts_with_sampling.py + Shows how to retrieve and inspect context information when using sampling strategies and validation. **Key Features:** + - Using `RejectionSamplingStrategy` with requirements - Accessing `SamplingResult` objects to inspect generation attempts - Retrieving context for different generation attempts @@ -15,10 +17,34 @@ Shows how to retrieve and inspect context information when using sampling strate - Understanding the context tree structure **Usage:** -```bash + +``` python docs/examples/context/contexts_with_sampling.py ``` +### window_compactor.py + +`WindowCompactor` — opt-in by passing `compactor=` (or the `window_size=` sugar). 
Demonstrates system-prefix pinning, `pin_system_and_initial_user`, `pin_nothing` (pure last-N), and `size=0` to clear the body. + +### threshold_compactor.py + +`ThresholdCompactor` — gate an inner compactor on the conversation's cumulative token size. The reading is taken from the most recent `ModelOutputThunk`'s `total_tokens`, which for a chat backend equals `prompt_tokens` (full conversation history sent to the model) + `completion_tokens` (reply). The gate fires once the running conversation size crosses the threshold; once compaction shrinks the context, the next call produces a smaller reading and the gate closes again. + +### custom_compactor.py + +Implement the `Compactor` protocol with a plain class (no inheritance). Shows Pattern 1 (wired into `ChatContext`) and Pattern 2 (manual `compact()` call). + +### react_compaction.py + +Compose the ReACT loop with a sync `Compactor`. Two integration points: + +- **Per-add** — wire a `Compactor` onto the `ChatContext` so it runs every time `react()` appends a Message, ToolMessage, or thunk. +- **Per-turn** — pass `compactor=` to `react()`; it fires once per ReACT iteration after the tool observation. + +`LLMSummarizeCompactor` is also a sync `Compactor` — it hides the async backend call internally (worker thread when called from an already-running event loop) so callers don't have to think about sync vs async. + +Use `pin_react_initiator` (from `mellea.stdlib.components.react`) as the predicate so the goal and tool registration survive compaction. 
+ ## Concepts Demonstrated - **Sampling Results**: Working with `SamplingResult` objects @@ -26,6 +52,8 @@ python docs/examples/context/contexts_with_sampling.py - **Multiple Attempts**: Examining different generation attempts - **Context Trees**: Understanding how contexts link together - **Validation Context**: Inspecting how requirements were evaluated +- **Compaction Protocol**: Sync `Compactor` for per-`add()` shrinking +- **Pin Predicates**: Auto-protect leading system messages or the user's initial prompt during compaction ## Key APIs @@ -48,8 +76,23 @@ gen_ctx.previous_node.node_data val_ctx.node_data ``` +```python +# Wire a compactor into a ChatContext (Pattern 1 — runs on every add()) +from mellea.stdlib.context import ChatContext, WindowCompactor, ThresholdCompactor + +ctx = ChatContext(compactor=WindowCompactor(size=5)) # default: pin_system +ctx = ChatContext(window_size=5) # sugar for the line above +ctx = ChatContext( + compactor=ThresholdCompactor(WindowCompactor(size=5), threshold=8000), +) + +# Manual compaction (Pattern 2) +ctx = WindowCompactor(size=0).compact(ctx) # drop body, keep pinned prefix +``` + ## Related Documentation -- See `mellea/stdlib/context.py` for context implementation +- See `mellea/stdlib/context/` for context and compactor implementations - See `mellea/stdlib/sampling/` for sampling strategies -- See `docs/dev/spans.md` for context architecture details \ No newline at end of file +- See `mellea/stdlib/frameworks/react.py` for the ReACT loop +- See `docs/dev/spans.md` for context architecture details diff --git a/docs/examples/context/custom_compactor.py b/docs/examples/context/custom_compactor.py new file mode 100644 index 000000000..663f21b4a --- /dev/null +++ b/docs/examples/context/custom_compactor.py @@ -0,0 +1,63 @@ +# pytest: unit +"""Implementing the Compactor protocol — anything with ``compact()`` works. 
+ +The protocol is structurally typed: a class with a ``compact(ctx, *, +backend=None) -> ChatContext`` method is a valid Compactor. No +inheritance is required. +""" + +from mellea.stdlib.components.chat import Message +from mellea.stdlib.context import ChatContext, Compactor +from mellea.stdlib.context.chat import _rebuild_chat_context + + +class TruncateOldest: + """Drop only the very first body component each call. + + Demonstrates the smallest possible Compactor implementation. Pattern + 1 (wired into ``ChatContext``) means each ``add()`` removes the + oldest item then appends — net result: the context never grows. + """ + + def compact(self, ctx, *, backend=None): + items = ctx.as_list() + if len(items) <= 1: + return ctx + return _rebuild_chat_context(items[1:], compactor=ctx._compactor) + + +def pattern_1_wired_into_context(): + """Pattern 1: compactor lives on the context, runs in ``add()``.""" + ctx = ChatContext(compactor=TruncateOldest()) + for i in range(4): + ctx = ctx.add(Message("user", f"msg {i}")) + return [m.content for m in ctx.as_list()] + # → ['msg 3'] (oldest dropped before each append) + + +def pattern_2_manual_call(): + """Pattern 2: caller invokes ``compact()`` directly between turns.""" + ctx = ChatContext(window_size=10_000) # permissive — no auto-compaction + for i in range(5): + ctx = ctx.add(Message("user", f"msg {i}")) + truncated = TruncateOldest().compact(ctx) + return [m.content for m in truncated.as_list()] + + +def structural_typing_check(): + """The Compactor protocol is satisfied structurally, no inheritance.""" + c: Compactor = TruncateOldest() # mypy-checked Protocol assignment + return type(c).__name__ + + +if __name__ == "__main__": + for fn in [pattern_1_wired_into_context, pattern_2_manual_call]: + print(f"--- {fn.__name__} ---") + print(fn()) + print(f"structural typing: {structural_typing_check()} satisfies Compactor") + + +def test_custom_compactor_examples(): + assert pattern_1_wired_into_context() == ["msg 3"] + 
assert pattern_2_manual_call() == ["msg 1", "msg 2", "msg 3", "msg 4"] + assert structural_typing_check() == "TruncateOldest" diff --git a/docs/examples/context/react_compaction.py b/docs/examples/context/react_compaction.py new file mode 100644 index 000000000..baa19c3dd --- /dev/null +++ b/docs/examples/context/react_compaction.py @@ -0,0 +1,235 @@ +# pytest: unit +"""Compose the ReACT loop with a sync `Compactor`. + +Two integration points are available, and they're complementary: + +1. **Per-add** — the `ChatContext`'s own compactor runs every time the + ReACT loop appends a Message, ToolMessage, or thunk. This is fine + for cheap strategies like `WindowCompactor`. +2. **Per-turn** — pass `compactor=` to ``react(...)`` to invoke a + compactor once per ReACT iteration after the tool observation. Use + it for heavier strategies that should fire at turn boundaries + instead of on every component append. + +In both cases use ``pin_react_initiator`` (from +``mellea.stdlib.components.react``) so the goal and tool registration +survive compaction. + +This example exercises the wiring end-to-end against a fake backend so +no LLM is required. 
+""" + +from __future__ import annotations + +import asyncio +from collections.abc import Sequence +from dataclasses import dataclass + +from mellea.backends.tools import MelleaTool +from mellea.core.backend import Backend, BaseModelSubclass +from mellea.core.base import ( + C, + CBlock, + Component, + Context, + GenerateLog, + ModelOutputThunk, + ModelToolCall, +) +from mellea.stdlib.components.react import ( + MELLEA_FINALIZER_TOOL, + ReactInitiator, + _mellea_finalize_tool, + pin_react_initiator, +) +from mellea.stdlib.context import ChatContext, WindowCompactor +from mellea.stdlib.frameworks.react import react + +# --------------------------------------------------------------------------- # +# Fake backend so the example runs without an LLM # +# --------------------------------------------------------------------------- # + + +@dataclass +class _ScriptedTurn: + value: str + tool_calls: dict[str, ModelToolCall] | None = None + + +class ScriptedBackend(Backend): + """Returns pre-scripted responses; no real model is called.""" + + def __init__(self, script: list[_ScriptedTurn]) -> None: + self._script = iter(script) + + async def _generate_from_context( + self, + action: Component[C] | CBlock, + ctx: Context, + *, + format: type[BaseModelSubclass] | None = None, + model_options: dict | None = None, + tool_calls: bool = False, + ) -> tuple[ModelOutputThunk[C], Context]: + turn = next(self._script) + mot: ModelOutputThunk = ModelOutputThunk( + value=turn.value, tool_calls=turn.tool_calls + ) + mot._generate_log = GenerateLog(is_final_result=True) + return mot, ctx.add(action).add(mot) + + async def generate_from_raw( + self, + actions: Sequence[Component[C] | CBlock], + ctx: Context, + *, + format: type[BaseModelSubclass] | None = None, + model_options: dict | None = None, + tool_calls: bool = False, + ) -> list[ModelOutputThunk]: + raise NotImplementedError + + +def _tool(name: str, return_value: str = "ok") -> MelleaTool: + def _fn() -> str: + return return_value 
+ + return MelleaTool.from_callable(_fn, name=name) + + +def _tool_call(tool_name: str, tool: MelleaTool, thought: str) -> _ScriptedTurn: + tc = ModelToolCall(name=tool_name, func=tool, args={}) + return _ScriptedTurn(value=thought, tool_calls={tool_name: tc}) + + +def _final(answer: str) -> _ScriptedTurn: + finalizer = MelleaTool.from_callable(_mellea_finalize_tool, MELLEA_FINALIZER_TOOL) + tc = ModelToolCall( + name=MELLEA_FINALIZER_TOOL, func=finalizer, args={"answer": answer} + ) + return _ScriptedTurn(value="", tool_calls={MELLEA_FINALIZER_TOOL: tc}) + + +# --------------------------------------------------------------------------- # +# Pattern A — per-add compaction wired into the ChatContext # +# --------------------------------------------------------------------------- # + + +async def per_add_compaction(): + """A `WindowCompactor(pin_react_initiator)` on the ChatContext compacts + on every ``add()`` — Messages, ToolMessages, thunks. The ReactInitiator + stays pinned across the whole loop. + """ + search = _tool("search") + backend = ScriptedBackend( + [ + _tool_call("search", search, "step 1"), + _tool_call("search", search, "step 2"), + _tool_call("search", search, "step 3"), + _final("done"), + ] + ) + ctx = ChatContext( + compactor=WindowCompactor(size=3, pin_predicate=pin_react_initiator) + ) + result, ctx = await react( + goal="find info", context=ctx, backend=backend, tools=[search], loop_budget=10 + ) + return ( + result.value, + any(isinstance(c, ReactInitiator) for c in ctx.as_list()), + len(ctx.as_list()), + ) + + +# --------------------------------------------------------------------------- # +# Pattern B — per-turn compaction passed to react() # +# --------------------------------------------------------------------------- # + + +async def per_turn_compaction(): + """Pass ``compactor=`` to ``react`` for once-per-turn invocation. 
+ + Use a permissive ``ChatContext`` (large window) so the per-add path is + effectively disabled — only the per-turn hook drives compaction. + """ + search = _tool("search") + backend = ScriptedBackend( + [ + _tool_call("search", search, "step 1"), + _tool_call("search", search, "step 2"), + _tool_call("search", search, "step 3"), + _final("done"), + ] + ) + result, ctx = await react( + goal="find info", + context=ChatContext(window_size=10_000), + backend=backend, + tools=[search], + loop_budget=10, + compactor=WindowCompactor(size=2, pin_predicate=pin_react_initiator), + ) + return (result.value, any(isinstance(c, ReactInitiator) for c in ctx.as_list())) + + +# --------------------------------------------------------------------------- # +# Pattern C — LLM-driven summarisation # +# --------------------------------------------------------------------------- # + + +async def llm_summarize_compaction(): + """Wire :class:`LLMSummarizeCompactor` into ``react()``. + + ``LLMSummarizeCompactor`` implements the sync :class:`Compactor` + protocol — its ``compact`` method internally orchestrates the async + backend call (running it on a worker thread when invoked from inside + an event loop). From ``react()``'s perspective it's just another + sync compactor. + + To keep the scripted backend simple, this example sets ``keep_n`` + large enough that summarisation never fires (no LLM call is needed). + Real usage would pair it with ``ThresholdCompactor`` so it only + activates once the conversation crosses a token budget. See + ``TestLLMSummarizeCompactor`` in ``test/stdlib/test_compactor.py`` for + unit tests that exercise the actual summary path. 
+ """ + from mellea.stdlib.context import LLMSummarizeCompactor + + search = _tool("search") + backend = ScriptedBackend([_tool_call("search", search, "step 1"), _final("done")]) + result, ctx = await react( + goal="find info", + context=ChatContext(window_size=10_000), + backend=backend, + tools=[search], + loop_budget=10, + # keep_n=1000 → no summarisation triggers in this short script; + # the example just shows the async compactor is wired correctly. + compactor=LLMSummarizeCompactor(keep_n=1000, pin_predicate=pin_react_initiator), + ) + return (result.value, any(isinstance(c, ReactInitiator) for c in ctx.as_list())) + + +if __name__ == "__main__": + print(f"per_add_compaction: {asyncio.run(per_add_compaction())}") + print(f"per_turn_compaction: {asyncio.run(per_turn_compaction())}") + print(f"llm_summarize_compact: {asyncio.run(llm_summarize_compaction())}") + + +def test_per_add_compaction(): + answer, has_initiator, _length = asyncio.run(per_add_compaction()) + assert answer == "done" + assert has_initiator + + +def test_per_turn_compaction(): + answer, has_initiator = asyncio.run(per_turn_compaction()) + assert answer == "done" + assert has_initiator + + +def test_llm_summarize_compaction(): + answer, has_initiator = asyncio.run(llm_summarize_compaction()) + assert answer == "done" + assert has_initiator diff --git a/docs/examples/context/threshold_compactor.py b/docs/examples/context/threshold_compactor.py new file mode 100644 index 000000000..120eba07c --- /dev/null +++ b/docs/examples/context/threshold_compactor.py @@ -0,0 +1,57 @@ +# pytest: unit +"""ThresholdCompactor — gate an inner Compactor on conversation size. + +Reads ``ModelOutputThunk.generation.usage`` from the most recent thunk +in the context. For a chat backend, ``total_tokens`` on that thunk is +``prompt_tokens`` (full conversation history sent to the model) plus +``completion_tokens`` (the reply), so it tracks *cumulative* context +size — not just one call's isolated tokens. 
The inner compactor fires +once that running size exceeds the configured threshold. +""" + +from mellea.core.base import ModelOutputThunk +from mellea.stdlib.components.chat import Message +from mellea.stdlib.context import ChatContext, ThresholdCompactor, WindowCompactor + + +def _thunk(total_tokens: int) -> ModelOutputThunk: + """Build a ModelOutputThunk with a populated usage dict (test helper).""" + mot = ModelOutputThunk(value="") + mot.generation.usage = { + "prompt_tokens": total_tokens, + "completion_tokens": 0, + "total_tokens": total_tokens, + } + return mot + + +def below_threshold_passthrough(): + """Token usage is below threshold → inner compactor is NOT invoked.""" + gated = ThresholdCompactor(WindowCompactor(size=2), threshold=1000) + ctx = ChatContext(window_size=10_000) + for i in range(5): + ctx = ctx.add(Message("user", f"msg {i}")) + ctx = ctx.add(_thunk(50)) # only 50 tokens — below 1000 + out = gated.compact(ctx) + return len(out.as_list()) # 6 (5 messages + thunk) — unchanged + + +def above_threshold_compacts(): + """Token usage exceeds threshold → inner compactor runs.""" + gated = ThresholdCompactor(WindowCompactor(size=2), threshold=1000) + ctx = ChatContext(window_size=10_000) + for i in range(5): + ctx = ctx.add(Message("user", f"msg {i}")) + ctx = ctx.add(_thunk(2000)) # 2000 tokens — over the gate + out = gated.compact(ctx) + return len(out.as_list()) # 2 — WindowCompactor(size=2) ran + + +if __name__ == "__main__": + print(f"below_threshold_passthrough: {below_threshold_passthrough()}") + print(f"above_threshold_compacts: {above_threshold_compacts()}") + + +def test_threshold_compactor_examples(): + assert below_threshold_passthrough() == 6 + assert above_threshold_compacts() == 2 diff --git a/docs/examples/context/window_compactor.py b/docs/examples/context/window_compactor.py new file mode 100644 index 000000000..320dd0e31 --- /dev/null +++ b/docs/examples/context/window_compactor.py @@ -0,0 +1,101 @@ +# pytest: unit 
+"""WindowCompactor — keep the last N body components. + +Demonstrates the default behaviour, the ``window_size=`` sugar on +``ChatContext``, and how the auto-pinned system prefix is preserved. +""" + +from mellea.stdlib.components.chat import Message +from mellea.stdlib.context import ( + ChatContext, + WindowCompactor, + pin_nothing, + pin_system_and_initial_user, +) + + +def basic_window(): + """``ChatContext()`` keeps the full history by default; opt in via + ``compactor=`` to start truncating. + """ + ctx = ChatContext(compactor=WindowCompactor(size=5)) + for i in range(8): + ctx = ctx.add(Message("user", f"msg {i}")) + return [m.content for m in ctx.as_list()] + # → ['msg 3', 'msg 4', 'msg 5', 'msg 6', 'msg 7'] + + +def window_size_sugar(): + """``window_size=`` is sugar for ``WindowCompactor(size=...)``.""" + ctx = ChatContext(window_size=3) + for i in range(6): + ctx = ctx.add(Message("user", f"msg {i}")) + return [m.content for m in ctx.as_list()] + # → ['msg 3', 'msg 4', 'msg 5'] + + +def system_prefix_pinned(): + """Default predicate ``pin_system`` keeps a leading system message.""" + ctx = ChatContext(window_size=3) + ctx = ctx.add(Message("system", "You are a helpful assistant.")) + for i in range(6): + ctx = ctx.add(Message("user", f"msg {i}")) + return [(m.role, m.content) for m in ctx.as_list()] + # → [('system', '...'), ('user', 'msg 3'), ('user', 'msg 4'), ('user', 'msg 5')] + + +def pin_initial_user_too(): + """Use ``pin_system_and_initial_user`` to also keep the user's first turn.""" + ctx = ChatContext( + compactor=WindowCompactor(size=3, pin_predicate=pin_system_and_initial_user) + ) + ctx = ctx.add(Message("system", "You are helpful.")) + ctx = ctx.add(Message("user", "What is the capital of France?")) + for i in range(6): + ctx = ctx.add(Message("assistant", f"reply {i}")) + return [(m.role, m.content) for m in ctx.as_list()] + + +def pure_last_n(): + """``pin_nothing`` disables prefix pinning — the system message is dropped.""" + ctx = 
ChatContext(compactor=WindowCompactor(size=3, pin_predicate=pin_nothing)) + ctx = ctx.add(Message("system", "ignored after a few turns")) + for i in range(6): + ctx = ctx.add(Message("user", f"msg {i}")) + return [(m.role, m.content) for m in ctx.as_list()] + + +def clear_body_keep_prefix(): + """``size=0`` drops the body entirely while keeping the pinned prefix.""" + ctx = ChatContext(window_size=10_000) + ctx = ctx.add(Message("system", "You are helpful.")) + for i in range(5): + ctx = ctx.add(Message("user", f"msg {i}")) + cleared = WindowCompactor(size=0).compact(ctx) + return [(m.role, m.content) for m in cleared.as_list()] + # → [('system', 'You are helpful.')] + + +if __name__ == "__main__": + for fn in [ + basic_window, + window_size_sugar, + system_prefix_pinned, + pin_initial_user_too, + pure_last_n, + clear_body_keep_prefix, + ]: + print(f"--- {fn.__name__} ---") + print(fn()) + + +def test_window_compactor_examples(): + """Smoke test all examples — invariants documented in each docstring.""" + assert basic_window() == ["msg 3", "msg 4", "msg 5", "msg 6", "msg 7"] + assert window_size_sugar() == ["msg 3", "msg 4", "msg 5"] + assert system_prefix_pinned()[0] == ("system", "You are a helpful assistant.") + pinned = pin_initial_user_too() + assert pinned[0] == ("system", "You are helpful.") + assert pinned[1] == ("user", "What is the capital of France?") + assert all(role == "user" for role, _ in pure_last_n()) + assert clear_body_keep_prefix() == [("system", "You are helpful.")] diff --git a/mellea/stdlib/components/react.py b/mellea/stdlib/components/react.py index 8f08fe8a0..dfc423043 100644 --- a/mellea/stdlib/components/react.py +++ b/mellea/stdlib/components/react.py @@ -32,6 +32,102 @@ def _mellea_finalize_tool(answer: str) -> str: return answer +def pin_react_initiator(components: list[Component | CBlock]) -> int: + """A ``PinPredicate`` that pins everything up to and including the first ``ReactInitiator``. 
+ + Plug it into any compactor in :mod:`mellea.stdlib.context` that takes a + ``pin_predicate`` (e.g. :class:`WindowCompactor`, + :class:`ThresholdCompactor`'s inner compactor) so the react goal and + tool registration survive compaction: + + from mellea.stdlib.context import ChatContext, WindowCompactor + from mellea.stdlib.components.react import pin_react_initiator + + ctx = ChatContext( + compactor=WindowCompactor(size=5, pin_predicate=pin_react_initiator), + ) + result, _ = await react(goal=..., context=ctx, ...) + + Returns ``0`` when no ``ReactInitiator`` is found, so a context that + has not yet been seeded with a react goal compacts as if there were + no prefix. + """ + for i, c in enumerate(components): + if isinstance(c, ReactInitiator): + return i + 1 + return 0 + + +def react_summary_prompt( + goal: str | None = None, + max_tokens_hint: int | None = None, +) -> str: + """Build a research-flavoured summary prompt for :class:`LLMSummarizeCompactor`. + + Returns a template with a ``{conversation}`` placeholder that + :class:`LLMSummarizeCompactor` fills in at compaction time. Pass the + react goal via ``goal=`` to anchor the summarisation around the + objective; with ``goal=None`` the ``GOAL:`` line is omitted. + + Pass ``max_tokens_hint=N`` to inject a soft length-cap bullet + ("Be at most ~N tokens") into the summarizer's instructions. The hint + is a plan-time anchor for the model — combine it with a hard + ``max_tokens`` API arg on the summarizer's LLM call to enforce. + ``max_tokens_hint=None`` (default) or non-positive values omit the + bullet, so the prompt is byte-identical to the un-hinted form. + + Curly braces in ``goal`` are escaped so :meth:`str.format` (used by the + compactor) preserves them as literal characters. 
+ + Example:: + + from mellea.stdlib.components.react import ( + pin_react_initiator, + react_summary_prompt, + ) + from mellea.stdlib.context import LLMSummarizeCompactor + + compactor = LLMSummarizeCompactor( + keep_n=5, + pin_predicate=pin_react_initiator, + prompt_template=react_summary_prompt( + goal="find papers on X", + max_tokens_hint=2000, + ), + ) + """ + if goal is not None: + # Escape braces so .format() in the compactor keeps them literal. + safe_goal = goal.replace("{", "{{").replace("}", "}}") + goal_block = f"GOAL: {safe_goal}\n\n" + else: + goal_block = "" + if max_tokens_hint is not None and max_tokens_hint > 0: + # Rough heuristic: ~0.75 words per token for English research text. + words_estimate = int(max_tokens_hint * 0.75) + length_bullet = ( + f"- Be at most ~{max_tokens_hint} tokens (roughly " + f"{words_estimate} words). Prioritize density: drop redundant " + "or ancillary detail.\n" + ) + else: + length_bullet = "" + return ( + "You are summarizing research progress to maintain context " + "within token limits.\n\n" + f"{goal_block}" + "Provide a comprehensive summary of the research context below. " + "Your summary should:\n" + "- Preserve ALL specific facts, numbers, names, URLs, and search " + "queries found\n" + "- Note which tools were called and what results were obtained\n" + "- Highlight key findings and any dead ends encountered\n" + "- Be structured clearly so the research can continue seamlessly\n" + f"{length_bullet}" + "\nContext to summarize:\n{conversation}" + ) + + class ReactInitiator(Component[str]): """`ReactInitiator` is used at the start of the ReACT loop to prime the model. diff --git a/mellea/stdlib/context.py b/mellea/stdlib/context.py deleted file mode 100644 index f0e2e8b2f..000000000 --- a/mellea/stdlib/context.py +++ /dev/null @@ -1,82 +0,0 @@ -"""Concrete ``Context`` implementations for common conversation patterns. 
- -Provides ``ChatContext``, which accumulates all turns in a sliding-window chat history -(configurable via ``window_size``), and ``SimpleContext``, in which each interaction -is treated as a stateless single-turn exchange (no prior history is passed to the -model). Import ``ChatContext`` for multi-turn conversations and ``SimpleContext`` when -you want each call to the model to be independent. -""" - -from __future__ import annotations - -# Leave unused `ContextTurn` import for import ergonomics. -from ..core import CBlock, Component, Context, ContextTurn - - -class ChatContext(Context): - """Initializes a chat context with unbounded window_size and is_chat=True by default. - - Args: - window_size (int | None): Maximum number of context turns to include when - calling ``view_for_generation``. ``None`` (the default) means the full - history is always returned. - """ - - def __init__(self, *, window_size: int | None = None): - """Initialize ChatContext with an optional sliding-window size.""" - super().__init__() - self._window_size = window_size - - def add(self, c: Component | CBlock) -> ChatContext: - """Add a new component or CBlock to the context and return the updated context. - - Args: - c (Component | CBlock): The component or content block to append. - - Returns: - ChatContext: A new ``ChatContext`` with the added entry, preserving the - current ``window_size`` setting. - """ - new = ChatContext.from_previous(self, c) - new._window_size = self._window_size - return new - - def view_for_generation(self) -> list[Component | CBlock] | None: - """Return the context entries to pass to the model, respecting the configured window. - - Uses the ``window_size`` set during initialisation to limit how many past - turns are included. ``None`` is returned when the underlying history is - non-linear. - - Returns: - list[Component | CBlock] | None: Ordered list of context entries up to - ``window_size`` turns, or ``None`` if the history is non-linear. 
- """ - return self.as_list(self._window_size) - - -class SimpleContext(Context): - """A `SimpleContext` is a context in which each interaction is a separate and independent turn. The history of all previous turns is NOT saved..""" - - def add(self, c: Component | CBlock) -> SimpleContext: - """Add a new component or CBlock to the context and return the updated context. - - Args: - c (Component | CBlock): The component or content block to record. - - Returns: - SimpleContext: A new ``SimpleContext`` containing only the added entry; - prior history is not retained. - """ - return SimpleContext.from_previous(self, c) - - def view_for_generation(self) -> list[Component | CBlock] | None: - """Return an empty list, since ``SimpleContext`` does not pass history to the model. - - Each call to the model is treated as a stateless, independent exchange. - No prior turns are forwarded. - - Returns: - list[Component | CBlock] | None: Always an empty list. - """ - return [] diff --git a/mellea/stdlib/context/__init__.py b/mellea/stdlib/context/__init__.py new file mode 100644 index 000000000..60bf94d94 --- /dev/null +++ b/mellea/stdlib/context/__init__.py @@ -0,0 +1,45 @@ +"""Concrete ``Context`` implementations and the ``Compactor`` protocol. + +Provides: + +- :class:`ChatContext` — accumulates all turns in a chat history (with an + optional sliding window). +- :class:`SimpleContext` — stateless, single-turn exchange (no prior history is + passed to the model). +- :class:`Compactor` — generic protocol for shrinking any ``Context`` subtype. + +The names :class:`Context`, :class:`ContextTurn`, :class:`CBlock`, and +:class:`Component` are re-exported from :mod:`mellea.core` for the convenience +of callers that import them via ``mellea.stdlib.context``. 
+""" + +from mellea.core import CBlock, Component, Context, ContextTurn +from mellea.stdlib.context.chat import ChatContext +from mellea.stdlib.context.compactor import ( + Compactor, + LLMSummarizeCompactor, + PinPredicate, + ThresholdCompactor, + WindowCompactor, + pin_nothing, + pin_system, + pin_system_and_initial_user, +) +from mellea.stdlib.context.simple import SimpleContext + +__all__ = [ + "CBlock", + "ChatContext", + "Compactor", + "Component", + "Context", + "ContextTurn", + "LLMSummarizeCompactor", + "PinPredicate", + "SimpleContext", + "ThresholdCompactor", + "WindowCompactor", + "pin_nothing", + "pin_system", + "pin_system_and_initial_user", +] diff --git a/mellea/stdlib/context/chat.py b/mellea/stdlib/context/chat.py new file mode 100644 index 000000000..0ac548460 --- /dev/null +++ b/mellea/stdlib/context/chat.py @@ -0,0 +1,101 @@ +"""Chat-style context with pluggable compaction.""" + +from __future__ import annotations + +from typing import TYPE_CHECKING, cast + +from mellea.core import CBlock, Component, Context + +if TYPE_CHECKING: + from mellea.stdlib.context.compactor import Compactor + + +class ChatContext(Context): + """Chat context that accumulates turns and optionally compacts on each ``add``. + + By default the context performs **no compaction** — the full history is + retained. Compaction is opt-in: pass ``compactor=`` for a custom + strategy, or ``window_size=`` as sugar for ``WindowCompactor(size=...)``. + + Args: + compactor (Compactor | None): The compactor invoked on every ``add``. + ``None`` (the default) means no compaction; full history is kept. + window_size (int | None): Sugar that constructs a + :class:`WindowCompactor`. Mutually exclusive with ``compactor``. + ``None`` (the default) means no windowing. 
+ """ + + def __init__( + self, *, compactor: Compactor | None = None, window_size: int | None = None + ) -> None: + """Initialize a ChatContext with an optional compactor.""" + if compactor is not None and window_size is not None: + raise ValueError( + "ChatContext: pass either `compactor` or `window_size`, not both." + ) + super().__init__() + if compactor is None and window_size is not None: + from mellea.stdlib.context.compactor import WindowCompactor + + self._compactor: Compactor | None = cast( + "Compactor", WindowCompactor(size=window_size) + ) + else: + self._compactor = compactor + + def add(self, c: Component | CBlock) -> ChatContext: + """Append ``c`` and run the compactor; return the resulting context. + + Args: + c (Component | CBlock): The component or content block to append. + + Returns: + ChatContext: A new ``ChatContext`` carrying the same compactor. + """ + new = ChatContext.from_previous(self, c) + new._compactor = self._compactor + if self._compactor is not None: + new = self._compactor.compact(new) + return new + + def view_for_generation(self) -> list[Component | CBlock] | None: + """Return the components to forward to the model. + + Compaction is now applied at ``add`` time (Pattern 1), so this just + returns the linear history. ``None`` is returned when the underlying + history is non-linear. + + Returns: + list[Component | CBlock] | None: Ordered list of context entries. + """ + return self.as_list() + + +def _rebuild_chat_context( + components: list[Component | CBlock], *, compactor: Compactor | None = None +) -> ChatContext: + """Build a fresh ``ChatContext`` linked-list without triggering compaction. + + Used by ``WindowCompactor`` (and any future compactors that need to rebuild + a chat history). Manual node construction sidesteps ``ChatContext.add`` so + compactors don't recurse during their own work. + + Args: + components: Components to materialise as the new context, in order. 
+ compactor: Compactor to attach to every node of the rebuilt context. + + Returns: + A new ``ChatContext`` whose linear history is exactly ``components``. + """ + ctx: ChatContext = ChatContext.__new__(ChatContext) + Context.__init__(ctx) + ctx._compactor = compactor + for c in components: + new: ChatContext = ChatContext.__new__(ChatContext) + new._previous = ctx + new._data = c + new._is_root = False + new._is_chat_context = ctx._is_chat_context + new._compactor = compactor + ctx = new + return ctx diff --git a/mellea/stdlib/context/compactor.py b/mellea/stdlib/context/compactor.py new file mode 100644 index 000000000..409bebaca --- /dev/null +++ b/mellea/stdlib/context/compactor.py @@ -0,0 +1,427 @@ +"""Generic ``Compactor`` protocol for shrinking a ``Context``. + +A ``Compactor`` returns a fresh, compacted copy of a context. Implementations +must never mutate the input — by convention, every alteration must produce a +new ``Context`` instance (the base class enforces this via ``from_previous``). + +Two usage patterns are supported: + +- **Pattern 1 (in ``Context.add``):** A subclass of ``Context`` holds a + ``Compactor`` and applies it whenever a new component is appended. +- **Pattern 2 (manual):** The caller invokes ``compactor.compact(ctx)`` + directly between turns, e.g. when compaction is exposed to the model as a + tool. + +See ``docs/rewrite/`` for full usage examples. 
+""" + +from __future__ import annotations + +from collections.abc import Callable +from typing import TYPE_CHECKING, Protocol, TypeAlias, TypeVar + +from mellea.core import CBlock, Component, Context, ModelOutputThunk +from mellea.core.backend import Backend + +if TYPE_CHECKING: + from mellea.stdlib.context.chat import ChatContext + +T = TypeVar("T", bound=Context) + + +# --------------------------------------------------------------------------- # +# Pin predicates # +# --------------------------------------------------------------------------- # + +PinPredicate: TypeAlias = Callable[[list[Component | CBlock]], int] +"""A function that returns the index after the pinned prefix. + +Given the full ordered list of context components, a ``PinPredicate`` +returns the integer index ``idx`` such that ``components[:idx]`` is the +pinned prefix that the compactor must preserve, and ``components[idx:]`` +is the body that compaction acts on. + +The shape subsumes both "contiguous role-based prefix" (e.g. +:func:`pin_system`) and "find the first marker component" styles. +""" + + +def pin_nothing(components: list[Component | CBlock]) -> int: + """A :class:`PinPredicate` that pins nothing — pure body, no protected prefix.""" + return 0 + + +def pin_system(components: list[Component | CBlock]) -> int: + """Pin contiguous leading ``Message(role="system")`` components. + + Stops at the first non-system component. A system message that appears + later in the conversation is *not* pinned. + """ + from mellea.stdlib.components.chat import Message + + i = 0 + while i < len(components): + c = components[i] + if isinstance(c, Message) and c.role == "system": + i += 1 + else: + break + return i + + +def pin_system_and_initial_user(components: list[Component | CBlock]) -> int: + """Pin leading system messages PLUS the first user message that follows. 
+ + Useful when the initial user prompt encodes the goal of the conversation + and should survive compaction along with any system instructions. + """ + from mellea.stdlib.components.chat import Message + + i = pin_system(components) + if i < len(components): + c = components[i] + if isinstance(c, Message) and c.role == "user": + i += 1 + return i + + +def _last_usage_tokens(ctx: Context) -> int | None: + """Return cumulative token count of the conversation as of the most recent turn. + + Walks ``ctx`` back-to-front looking for a ``ModelOutputThunk`` whose + ``generation.usage`` dict has been populated by a backend's + ``post_processing``. Returns ``total_tokens`` from that thunk — which, + for a chat backend, is ``prompt_tokens`` (size of the full conversation + sent to the model) plus ``completion_tokens`` (the model's reply). It + is therefore an estimate of the *current* conversation size, not just + one call's tokens in isolation. + + Falls back to ``prompt_tokens + completion_tokens`` when ``total_tokens`` + is missing. Returns ``None`` if no usable token count can be recovered + (typical before the first model call completes). + """ + for c in reversed(ctx.as_list()): + if isinstance(c, ModelOutputThunk) and c.generation.usage is not None: + usage = c.generation.usage + total = usage.get("total_tokens") + if total is None: + pt = usage.get("prompt_tokens") or 0 + ct = usage.get("completion_tokens") or 0 + total = pt + ct + return total if total and total > 0 else None + return None + + +class Compactor(Protocol): + """Protocol for objects that compact a ``Context`` into a smaller copy. + + A compactor receives a context and returns a new context that retains only + the data the strategy considers worth keeping. Implementations MUST NOT + mutate the input context; they must return a fresh instance and copy over + any data that should be preserved. 
+ + The protocol is generic in ``T`` (a ``Context`` subtype) so concrete + compactors can narrow their input/output type — for example a chat-only + compactor declares ``T = ChatContext``. + + The protocol is sync. Compactors that need to perform a backend call + (e.g. :class:`LLMSummarizeCompactor`) hide the async work behind the sync + method internally — see that class for the strategy used. + """ + + def compact(self, ctx: T, *, backend: Backend | None = None) -> T: + """Return a compacted copy of ``ctx``. + + Args: + ctx: The context to compact. Must be left unchanged. + backend: Optional backend. Generic compactors that only filter + components can ignore it. + + Returns: + A new context of the same type as ``ctx`` containing only the + retained data. + """ + ... + + +class WindowCompactor: + """Retains the last ``size`` body components of a ``ChatContext``. + + Uses ``pin_predicate`` to decide which leading components to preserve as + a protected prefix; the size limit is then applied to the body that + remains. The total context length after compaction is + ``len(prefix) + min(size, body_len)``. ``size`` counts only body + components. + + When the body is already at or below ``size``, ``ctx`` is returned + unchanged so the original linked-list and ``previous_node`` chain are + preserved. The result carries the same ``Compactor`` as the input so + subsequent ``add()`` calls keep compacting. + + Args: + size (int): Maximum number of most-recent body components to retain. + Pinned prefix components do NOT count against this budget. + ``size=0`` is a special case that drops the body entirely, + keeping only the pinned prefix. Negative values raise + :class:`ValueError`. + pin_predicate (PinPredicate): Function that decides the prefix + boundary. Defaults to :func:`pin_system`, which pins contiguous + leading ``Message(role="system")`` components. Pass + :func:`pin_nothing` for pure last-N behaviour or any other + ``PinPredicate`` (e.g. 
:func:`pin_system_and_initial_user`). + """ + + def __init__(self, *, size: int, pin_predicate: PinPredicate = pin_system) -> None: + """Initialize with the desired body window size and a pin predicate.""" + if size < 0: + raise ValueError("WindowCompactor size must be non-negative") + self.size = size + self.pin_predicate = pin_predicate + + def compact( + self, ctx: ChatContext, *, backend: Backend | None = None + ) -> ChatContext: + """Return a copy of ``ctx`` truncated to the last ``size`` body components. + + Args: + ctx: The chat context to compact. + backend: Unused by this strategy; accepted for protocol compatibility. + + Returns: + A new ``ChatContext`` whose history is the pinned prefix plus the + last ``size`` body components, carrying ``ctx``'s compactor. + Returns ``ctx`` itself if no truncation is required. + """ + full = ctx.as_list() + pin_end = self.pin_predicate(full) + body_len = len(full) - pin_end + + if body_len <= self.size: + return ctx + + from mellea.stdlib.context.chat import _rebuild_chat_context + + keep_body = full[pin_end:][-self.size :] if self.size > 0 else [] + compacted = full[:pin_end] + keep_body + return _rebuild_chat_context(compacted, compactor=ctx._compactor) + + +class ThresholdCompactor: + """Wraps an inner ``Compactor``, gating it on the conversation's token size. + + Despite the suffix, this class does not compact directly — it forwards + to ``inner.compact`` only when the conversation has grown larger than + ``threshold`` tokens; otherwise the input is returned unchanged. + + The token measurement is read off the most recent ``ModelOutputThunk``'s + ``generation.usage`` (via :func:`_last_usage_tokens`). Because chat + backends report ``prompt_tokens`` as the size of the full history they + were given as input, ``total_tokens = prompt_tokens + completion_tokens`` + on the latest thunk effectively measures *the size of the conversation + after that turn*, not just one isolated call. 
So the gate fires once + cumulative context size crosses ``threshold``. + + Caveats: + + - Components appended *after* the last thunk (e.g. a tool response in + the same turn) are not yet reflected in the reading — there is a + one-turn lag, negligible unless a single tool call adds a very large + payload. + - When the inner compactor shrinks the context, the *next* model call + will produce a smaller ``prompt_tokens``, so the gate will close + again. The threshold is not a high-water mark. + - Returns the input unchanged if no thunk with usage is found yet + (typical before the first model call completes). + + Args: + inner (Compactor): The compactor to invoke once the threshold is + exceeded. + threshold (int): Trigger the inner compactor when the conversation's + measured token size (most recent thunk's ``total_tokens``) + exceeds this value. ``0`` or negative disables the gate (the + inner is never invoked). + """ + + def __init__(self, inner: Compactor, *, threshold: int) -> None: + """Initialize with the inner compactor and token threshold.""" + self.inner = inner + self.threshold = threshold + + def compact(self, ctx: T, *, backend: Backend | None = None) -> T: + """Forward to ``inner.compact`` only when ``ctx`` exceeds the threshold. + + Args: + ctx: The context to potentially compact. + backend: Forwarded to the inner compactor. + + Returns: + ``inner.compact(ctx, backend=backend)`` when the recovered token + count exceeds ``self.threshold``, otherwise ``ctx`` unchanged. 
+ """ + if self.threshold <= 0: + return ctx + tokens = _last_usage_tokens(ctx) + if tokens is None or tokens <= self.threshold: + return ctx + return self.inner.compact(ctx, backend=backend) + + +_DEFAULT_SUMMARY_PROMPT = ( + "You are summarizing a conversation to maintain context within token " + "limits.\n\n" + "Provide a concise summary that:\n" + "- Preserves specific facts, numbers, names, URLs, and key data\n" + "- Notes which tools were called and what results were obtained\n" + "- Highlights key decisions, findings, and unresolved issues\n" + "- Is structured clearly so the conversation can continue seamlessly\n\n" + "Conversation to summarize:\n{conversation}" +) + + +def _run_coro_blocking(coro): # type: ignore[no-untyped-def] + """Run an awaitable to completion regardless of the calling context. + + - Outside any event loop: ``asyncio.run(coro)``. + - Inside a running event loop: spawn a worker thread that runs a fresh + event loop with ``asyncio.run`` and block until it returns. + + Used by sync compactors that need to call async backend code (e.g. + :class:`LLMSummarizeCompactor`). Note that the second branch blocks the + calling thread (and, transitively, the running event loop) for the + duration of the coroutine — fine for a serial loop like ReACT, but not + suitable if other tasks need to make progress concurrently. + """ + import asyncio + import concurrent.futures + + try: + asyncio.get_running_loop() + except RuntimeError: + return asyncio.run(coro) + + with concurrent.futures.ThreadPoolExecutor(max_workers=1) as pool: + return pool.submit(asyncio.run, coro).result() + + +class LLMSummarizeCompactor: + """Replace old body components with an LLM-generated summary, keep last ``keep_n`` verbatim. + + Implements the sync :class:`Compactor` protocol. The compactor's body + needs to call the (async) backend; that async work is hidden inside the + sync ``compact`` method via :func:`_run_coro_blocking`. 
The pinned + prefix (chosen by ``pin_predicate``) is preserved unchanged; body + components older than the last ``keep_n`` are flattened into a single + ``Message(role="user")`` whose content is a structured summary; the + last ``keep_n`` body components are kept verbatim. + + Default ``pin_predicate`` is :func:`pin_nothing`, which means the entire + conversation participates in summarisation. For react workflows pass + :func:`mellea.stdlib.components.react.pin_react_initiator` so the goal + and tool registration survive untouched. + + Args: + keep_n (int): Number of recent body components to keep verbatim. + ``0`` summarises everything below the prefix. + pin_predicate (PinPredicate): Function that decides the prefix + boundary. Defaults to :func:`pin_nothing`. + prompt_template (str | None): Custom summary prompt. Must contain + the literal ``{conversation}`` placeholder, which is filled in + with a textual rendering of the body to summarise. Defaults to + a generic conversation-summary template. + """ + + def __init__( + self, + *, + keep_n: int = 5, + pin_predicate: PinPredicate = pin_nothing, + prompt_template: str | None = None, + ) -> None: + """Initialize with the recent-body window, pin predicate, and prompt.""" + if keep_n < 0: + raise ValueError("LLMSummarizeCompactor keep_n must be non-negative") + template = ( + prompt_template if prompt_template is not None else _DEFAULT_SUMMARY_PROMPT + ) + if "{conversation}" not in template: + raise ValueError( + "LLMSummarizeCompactor prompt_template must contain '{conversation}'" + ) + self.keep_n = keep_n + self.pin_predicate = pin_predicate + self.prompt_template = template + + def compact( + self, ctx: ChatContext, *, backend: Backend | None = None + ) -> ChatContext: + """Return a context with the prefix, an LLM summary, and recent body components. + + Args: + ctx: The chat context to compact. + backend: Backend used to generate the summary; required. 
+ + Returns: + A new ``ChatContext`` containing the prefix, a single summary + ``Message`` produced by the backend, and the most-recent + ``keep_n`` body components verbatim. Returns ``ctx`` unchanged + when the body is already at or below ``keep_n`` in length. + + Raises: + ValueError: If ``backend`` is not provided. + """ + if backend is None: + raise ValueError("LLMSummarizeCompactor requires a `backend`") + + full = ctx.as_list() + pin_end = self.pin_predicate(full) + body = full[pin_end:] + if len(body) <= self.keep_n: + return ctx + + return _run_coro_blocking(self._async_compact(ctx, backend)) + + async def _async_compact(self, ctx: ChatContext, backend: Backend) -> ChatContext: + """Async core — renders the body, calls the backend, rebuilds the context.""" + # Lazy imports to keep this module free of mellea.stdlib.components dependencies. + from mellea.stdlib import functional as mfuncs + from mellea.stdlib.components.chat import Message, ToolMessage + from mellea.stdlib.context.chat import _rebuild_chat_context + from mellea.stdlib.context.simple import SimpleContext + + full = ctx.as_list() + pin_end = self.pin_predicate(full) + prefix = full[:pin_end] + body = full[pin_end:] + + old = body[: -self.keep_n] if self.keep_n > 0 else body + recent = body[-self.keep_n :] if self.keep_n > 0 else [] + + # Render `old` to text the LLM can consume. 
+ lines: list[str] = [] + for c in old: + if isinstance(c, ToolMessage): + lines.append(f"tool ({c.name}): {c.content}") + elif isinstance(c, Message): + lines.append(f"{c.role}: {c.content}") + elif isinstance(c, ModelOutputThunk): + lines.append(f"assistant: {c.value}") + elif isinstance(c, CBlock): + lines.append(str(c)) + else: + lines.append(str(getattr(c, "content", c))) + + prompt = self.prompt_template.format(conversation="\n".join(lines)) + result, _ = await mfuncs.aact( + action=Message(role="user", content=prompt), + context=SimpleContext(), + backend=backend, + requirements=[], + strategy=None, + await_result=True, + ) + + summary_message = Message( + role="user", content=f"[CONTEXT SUMMARY]\n{result.value or ''}" + ) + compacted = [*prefix, summary_message, *recent] + return _rebuild_chat_context(compacted, compactor=ctx._compactor) diff --git a/mellea/stdlib/context/simple.py b/mellea/stdlib/context/simple.py new file mode 100644 index 000000000..81f3cfb23 --- /dev/null +++ b/mellea/stdlib/context/simple.py @@ -0,0 +1,32 @@ +"""Stateless single-turn context (no history is forwarded to the model).""" + +from __future__ import annotations + +from mellea.core import CBlock, Component, Context + + +class SimpleContext(Context): + """A `SimpleContext` is a context in which each interaction is a separate and independent turn. The history of all previous turns is NOT saved.""" + + def add(self, c: Component | CBlock) -> SimpleContext: + """Add a new component or CBlock to the context and return the updated context. + + Args: + c (Component | CBlock): The component or content block to record. + + Returns: + SimpleContext: A new ``SimpleContext`` containing only the added entry; + prior history is not retained. + """ + return SimpleContext.from_previous(self, c) + + def view_for_generation(self) -> list[Component | CBlock] | None: + """Return an empty list, since ``SimpleContext`` does not pass history to the model. 
+ + Each call to the model is treated as a stateless, independent exchange. + No prior turns are forwarded. + + Returns: + list[Component | CBlock] | None: Always an empty list. + """ + return [] diff --git a/mellea/stdlib/frameworks/react.py b/mellea/stdlib/frameworks/react.py index d65574a77..fcbc86db7 100644 --- a/mellea/stdlib/frameworks/react.py +++ b/mellea/stdlib/frameworks/react.py @@ -15,14 +15,13 @@ from mellea.stdlib import functional as mfuncs # from mellea.stdlib.components.docs.document import Document -from mellea.stdlib.frameworks.react_compaction import CompactionStrategy from mellea.stdlib.components.chat import ToolMessage from mellea.stdlib.components.react import ( MELLEA_FINALIZER_TOOL, ReactInitiator, ReactThought, ) -from mellea.stdlib.context import ChatContext +from mellea.stdlib.context import ChatContext, Compactor async def react( @@ -37,7 +36,7 @@ async def react( model_options: dict | None = None, tools: list[AbstractMelleaTool] | None, loop_budget: int = 10, - compaction: CompactionStrategy | None = None, + compactor: Compactor | None = None, ) -> tuple[ComputedModelOutputThunk[str], ChatContext]: """Asynchronous ReACT pattern (Think -> Act -> Observe -> Repeat Until Done); attempts to accomplish the provided goal given the provided tools. @@ -49,9 +48,14 @@ async def react( model_options: additional model options, which will upsert into the model/backend's defaults. tools: the list of tools to use loop_budget: the number of steps allowed; use -1 for unlimited - compaction: an optional ``CompactionStrategy`` to apply when the context - exceeds the strategy's configured threshold - (e.g. ``KeepLastN(keep_n=5, threshold=20)``). + compactor: optional sync ``Compactor`` invoked once per turn after the + tool observation. Use this for strategies that should fire at turn + boundaries rather than on every component append (per-add + compaction is configured on ``context`` itself). 
Compose with + :func:`mellea.stdlib.components.react.pin_react_initiator` to + preserve the goal across compactions. Compactors that need to + call the backend (e.g. ``LLMSummarizeCompactor``) hide the async + work behind their sync ``compact`` method internally. Returns: A (ModelOutputThunk, Context) if `return_sampling_results` is `False`, else returns a `SamplingResult`. @@ -84,7 +88,6 @@ async def react( turn_num = 0 while (turn_num < loop_budget) or (loop_budget == -1): turn_num += 1 - MelleaLogger.get_logger().info(f"## ReACT TURN NUMBER {turn_num}") step, next_context = await mfuncs.aact( @@ -135,10 +138,8 @@ async def react( step._underlying_value = str(tool_responses[0].content) return step, context - # Compact after the final-answer check so terminal turns skip it. - if compaction is not None: - context = await compaction.maybe_compact( - context, backend=backend, goal=goal - ) + # Per-turn compaction hook (terminal turns skip this since `is_final` returned). + if compactor is not None: + context = compactor.compact(context, backend=backend) raise RuntimeError(f"could not complete react loop in {loop_budget} iterations") diff --git a/mellea/stdlib/frameworks/react_compaction.py b/mellea/stdlib/frameworks/react_compaction.py deleted file mode 100644 index ccc312b5f..000000000 --- a/mellea/stdlib/frameworks/react_compaction.py +++ /dev/null @@ -1,397 +0,0 @@ -"""Context compaction strategies for the ReACT framework. - -Provides modular, callable strategy objects to compact a ``ChatContext`` that -has grown too large during a react loop. Three strategies are available: - -- ``ClearAll`` — discard the entire conversation body, keeping only the prefix - (everything up to and including the ``ReactInitiator``). -- ``KeepLastN`` — keep the prefix plus the *n* most recent body components. -- ``LLMSummarize`` — ask the backend to summarize old body components into a - single ``Message``, then keep the last *n* body components verbatim. 
- -All strategies preserve the **prefix** (every component up to and including the -first ``ReactInitiator``) so the model retains its goal and tool definitions. - -Example:: - - from mellea.stdlib.frameworks.react_compaction import KeepLastN - from mellea.stdlib.frameworks.react import react - - # Compact once the most recent model call reports > 8000 prompt+completion tokens. - await react( - goal="...", - context=ChatContext(), - backend=m.backend, - tools=[search_tool], - compaction=KeepLastN(keep_n=5, threshold=8000), - ) -""" - -from __future__ import annotations - -import abc - -from mellea.core.backend import Backend -from mellea.core.base import CBlock, Component, ModelOutputThunk -from mellea.core.utils import MelleaLogger -from mellea.stdlib.components.chat import Message, ToolMessage -from mellea.stdlib.components.react import ReactInitiator -from mellea.stdlib.context import ChatContext - -# --------------------------------------------------------------------------- -# Helpers -# --------------------------------------------------------------------------- - - -def rebuild_chat_context( - components: list[Component | CBlock], *, window_size: int | None = None -) -> ChatContext: - """Build a fresh ``ChatContext`` from an ordered list of components. - - Args: - components: Components to add, in chronological order. - window_size: Optional sliding-window size for the new context. - - Returns: - A new ``ChatContext`` containing all *components*. - """ - ctx = ChatContext(window_size=window_size) - for c in components: - ctx = ctx.add(c) - return ctx - - -def _find_prefix_end(components: list[Component | CBlock]) -> int: - """Return the index *after* the first ``ReactInitiator``. - - Everything in ``components[:idx]`` is the prefix that must be preserved by - every compaction strategy. Returns 0 when no ``ReactInitiator`` is found. 
- """ - for i, c in enumerate(components): - if isinstance(c, ReactInitiator): - return i + 1 - return 0 - - -def _last_usage_tokens(context: ChatContext) -> int | None: - """Return ``total_tokens`` from the most recent ``ModelOutputThunk`` with usage. - - Walks *context* back-to-front looking for a ``ModelOutputThunk`` whose - ``usage`` dict has been populated by a backend's ``post_processing``. - Falls back to ``prompt_tokens + completion_tokens`` when ``total_tokens`` - is missing. Returns ``None`` if no usable token count can be recovered — - typically the case before the first model call completes. - """ - for c in reversed(context.as_list()): - if isinstance(c, ModelOutputThunk) and c.generation.usage is not None: - total = c.generation.usage.get("total_tokens") - if total is None: - pt = c.generation.usage.get("prompt_tokens") or 0 - ct = c.generation.usage.get("completion_tokens") or 0 - total = pt + ct - return total if total and total > 0 else None - return None - - -# --------------------------------------------------------------------------- -# Abstract base -# --------------------------------------------------------------------------- - - -class CompactionStrategy(abc.ABC): - """Abstract base class for context compaction strategies. - - Each strategy carries a ``threshold`` — the token count above which - compaction should fire. The :meth:`should_compact` helper reads the - most recent ``ModelOutputThunk.usage`` populated by the backend and - compares its total token count to ``threshold``. - - Because ``usage`` is recorded when a model call completes, the measured - token count reflects the context as of the *previous* turn — any - components appended since (e.g. a tool response) are not yet included. - In practice this one-turn lag is negligible unless a single tool call - adds a very large payload. - - Subclasses implement :meth:`compact` which receives the current - ``ChatContext`` and returns a compacted copy. 
The method is ``async`` - so that strategies requiring LLM calls (e.g. ``LLMSummarize``) work - transparently; synchronous strategies simply never ``await``. - - Args: - threshold (int): Trigger compaction when the most recent thunk's - total token usage exceeds this value. ``0`` disables compaction. - """ - - def __init__(self, *, threshold: int = 0) -> None: - """Initialize with the token-count threshold.""" - self.threshold = threshold - - def should_compact(self, context: ChatContext) -> bool: - """Return ``True`` when the last thunk's token usage exceeds ``threshold``. - - Reads ``total_tokens`` from the most recent ``ModelOutputThunk.usage`` - in *context*. Returns ``False`` when no thunk with usage is present - (e.g. before the first model call) or when ``threshold`` is not - positive. - - Args: - context: The context to check. - - Returns: - ``True`` if the recovered token count exceeds ``self.threshold`` - and ``self.threshold`` is greater than 0. - """ - if self.threshold <= 0: - return False - tokens = _last_usage_tokens(context) - if tokens is None: - return False - return tokens > self.threshold - - async def maybe_compact( - self, - context: ChatContext, - *, - backend: Backend | None = None, - goal: str | None = None, - ) -> ChatContext: - """Compact *context* only if it exceeds the threshold, otherwise return it unchanged. - - Args: - context: The context to check and potentially compact. - backend: The backend (forwarded to :meth:`compact`). - goal: The react goal string (forwarded to :meth:`compact`). - - Returns: - A compacted ``ChatContext`` if the threshold was exceeded, - or the original *context* unchanged. - """ - if self.should_compact(context): - return await self.compact(context, backend=backend, goal=goal) - return context - - @abc.abstractmethod - async def compact( - self, - context: ChatContext, - *, - backend: Backend | None = None, - goal: str | None = None, - ) -> ChatContext: - """Return a compacted copy of *context*. 
- - Args: - context: The context to compact. - backend: The backend (required by ``LLMSummarize``). - goal: The react goal string (required by ``LLMSummarize``). - - Returns: - A new, compacted ``ChatContext``. - """ - - -# --------------------------------------------------------------------------- -# Concrete strategies -# --------------------------------------------------------------------------- - - -class ClearAll(CompactionStrategy): - """Discard the entire conversation body, keeping only the prefix. - - The prefix is everything up to and including the first ``ReactInitiator``. - - Args: - threshold (int): Trigger compaction when the most recent thunk's total - token usage exceeds this value. - """ - - async def compact( - self, - context: ChatContext, - *, - backend: Backend | None = None, - goal: str | None = None, - ) -> ChatContext: - """Return a context containing only the prefix. - - Args: - context: The context to compact. - backend: Unused by this strategy; accepted for interface compatibility. - goal: Unused by this strategy; accepted for interface compatibility. - - Returns: - A new ``ChatContext`` containing only the prefix components. - """ - components = context.as_list() - prefix_end = _find_prefix_end(components) - compacted = components[:prefix_end] - - MelleaLogger.get_logger().info( - f"ClearAll: compacted context from {len(components)} to " - f"{len(compacted)} components" - ) - return rebuild_chat_context(compacted, window_size=context._window_size) - - -class KeepLastN(CompactionStrategy): - """Keep the prefix plus the last *keep_n* body components. - - Args: - keep_n (int): Number of recent body components to retain. - threshold (int): Trigger compaction when the most recent thunk's total - token usage exceeds this value. 
- """ - - def __init__(self, *, keep_n: int = 5, threshold: int = 0) -> None: - """Initialize with the number of recent body components to keep.""" - super().__init__(threshold=threshold) - self.keep_n = keep_n - - async def compact( - self, - context: ChatContext, - *, - backend: Backend | None = None, - goal: str | None = None, - ) -> ChatContext: - """Return a context with the prefix and the last *keep_n* body components. - - Args: - context: The context to compact. - backend: Unused by this strategy; accepted for interface compatibility. - goal: Unused by this strategy; accepted for interface compatibility. - - Returns: - A new ``ChatContext`` with the prefix plus the most recent *keep_n* - body components, or the original *context* if the body is already - at or below *keep_n* in length. - """ - components = context.as_list() - prefix_end = _find_prefix_end(components) - prefix = components[:prefix_end] - body = components[prefix_end:] - - if len(body) <= self.keep_n: - return context # nothing to compact - - compacted = prefix + body[-self.keep_n :] - - MelleaLogger.get_logger().info( - f"KeepLastN(keep_n={self.keep_n}): compacted context from " - f"{len(components)} to {len(compacted)} components" - ) - return rebuild_chat_context(compacted, window_size=context._window_size) - - -class LLMSummarize(CompactionStrategy): - """Summarize old body components with the LLM, keep last *keep_n* verbatim. - - Requires ``backend`` and ``goal`` to be passed to :meth:`compact`. - - Args: - keep_n (int): Number of recent body components to retain verbatim. - threshold (int): Trigger compaction when the most recent thunk's total - token usage exceeds this value. 
- """ - - def __init__(self, *, keep_n: int = 5, threshold: int = 0) -> None: - """Initialize with the number of recent body components to keep.""" - super().__init__(threshold=threshold) - self.keep_n = keep_n - - async def compact( - self, - context: ChatContext, - *, - backend: Backend | None = None, - goal: str | None = None, - ) -> ChatContext: - """Return a context with the prefix, an LLM summary, and recent body components. - - Args: - context: The context to compact. - backend: Backend used to generate the summary; required. - goal: The react goal string, included in the summary prompt; required. - - Returns: - A new ``ChatContext`` containing the prefix, a single summary - ``Message`` produced by the backend, and the most recent *keep_n* - body components verbatim. Returns the original *context* if the - body is already at or below *keep_n* in length. - - Raises: - ValueError: If *backend* or *goal* are not provided. - """ - if backend is None or goal is None: - raise ValueError( - "LLMSummarize requires both 'backend' and 'goal' arguments" - ) - - from mellea.stdlib import functional as mfuncs - from mellea.stdlib.context import SimpleContext - - components = context.as_list() - prefix_end = _find_prefix_end(components) - prefix = components[:prefix_end] - body = components[prefix_end:] - - if len(body) <= self.keep_n: - return context # nothing to compact - - old = body[: -self.keep_n] if self.keep_n > 0 else body - recent = body[-self.keep_n :] if self.keep_n > 0 else [] - - # Build a textual representation of old components for summarization. 
- context_lines: list[str] = [] - for c in old: - if isinstance(c, ToolMessage): - context_lines.append(f"tool ({c.name}): {c.content}") - elif isinstance(c, Message): - context_lines.append(f"{c.role}: {c.content}") - elif isinstance(c, ModelOutputThunk): - context_lines.append(f"assistant: {c.value}") - elif isinstance(c, CBlock): - context_lines.append(str(c)) - else: - context_lines.append(str(getattr(c, "content", c))) - - summary_prompt = ( - "You are summarizing research progress to maintain context " - "within token limits.\n\n" - f"GOAL: {goal}\n\n" - "Provide a comprehensive summary of the research context below. " - "Your summary should:\n" - "- Preserve ALL specific facts, numbers, names, URLs, and search " - "queries found\n" - "- Note which tools were called and what results were obtained\n" - "- Highlight key findings and any dead ends encountered\n" - "- Be structured clearly so the research can continue seamlessly" - "\n\nContext to summarize:\n" - f"{chr(10).join(context_lines)}" - ) - - summary_action = Message(role="user", content=summary_prompt) - result, _ = await mfuncs.aact( - action=summary_action, - context=SimpleContext(), - backend=backend, - requirements=[], - strategy=None, - await_result=True, - ) - - summary_text = result.value or "" - summary_message = Message( - role="user", - content=( - f"[CONTEXT SUMMARY]\n{summary_text}\n\nContinue working on: {goal}" - ), - ) - - compacted = [*prefix, summary_message, *recent] - - MelleaLogger.get_logger().info( - f"LLMSummarize(keep_n={self.keep_n}): compacted context from " - f"{len(components)} to {len(compacted)} components" - ) - return rebuild_chat_context(compacted, window_size=context._window_size) diff --git a/test/stdlib/frameworks/test_react_compaction.py b/test/stdlib/frameworks/test_react_compaction.py deleted file mode 100644 index 07e5e44ce..000000000 --- a/test/stdlib/frameworks/test_react_compaction.py +++ /dev/null @@ -1,395 +0,0 @@ -"""Unit and integration tests for 
mellea.stdlib.frameworks.react_compaction.""" - -from collections.abc import Sequence -from dataclasses import dataclass - -import pytest - -from mellea.backends.tools import MelleaTool -from mellea.core.backend import Backend, BaseModelSubclass -from mellea.core.base import ( - C, - CBlock, - Component, - Context, - GenerateLog, - ModelOutputThunk, - ModelToolCall, -) -from mellea.stdlib.frameworks.react_compaction import ( - ClearAll, - KeepLastN, - LLMSummarize, - _find_prefix_end, - _last_usage_tokens, - rebuild_chat_context, -) -from mellea.stdlib.components.chat import Message -from mellea.stdlib.components.react import ( - MELLEA_FINALIZER_TOOL, - ReactInitiator, - _mellea_finalize_tool, -) -from mellea.stdlib.context import ChatContext -from mellea.stdlib.frameworks.react import react - -# --------------------------------------------------------------------------- -# Helpers -# --------------------------------------------------------------------------- - - -def _build_context(components: list[Component | CBlock]) -> ChatContext: - """Build a ChatContext from a list of components.""" - ctx = ChatContext() - for c in components: - ctx = ctx.add(c) - return ctx - - -def _msg(role: Message.Role, content: str) -> Message: - return Message(role=role, content=content) - - -def _thunk(total_tokens: int, value: str = "") -> ModelOutputThunk: - """Build a ModelOutputThunk with a populated usage dict.""" - mot = ModelOutputThunk(value=value) - mot.generation.usage = { - "prompt_tokens": total_tokens, - "completion_tokens": 0, - "total_tokens": total_tokens, - } - return mot - - -# --------------------------------------------------------------------------- -# rebuild_chat_context -# --------------------------------------------------------------------------- - - -class TestRebuildChatContext: - def test_empty(self): - ctx = rebuild_chat_context([]) - assert ctx.as_list() == [] - - def test_round_trip(self): - components = [_msg("user", "hello"), _msg("assistant", "hi")] 
- ctx = rebuild_chat_context(components) - result = ctx.as_list() - assert len(result) == 2 - assert all(isinstance(c, Message) for c in result) - - def test_preserves_window_size(self): - ctx = rebuild_chat_context([_msg("user", "a")], window_size=3) - assert ctx._window_size == 3 - - -# --------------------------------------------------------------------------- -# _find_prefix_end -# --------------------------------------------------------------------------- - - -class TestFindPrefixEnd: - def test_no_initiator(self): - components = [_msg("user", "a"), _msg("assistant", "b")] - assert _find_prefix_end(components) == 0 - - def test_initiator_at_start(self): - components = [ReactInitiator("goal", []), _msg("user", "a")] - assert _find_prefix_end(components) == 1 - - def test_initiator_after_system_msg(self): - components = [ - _msg("system", "sys"), - ReactInitiator("goal", []), - _msg("user", "a"), - ] - assert _find_prefix_end(components) == 2 - - -# --------------------------------------------------------------------------- -# should_compact -# --------------------------------------------------------------------------- - - -class TestLastUsageTokens: - def test_no_thunk_returns_none(self): - ctx = _build_context([_msg("user", "a"), _msg("assistant", "b")]) - assert _last_usage_tokens(ctx) is None - - def test_thunk_without_usage_returns_none(self): - ctx = _build_context([_msg("user", "a"), ModelOutputThunk(value="b")]) - assert _last_usage_tokens(ctx) is None - - def test_reads_total_tokens(self): - ctx = _build_context([_msg("user", "a"), _thunk(total_tokens=150)]) - assert _last_usage_tokens(ctx) == 150 - - def test_falls_back_to_prompt_plus_completion(self): - mot = ModelOutputThunk(value="x") - mot.generation.usage = {"prompt_tokens": 40, "completion_tokens": 20} - ctx = _build_context([_msg("user", "a"), mot]) - assert _last_usage_tokens(ctx) == 60 - - def test_uses_most_recent_thunk(self): - ctx = _build_context([_thunk(100), _msg("user", "x"), 
_thunk(500)]) - assert _last_usage_tokens(ctx) == 500 - - -class TestShouldCompact: - def test_no_thunk_does_not_trigger(self): - ctx = _build_context([_msg("user", "a"), _msg("assistant", "b")]) - strategy = KeepLastN(keep_n=1, threshold=100) - assert strategy.should_compact(ctx) is False - - def test_below_threshold(self): - ctx = _build_context([_msg("user", "a"), _thunk(total_tokens=50)]) - strategy = KeepLastN(keep_n=1, threshold=100) - assert strategy.should_compact(ctx) is False - - def test_above_threshold(self): - ctx = _build_context([_msg("user", "a"), _thunk(total_tokens=500)]) - strategy = KeepLastN(keep_n=1, threshold=100) - assert strategy.should_compact(ctx) is True - - def test_zero_threshold_never_triggers(self): - ctx = _build_context([_msg("user", "a"), _thunk(total_tokens=10_000)]) - strategy = KeepLastN(keep_n=1, threshold=0) - assert strategy.should_compact(ctx) is False - - -# --------------------------------------------------------------------------- -# ClearAll -# --------------------------------------------------------------------------- - - -class TestClearAll: - @pytest.mark.asyncio - async def test_keeps_only_prefix(self): - initiator = ReactInitiator("find the answer", []) - components = [initiator, _msg("user", "a"), _msg("assistant", "b")] - ctx = _build_context(components) - - result = await ClearAll().compact(ctx) - result_list = result.as_list() - assert len(result_list) == 1 - assert isinstance(result_list[0], ReactInitiator) - - @pytest.mark.asyncio - async def test_empty_body_is_noop(self): - initiator = ReactInitiator("goal", []) - ctx = _build_context([initiator]) - - result = await ClearAll().compact(ctx) - assert len(result.as_list()) == 1 - - -# --------------------------------------------------------------------------- -# KeepLastN -# --------------------------------------------------------------------------- - - -class TestKeepLastN: - @pytest.mark.asyncio - async def test_keeps_prefix_and_last_n(self): - initiator = 
ReactInitiator("goal", []) - body = [_msg("user", str(i)) for i in range(10)] - ctx = _build_context([initiator, *body]) - - result = await KeepLastN(keep_n=3).compact(ctx) - result_list = result.as_list() - assert len(result_list) == 4 # 1 prefix + 3 body - assert isinstance(result_list[0], ReactInitiator) - # Last 3 body messages - for i, c in enumerate(result_list[1:]): - assert isinstance(c, Message) - assert c.content == str(7 + i) - - @pytest.mark.asyncio - async def test_fewer_than_n_is_noop(self): - initiator = ReactInitiator("goal", []) - body = [_msg("user", "a"), _msg("assistant", "b")] - ctx = _build_context([initiator, *body]) - - result = await KeepLastN(keep_n=5).compact(ctx) - # Should return original context unchanged - assert result is ctx - - @pytest.mark.asyncio - async def test_preserves_window_size(self): - initiator = ReactInitiator("goal", []) - body = [_msg("user", str(i)) for i in range(10)] - ctx = rebuild_chat_context([initiator, *body], window_size=7) - - result = await KeepLastN(keep_n=2).compact(ctx) - assert result._window_size == 7 - - -# --------------------------------------------------------------------------- -# LLMSummarize -# --------------------------------------------------------------------------- - - -@dataclass -class _ScriptedTurn: - """A single scripted backend response.""" - - value: str - tool_calls: dict[str, ModelToolCall] | None = None - total_tokens: int | None = None - - -class ScriptedBackend(Backend): - """Fake backend returning pre-scripted responses.""" - - def __init__(self, script: list[_ScriptedTurn]) -> None: - self._script = iter(script) - - async def _generate_from_context( - self, - action: Component[C] | CBlock, - ctx: Context, - *, - format: type[BaseModelSubclass] | None = None, - model_options: dict | None = None, - tool_calls: bool = False, - ) -> tuple[ModelOutputThunk[C], Context]: - turn = next(self._script) - mot: ModelOutputThunk = ModelOutputThunk( - value=turn.value, 
tool_calls=turn.tool_calls - ) - mot._generate_log = GenerateLog(is_final_result=True) - if turn.total_tokens is not None: - mot.generation.usage = { - "prompt_tokens": turn.total_tokens, - "completion_tokens": 0, - "total_tokens": turn.total_tokens, - } - return mot, ctx.add(action).add(mot) - - async def generate_from_raw( - self, - actions: Sequence[Component[C] | CBlock], - ctx: Context, - *, - format: type[BaseModelSubclass] | None = None, - model_options: dict | None = None, - tool_calls: bool = False, - ) -> list[ModelOutputThunk]: - raise NotImplementedError - - -class TestLLMSummarize: - @pytest.mark.asyncio - async def test_raises_without_backend(self): - ctx = _build_context([ReactInitiator("g", []), _msg("user", "a")]) - with pytest.raises(ValueError, match="backend"): - await LLMSummarize(keep_n=0).compact(ctx) - - @pytest.mark.asyncio - async def test_raises_without_goal(self): - ctx = _build_context([ReactInitiator("g", []), _msg("user", "a")]) - backend = ScriptedBackend([]) - with pytest.raises(ValueError, match="goal"): - await LLMSummarize(keep_n=0).compact(ctx, backend=backend) - - @pytest.mark.asyncio - async def test_summarizes_old_keeps_recent(self): - initiator = ReactInitiator("goal", []) - body = [_msg("user", f"msg-{i}") for i in range(6)] - ctx = _build_context([initiator, *body]) - - # The backend will return one summary when the summarization prompt is sent - backend = ScriptedBackend([_ScriptedTurn(value="Summary of old messages")]) - - result = await LLMSummarize(keep_n=2).compact(ctx, backend=backend, goal="goal") - result_list = result.as_list() - - # prefix (1) + summary message (1) + last 2 body = 4 - assert len(result_list) == 4 - assert isinstance(result_list[0], ReactInitiator) - # Summary message - assert isinstance(result_list[1], Message) - assert "[CONTEXT SUMMARY]" in result_list[1].content - # Recent messages preserved - assert result_list[2].content == "msg-4" - assert result_list[3].content == "msg-5" - - 
@pytest.mark.asyncio - async def test_fewer_than_n_is_noop(self): - initiator = ReactInitiator("goal", []) - body = [_msg("user", "a")] - ctx = _build_context([initiator, *body]) - backend = ScriptedBackend([]) - - result = await LLMSummarize(keep_n=5).compact(ctx, backend=backend, goal="goal") - assert result is ctx - - -# --------------------------------------------------------------------------- -# Integration: react() with compaction -# --------------------------------------------------------------------------- - - -def _make_tool(name: str, return_value: str = "tool_result") -> MelleaTool: - def _fn() -> str: - return return_value - - return MelleaTool.from_callable(_fn, name=name) - - -def _final_answer_call(answer: str = "42") -> _ScriptedTurn: - tool = MelleaTool.from_callable(_mellea_finalize_tool, MELLEA_FINALIZER_TOOL) - tc = ModelToolCall(name=MELLEA_FINALIZER_TOOL, func=tool, args={"answer": answer}) - return _ScriptedTurn(value="", tool_calls={MELLEA_FINALIZER_TOOL: tc}) - - -def _tool_call_turn( - tool_name: str, - tool: MelleaTool, - thought: str = "thinking...", - total_tokens: int | None = None, -) -> _ScriptedTurn: - tc = ModelToolCall(name=tool_name, func=tool, args={}) - return _ScriptedTurn( - value=thought, tool_calls={tool_name: tc}, total_tokens=total_tokens - ) - - -class TestReactWithCompaction: - @pytest.mark.asyncio - @pytest.mark.integration - async def test_compaction_triggers_during_react(self): - """Compaction fires when last thunk's token usage exceeds threshold.""" - search = _make_tool("search", "found it") - backend = ScriptedBackend( - [ - _tool_call_turn("search", search, "step 1", total_tokens=200), - _tool_call_turn("search", search, "step 2", total_tokens=200), - _tool_call_turn("search", search, "step 3", total_tokens=200), - _final_answer_call("done"), - ] - ) - - result, _ctx = await react( - goal="find info", - context=ChatContext(), - backend=backend, - tools=[search], - loop_budget=10, - compaction=KeepLastN(keep_n=3, 
threshold=100), - ) - assert result.value == "done" - - @pytest.mark.asyncio - @pytest.mark.integration - async def test_no_compaction_when_disabled(self): - """Without compaction params, react behaves identically to before.""" - backend = ScriptedBackend([_final_answer_call("42")]) - result, _ = await react( - goal="answer", - context=ChatContext(), - backend=backend, - tools=None, - loop_budget=5, - ) - assert result.value == "42" diff --git a/test/stdlib/frameworks/test_react_framework.py b/test/stdlib/frameworks/test_react_framework.py index e121a91f5..8ae2d0b7b 100644 --- a/test/stdlib/frameworks/test_react_framework.py +++ b/test/stdlib/frameworks/test_react_framework.py @@ -231,5 +231,217 @@ async def test_react_rejects_non_chat_context(): await react(goal="g", context=Mock(), backend=Mock(), tools=None) +# --- compaction integration --- + + +def test_pin_react_initiator_finds_initiator(): + from mellea.stdlib.components.chat import Message + from mellea.stdlib.components.react import pin_react_initiator + + components = [ + Message("system", "sys"), + ReactInitiator("solve x", []), + Message("user", "step 1"), + ] + # Pinned prefix = system + initiator = first two indices. 
+ assert pin_react_initiator(components) == 2 + + +def test_pin_react_initiator_returns_zero_when_absent(): + from mellea.stdlib.components.chat import Message + from mellea.stdlib.components.react import pin_react_initiator + + components = [Message("user", "a"), Message("assistant", "b")] + assert pin_react_initiator(components) == 0 + + +def test_react_summary_prompt_default(): + """Without a goal the prompt has no GOAL: line and contains {conversation}.""" + from mellea.stdlib.components.react import react_summary_prompt + + prompt = react_summary_prompt() + assert "{conversation}" in prompt + assert "GOAL:" not in prompt + assert "research progress" in prompt + assert "search queries" in prompt + assert "dead ends" in prompt + + +def test_react_summary_prompt_with_goal(): + """Goal is interpolated and the prompt still has the {conversation} placeholder.""" + from mellea.stdlib.components.react import react_summary_prompt + + prompt = react_summary_prompt(goal="find papers on context compaction") + assert "GOAL: find papers on context compaction" in prompt + assert "{conversation}" in prompt + + +def test_react_summary_prompt_escapes_braces_in_goal(): + """Braces in the goal must survive str.format() in LLMSummarizeCompactor.""" + from mellea.stdlib.components.react import react_summary_prompt + + prompt = react_summary_prompt(goal="solve {x: 1, y: 2}") + # After str.format(conversation=...), the goal should appear with literal braces. + rendered = prompt.format(conversation="") + assert "GOAL: solve {x: 1, y: 2}" in rendered + assert "" in rendered + + +def test_react_summary_prompt_works_with_llm_summarize_compactor(): + """The factory's output passes LLMSummarizeCompactor's template validation.""" + from mellea.stdlib.components.react import react_summary_prompt + from mellea.stdlib.context import LLMSummarizeCompactor + + # Should not raise on construction (template contains {conversation}). 
+ LLMSummarizeCompactor(prompt_template=react_summary_prompt(goal="g")) + LLMSummarizeCompactor(prompt_template=react_summary_prompt()) + LLMSummarizeCompactor( + prompt_template=react_summary_prompt(goal="g", max_tokens_hint=2000) + ) + + +def test_react_summary_prompt_max_tokens_hint_omitted_by_default(): + """Without a hint, the prompt is byte-identical to the un-hinted form.""" + from mellea.stdlib.components.react import react_summary_prompt + + prompt = react_summary_prompt(goal="g") + prompt_explicit_none = react_summary_prompt(goal="g", max_tokens_hint=None) + assert prompt == prompt_explicit_none + assert "Be at most" not in prompt + assert "tokens (roughly" not in prompt + + +def test_react_summary_prompt_max_tokens_hint_injects_bullet(): + """Positive hint adds a bullet with token + word estimates.""" + from mellea.stdlib.components.react import react_summary_prompt + + prompt = react_summary_prompt(goal="g", max_tokens_hint=2000) + # The bullet sits after "structured clearly" and before "Context to summarize:". + assert "- Be at most ~2000 tokens (roughly 1500 words)" in prompt + assert "Prioritize density" in prompt + # Ordering: structured-clearly bullet comes before the length bullet, + # length bullet comes before the conversation marker. 
+ sc_idx = prompt.index("structured clearly") + bullet_idx = prompt.index("Be at most ~2000") + conv_idx = prompt.index("Context to summarize:") + assert sc_idx < bullet_idx < conv_idx + + +def test_react_summary_prompt_max_tokens_hint_zero_or_negative_omits_bullet(): + """Non-positive hint values are treated as no hint.""" + from mellea.stdlib.components.react import react_summary_prompt + + base = react_summary_prompt() + assert react_summary_prompt(max_tokens_hint=0) == base + assert react_summary_prompt(max_tokens_hint=-1) == base + + +def test_react_summary_prompt_max_tokens_hint_word_estimate_scales(): + """Word estimate uses the ~0.75 words/token heuristic (int truncation).""" + from mellea.stdlib.components.react import react_summary_prompt + + # 1000 tokens → 750 words; 4000 → 3000. + assert "~1000 tokens (roughly 750 words)" in react_summary_prompt( + max_tokens_hint=1000 + ) + assert "~4000 tokens (roughly 3000 words)" in react_summary_prompt( + max_tokens_hint=4000 + ) + + +@pytest.mark.asyncio +async def test_react_invokes_per_turn_compactor(): + """The ``compactor=`` hook runs once per turn after the tool observation.""" + search = _make_tool("search", "found it") + backend = ScriptedBackend( + [ + _tool_call_turn("search", search, "step 1"), + _tool_call_turn("search", search, "step 2"), + _final_answer_call("done"), + ] + ) + + calls = [] + + class RecordingCompactor: + def compact(self, ctx, *, backend=None): + calls.append(len(ctx.as_list())) + return ctx # no-op compaction; we just observe + + result, _ctx = await react( + goal="find info", + context=ChatContext(), + backend=backend, + tools=[search], + loop_budget=10, + compactor=RecordingCompactor(), + ) + + # Two non-terminal turns each invoke the compactor; the final turn skips it. + assert result.value == "done" + assert len(calls) == 2 + # Per-turn context monotonically grows in this trace. 
+ assert calls[0] < calls[1] + + +@pytest.mark.asyncio +async def test_react_runs_llm_summarize_compactor(): + """LLMSummarizeCompactor.compact is sync (hides async internally), so react() + just calls it like any other sync Compactor. + """ + from mellea.stdlib.components.react import pin_react_initiator + from mellea.stdlib.context import LLMSummarizeCompactor + + search = _make_tool("search", "found it") + backend = ScriptedBackend( + [_tool_call_turn("search", search, "step 1"), _final_answer_call("done")] + ) + + # keep_n large → no actual summarisation fires; the test verifies that + # the sync compact() method is callable from inside the async react() + # loop without exception. + result, ctx = await react( + goal="find info", + context=ChatContext(window_size=10_000), + backend=backend, + tools=[search], + loop_budget=10, + compactor=LLMSummarizeCompactor(keep_n=1000, pin_predicate=pin_react_initiator), + ) + assert result.value == "done" + assert any(isinstance(c, ReactInitiator) for c in ctx.as_list()) + + +@pytest.mark.asyncio +async def test_react_compactor_can_actually_compact(): + """A real WindowCompactor wired in via the per-turn hook truncates context.""" + from mellea.stdlib.components.react import pin_react_initiator + from mellea.stdlib.context import WindowCompactor + + search = _make_tool("search", "found it") + backend = ScriptedBackend( + [ + _tool_call_turn("search", search, "step 1"), + _tool_call_turn("search", search, "step 2"), + _tool_call_turn("search", search, "step 3"), + _final_answer_call("done"), + ] + ) + + result, ctx = await react( + goal="find info", + # Permissive per-add window so we isolate the per-turn compactor's effect. + context=ChatContext(window_size=10_000), + backend=backend, + tools=[search], + loop_budget=10, + compactor=WindowCompactor(size=2, pin_predicate=pin_react_initiator), + ) + + # The ReactInitiator must survive thanks to pin_react_initiator. 
+ assert any(isinstance(c, ReactInitiator) for c in ctx.as_list()) + assert result.value == "done" + + if __name__ == "__main__": pytest.main([__file__]) diff --git a/test/stdlib/test_base_context.py b/test/stdlib/test_base_context.py index 2fccb11fd..83a4b42f8 100644 --- a/test/stdlib/test_base_context.py +++ b/test/stdlib/test_base_context.py @@ -4,8 +4,8 @@ from mellea.stdlib.context import ChatContext, SimpleContext -def context_construction(cls: type[Context]): - tree0 = cls() +def context_construction(cls: type[Context], **kwargs): + tree0 = cls(**kwargs) tree1 = tree0.add(CBlock("abc")) assert tree1.previous_node == tree0 @@ -15,11 +15,14 @@ def context_construction(cls: type[Context]): def test_context_construction(): context_construction(SimpleContext) + # A bare ChatContext() has no compactor (compaction is opt-in), so a single + # add leaves the linked-list shape identical to the pre-compaction + # behaviour. context_construction(ChatContext) -def large_context_construction(cls: type[Context]): - root = cls() +def large_context_construction(cls: type[Context], **kwargs): + root = cls(**kwargs) full_graph: Context = root for i in range(1000): @@ -31,7 +34,9 @@ def large_context_construction(cls: type[Context]): def test_large_context_construction(): large_context_construction(SimpleContext) - large_context_construction(ChatContext) + # Compaction, when configured, is applied at add() time; pass a window + # large enough that all 1000 components survive. + large_context_construction(ChatContext, window_size=2000) def test_render_view_for_simple_context(): @@ -48,7 +53,9 @@ def test_render_view_for_chat_context(): ctx = ChatContext(window_size=3) for i in range(5): ctx = ctx.add(CBlock(f"a {i}")) - assert len(ctx.as_list()) == 5, "Adding 5 items to context should result in 5 items" + # Compaction is now applied at add() time, so as_list and view_for_generation + # both reflect the sliding window of 3.
+ assert len(ctx.as_list()) == 3, "WindowCompactor(3) should keep 3 items" assert len(ctx.view_for_generation()) == 3, "Render size should be 3" # type: ignore diff --git a/test/stdlib/test_compactor.py b/test/stdlib/test_compactor.py new file mode 100644 index 000000000..bdb601049 --- /dev/null +++ b/test/stdlib/test_compactor.py @@ -0,0 +1,492 @@ +"""Tests for the ``Compactor`` protocol, ``WindowCompactor``, ``ThresholdCompactor``.""" + +from __future__ import annotations + +import pytest + +from mellea.core.base import ModelOutputThunk +from mellea.stdlib.components.chat import Message +from mellea.stdlib.context import ( + ChatContext, + Compactor, + LLMSummarizeCompactor, + PinPredicate, + ThresholdCompactor, + WindowCompactor, + pin_nothing, + pin_system, + pin_system_and_initial_user, +) +from mellea.stdlib.context.compactor import _last_usage_tokens + + +def _msg(i: int) -> Message: + return Message(role="user", content=f"m{i}") + + +def _thunk(total_tokens: int, value: str = "") -> ModelOutputThunk: + """Build a ModelOutputThunk with a populated usage dict.""" + mot = ModelOutputThunk(value=value) + mot.generation.usage = { + "prompt_tokens": total_tokens, + "completion_tokens": 0, + "total_tokens": total_tokens, + } + return mot + + +class TestChatContextDefaults: + def test_default_has_no_compactor(self): + # Compaction is opt-in: bare ChatContext() retains full history. 
+ ctx = ChatContext() + assert ctx._compactor is None + + def test_default_keeps_full_history(self): + ctx = ChatContext() + for i in range(20): + ctx = ctx.add(_msg(i)) + assert len(ctx.as_list()) == 20 + + def test_window_size_arg_constructs_window_compactor(self): + ctx = ChatContext(window_size=3) + assert isinstance(ctx._compactor, WindowCompactor) + assert ctx._compactor.size == 3 + + def test_passing_both_args_raises(self): + with pytest.raises(ValueError): + ChatContext(compactor=WindowCompactor(size=2), window_size=3) + + def test_explicit_compactor_overrides_default(self): + comp = WindowCompactor(size=2) + ctx = ChatContext(compactor=comp) + assert ctx._compactor is comp + + +class TestWindowCompactor: + def test_compact_keeps_last_n(self): + ctx = ChatContext(window_size=3) + for i in range(7): + ctx = ctx.add(_msg(i)) + items = ctx.as_list() + assert len(items) == 3 + assert [m.content for m in items] == ["m4", "m5", "m6"] + + def test_compact_does_not_mutate_original(self): + # Build with a permissive window so all 3 items are retained, then + # apply a tighter compactor manually (Pattern 2). 
+ ctx = ChatContext(window_size=10_000) + ctx = ctx.add(_msg(0)) + ctx = ctx.add(_msg(1)) + ctx = ctx.add(_msg(2)) + before_compact = [m.content for m in ctx.as_list()] + compacted = WindowCompactor(size=2).compact(ctx) + # original unchanged + assert [m.content for m in ctx.as_list()] == before_compact + # compacted is shorter and a different object + assert compacted is not ctx + assert len(compacted.as_list()) == 2 + + def test_compact_preserves_compactor_on_result(self): + comp = WindowCompactor(size=2) + ctx = ChatContext(compactor=comp) + ctx = ctx.add(_msg(0)).add(_msg(1)).add(_msg(2)) + # subsequent adds keep using the same compactor + ctx = ctx.add(_msg(3)) + assert ctx._compactor is comp + assert len(ctx.as_list()) == 2 + + def test_view_for_generation_no_double_truncation(self): + ctx = ChatContext(window_size=3) + for i in range(7): + ctx = ctx.add(_msg(i)) + # add() already compacted; view should match the linear history exactly + view = ctx.view_for_generation() + assert view is not None + assert [m.content for m in view] == [m.content for m in ctx.as_list()] + + def test_negative_size_raises(self): + with pytest.raises(ValueError): + WindowCompactor(size=-1) + + def test_size_zero_clears_body(self): + # Regression: `[-0:]` evaluates to `[0:]` in Python, which would keep + # the entire body instead of nothing. size=0 must keep zero body items. + ctx = ChatContext(window_size=10_000) + for i in range(5): + ctx = ctx.add(_msg(i)) + result = WindowCompactor(size=0).compact(ctx) + assert result.as_list() == [] + + def test_size_zero_keeps_pinned_prefix(self): + ctx = ChatContext(window_size=10_000) + ctx = ctx.add(Message(role="system", content="sys")) + for i in range(3): + ctx = ctx.add(_msg(i)) + # Default pin_predicate=pin_system → system stays, body cleared. 
+ result = WindowCompactor(size=0).compact(ctx) + items = result.as_list() + assert len(items) == 1 + assert items[0].content == "sys" + + def test_pins_leading_system_message(self): + ctx = ChatContext(window_size=10_000) + ctx = ctx.add(Message(role="system", content="You are helpful.")) + for i in range(5): + ctx = ctx.add(_msg(i)) + # Apply WindowCompactor(size=2) manually — keep system + last 2 body. + result = WindowCompactor(size=2).compact(ctx) + items = result.as_list() + assert len(items) == 3 + assert isinstance(items[0], Message) and items[0].role == "system" + assert [m.content for m in items[1:]] == ["m3", "m4"] + + def test_pins_multiple_leading_system_messages(self): + ctx = ChatContext(window_size=10_000) + ctx = ctx.add(Message(role="system", content="sys1")) + ctx = ctx.add(Message(role="system", content="sys2")) + for i in range(5): + ctx = ctx.add(_msg(i)) + result = WindowCompactor(size=2).compact(ctx) + items = result.as_list() + assert [m.content for m in items[:2]] == ["sys1", "sys2"] + assert [m.content for m in items[2:]] == ["m3", "m4"] + + def test_does_not_pin_non_contiguous_system(self): + # System message in the middle is NOT pinned — only the contiguous prefix. + ctx = ChatContext(window_size=10_000) + ctx = ctx.add(_msg(0)) # body starts here + ctx = ctx.add(Message(role="system", content="late-sys")) + for i in range(1, 6): + ctx = ctx.add(_msg(i)) + result = WindowCompactor(size=2).compact(ctx) + items = result.as_list() + assert len(items) == 2 + assert "late-sys" not in [getattr(m, "content", None) for m in items] + + def test_no_system_message_pure_last_n(self): + # Without any system prefix, behaviour is pure last-N (matches Phase 2 semantics). 
+ ctx = ChatContext(window_size=10_000) + for i in range(7): + ctx = ctx.add(_msg(i)) + result = WindowCompactor(size=3).compact(ctx) + items = result.as_list() + assert [m.content for m in items] == ["m4", "m5", "m6"] + + +class TestCompactorProtocol: + def test_user_class_satisfies_protocol(self): + """A plain class with the right method should be a Compactor.""" + + class Identity: + def compact(self, ctx, *, backend=None): + return ctx + + # structural subtyping check — at runtime this is just isinstance against Protocol + # which requires `runtime_checkable` to actually work; instead assert duck-typing. + c = Identity() + ctx = ChatContext(compactor=c) + ctx = ctx.add(_msg(0)) + # Identity returns ctx unchanged, so we still see m0 + assert [m.content for m in ctx.as_list()] == ["m0"] + + def test_pattern_2_manual_compaction(self): + """Pattern 2: caller invokes compactor.compact() directly.""" + comp = WindowCompactor(size=2) + # a bare ChatContext() (no compactor by default) would also avoid + # auto-compaction; here we use a window large enough that auto-compaction + # never fires, then apply comp manually.
+ ctx = ChatContext(window_size=100) + for i in range(5): + ctx = ctx.add(_msg(i)) + assert len(ctx.as_list()) == 5 + ctx2 = comp.compact(ctx) + assert len(ctx2.as_list()) == 2 + # original still untouched + assert len(ctx.as_list()) == 5 + + +class TestLastUsageTokens: + def test_no_thunk_returns_none(self): + ctx = ChatContext(window_size=100).add(_msg(0)) + assert _last_usage_tokens(ctx) is None + + def test_thunk_without_usage_returns_none(self): + ctx = ChatContext(window_size=100).add(_msg(0)).add(ModelOutputThunk(value="x")) + assert _last_usage_tokens(ctx) is None + + def test_reads_total_tokens(self): + ctx = ChatContext(window_size=100).add(_msg(0)).add(_thunk(150)) + assert _last_usage_tokens(ctx) == 150 + + def test_falls_back_to_prompt_plus_completion(self): + mot = ModelOutputThunk(value="x") + mot.generation.usage = {"prompt_tokens": 40, "completion_tokens": 20} + ctx = ChatContext(window_size=100).add(_msg(0)).add(mot) + assert _last_usage_tokens(ctx) == 60 + + def test_uses_most_recent_thunk(self): + ctx = ( + ChatContext(window_size=100).add(_thunk(100)).add(_msg(0)).add(_thunk(500)) + ) + assert _last_usage_tokens(ctx) == 500 + + +class TestThresholdCompactor: + def test_below_threshold_returns_input(self): + inner = WindowCompactor(size=2) + gated = ThresholdCompactor(inner, threshold=1000) + ctx = ChatContext(window_size=100).add(_msg(0)).add(_thunk(50)) + # Add 5 more messages (7 components total); the gate is expected to stay closed because the reported token count (50) is below the threshold (1000) + for i in range(1, 6): + ctx = ctx.add(_msg(i)) + result = gated.compact(ctx) + assert result is ctx + + def test_above_threshold_runs_inner(self): + inner = WindowCompactor(size=2) + gated = ThresholdCompactor(inner, threshold=100) + # Build a context with the last thunk reporting >threshold tokens. + ctx = ChatContext(window_size=100) + for i in range(5): + ctx = ctx.add(_msg(i)) + ctx = ctx.add(_thunk(500)) + result = gated.compact(ctx) + # Inner was invoked → only last 2 components retained.
+ assert len(result.as_list()) == 2 + + def test_no_thunk_no_compaction(self): + """No thunk means no usage info — gate stays closed.""" + inner = WindowCompactor(size=2) + gated = ThresholdCompactor(inner, threshold=100) + ctx = ChatContext(window_size=100) + for i in range(5): + ctx = ctx.add(_msg(i)) + result = gated.compact(ctx) + assert result is ctx + + def test_zero_threshold_disables_gate(self): + inner = WindowCompactor(size=2) + gated = ThresholdCompactor(inner, threshold=0) + ctx = ChatContext(window_size=100).add(_msg(0)).add(_thunk(10_000)) + result = gated.compact(ctx) + # Threshold 0 means "never trigger" — input passes through. + assert result is ctx + + +class TestPinPredicates: + def test_pin_nothing(self): + assert pin_nothing([_msg(0), _msg(1)]) == 0 + assert pin_nothing([]) == 0 + + def test_pin_system_zero_when_no_system(self): + assert pin_system([_msg(0), _msg(1)]) == 0 + + def test_pin_system_counts_contiguous(self): + components = [ + Message(role="system", content="s1"), + Message(role="system", content="s2"), + _msg(0), + Message(role="system", content="late-s"), # not pinned — non-contiguous + ] + assert pin_system(components) == 2 + + def test_pin_system_and_initial_user_with_both(self): + components = [ + Message(role="system", content="s1"), + Message(role="user", content="goal"), + Message(role="assistant", content="ack"), + ] + assert pin_system_and_initial_user(components) == 2 + + def test_pin_system_and_initial_user_no_user(self): + components = [ + Message(role="system", content="s1"), + Message(role="assistant", content="x"), + ] + # First non-system is "assistant", not "user" — not pinned beyond system. 
+ assert pin_system_and_initial_user(components) == 1 + + def test_pin_system_and_initial_user_user_only(self): + components = [ + Message(role="user", content="goal"), + Message(role="assistant", content="ok"), + ] + assert pin_system_and_initial_user(components) == 1 + + +class TestWindowCompactorPredicate: + def test_pin_nothing_pure_last_n(self): + comp = WindowCompactor(size=2, pin_predicate=pin_nothing) + ctx = ChatContext(window_size=10_000) + ctx = ctx.add(Message(role="system", content="sys")) + for i in range(5): + ctx = ctx.add(_msg(i)) + result = comp.compact(ctx) + items = result.as_list() + assert len(items) == 2 + # System is dropped because predicate returned 0. + assert "sys" not in [getattr(m, "content", None) for m in items] + + def test_pin_system_and_initial_user_protects_first_user(self): + comp = WindowCompactor(size=2, pin_predicate=pin_system_and_initial_user) + ctx = ChatContext(window_size=10_000) + ctx = ctx.add(Message(role="system", content="sys")) + ctx = ctx.add(Message(role="user", content="goal")) + for i in range(6): + ctx = ctx.add(_msg(i)) + result = comp.compact(ctx) + items = result.as_list() + # prefix (sys + goal) + last 2 body = 4 + assert len(items) == 4 + assert items[0].content == "sys" + assert items[1].content == "goal" + + def test_custom_predicate(self): + # Predicate that pins the first 3 components unconditionally. 
+ def pin_first_3(components): + return min(3, len(components)) + + comp = WindowCompactor(size=2, pin_predicate=pin_first_3) + ctx = ChatContext(window_size=10_000) + for i in range(8): + ctx = ctx.add(_msg(i)) + result = comp.compact(ctx) + items = result.as_list() + # prefix (m0, m1, m2) + last 2 of body (m6, m7) = 5 + assert [m.content for m in items] == ["m0", "m1", "m2", "m6", "m7"] + + +# --------------------------------------------------------------------------- # +# LLMSummarizeCompactor # +# --------------------------------------------------------------------------- # + + +@pytest.fixture +def scripted_summary_backend(): + """Lazy-built fake backend that returns a fixed summary on each generate call.""" + from collections.abc import Sequence + + from mellea.core.backend import Backend, BaseModelSubclass + from mellea.core.base import C, GenerateLog + + class FakeBackend(Backend): + def __init__(self, summary: str = "SUMMARY-OF-OLD") -> None: + self.summary = summary + self.calls = 0 + + async def _generate_from_context( + self, + action, + ctx, + *, + format=None, + model_options=None, + tool_calls: bool = False, + ): + self.calls += 1 + mot = ModelOutputThunk(value=self.summary) + mot._generate_log = GenerateLog(is_final_result=True) + return mot, ctx.add(action).add(mot) + + async def generate_from_raw( + self, + actions, + ctx, + *, + format=None, + model_options=None, + tool_calls: bool = False, + ): + raise NotImplementedError + + return FakeBackend() + + +class TestLLMSummarizeCompactor: + def test_negative_keep_n_raises(self): + with pytest.raises(ValueError): + LLMSummarizeCompactor(keep_n=-1) + + def test_prompt_template_must_have_placeholder(self): + with pytest.raises(ValueError, match="conversation"): + LLMSummarizeCompactor(prompt_template="no placeholder here") + + def test_compact_is_sync(self): + import inspect + + comp = LLMSummarizeCompactor() + # Sync from the outside even though the implementation calls async backend code. 
+ assert not inspect.iscoroutinefunction(comp.compact) + + def test_raises_without_backend(self): + comp = LLMSummarizeCompactor() + ctx = ChatContext(window_size=10_000) + for i in range(3): + ctx = ctx.add(_msg(i)) + with pytest.raises(ValueError, match="backend"): + comp.compact(ctx) + + def test_short_body_is_noop(self, scripted_summary_backend): + comp = LLMSummarizeCompactor(keep_n=5) + ctx = ChatContext(window_size=10_000) + for i in range(3): + ctx = ctx.add(_msg(i)) + result = comp.compact(ctx, backend=scripted_summary_backend) + # body length (3) <= keep_n (5) → no-op, backend not called + assert result is ctx + assert scripted_summary_backend.calls == 0 + + def test_summarises_old_keeps_recent(self, scripted_summary_backend): + comp = LLMSummarizeCompactor(keep_n=2) + ctx = ChatContext(window_size=10_000) + for i in range(6): + ctx = ctx.add(_msg(i)) + result = comp.compact(ctx, backend=scripted_summary_backend) + items = result.as_list() + # summary (1) + last 2 verbatim = 3 + assert len(items) == 3 + assert "[CONTEXT SUMMARY]" in items[0].content + assert items[1].content == "m4" + assert items[2].content == "m5" + assert scripted_summary_backend.calls == 1 + + def test_pin_predicate_preserves_prefix(self, scripted_summary_backend): + comp = LLMSummarizeCompactor(keep_n=1, pin_predicate=pin_system) + ctx = ChatContext(window_size=10_000) + ctx = ctx.add(Message(role="system", content="sys")) + for i in range(4): + ctx = ctx.add(_msg(i)) + result = comp.compact(ctx, backend=scripted_summary_backend) + items = result.as_list() + # system (pinned) + summary + last 1 verbatim = 3 + assert items[0].role == "system" + assert items[0].content == "sys" + assert "[CONTEXT SUMMARY]" in items[1].content + assert items[2].content == "m3" + + def test_does_not_mutate_original(self, scripted_summary_backend): + comp = LLMSummarizeCompactor(keep_n=1) + ctx = ChatContext(window_size=10_000) + for i in range(4): + ctx = ctx.add(_msg(i)) + before = [m.content for m in 
ctx.as_list()] + comp.compact(ctx, backend=scripted_summary_backend) + assert [m.content for m in ctx.as_list()] == before + + def test_satisfies_compactor_protocol(self): + comp: Compactor = LLMSummarizeCompactor() + # Just a typing-level check that the assignment is accepted. + assert callable(comp.compact) + + @pytest.mark.asyncio + async def test_works_inside_running_event_loop(self, scripted_summary_backend): + """compact() is callable from within an async function — uses worker thread.""" + comp = LLMSummarizeCompactor(keep_n=1) + ctx = ChatContext(window_size=10_000) + for i in range(4): + ctx = ctx.add(_msg(i)) + # No await: this is a sync call from inside an async test. + result = comp.compact(ctx, backend=scripted_summary_backend) + items = result.as_list() + assert "[CONTEXT SUMMARY]" in items[0].content + assert items[1].content == "m3"