diff --git a/docs/examples/context/README.md b/docs/examples/context/README.md index dde027bc5..e7b8b3752 100644 --- a/docs/examples/context/README.md +++ b/docs/examples/context/README.md @@ -1,13 +1,15 @@ # Context Examples -This directory contains examples demonstrating how to work with Mellea's context system, particularly when using sampling strategies and validation. +This directory contains examples demonstrating how to work with Mellea's context system: inspecting per-attempt contexts produced by sampling strategies, and shrinking contexts with the `Compactor` protocol. ## Files ### contexts_with_sampling.py + Shows how to retrieve and inspect context information when using sampling strategies and validation. **Key Features:** + - Using `RejectionSamplingStrategy` with requirements - Accessing `SamplingResult` objects to inspect generation attempts - Retrieving context for different generation attempts @@ -15,10 +17,34 @@ Shows how to retrieve and inspect context information when using sampling strate - Understanding the context tree structure **Usage:** -```bash + +```bash python docs/examples/context/contexts_with_sampling.py ``` +### window_compactor.py + +`WindowCompactor` — opt-in by passing `compactor=` (or the `window_size=` sugar). Demonstrates system-prefix pinning, `pin_system_and_initial_user`, `pin_nothing` (pure last-N), and `size=0` to clear the body. + +### threshold_compactor.py + +`ThresholdCompactor` — gate an inner compactor on the conversation's cumulative token size. The reading is taken from the most recent `ModelOutputThunk`'s `total_tokens`, which for a chat backend equals `prompt_tokens` (full conversation history sent to the model) + `completion_tokens` (reply). The gate fires once the running conversation size crosses the threshold; once compaction shrinks the context, the next call produces a smaller reading and the gate closes again.
+ +### custom_compactor.py + +Implement the `Compactor` protocol with a plain class (no inheritance). Shows Pattern 1 (wired into `ChatContext`) and Pattern 2 (manual `compact()` call). + +### react_compaction.py + +Compose the ReACT loop with a sync `Compactor`. Two integration points: + +- **Per-add** — wire a `Compactor` onto the `ChatContext` so it runs every time `react()` appends a Message, ToolMessage, or thunk. +- **Per-turn** — pass `compactor=` to `react()`; it fires once per ReACT iteration after the tool observation. + +`LLMSummarizeCompactor` is also a sync `Compactor` — it hides the async backend call internally (worker thread when called from an already-running event loop) so callers don't have to think about sync vs async. + +Use `pin_react_initiator` (from `mellea.stdlib.components.react`) as the predicate so the goal and tool registration survive compaction. + ## Concepts Demonstrated - **Sampling Results**: Working with `SamplingResult` objects @@ -26,6 +52,8 @@ python docs/examples/context/contexts_with_sampling.py - **Multiple Attempts**: Examining different generation attempts - **Context Trees**: Understanding how contexts link together - **Validation Context**: Inspecting how requirements were evaluated +- **Compaction Protocol**: Sync `Compactor` for per-`add()` shrinking +- **Pin Predicates**: Auto-protect leading system messages or the user's initial prompt during compaction ## Key APIs @@ -48,8 +76,23 @@ gen_ctx.previous_node.node_data val_ctx.node_data ``` +```python +# Wire a compactor into a ChatContext (Pattern 1 — runs on every add()) +from mellea.stdlib.context import ChatContext, WindowCompactor, ThresholdCompactor + +ctx = ChatContext(compactor=WindowCompactor(size=5)) # default: pin_system +ctx = ChatContext(window_size=5) # sugar for the line above +ctx = ChatContext( + compactor=ThresholdCompactor(WindowCompactor(size=5), threshold=8000), +) + +# Manual compaction (Pattern 2) +ctx = WindowCompactor(size=0).compact(ctx) # drop 
body, keep pinned prefix +``` + ## Related Documentation -- See `mellea/stdlib/context.py` for context implementation +- See `mellea/stdlib/context/` for context and compactor implementations - See `mellea/stdlib/sampling/` for sampling strategies -- See `docs/dev/spans.md` for context architecture details \ No newline at end of file +- See `mellea/stdlib/frameworks/react.py` for the ReACT loop +- See `docs/dev/spans.md` for context architecture details diff --git a/docs/examples/context/custom_compactor.py b/docs/examples/context/custom_compactor.py new file mode 100644 index 000000000..663f21b4a --- /dev/null +++ b/docs/examples/context/custom_compactor.py @@ -0,0 +1,63 @@ +# pytest: unit +"""Implementing the Compactor protocol — anything with ``compact()`` works. + +The protocol is structurally typed: a class with a ``compact(ctx, *, +backend=None) -> ChatContext`` method is a valid Compactor. No +inheritance is required. +""" + +from mellea.stdlib.components.chat import Message +from mellea.stdlib.context import ChatContext, Compactor +from mellea.stdlib.context.chat import _rebuild_chat_context + + +class TruncateOldest: + """Drop only the very first body component each call. + + Demonstrates the smallest possible Compactor implementation. Pattern + 1 (wired into ``ChatContext``) means each ``add()`` appends and then + drops the oldest item — net result: the context never grows past one item.
+ """ + + def compact(self, ctx, *, backend=None): + items = ctx.as_list() + if len(items) <= 1: + return ctx + return _rebuild_chat_context(items[1:], compactor=ctx._compactor) + + +def pattern_1_wired_into_context(): + """Pattern 1: compactor lives on the context, runs in ``add()``.""" + ctx = ChatContext(compactor=TruncateOldest()) + for i in range(4): + ctx = ctx.add(Message("user", f"msg {i}")) + return [m.content for m in ctx.as_list()] + # → ['msg 3'] (oldest dropped after each append) + + +def pattern_2_manual_call(): + """Pattern 2: caller invokes ``compact()`` directly between turns.""" + ctx = ChatContext(window_size=10_000) # permissive — no auto-compaction + for i in range(5): + ctx = ctx.add(Message("user", f"msg {i}")) + truncated = TruncateOldest().compact(ctx) + return [m.content for m in truncated.as_list()] + + +def structural_typing_check(): + """The Compactor protocol is satisfied structurally, no inheritance.""" + c: Compactor = TruncateOldest() # mypy-checked Protocol assignment + return type(c).__name__ + + +if __name__ == "__main__": + for fn in [pattern_1_wired_into_context, pattern_2_manual_call]: + print(f"--- {fn.__name__} ---") + print(fn()) + print(f"structural typing: {structural_typing_check()} satisfies Compactor") + + +def test_custom_compactor_examples(): + assert pattern_1_wired_into_context() == ["msg 3"] + assert pattern_2_manual_call() == ["msg 1", "msg 2", "msg 3", "msg 4"] + assert structural_typing_check() == "TruncateOldest" diff --git a/docs/examples/context/react_compaction.py b/docs/examples/context/react_compaction.py new file mode 100644 index 000000000..baa19c3dd --- /dev/null +++ b/docs/examples/context/react_compaction.py @@ -0,0 +1,235 @@ +# pytest: unit +"""Compose the ReACT loop with a sync `Compactor`. + +Two integration points are available, and they're complementary: + +1. **Per-add** — the `ChatContext`'s own compactor runs every time the + ReACT loop appends a Message, ToolMessage, or thunk.
This is fine + for cheap strategies like `WindowCompactor`. +2. **Per-turn** — pass `compactor=` to ``react(...)`` to invoke a + compactor once per ReACT iteration after the tool observation. Use + it for heavier strategies that should fire at turn boundaries + instead of on every component append. + +In both cases use ``pin_react_initiator`` (from +``mellea.stdlib.components.react``) so the goal and tool registration +survive compaction. + +This example exercises the wiring end-to-end against a fake backend so +no LLM is required. +""" + +from __future__ import annotations + +import asyncio +from collections.abc import Sequence +from dataclasses import dataclass + +from mellea.backends.tools import MelleaTool +from mellea.core.backend import Backend, BaseModelSubclass +from mellea.core.base import ( + C, + CBlock, + Component, + Context, + GenerateLog, + ModelOutputThunk, + ModelToolCall, +) +from mellea.stdlib.components.react import ( + MELLEA_FINALIZER_TOOL, + ReactInitiator, + _mellea_finalize_tool, + pin_react_initiator, +) +from mellea.stdlib.context import ChatContext, WindowCompactor +from mellea.stdlib.frameworks.react import react + +# --------------------------------------------------------------------------- # +# Fake backend so the example runs without an LLM # +# --------------------------------------------------------------------------- # + + +@dataclass +class _ScriptedTurn: + value: str + tool_calls: dict[str, ModelToolCall] | None = None + + +class ScriptedBackend(Backend): + """Returns pre-scripted responses; no real model is called.""" + + def __init__(self, script: list[_ScriptedTurn]) -> None: + self._script = iter(script) + + async def _generate_from_context( + self, + action: Component[C] | CBlock, + ctx: Context, + *, + format: type[BaseModelSubclass] | None = None, + model_options: dict | None = None, + tool_calls: bool = False, + ) -> tuple[ModelOutputThunk[C], Context]: + turn = next(self._script) + mot: ModelOutputThunk = 
ModelOutputThunk( + value=turn.value, tool_calls=turn.tool_calls + ) + mot._generate_log = GenerateLog(is_final_result=True) + return mot, ctx.add(action).add(mot) + + async def generate_from_raw( + self, + actions: Sequence[Component[C] | CBlock], + ctx: Context, + *, + format: type[BaseModelSubclass] | None = None, + model_options: dict | None = None, + tool_calls: bool = False, + ) -> list[ModelOutputThunk]: + raise NotImplementedError + + +def _tool(name: str, return_value: str = "ok") -> MelleaTool: + def _fn() -> str: + return return_value + + return MelleaTool.from_callable(_fn, name=name) + + +def _tool_call(tool_name: str, tool: MelleaTool, thought: str) -> _ScriptedTurn: + tc = ModelToolCall(name=tool_name, func=tool, args={}) + return _ScriptedTurn(value=thought, tool_calls={tool_name: tc}) + + +def _final(answer: str) -> _ScriptedTurn: + finalizer = MelleaTool.from_callable(_mellea_finalize_tool, MELLEA_FINALIZER_TOOL) + tc = ModelToolCall( + name=MELLEA_FINALIZER_TOOL, func=finalizer, args={"answer": answer} + ) + return _ScriptedTurn(value="", tool_calls={MELLEA_FINALIZER_TOOL: tc}) + + +# --------------------------------------------------------------------------- # +# Pattern A — per-add compaction wired into the ChatContext # +# --------------------------------------------------------------------------- # + + +async def per_add_compaction(): + """A `WindowCompactor(pin_react_initiator)` on the ChatContext compacts + on every ``add()`` — Messages, ToolMessages, thunks. The ReactInitiator + stays pinned across the whole loop. 
+ """ + search = _tool("search") + backend = ScriptedBackend( + [ + _tool_call("search", search, "step 1"), + _tool_call("search", search, "step 2"), + _tool_call("search", search, "step 3"), + _final("done"), + ] + ) + ctx = ChatContext( + compactor=WindowCompactor(size=3, pin_predicate=pin_react_initiator) + ) + result, ctx = await react( + goal="find info", context=ctx, backend=backend, tools=[search], loop_budget=10 + ) + return ( + result.value, + any(isinstance(c, ReactInitiator) for c in ctx.as_list()), + len(ctx.as_list()), + ) + + +# --------------------------------------------------------------------------- # +# Pattern B — per-turn compaction passed to react() # +# --------------------------------------------------------------------------- # + + +async def per_turn_compaction(): + """Pass ``compactor=`` to ``react`` for once-per-turn invocation. + + Use a permissive ``ChatContext`` (large window) so the per-add path is + effectively disabled — only the per-turn hook drives compaction. + """ + search = _tool("search") + backend = ScriptedBackend( + [ + _tool_call("search", search, "step 1"), + _tool_call("search", search, "step 2"), + _tool_call("search", search, "step 3"), + _final("done"), + ] + ) + result, ctx = await react( + goal="find info", + context=ChatContext(window_size=10_000), + backend=backend, + tools=[search], + loop_budget=10, + compactor=WindowCompactor(size=2, pin_predicate=pin_react_initiator), + ) + return (result.value, any(isinstance(c, ReactInitiator) for c in ctx.as_list())) + + +# --------------------------------------------------------------------------- # +# Pattern C — LLM-driven summarisation # +# --------------------------------------------------------------------------- # + + +async def llm_summarize_compaction(): + """Wire :class:`LLMSummarizeCompactor` into ``react()``. 
+ + ``LLMSummarizeCompactor`` implements the sync :class:`Compactor` + protocol — its ``compact`` method internally orchestrates the async + backend call (running it on a worker thread when invoked from inside + an event loop). From ``react()``'s perspective it's just another + sync compactor. + + To keep the scripted backend simple, this example sets ``keep_n`` + large enough that summarisation never fires (no LLM call is needed). + Real usage would pair it with ``ThresholdCompactor`` so it only + activates once the conversation crosses a token budget. See + ``TestLLMSummarizeCompactor`` in ``test/stdlib/test_compactor.py`` for + unit tests that exercise the actual summary path. + """ + from mellea.stdlib.context import LLMSummarizeCompactor + + search = _tool("search") + backend = ScriptedBackend([_tool_call("search", search, "step 1"), _final("done")]) + result, ctx = await react( + goal="find info", + context=ChatContext(window_size=10_000), + backend=backend, + tools=[search], + loop_budget=10, + # keep_n=1000 → no summarisation triggers in this short script; + # the example just shows the LLM-backed compactor is wired correctly.
+ compactor=LLMSummarizeCompactor(keep_n=1000, pin_predicate=pin_react_initiator), + ) + return (result.value, any(isinstance(c, ReactInitiator) for c in ctx.as_list())) + + +if __name__ == "__main__": + print(f"per_add_compaction: {asyncio.run(per_add_compaction())}") + print(f"per_turn_compaction: {asyncio.run(per_turn_compaction())}") + print(f"llm_summarize_compact: {asyncio.run(llm_summarize_compaction())}") + + +def test_per_add_compaction(): + answer, has_initiator, _length = asyncio.run(per_add_compaction()) + assert answer == "done" + assert has_initiator + + +def test_per_turn_compaction(): + answer, has_initiator = asyncio.run(per_turn_compaction()) + assert answer == "done" + assert has_initiator + + +def test_llm_summarize_compaction(): + answer, has_initiator = asyncio.run(llm_summarize_compaction()) + assert answer == "done" + assert has_initiator diff --git a/docs/examples/context/threshold_compactor.py b/docs/examples/context/threshold_compactor.py new file mode 100644 index 000000000..120eba07c --- /dev/null +++ b/docs/examples/context/threshold_compactor.py @@ -0,0 +1,57 @@ +# pytest: unit +"""ThresholdCompactor — gate an inner Compactor on conversation size. + +Reads ``ModelOutputThunk.generation.usage`` from the most recent thunk +in the context. For a chat backend, ``total_tokens`` on that thunk is +``prompt_tokens`` (full conversation history sent to the model) plus +``completion_tokens`` (the reply), so it tracks *cumulative* context +size — not just one call's isolated tokens. The inner compactor fires +once that running size exceeds the configured threshold. 
+""" + +from mellea.core.base import ModelOutputThunk +from mellea.stdlib.components.chat import Message +from mellea.stdlib.context import ChatContext, ThresholdCompactor, WindowCompactor + + +def _thunk(total_tokens: int) -> ModelOutputThunk: + """Build a ModelOutputThunk with a populated usage dict (test helper).""" + mot = ModelOutputThunk(value="") + mot.generation.usage = { + "prompt_tokens": total_tokens, + "completion_tokens": 0, + "total_tokens": total_tokens, + } + return mot + + +def below_threshold_passthrough(): + """Token usage is below threshold → inner compactor is NOT invoked.""" + gated = ThresholdCompactor(WindowCompactor(size=2), threshold=1000) + ctx = ChatContext(window_size=10_000) + for i in range(5): + ctx = ctx.add(Message("user", f"msg {i}")) + ctx = ctx.add(_thunk(50)) # only 50 tokens — below 1000 + out = gated.compact(ctx) + return len(out.as_list()) # 6 (5 messages + thunk) — unchanged + + +def above_threshold_compacts(): + """Token usage exceeds threshold → inner compactor runs.""" + gated = ThresholdCompactor(WindowCompactor(size=2), threshold=1000) + ctx = ChatContext(window_size=10_000) + for i in range(5): + ctx = ctx.add(Message("user", f"msg {i}")) + ctx = ctx.add(_thunk(2000)) # 2000 tokens — over the gate + out = gated.compact(ctx) + return len(out.as_list()) # 2 — WindowCompactor(size=2) ran + + +if __name__ == "__main__": + print(f"below_threshold_passthrough: {below_threshold_passthrough()}") + print(f"above_threshold_compacts: {above_threshold_compacts()}") + + +def test_threshold_compactor_examples(): + assert below_threshold_passthrough() == 6 + assert above_threshold_compacts() == 2 diff --git a/docs/examples/context/window_compactor.py b/docs/examples/context/window_compactor.py new file mode 100644 index 000000000..320dd0e31 --- /dev/null +++ b/docs/examples/context/window_compactor.py @@ -0,0 +1,101 @@ +# pytest: unit +"""WindowCompactor — keep the last N body components. 
+ +Demonstrates the default behaviour, the ``window_size=`` sugar on +``ChatContext``, and how the auto-pinned system prefix is preserved. +""" + +from mellea.stdlib.components.chat import Message +from mellea.stdlib.context import ( + ChatContext, + WindowCompactor, + pin_nothing, + pin_system_and_initial_user, +) + + +def basic_window(): + """``ChatContext()`` keeps the full history by default; opt in via + ``compactor=`` to start truncating. + """ + ctx = ChatContext(compactor=WindowCompactor(size=5)) + for i in range(8): + ctx = ctx.add(Message("user", f"msg {i}")) + return [m.content for m in ctx.as_list()] + # → ['msg 3', 'msg 4', 'msg 5', 'msg 6', 'msg 7'] + + +def window_size_sugar(): + """``window_size=`` is sugar for ``WindowCompactor(size=...)``.""" + ctx = ChatContext(window_size=3) + for i in range(6): + ctx = ctx.add(Message("user", f"msg {i}")) + return [m.content for m in ctx.as_list()] + # → ['msg 3', 'msg 4', 'msg 5'] + + +def system_prefix_pinned(): + """Default predicate ``pin_system`` keeps a leading system message.""" + ctx = ChatContext(window_size=3) + ctx = ctx.add(Message("system", "You are a helpful assistant.")) + for i in range(6): + ctx = ctx.add(Message("user", f"msg {i}")) + return [(m.role, m.content) for m in ctx.as_list()] + # → [('system', '...'), ('user', 'msg 3'), ('user', 'msg 4'), ('user', 'msg 5')] + + +def pin_initial_user_too(): + """Use ``pin_system_and_initial_user`` to also keep the user's first turn.""" + ctx = ChatContext( + compactor=WindowCompactor(size=3, pin_predicate=pin_system_and_initial_user) + ) + ctx = ctx.add(Message("system", "You are helpful.")) + ctx = ctx.add(Message("user", "What is the capital of France?")) + for i in range(6): + ctx = ctx.add(Message("assistant", f"reply {i}")) + return [(m.role, m.content) for m in ctx.as_list()] + + +def pure_last_n(): + """``pin_nothing`` disables prefix pinning — the system message is dropped.""" + ctx = ChatContext(compactor=WindowCompactor(size=3, 
pin_predicate=pin_nothing)) + ctx = ctx.add(Message("system", "ignored after a few turns")) + for i in range(6): + ctx = ctx.add(Message("user", f"msg {i}")) + return [(m.role, m.content) for m in ctx.as_list()] + + +def clear_body_keep_prefix(): + """``size=0`` drops the body entirely while keeping the pinned prefix.""" + ctx = ChatContext(window_size=10_000) + ctx = ctx.add(Message("system", "You are helpful.")) + for i in range(5): + ctx = ctx.add(Message("user", f"msg {i}")) + cleared = WindowCompactor(size=0).compact(ctx) + return [(m.role, m.content) for m in cleared.as_list()] + # → [('system', 'You are helpful.')] + + +if __name__ == "__main__": + for fn in [ + basic_window, + window_size_sugar, + system_prefix_pinned, + pin_initial_user_too, + pure_last_n, + clear_body_keep_prefix, + ]: + print(f"--- {fn.__name__} ---") + print(fn()) + + +def test_window_compactor_examples(): + """Smoke test all examples — invariants documented in each docstring.""" + assert basic_window() == ["msg 3", "msg 4", "msg 5", "msg 6", "msg 7"] + assert window_size_sugar() == ["msg 3", "msg 4", "msg 5"] + assert system_prefix_pinned()[0] == ("system", "You are a helpful assistant.") + pinned = pin_initial_user_too() + assert pinned[0] == ("system", "You are helpful.") + assert pinned[1] == ("user", "What is the capital of France?") + assert all(role == "user" for role, _ in pure_last_n()) + assert clear_body_keep_prefix() == [("system", "You are helpful.")] diff --git a/mellea/stdlib/components/react.py b/mellea/stdlib/components/react.py index 8f08fe8a0..dfc423043 100644 --- a/mellea/stdlib/components/react.py +++ b/mellea/stdlib/components/react.py @@ -32,6 +32,102 @@ def _mellea_finalize_tool(answer: str) -> str: return answer +def pin_react_initiator(components: list[Component | CBlock]) -> int: + """A ``PinPredicate`` that pins everything up to and including the first ``ReactInitiator``. 
+ + Plug it into any compactor in :mod:`mellea.stdlib.context` that takes a + ``pin_predicate`` (e.g. :class:`WindowCompactor`, + :class:`ThresholdCompactor`'s inner compactor) so the react goal and + tool registration survive compaction:: + + from mellea.stdlib.context import ChatContext, WindowCompactor + from mellea.stdlib.components.react import pin_react_initiator + + ctx = ChatContext( + compactor=WindowCompactor(size=5, pin_predicate=pin_react_initiator), + ) + result, _ = await react(goal=..., context=ctx, ...) + + Returns ``0`` when no ``ReactInitiator`` is found, so a context that + has not yet been seeded with a react goal compacts as if there were + no prefix. + """ + for i, c in enumerate(components): + if isinstance(c, ReactInitiator): + return i + 1 + return 0 + + +def react_summary_prompt( + goal: str | None = None, + max_tokens_hint: int | None = None, +) -> str: + """Build a research-flavoured summary prompt for :class:`LLMSummarizeCompactor`. + + Returns a template with a ``{conversation}`` placeholder that + :class:`LLMSummarizeCompactor` fills in at compaction time. Pass the + react goal via ``goal=`` to anchor the summarisation around the + objective; with ``goal=None`` the ``GOAL:`` line is omitted. + + Pass ``max_tokens_hint=N`` to inject a soft length-cap bullet + ("Be at most ~N tokens") into the summarizer's instructions. The hint + is a plan-time anchor for the model — combine it with a hard + ``max_tokens`` API arg on the summarizer's LLM call to enforce. + ``max_tokens_hint=None`` (default) or non-positive values omit the + bullet, so the prompt is byte-identical to the un-hinted form. + + Curly braces in ``goal`` are escaped so :meth:`str.format` (used by the + compactor) preserves them as literal characters.
+ + Example:: + + from mellea.stdlib.components.react import ( + pin_react_initiator, + react_summary_prompt, + ) + from mellea.stdlib.context import LLMSummarizeCompactor + + compactor = LLMSummarizeCompactor( + keep_n=5, + pin_predicate=pin_react_initiator, + prompt_template=react_summary_prompt( + goal="find papers on X", + max_tokens_hint=2000, + ), + ) + """ + if goal is not None: + # Escape braces so .format() in the compactor keeps them literal. + safe_goal = goal.replace("{", "{{").replace("}", "}}") + goal_block = f"GOAL: {safe_goal}\n\n" + else: + goal_block = "" + if max_tokens_hint is not None and max_tokens_hint > 0: + # Rough heuristic: ~0.75 words per token for English research text. + words_estimate = int(max_tokens_hint * 0.75) + length_bullet = ( + f"- Be at most ~{max_tokens_hint} tokens (roughly " + f"{words_estimate} words). Prioritize density: drop redundant " + "or ancillary detail.\n" + ) + else: + length_bullet = "" + return ( + "You are summarizing research progress to maintain context " + "within token limits.\n\n" + f"{goal_block}" + "Provide a comprehensive summary of the research context below. " + "Your summary should:\n" + "- Preserve ALL specific facts, numbers, names, URLs, and search " + "queries found\n" + "- Note which tools were called and what results were obtained\n" + "- Highlight key findings and any dead ends encountered\n" + "- Be structured clearly so the research can continue seamlessly\n" + f"{length_bullet}" + "\nContext to summarize:\n{conversation}" + ) + + class ReactInitiator(Component[str]): """`ReactInitiator` is used at the start of the ReACT loop to prime the model. diff --git a/mellea/stdlib/context.py b/mellea/stdlib/context.py deleted file mode 100644 index f0e2e8b2f..000000000 --- a/mellea/stdlib/context.py +++ /dev/null @@ -1,82 +0,0 @@ -"""Concrete ``Context`` implementations for common conversation patterns. 
- -Provides ``ChatContext``, which accumulates all turns in a sliding-window chat history -(configurable via ``window_size``), and ``SimpleContext``, in which each interaction -is treated as a stateless single-turn exchange (no prior history is passed to the -model). Import ``ChatContext`` for multi-turn conversations and ``SimpleContext`` when -you want each call to the model to be independent. -""" - -from __future__ import annotations - -# Leave unused `ContextTurn` import for import ergonomics. -from ..core import CBlock, Component, Context, ContextTurn - - -class ChatContext(Context): - """Initializes a chat context with unbounded window_size and is_chat=True by default. - - Args: - window_size (int | None): Maximum number of context turns to include when - calling ``view_for_generation``. ``None`` (the default) means the full - history is always returned. - """ - - def __init__(self, *, window_size: int | None = None): - """Initialize ChatContext with an optional sliding-window size.""" - super().__init__() - self._window_size = window_size - - def add(self, c: Component | CBlock) -> ChatContext: - """Add a new component or CBlock to the context and return the updated context. - - Args: - c (Component | CBlock): The component or content block to append. - - Returns: - ChatContext: A new ``ChatContext`` with the added entry, preserving the - current ``window_size`` setting. - """ - new = ChatContext.from_previous(self, c) - new._window_size = self._window_size - return new - - def view_for_generation(self) -> list[Component | CBlock] | None: - """Return the context entries to pass to the model, respecting the configured window. - - Uses the ``window_size`` set during initialisation to limit how many past - turns are included. ``None`` is returned when the underlying history is - non-linear. - - Returns: - list[Component | CBlock] | None: Ordered list of context entries up to - ``window_size`` turns, or ``None`` if the history is non-linear. 
- """ - return self.as_list(self._window_size) - - -class SimpleContext(Context): - """A `SimpleContext` is a context in which each interaction is a separate and independent turn. The history of all previous turns is NOT saved..""" - - def add(self, c: Component | CBlock) -> SimpleContext: - """Add a new component or CBlock to the context and return the updated context. - - Args: - c (Component | CBlock): The component or content block to record. - - Returns: - SimpleContext: A new ``SimpleContext`` containing only the added entry; - prior history is not retained. - """ - return SimpleContext.from_previous(self, c) - - def view_for_generation(self) -> list[Component | CBlock] | None: - """Return an empty list, since ``SimpleContext`` does not pass history to the model. - - Each call to the model is treated as a stateless, independent exchange. - No prior turns are forwarded. - - Returns: - list[Component | CBlock] | None: Always an empty list. - """ - return [] diff --git a/mellea/stdlib/context/__init__.py b/mellea/stdlib/context/__init__.py new file mode 100644 index 000000000..60bf94d94 --- /dev/null +++ b/mellea/stdlib/context/__init__.py @@ -0,0 +1,45 @@ +"""Concrete ``Context`` implementations and the ``Compactor`` protocol. + +Provides: + +- :class:`ChatContext` — accumulates all turns in a chat history (with an + optional sliding window). +- :class:`SimpleContext` — stateless, single-turn exchange (no prior history is + passed to the model). +- :class:`Compactor` — generic protocol for shrinking any ``Context`` subtype. + +The names :class:`Context`, :class:`ContextTurn`, :class:`CBlock`, and +:class:`Component` are re-exported from :mod:`mellea.core` for the convenience +of callers that import them via ``mellea.stdlib.context``. 
+""" + +from mellea.core import CBlock, Component, Context, ContextTurn +from mellea.stdlib.context.chat import ChatContext +from mellea.stdlib.context.compactor import ( + Compactor, + LLMSummarizeCompactor, + PinPredicate, + ThresholdCompactor, + WindowCompactor, + pin_nothing, + pin_system, + pin_system_and_initial_user, +) +from mellea.stdlib.context.simple import SimpleContext + +__all__ = [ + "CBlock", + "ChatContext", + "Compactor", + "Component", + "Context", + "ContextTurn", + "LLMSummarizeCompactor", + "PinPredicate", + "SimpleContext", + "ThresholdCompactor", + "WindowCompactor", + "pin_nothing", + "pin_system", + "pin_system_and_initial_user", +] diff --git a/mellea/stdlib/context/chat.py b/mellea/stdlib/context/chat.py new file mode 100644 index 000000000..0ac548460 --- /dev/null +++ b/mellea/stdlib/context/chat.py @@ -0,0 +1,101 @@ +"""Chat-style context with pluggable compaction.""" + +from __future__ import annotations + +from typing import TYPE_CHECKING, cast + +from mellea.core import CBlock, Component, Context + +if TYPE_CHECKING: + from mellea.stdlib.context.compactor import Compactor + + +class ChatContext(Context): + """Chat context that accumulates turns and optionally compacts on each ``add``. + + By default the context performs **no compaction** — the full history is + retained. Compaction is opt-in: pass ``compactor=`` for a custom + strategy, or ``window_size=`` as sugar for ``WindowCompactor(size=...)``. + + Args: + compactor (Compactor | None): The compactor invoked on every ``add``. + ``None`` (the default) means no compaction; full history is kept. + window_size (int | None): Sugar that constructs a + :class:`WindowCompactor`. Mutually exclusive with ``compactor``. + ``None`` (the default) means no windowing. 
+ """ + + def __init__( + self, *, compactor: Compactor | None = None, window_size: int | None = None + ) -> None: + """Initialize a ChatContext with an optional compactor.""" + if compactor is not None and window_size is not None: + raise ValueError( + "ChatContext: pass either `compactor` or `window_size`, not both." + ) + super().__init__() + if compactor is None and window_size is not None: + from mellea.stdlib.context.compactor import WindowCompactor + + self._compactor: Compactor | None = cast( + "Compactor", WindowCompactor(size=window_size) + ) + else: + self._compactor = compactor + + def add(self, c: Component | CBlock) -> ChatContext: + """Append ``c`` and run the compactor; return the resulting context. + + Args: + c (Component | CBlock): The component or content block to append. + + Returns: + ChatContext: A new ``ChatContext`` carrying the same compactor. + """ + new = ChatContext.from_previous(self, c) + new._compactor = self._compactor + if self._compactor is not None: + new = self._compactor.compact(new) + return new + + def view_for_generation(self) -> list[Component | CBlock] | None: + """Return the components to forward to the model. + + Compaction is now applied at ``add`` time (Pattern 1), so this just + returns the linear history. ``None`` is returned when the underlying + history is non-linear. + + Returns: + list[Component | CBlock] | None: Ordered list of context entries. + """ + return self.as_list() + + +def _rebuild_chat_context( + components: list[Component | CBlock], *, compactor: Compactor | None = None +) -> ChatContext: + """Build a fresh ``ChatContext`` linked-list without triggering compaction. + + Used by ``WindowCompactor`` (and any future compactors that need to rebuild + a chat history). Manual node construction sidesteps ``ChatContext.add`` so + compactors don't recurse during their own work. + + Args: + components: Components to materialise as the new context, in order. 
def _rebuild_chat_context(
    components: list[Component | CBlock], *, compactor: Compactor | None = None
) -> ChatContext:
    """Assemble a ``ChatContext`` chain from ``components`` without compacting.

    Nodes are created with ``__new__`` and linked by hand instead of going
    through ``ChatContext.add``, so the compactor that asked for the rebuild
    is not re-entered while it is still running.

    Args:
        components: Entries for the new history, oldest first.
        compactor: Compactor stored on the root and on every node, so later
            ``add`` calls on the result keep compacting.

    Returns:
        The last node of the new chain (the bare root when ``components`` is
        empty).
    """
    # Root node: skip ChatContext.__init__ and its argument handling, but
    # still run the base-class initialiser.
    ctx: ChatContext = ChatContext.__new__(ChatContext)
    Context.__init__(ctx)
    ctx._compactor = compactor
    for c in components:
        # NOTE(review): the private fields set below are assumed to be all the
        # state ``Context`` needs on a non-root node — confirm against
        # ``Context.from_previous`` whenever the base class changes.
        new: ChatContext = ChatContext.__new__(ChatContext)
        new._previous = ctx
        new._data = c
        new._is_root = False
        new._is_chat_context = ctx._is_chat_context
        new._compactor = compactor
        ctx = new
    return ctx
def pin_nothing(components: list[Component | CBlock]) -> int:
    """A :class:`PinPredicate` with an empty prefix: every component is body."""
    prefix_end = 0
    return prefix_end


def pin_system(components: list[Component | CBlock]) -> int:
    """Return the length of the leading run of ``Message(role="system")`` entries.

    Scanning stops at the first entry that is not a system message, so a
    system message appearing later in the conversation is not pinned.
    """
    from mellea.stdlib.components.chat import Message

    for idx, component in enumerate(components):
        is_system = isinstance(component, Message) and component.role == "system"
        if not is_system:
            return idx
    # Every entry (possibly none) was a system message.
    return len(components)


def pin_system_and_initial_user(components: list[Component | CBlock]) -> int:
    """Pin the leading system messages and the user message right after them.

    Handy when the first user prompt states the goal of the conversation and
    must survive compaction together with the system instructions. If the
    entry following the system run is not a user message, only the system
    run is pinned.
    """
    from mellea.stdlib.components.chat import Message

    boundary = pin_system(components)
    if boundary >= len(components):
        return boundary
    candidate = components[boundary]
    if isinstance(candidate, Message) and candidate.role == "user":
        return boundary + 1
    return boundary
def _last_usage_tokens(ctx: Context) -> int | None:
    """Return the token count reported by the most recent thunk that has usage.

    ``ctx.as_list()`` is scanned newest-first for a ``ModelOutputThunk`` whose
    ``generation.usage`` is populated. Only that first match is consulted:
    its ``total_tokens`` is used, or ``prompt_tokens + completion_tokens``
    when ``total_tokens`` is absent (missing parts count as ``0``).

    Returns:
        The positive token count, or ``None`` when no thunk carries usage or
        the recovered count is zero.
    """
    for item in reversed(ctx.as_list()):
        if not isinstance(item, ModelOutputThunk):
            continue
        usage = item.generation.usage
        if usage is None:
            continue

        total = usage.get("total_tokens")
        if total is None:
            prompt_part = usage.get("prompt_tokens") or 0
            completion_part = usage.get("completion_tokens") or 0
            total = prompt_part + completion_part

        # The newest thunk with usage decides the answer; older thunks are
        # never consulted, even if this one reports nothing useful.
        if total and total > 0:
            return total
        return None
    return None
class Compactor(Protocol):
    """Structural interface for objects that shrink a ``Context``.

    A compactor receives a context and returns a new context holding only
    the data its strategy wants to keep. Implementations MUST NOT mutate the
    input; they return a fresh instance and copy over whatever should
    survive.

    The class itself is a plain ``Protocol``; it is the ``compact`` method
    that is generic in ``T`` (a ``TypeVar`` bound to ``Context``), so the
    returned context has the same static type as the one passed in.

    The protocol is synchronous. Compactors that need a backend call (e.g.
    :class:`LLMSummarizeCompactor`) hide the async work behind the sync
    method.
    """

    def compact(self, ctx: T, *, backend: Backend | None = None) -> T:
        """Return a compacted copy of ``ctx``.

        Args:
            ctx: The context to compact. Must be left unchanged.
            backend: Optional backend. Compactors that only filter
                components can ignore it.

        Returns:
            A new context of the same type as ``ctx`` containing only the
            retained data.
        """
        ...
class WindowCompactor:
    """Keep the pinned prefix plus the newest ``size`` body components.

    ``pin_predicate`` splits the history into a protected prefix and a body;
    only the body is windowed, so the compacted length is
    ``len(prefix) + min(size, len(body))``.

    If the body already fits within ``size`` the input context is returned
    as-is, leaving its linked list and ``previous_node`` chain untouched.
    A rebuilt context carries the input's compactor so later ``add()`` calls
    keep compacting.

    Args:
        size (int): Maximum number of most-recent body components to keep.
            Prefix components do not count against it. ``0`` drops the whole
            body; negative values raise :class:`ValueError`.
        pin_predicate (PinPredicate): Chooses the prefix boundary. Defaults
            to :func:`pin_system`; pass :func:`pin_nothing` for pure last-N
            behaviour, or any other ``PinPredicate``.
    """

    def __init__(self, *, size: int, pin_predicate: PinPredicate = pin_system) -> None:
        """Store the window size and pin predicate after validating ``size``."""
        if size < 0:
            raise ValueError("WindowCompactor size must be non-negative")
        self.pin_predicate = pin_predicate
        self.size = size

    def compact(
        self, ctx: ChatContext, *, backend: Backend | None = None
    ) -> ChatContext:
        """Return ``ctx`` trimmed to the prefix plus its last ``size`` body items.

        Args:
            ctx: The chat context to compact.
            backend: Ignored; present only to satisfy the protocol.

        Returns:
            ``ctx`` itself when nothing needs dropping, otherwise a rebuilt
            ``ChatContext`` carrying ``ctx``'s compactor.
        """
        components = ctx.as_list()
        prefix_len = self.pin_predicate(components)
        if len(components) - prefix_len <= self.size:
            return ctx

        from mellea.stdlib.context.chat import _rebuild_chat_context

        prefix, body = components[:prefix_len], components[prefix_len:]
        # ``body[-0:]`` is the whole list, so a zero window is spelled out.
        tail = body[-self.size :] if self.size > 0 else []
        return _rebuild_chat_context(prefix + tail, compactor=ctx._compactor)
class ThresholdCompactor:
    """Token-gated wrapper that delegates to an inner ``Compactor``.

    This class never compacts by itself. ``compact`` forwards to
    ``inner.compact`` only when the measured conversation size is above
    ``threshold``; otherwise the input is handed back untouched.

    The measurement comes from :func:`_last_usage_tokens`, i.e. the usage
    recorded on the most recent ``ModelOutputThunk``.

    Caveats:

    - Components appended after the last thunk (for example a tool response
      in the same turn) are not part of the reading yet.
    - The threshold is not a high-water mark: the gate is re-evaluated from
      the latest reading on every call.
    - With no usable reading (no thunk with usage yet) the input is returned
      unchanged.

    Args:
        inner (Compactor): Compactor to run once the threshold is exceeded.
        threshold (int): Token count that must be exceeded for ``inner`` to
            run. ``0`` or a negative value disables the gate, so ``inner``
            is never invoked.
    """

    def __init__(self, inner: Compactor, *, threshold: int) -> None:
        """Remember the wrapped compactor and the token threshold."""
        self.threshold = threshold
        self.inner = inner

    def compact(self, ctx: T, *, backend: Backend | None = None) -> T:
        """Delegate to ``inner.compact`` when ``ctx`` is over the threshold.

        Args:
            ctx: The context to potentially compact.
            backend: Passed through to the inner compactor.

        Returns:
            ``inner.compact(ctx, backend=backend)`` when the measured token
            count exceeds ``self.threshold``; ``ctx`` otherwise.
        """
        if self.threshold <= 0:
            # Gate disabled: the token reading is not even consulted.
            return ctx

        measured = _last_usage_tokens(ctx)
        over_limit = measured is not None and measured > self.threshold
        if not over_limit:
            return ctx
        return self.inner.compact(ctx, backend=backend)
_DEFAULT_SUMMARY_PROMPT = (
    "You are summarizing a conversation to maintain context within token "
    "limits.\n\n"
    "Provide a concise summary that:\n"
    "- Preserves specific facts, numbers, names, URLs, and key data\n"
    "- Notes which tools were called and what results were obtained\n"
    "- Highlights key decisions, findings, and unresolved issues\n"
    "- Is structured clearly so the conversation can continue seamlessly\n\n"
    "Conversation to summarize:\n{conversation}"
)


def _run_coro_blocking(coro):  # type: ignore[no-untyped-def]
    """Run an awaitable to completion from synchronous code.

    - No event loop running in this thread: ``asyncio.run(coro)``.
    - A loop is already running: run ``asyncio.run(coro)`` on a one-off
      worker thread and block until it finishes.

    In the worker-thread branch the coroutine runs inside a copy of the
    caller's ``contextvars`` context. A fresh thread otherwise starts with an
    empty context, so values the caller set would silently revert to their
    defaults inside the coroutine.

    The worker-thread branch blocks the calling thread, and with it the
    running event loop, until the coroutine finishes. That suits a serial
    loop but not code where other tasks must keep making progress.

    Args:
        coro: The awaitable to drive to completion.

    Returns:
        Whatever ``coro`` returns. Exceptions raised by ``coro`` propagate
        to the caller in both branches.
    """
    import asyncio
    import concurrent.futures
    import contextvars

    try:
        asyncio.get_running_loop()
    except RuntimeError:
        # Not inside a loop: asyncio.run can own this thread directly.
        return asyncio.run(coro)

    # Snapshot the caller's context variables before hopping threads.
    caller_context = contextvars.copy_context()
    with concurrent.futures.ThreadPoolExecutor(max_workers=1) as pool:
        return pool.submit(caller_context.run, asyncio.run, coro).result()
class LLMSummarizeCompactor:
    """Summarise older body components with an LLM; keep the newest ``keep_n`` verbatim.

    Implements the sync :class:`Compactor` protocol. The backend call is
    async and is driven from the sync ``compact`` method through
    :func:`_run_coro_blocking`.

    The pinned prefix (chosen by ``pin_predicate``) is preserved unchanged.
    Body components older than the last ``keep_n`` are flattened into one
    ``Message(role="user")`` holding the summary, and the last ``keep_n``
    body components are kept verbatim.

    The default ``pin_predicate`` is :func:`pin_nothing`, so the whole
    conversation takes part in summarisation. For react workflows pass
    :func:`mellea.stdlib.components.react.pin_react_initiator` so the goal
    and tool registration survive untouched.

    Args:
        keep_n (int): Number of recent body components to keep verbatim.
            ``0`` summarises everything below the prefix.
        pin_predicate (PinPredicate): Chooses the prefix boundary. Defaults
            to :func:`pin_nothing`.
        prompt_template (str | None): Custom summary prompt, rendered with
            ``str.format(conversation=...)``. It must contain the
            ``{conversation}`` replacement field, and any literal braces
            must be escaped as ``{{`` / ``}}``. Defaults to a generic
            conversation-summary template.

    Raises:
        ValueError: If ``keep_n`` is negative, or if the template lacks a
            working ``{conversation}`` field or cannot be formatted.
    """

    def __init__(
        self,
        *,
        keep_n: int = 5,
        pin_predicate: PinPredicate = pin_nothing,
        prompt_template: str | None = None,
    ) -> None:
        """Validate and store the recent-body window, pin predicate, and prompt."""
        if keep_n < 0:
            raise ValueError("LLMSummarizeCompactor keep_n must be non-negative")
        template = (
            prompt_template if prompt_template is not None else _DEFAULT_SUMMARY_PROMPT
        )
        if "{conversation}" not in template:
            raise ValueError(
                "LLMSummarizeCompactor prompt_template must contain '{conversation}'"
            )

        # A substring match is not enough: "{{conversation}}" contains the
        # text but renders as a literal, and a stray "{other}" field would
        # only fail later inside compact(). Dry-run the template now.
        probe = "__conversation_probe__"
        try:
            rendered = template.format(conversation=probe)
        except (AttributeError, IndexError, KeyError, ValueError) as exc:
            raise ValueError(
                "LLMSummarizeCompactor prompt_template must be a valid str.format "
                "template whose only replacement field is '{conversation}'; "
                "escape literal braces as '{{' and '}}'"
            ) from exc
        if probe not in rendered:
            raise ValueError(
                "LLMSummarizeCompactor prompt_template must contain '{conversation}'"
            )

        self.keep_n = keep_n
        self.pin_predicate = pin_predicate
        self.prompt_template = template

    def compact(
        self, ctx: ChatContext, *, backend: Backend | None = None
    ) -> ChatContext:
        """Return a context made of the prefix, an LLM summary, and the recent body.

        Args:
            ctx: The chat context to compact.
            backend: Backend used to generate the summary; required.

        Returns:
            ``ctx`` unchanged when the body holds at most ``keep_n``
            components; otherwise a new ``ChatContext`` containing the
            prefix, one summary ``Message``, and the last ``keep_n`` body
            components.

        Raises:
            ValueError: If ``backend`` is not provided. The check runs before
                the size check, so it fires even when no summary is needed.
        """
        if backend is None:
            raise ValueError("LLMSummarizeCompactor requires a `backend`")

        full = ctx.as_list()
        pin_end = self.pin_predicate(full)
        body = full[pin_end:]
        if len(body) <= self.keep_n:
            return ctx

        return _run_coro_blocking(self._async_compact(ctx, backend))

    async def _async_compact(self, ctx: ChatContext, backend: Backend) -> ChatContext:
        """Render the old body to text, ask the backend for a summary, rebuild."""
        # Lazy imports keep this module free of mellea.stdlib.components dependencies.
        from mellea.stdlib import functional as mfuncs
        from mellea.stdlib.components.chat import Message, ToolMessage
        from mellea.stdlib.context.chat import _rebuild_chat_context
        from mellea.stdlib.context.simple import SimpleContext

        full = ctx.as_list()
        pin_end = self.pin_predicate(full)
        prefix = full[:pin_end]
        body = full[pin_end:]

        # ``body[:-0]`` is empty and ``body[-0:]`` is everything, so the
        # keep_n == 0 case is spelled out on both sides.
        old = body[: -self.keep_n] if self.keep_n > 0 else body
        recent = body[-self.keep_n :] if self.keep_n > 0 else []

        # Render `old` to text the LLM can consume.
        lines: list[str] = []
        for c in old:
            # ToolMessage is tested before Message so it gets its own label.
            if isinstance(c, ToolMessage):
                lines.append(f"tool ({c.name}): {c.content}")
            elif isinstance(c, Message):
                lines.append(f"{c.role}: {c.content}")
            elif isinstance(c, ModelOutputThunk):
                # ``value`` may be None (the summary result below is guarded
                # the same way); render that as empty text rather than the
                # literal string "None".
                lines.append(f"assistant: {c.value or ''}")
            elif isinstance(c, CBlock):
                lines.append(str(c))
            else:
                lines.append(str(getattr(c, "content", c)))

        prompt = self.prompt_template.format(conversation="\n".join(lines))
        # The summary request runs in a throwaway SimpleContext.
        result, _ = await mfuncs.aact(
            action=Message(role="user", content=prompt),
            context=SimpleContext(),
            backend=backend,
            requirements=[],
            strategy=None,
            await_result=True,
        )

        summary_message = Message(
            role="user", content=f"[CONTEXT SUMMARY]\n{result.value or ''}"
        )
        compacted = [*prefix, summary_message, *recent]
        return _rebuild_chat_context(compacted, compactor=ctx._compactor)
class SimpleContext(Context):
    """A context in which each interaction is a separate, independent turn.

    The history of previous turns is NOT forwarded to the model.
    """

    def add(self, c: Component | CBlock) -> SimpleContext:
        """Record ``c`` and return the resulting context.

        Args:
            c (Component | CBlock): The component or content block to record.

        Returns:
            SimpleContext: The context produced by
            ``SimpleContext.from_previous(self, c)``.
        """
        successor = SimpleContext.from_previous(self, c)
        return successor

    def view_for_generation(self) -> list[Component | CBlock] | None:
        """Return an empty list: no earlier turns are passed to the model.

        Returns:
            list[Component | CBlock] | None: Always an empty list.
        """
        nothing_to_forward: list[Component | CBlock] = []
        return nothing_to_forward
def test_pin_react_initiator_finds_initiator():
    """The pinned prefix ends right after the ``ReactInitiator``."""
    from mellea.stdlib.components.chat import Message
    from mellea.stdlib.components.react import pin_react_initiator

    history = [
        Message("system", "sys"),
        ReactInitiator("solve x", []),
        Message("user", "step 1"),
    ]
    # System message + initiator occupy indices 0 and 1, so the boundary is 2.
    boundary = pin_react_initiator(history)
    assert boundary == 2


def test_pin_react_initiator_returns_zero_when_absent():
    """Without a ``ReactInitiator`` nothing is pinned."""
    from mellea.stdlib.components.chat import Message
    from mellea.stdlib.components.react import pin_react_initiator

    history = [Message("user", "a"), Message("assistant", "b")]
    boundary = pin_react_initiator(history)
    assert boundary == 0


def test_react_summary_prompt_default():
    """Without a goal the prompt has no GOAL: line and contains {conversation}."""
    from mellea.stdlib.components.react import react_summary_prompt

    template = react_summary_prompt()
    assert "GOAL:" not in template
    for expected in (
        "{conversation}",
        "research progress",
        "search queries",
        "dead ends",
    ):
        assert expected in template


def test_react_summary_prompt_with_goal():
    """Goal is interpolated and the prompt still has the {conversation} placeholder."""
    from mellea.stdlib.components.react import react_summary_prompt

    goal = "find papers on context compaction"
    template = react_summary_prompt(goal=goal)
    assert "GOAL: find papers on context compaction" in template
    assert "{conversation}" in template
def test_react_summary_prompt_escapes_braces_in_goal():
    """Braces in the goal must survive str.format() in LLMSummarizeCompactor."""
    from mellea.stdlib.components.react import react_summary_prompt

    prompt = react_summary_prompt(goal="solve {x: 1, y: 2}")
    # Render with a recognisable body: an empty string is a substring of every
    # string, so checking for "" would prove nothing.
    rendered = prompt.format(conversation="<<BODY>>")
    # After str.format(), the goal should appear with literal braces ...
    assert "GOAL: solve {x: 1, y: 2}" in rendered
    # ... and the conversation placeholder must actually have been filled in.
    assert "<<BODY>>" in rendered
    assert "{conversation}" not in rendered


def test_react_summary_prompt_works_with_llm_summarize_compactor():
    """The factory's output passes LLMSummarizeCompactor's template validation."""
    from mellea.stdlib.components.react import react_summary_prompt
    from mellea.stdlib.context import LLMSummarizeCompactor

    # Should not raise on construction (template contains {conversation}).
    LLMSummarizeCompactor(prompt_template=react_summary_prompt(goal="g"))
    LLMSummarizeCompactor(prompt_template=react_summary_prompt())
    LLMSummarizeCompactor(
        prompt_template=react_summary_prompt(goal="g", max_tokens_hint=2000)
    )


def test_react_summary_prompt_max_tokens_hint_omitted_by_default():
    """Without a hint, the prompt is byte-identical to the un-hinted form."""
    from mellea.stdlib.components.react import react_summary_prompt

    prompt = react_summary_prompt(goal="g")
    prompt_explicit_none = react_summary_prompt(goal="g", max_tokens_hint=None)
    assert prompt == prompt_explicit_none
    assert "Be at most" not in prompt
    assert "tokens (roughly" not in prompt


def test_react_summary_prompt_max_tokens_hint_injects_bullet():
    """Positive hint adds a bullet with token + word estimates."""
    from mellea.stdlib.components.react import react_summary_prompt

    prompt = react_summary_prompt(goal="g", max_tokens_hint=2000)
    # The bullet sits after "structured clearly" and before "Context to summarize:".
    assert "- Be at most ~2000 tokens (roughly 1500 words)" in prompt
    assert "Prioritize density" in prompt
    # Ordering: structured-clearly bullet comes before the length bullet,
    # length bullet comes before the conversation marker.
    sc_idx = prompt.index("structured clearly")
    bullet_idx = prompt.index("Be at most ~2000")
    conv_idx = prompt.index("Context to summarize:")
    assert sc_idx < bullet_idx < conv_idx
def test_react_summary_prompt_max_tokens_hint_zero_or_negative_omits_bullet():
    """Non-positive hint values are treated as no hint."""
    from mellea.stdlib.components.react import react_summary_prompt

    unhinted = react_summary_prompt()
    for non_positive in (0, -1):
        assert react_summary_prompt(max_tokens_hint=non_positive) == unhinted


def test_react_summary_prompt_max_tokens_hint_word_estimate_scales():
    """Word estimate uses the ~0.75 words/token heuristic (int truncation)."""
    from mellea.stdlib.components.react import react_summary_prompt

    # 1000 tokens → 750 words; 4000 → 3000.
    small = react_summary_prompt(max_tokens_hint=1000)
    large = react_summary_prompt(max_tokens_hint=4000)
    assert "~1000 tokens (roughly 750 words)" in small
    assert "~4000 tokens (roughly 3000 words)" in large
@pytest.mark.asyncio
async def test_react_invokes_per_turn_compactor():
    """The ``compactor=`` hook runs once per turn after the tool observation."""
    search = _make_tool("search", "found it")
    script = [
        _tool_call_turn("search", search, "step 1"),
        _tool_call_turn("search", search, "step 2"),
        _final_answer_call("done"),
    ]
    backend = ScriptedBackend(script)

    seen_sizes = []

    class RecordingCompactor:
        """Pass-through compactor that records the context size on each call."""

        def compact(self, ctx, *, backend=None):
            seen_sizes.append(len(ctx.as_list()))
            return ctx

    result, _ctx = await react(
        goal="find info",
        context=ChatContext(),
        backend=backend,
        tools=[search],
        loop_budget=10,
        compactor=RecordingCompactor(),
    )

    assert result.value == "done"
    # Two non-terminal turns each invoke the compactor; the final turn skips it.
    assert len(seen_sizes) == 2
    # The context grew between the first and the second observation.
    first, second = seen_sizes
    assert first < second


@pytest.mark.asyncio
async def test_react_runs_llm_summarize_compactor():
    """The sync ``LLMSummarizeCompactor.compact`` is callable from async ``react()``.

    ``keep_n`` is large enough that no summarisation fires; the test only
    checks that the hook runs inside the loop without raising.
    """
    from mellea.stdlib.components.react import pin_react_initiator
    from mellea.stdlib.context import LLMSummarizeCompactor

    search = _make_tool("search", "found it")
    backend = ScriptedBackend(
        [_tool_call_turn("search", search, "step 1"), _final_answer_call("done")]
    )
    summarizer = LLMSummarizeCompactor(keep_n=1000, pin_predicate=pin_react_initiator)

    result, ctx = await react(
        goal="find info",
        context=ChatContext(window_size=10_000),
        backend=backend,
        tools=[search],
        loop_budget=10,
        compactor=summarizer,
    )

    assert result.value == "done"
    assert any(isinstance(c, ReactInitiator) for c in ctx.as_list())
@pytest.mark.asyncio
async def test_react_compactor_can_actually_compact():
    """A real WindowCompactor wired in via the per-turn hook truncates context.

    The same scripted trace is run twice, once without a per-turn compactor
    and once with ``WindowCompactor(size=2)``, so the test can assert that
    the compacted run really ends with a shorter history.
    """
    from mellea.stdlib.components.react import pin_react_initiator
    from mellea.stdlib.context import WindowCompactor

    async def run_trace(compactor):
        # Fresh tool and scripted backend per run keep the traces independent.
        search = _make_tool("search", "found it")
        backend = ScriptedBackend(
            [
                _tool_call_turn("search", search, "step 1"),
                _tool_call_turn("search", search, "step 2"),
                _tool_call_turn("search", search, "step 3"),
                _final_answer_call("done"),
            ]
        )
        return await react(
            goal="find info",
            # Permissive per-add window so we isolate the per-turn compactor's effect.
            context=ChatContext(window_size=10_000),
            backend=backend,
            tools=[search],
            loop_budget=10,
            compactor=compactor,
        )

    baseline_result, baseline_ctx = await run_trace(None)
    result, ctx = await run_trace(
        WindowCompactor(size=2, pin_predicate=pin_react_initiator)
    )

    # The ReactInitiator must survive thanks to pin_react_initiator.
    assert any(isinstance(c, ReactInitiator) for c in ctx.as_list())
    assert result.value == "done"
    assert baseline_result.value == "done"
    # Compaction must have dropped something relative to the baseline.
    assert len(ctx.as_list()) < len(baseline_ctx.as_list())
+ assert any(isinstance(c, ReactInitiator) for c in ctx.as_list()) + assert result.value == "done" + + if __name__ == "__main__": pytest.main([__file__]) diff --git a/test/stdlib/test_base_context.py b/test/stdlib/test_base_context.py index 2fccb11fd..83a4b42f8 100644 --- a/test/stdlib/test_base_context.py +++ b/test/stdlib/test_base_context.py @@ -4,8 +4,8 @@ from mellea.stdlib.context import ChatContext, SimpleContext -def context_construction(cls: type[Context]): - tree0 = cls() +def context_construction(cls: type[Context], **kwargs): + tree0 = cls(**kwargs) tree1 = tree0.add(CBlock("abc")) assert tree1.previous_node == tree0 @@ -15,11 +15,14 @@ def context_construction(cls: type[Context]): def test_context_construction(): context_construction(SimpleContext) + # A bare ChatContext() has no compactor (compaction is opt-in), so a single + # add leaves the linked-list shape identical to the pre-compaction + # behaviour. context_construction(ChatContext) -def large_context_construction(cls: type[Context]): - root = cls() +def large_context_construction(cls: type[Context], **kwargs): + root = cls(**kwargs) full_graph: Context = root for i in range(1000): @@ -31,7 +34,9 @@ def test_large_context_construction(): large_context_construction(SimpleContext) - large_context_construction(ChatContext) + # Exercise the add()-time compaction path (opt-in via window_size) with a + # window large enough that all 1000 components survive. + large_context_construction(ChatContext, window_size=2000) def test_render_view_for_simple_context(): @@ -48,7 +53,9 @@ def test_render_view_for_chat_context(): ctx = ChatContext(window_size=3) for i in range(5): ctx = ctx.add(CBlock(f"a {i}")) - assert len(ctx.as_list()) == 5, "Adding 5 items to context should result in 5 items" + # Compaction is now applied at add() time, so as_list and view_for_generation + # both reflect the sliding window of 3.
def _msg(i: int) -> Message:
    """Build a user message whose content is ``m<i>``."""
    return Message(role="user", content=f"m{i}")


def _thunk(total_tokens: int, value: str = "") -> ModelOutputThunk:
    """Build a ModelOutputThunk with a populated usage dict."""
    thunk = ModelOutputThunk(value=value)
    # All tokens are attributed to the prompt; the completion contributes none.
    thunk.generation.usage = dict(
        prompt_tokens=total_tokens, completion_tokens=0, total_tokens=total_tokens
    )
    return thunk


class TestChatContextDefaults:
    """Constructor behaviour of ``ChatContext`` with and without a compactor."""

    def test_default_has_no_compactor(self):
        # Compaction is opt-in: bare ChatContext() retains full history.
        bare = ChatContext()
        assert bare._compactor is None

    def test_default_keeps_full_history(self):
        ctx = ChatContext()
        for index in range(20):
            ctx = ctx.add(_msg(index))
        history = ctx.as_list()
        assert len(history) == 20

    def test_window_size_arg_constructs_window_compactor(self):
        compactor = ChatContext(window_size=3)._compactor
        assert isinstance(compactor, WindowCompactor)
        assert compactor.size == 3

    def test_passing_both_args_raises(self):
        window = WindowCompactor(size=2)
        with pytest.raises(ValueError):
            ChatContext(compactor=window, window_size=3)

    def test_explicit_compactor_overrides_default(self):
        explicit = WindowCompactor(size=2)
        assert ChatContext(compactor=explicit)._compactor is explicit
+ ctx = ChatContext(window_size=10_000) + ctx = ctx.add(_msg(0)) + ctx = ctx.add(_msg(1)) + ctx = ctx.add(_msg(2)) + before_compact = [m.content for m in ctx.as_list()] + compacted = WindowCompactor(size=2).compact(ctx) + # original unchanged + assert [m.content for m in ctx.as_list()] == before_compact + # compacted is shorter and a different object + assert compacted is not ctx + assert len(compacted.as_list()) == 2 + + def test_compact_preserves_compactor_on_result(self): + comp = WindowCompactor(size=2) + ctx = ChatContext(compactor=comp) + ctx = ctx.add(_msg(0)).add(_msg(1)).add(_msg(2)) + # subsequent adds keep using the same compactor + ctx = ctx.add(_msg(3)) + assert ctx._compactor is comp + assert len(ctx.as_list()) == 2 + + def test_view_for_generation_no_double_truncation(self): + ctx = ChatContext(window_size=3) + for i in range(7): + ctx = ctx.add(_msg(i)) + # add() already compacted; view should match the linear history exactly + view = ctx.view_for_generation() + assert view is not None + assert [m.content for m in view] == [m.content for m in ctx.as_list()] + + def test_negative_size_raises(self): + with pytest.raises(ValueError): + WindowCompactor(size=-1) + + def test_size_zero_clears_body(self): + # Regression: `[-0:]` evaluates to `[0:]` in Python, which would keep + # the entire body instead of nothing. size=0 must keep zero body items. + ctx = ChatContext(window_size=10_000) + for i in range(5): + ctx = ctx.add(_msg(i)) + result = WindowCompactor(size=0).compact(ctx) + assert result.as_list() == [] + + def test_size_zero_keeps_pinned_prefix(self): + ctx = ChatContext(window_size=10_000) + ctx = ctx.add(Message(role="system", content="sys")) + for i in range(3): + ctx = ctx.add(_msg(i)) + # Default pin_predicate=pin_system → system stays, body cleared. 
+ result = WindowCompactor(size=0).compact(ctx) + items = result.as_list() + assert len(items) == 1 + assert items[0].content == "sys" + + def test_pins_leading_system_message(self): + ctx = ChatContext(window_size=10_000) + ctx = ctx.add(Message(role="system", content="You are helpful.")) + for i in range(5): + ctx = ctx.add(_msg(i)) + # Apply WindowCompactor(size=2) manually — keep system + last 2 body. + result = WindowCompactor(size=2).compact(ctx) + items = result.as_list() + assert len(items) == 3 + assert isinstance(items[0], Message) and items[0].role == "system" + assert [m.content for m in items[1:]] == ["m3", "m4"] + + def test_pins_multiple_leading_system_messages(self): + ctx = ChatContext(window_size=10_000) + ctx = ctx.add(Message(role="system", content="sys1")) + ctx = ctx.add(Message(role="system", content="sys2")) + for i in range(5): + ctx = ctx.add(_msg(i)) + result = WindowCompactor(size=2).compact(ctx) + items = result.as_list() + assert [m.content for m in items[:2]] == ["sys1", "sys2"] + assert [m.content for m in items[2:]] == ["m3", "m4"] + + def test_does_not_pin_non_contiguous_system(self): + # System message in the middle is NOT pinned — only the contiguous prefix. + ctx = ChatContext(window_size=10_000) + ctx = ctx.add(_msg(0)) # body starts here + ctx = ctx.add(Message(role="system", content="late-sys")) + for i in range(1, 6): + ctx = ctx.add(_msg(i)) + result = WindowCompactor(size=2).compact(ctx) + items = result.as_list() + assert len(items) == 2 + assert "late-sys" not in [getattr(m, "content", None) for m in items] + + def test_no_system_message_pure_last_n(self): + # Without any system prefix, behaviour is pure last-N (matches Phase 2 semantics). 
+ ctx = ChatContext(window_size=10_000) + for i in range(7): + ctx = ctx.add(_msg(i)) + result = WindowCompactor(size=3).compact(ctx) + items = result.as_list() + assert [m.content for m in items] == ["m4", "m5", "m6"] + + +class TestCompactorProtocol: + def test_user_class_satisfies_protocol(self): + """A plain class with the right method should be a Compactor.""" + + class Identity: + def compact(self, ctx, *, backend=None): + return ctx + + # structural subtyping check — at runtime this is just isinstance against Protocol + # which requires `runtime_checkable` to actually work; instead assert duck-typing. + c = Identity() + ctx = ChatContext(compactor=c) + ctx = ctx.add(_msg(0)) + # Identity returns ctx unchanged, so we still see m0 + assert [m.content for m in ctx.as_list()] == ["m0"] + + def test_pattern_2_manual_compaction(self): + """Pattern 2: caller invokes compactor.compact() directly.""" + comp = WindowCompactor(size=2) + # Compaction is opt-in, so a bare ChatContext() would not auto-compact + # either; here a window large enough that auto-compaction never fires + # is used instead, then comp is applied manually.
+ ctx = ChatContext(window_size=100) + for i in range(5): + ctx = ctx.add(_msg(i)) + assert len(ctx.as_list()) == 5 + ctx2 = comp.compact(ctx) + assert len(ctx2.as_list()) == 2 + # original still untouched + assert len(ctx.as_list()) == 5 + + +class TestLastUsageTokens: + def test_no_thunk_returns_none(self): + ctx = ChatContext(window_size=100).add(_msg(0)) + assert _last_usage_tokens(ctx) is None + + def test_thunk_without_usage_returns_none(self): + ctx = ChatContext(window_size=100).add(_msg(0)).add(ModelOutputThunk(value="x")) + assert _last_usage_tokens(ctx) is None + + def test_reads_total_tokens(self): + ctx = ChatContext(window_size=100).add(_msg(0)).add(_thunk(150)) + assert _last_usage_tokens(ctx) == 150 + + def test_falls_back_to_prompt_plus_completion(self): + mot = ModelOutputThunk(value="x") + mot.generation.usage = {"prompt_tokens": 40, "completion_tokens": 20} + ctx = ChatContext(window_size=100).add(_msg(0)).add(mot) + assert _last_usage_tokens(ctx) == 60 + + def test_uses_most_recent_thunk(self): + ctx = ( + ChatContext(window_size=100).add(_thunk(100)).add(_msg(0)).add(_thunk(500)) + ) + assert _last_usage_tokens(ctx) == 500 + + +class TestThresholdCompactor: + def test_below_threshold_returns_input(self): + inner = WindowCompactor(size=2) + gated = ThresholdCompactor(inner, threshold=1000) + ctx = ChatContext(window_size=100).add(_msg(0)).add(_thunk(50)) + # 7 components in total, but inner is not invoked because token count (50) <= threshold (1000) + for i in range(1, 6): + ctx = ctx.add(_msg(i)) + result = gated.compact(ctx) + assert result is ctx + + def test_above_threshold_runs_inner(self): + inner = WindowCompactor(size=2) + gated = ThresholdCompactor(inner, threshold=100) + # Build a context with the last thunk reporting >threshold tokens. + ctx = ChatContext(window_size=100) + for i in range(5): + ctx = ctx.add(_msg(i)) + ctx = ctx.add(_thunk(500)) + result = gated.compact(ctx) + # Inner was invoked → only last 2 components retained.
+ assert len(result.as_list()) == 2 + + def test_no_thunk_no_compaction(self): + """No thunk means no usage info — gate stays closed.""" + inner = WindowCompactor(size=2) + gated = ThresholdCompactor(inner, threshold=100) + ctx = ChatContext(window_size=100) + for i in range(5): + ctx = ctx.add(_msg(i)) + result = gated.compact(ctx) + assert result is ctx + + def test_zero_threshold_disables_gate(self): + inner = WindowCompactor(size=2) + gated = ThresholdCompactor(inner, threshold=0) + ctx = ChatContext(window_size=100).add(_msg(0)).add(_thunk(10_000)) + result = gated.compact(ctx) + # Threshold 0 means "never trigger" — input passes through. + assert result is ctx + + +class TestPinPredicates: + def test_pin_nothing(self): + assert pin_nothing([_msg(0), _msg(1)]) == 0 + assert pin_nothing([]) == 0 + + def test_pin_system_zero_when_no_system(self): + assert pin_system([_msg(0), _msg(1)]) == 0 + + def test_pin_system_counts_contiguous(self): + components = [ + Message(role="system", content="s1"), + Message(role="system", content="s2"), + _msg(0), + Message(role="system", content="late-s"), # not pinned — non-contiguous + ] + assert pin_system(components) == 2 + + def test_pin_system_and_initial_user_with_both(self): + components = [ + Message(role="system", content="s1"), + Message(role="user", content="goal"), + Message(role="assistant", content="ack"), + ] + assert pin_system_and_initial_user(components) == 2 + + def test_pin_system_and_initial_user_no_user(self): + components = [ + Message(role="system", content="s1"), + Message(role="assistant", content="x"), + ] + # First non-system is "assistant", not "user" — not pinned beyond system. 
+ assert pin_system_and_initial_user(components) == 1 + + def test_pin_system_and_initial_user_user_only(self): + components = [ + Message(role="user", content="goal"), + Message(role="assistant", content="ok"), + ] + assert pin_system_and_initial_user(components) == 1 + + +class TestWindowCompactorPredicate: + def test_pin_nothing_pure_last_n(self): + comp = WindowCompactor(size=2, pin_predicate=pin_nothing) + ctx = ChatContext(window_size=10_000) + ctx = ctx.add(Message(role="system", content="sys")) + for i in range(5): + ctx = ctx.add(_msg(i)) + result = comp.compact(ctx) + items = result.as_list() + assert len(items) == 2 + # System is dropped because predicate returned 0. + assert "sys" not in [getattr(m, "content", None) for m in items] + + def test_pin_system_and_initial_user_protects_first_user(self): + comp = WindowCompactor(size=2, pin_predicate=pin_system_and_initial_user) + ctx = ChatContext(window_size=10_000) + ctx = ctx.add(Message(role="system", content="sys")) + ctx = ctx.add(Message(role="user", content="goal")) + for i in range(6): + ctx = ctx.add(_msg(i)) + result = comp.compact(ctx) + items = result.as_list() + # prefix (sys + goal) + last 2 body = 4 + assert len(items) == 4 + assert items[0].content == "sys" + assert items[1].content == "goal" + + def test_custom_predicate(self): + # Predicate that pins the first 3 components unconditionally. 
+ def pin_first_3(components): + return min(3, len(components)) + + comp = WindowCompactor(size=2, pin_predicate=pin_first_3) + ctx = ChatContext(window_size=10_000) + for i in range(8): + ctx = ctx.add(_msg(i)) + result = comp.compact(ctx) + items = result.as_list() + # prefix (m0, m1, m2) + last 2 of body (m6, m7) = 5 + assert [m.content for m in items] == ["m0", "m1", "m2", "m6", "m7"] + + +# --------------------------------------------------------------------------- # +# LLMSummarizeCompactor # +# --------------------------------------------------------------------------- # + + +@pytest.fixture +def scripted_summary_backend(): + """Lazy-built fake backend that returns a fixed summary on each generate call.""" + from collections.abc import Sequence + + from mellea.core.backend import Backend, BaseModelSubclass + from mellea.core.base import C, GenerateLog + + class FakeBackend(Backend): + def __init__(self, summary: str = "SUMMARY-OF-OLD") -> None: + self.summary = summary + self.calls = 0 + + async def _generate_from_context( + self, + action, + ctx, + *, + format=None, + model_options=None, + tool_calls: bool = False, + ): + self.calls += 1 + mot = ModelOutputThunk(value=self.summary) + mot._generate_log = GenerateLog(is_final_result=True) + return mot, ctx.add(action).add(mot) + + async def generate_from_raw( + self, + actions, + ctx, + *, + format=None, + model_options=None, + tool_calls: bool = False, + ): + raise NotImplementedError + + return FakeBackend() + + +class TestLLMSummarizeCompactor: + def test_negative_keep_n_raises(self): + with pytest.raises(ValueError): + LLMSummarizeCompactor(keep_n=-1) + + def test_prompt_template_must_have_placeholder(self): + with pytest.raises(ValueError, match="conversation"): + LLMSummarizeCompactor(prompt_template="no placeholder here") + + def test_compact_is_sync(self): + import inspect + + comp = LLMSummarizeCompactor() + # Sync from the outside even though the implementation calls async backend code. 
+ assert not inspect.iscoroutinefunction(comp.compact) + + def test_raises_without_backend(self): + comp = LLMSummarizeCompactor() + ctx = ChatContext(window_size=10_000) + for i in range(3): + ctx = ctx.add(_msg(i)) + with pytest.raises(ValueError, match="backend"): + comp.compact(ctx) + + def test_short_body_is_noop(self, scripted_summary_backend): + comp = LLMSummarizeCompactor(keep_n=5) + ctx = ChatContext(window_size=10_000) + for i in range(3): + ctx = ctx.add(_msg(i)) + result = comp.compact(ctx, backend=scripted_summary_backend) + # body length (3) <= keep_n (5) → no-op, backend not called + assert result is ctx + assert scripted_summary_backend.calls == 0 + + def test_summarises_old_keeps_recent(self, scripted_summary_backend): + comp = LLMSummarizeCompactor(keep_n=2) + ctx = ChatContext(window_size=10_000) + for i in range(6): + ctx = ctx.add(_msg(i)) + result = comp.compact(ctx, backend=scripted_summary_backend) + items = result.as_list() + # summary (1) + last 2 verbatim = 3 + assert len(items) == 3 + assert "[CONTEXT SUMMARY]" in items[0].content + assert items[1].content == "m4" + assert items[2].content == "m5" + assert scripted_summary_backend.calls == 1 + + def test_pin_predicate_preserves_prefix(self, scripted_summary_backend): + comp = LLMSummarizeCompactor(keep_n=1, pin_predicate=pin_system) + ctx = ChatContext(window_size=10_000) + ctx = ctx.add(Message(role="system", content="sys")) + for i in range(4): + ctx = ctx.add(_msg(i)) + result = comp.compact(ctx, backend=scripted_summary_backend) + items = result.as_list() + # system (pinned) + summary + last 1 verbatim = 3 + assert items[0].role == "system" + assert items[0].content == "sys" + assert "[CONTEXT SUMMARY]" in items[1].content + assert items[2].content == "m3" + + def test_does_not_mutate_original(self, scripted_summary_backend): + comp = LLMSummarizeCompactor(keep_n=1) + ctx = ChatContext(window_size=10_000) + for i in range(4): + ctx = ctx.add(_msg(i)) + before = [m.content for m in 
ctx.as_list()] + comp.compact(ctx, backend=scripted_summary_backend) + assert [m.content for m in ctx.as_list()] == before + + def test_satisfies_compactor_protocol(self): + comp: Compactor = LLMSummarizeCompactor() + # Just a typing-level check that the assignment is accepted. + assert callable(comp.compact) + + @pytest.mark.asyncio + async def test_works_inside_running_event_loop(self, scripted_summary_backend): + """compact() is callable from within an async function — uses worker thread.""" + comp = LLMSummarizeCompactor(keep_n=1) + ctx = ChatContext(window_size=10_000) + for i in range(4): + ctx = ctx.add(_msg(i)) + # No await: this is a sync call from inside an async test. + result = comp.compact(ctx, backend=scripted_summary_backend) + items = result.as_list() + assert "[CONTEXT SUMMARY]" in items[0].content + assert items[1].content == "m3"