51 changes: 47 additions & 4 deletions docs/examples/context/README.md
@@ -1,31 +1,59 @@
# Context Examples

This directory contains examples demonstrating how to work with Mellea's context system: inspecting per-attempt contexts produced by sampling strategies, and shrinking contexts with the `Compactor` protocol.

## Files

### contexts_with_sampling.py

Shows how to retrieve and inspect context information when using sampling strategies and validation.

**Key Features:**

- Using `RejectionSamplingStrategy` with requirements
- Accessing `SamplingResult` objects to inspect generation attempts
- Retrieving context for different generation attempts
- Examining validation contexts for each requirement
- Understanding the context tree structure

**Usage:**

```bash
python docs/examples/context/contexts_with_sampling.py
```

### window_compactor.py

Shows `WindowCompactor`, which is opt-in: pass `compactor=` (or the `window_size=` shorthand) to `ChatContext`. Demonstrates system-prefix pinning, `pin_system_and_initial_user`, `pin_nothing` (pure last-N), and `size=0` to clear the body.
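
The windowing behaviour described above can be pictured with a small, library-free sketch. This is not Mellea's implementation: `window_compact`, the predicate signature, and the `(role, text)` tuples are hypothetical stand-ins. The assumption is that pinned items form a leading prefix that is always kept, and only the remaining body is trimmed to its last `size` items.

```python
from typing import Callable

Item = tuple[str, str]  # (role, text); a stand-in for real context components


def pin_system(items: list[Item]) -> int:
    """Hypothetical predicate: length of the leading run of system items."""
    n = 0
    while n < len(items) and items[n][0] == "system":
        n += 1
    return n


def pin_nothing(items: list[Item]) -> int:
    """Hypothetical predicate: pin no items (pure last-N window)."""
    return 0


def window_compact(
    items: list[Item],
    size: int,
    pinned: Callable[[list[Item]], int] = pin_system,
) -> list[Item]:
    """Keep the pinned prefix plus the last `size` items of the body."""
    k = pinned(items)
    prefix, body = items[:k], items[k:]
    # body[-0:] would return the whole body, so size=0 needs its own branch.
    kept = body[-size:] if size > 0 else []
    return prefix + kept


history = [
    ("system", "be brief"),
    ("user", "a"),
    ("assistant", "b"),
    ("user", "c"),
]
print(window_compact(history, 2))               # system prefix + last two
print(window_compact(history, 0))               # system prefix only
print(window_compact(history, 2, pin_nothing))  # last two, nothing pinned
```

Passing the predicate as a parameter keeps the trimming rule separate from the decision about what must survive, which is the same split the `pin_*` helpers named above suggest.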

### threshold_compactor.py

`ThresholdCompactor` — gate an inner compactor on the conversation's cumulative token size. The reading is taken from the most recent `ModelOutputThunk`'s `total_tokens`, which for a chat backend equals `prompt_tokens` (full conversation history sent to the model) + `completion_tokens` (reply). The gate fires once the running conversation size crosses the threshold; once compaction shrinks the context, the next call produces a smaller reading and the gate closes again.
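
The gating logic can be sketched without the library. Everything here is illustrative: `Usage`, `threshold_gate`, and the "at or above" comparison are assumptions, not Mellea's API. The sketch only shows the shape of the idea, where the last reported token count decides whether the inner compactor runs.

```python
from dataclasses import dataclass
from typing import Callable, Optional


@dataclass
class Usage:
    """Hypothetical stand-in for the token counts reported with a reply."""

    prompt_tokens: int
    completion_tokens: int

    @property
    def total_tokens(self) -> int:
        # For a chat backend: full history sent + reply generated.
        return self.prompt_tokens + self.completion_tokens


def threshold_gate(
    items: list[str],
    last_usage: Optional[Usage],
    threshold: int,
    inner: Callable[[list[str]], list[str]],
) -> list[str]:
    """Run `inner` only when the last reading is at or above `threshold`.

    Assumption: "crosses the threshold" is modelled as >=; the real
    comparison may differ.
    """
    if last_usage is None or last_usage.total_tokens < threshold:
        return items  # gate closed: context passes through unchanged
    return inner(items)  # gate open: delegate to the inner compactor


def keep_last_two(items: list[str]) -> list[str]:
    return items[-2:]


msgs = ["m0", "m1", "m2", "m3"]
print(threshold_gate(msgs, Usage(7000, 500), 8000, keep_last_two))  # unchanged
print(threshold_gate(msgs, Usage(7800, 400), 8000, keep_last_two))  # compacted
```

Because the reading comes from the previous model call, the gate reacts one call late: the context that pushed the count over the limit has already been sent once before it is shrunk.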

### custom_compactor.py

Implement the `Compactor` protocol with a plain class (no inheritance). Shows Pattern 1 (wired into `ChatContext`) and Pattern 2 (manual `compact()` call).
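
Structural typing is what makes "no inheritance" work. The sketch below uses only `typing.Protocol` from the standard library; `Shrinker` and `DropOldest` are hypothetical names, and the method takes a plain list rather than the `compact(ctx, *, backend=None)` signature the real protocol uses. Whether Mellea's `Compactor` is marked `runtime_checkable` is not stated here, so treat the `isinstance` check as a property of this sketch only.

```python
from typing import Protocol, runtime_checkable


@runtime_checkable
class Shrinker(Protocol):
    """Hypothetical protocol: anything with a matching compact() qualifies."""

    def compact(self, items: list[str]) -> list[str]: ...


class DropOldest:
    """Plain class: does not inherit from Shrinker."""

    def compact(self, items: list[str]) -> list[str]:
        return items[1:] if len(items) > 1 else items


def apply(shrinker: Shrinker, items: list[str]) -> list[str]:
    # A static type checker accepts DropOldest here because its shape matches.
    return shrinker.compact(items)


print(isinstance(DropOldest(), Shrinker))  # runtime check looks for compact()
print(apply(DropOldest(), ["a", "b", "c"]))
```

Note that the runtime `isinstance` check only verifies that a `compact` attribute exists; argument and return types are checked by a static type checker, not at runtime.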

### react_compaction.py

Compose the ReACT loop with a sync `Compactor`. Two integration points:

- **Per-add** — wire a `Compactor` onto the `ChatContext` so it runs every time `react()` appends a Message, ToolMessage, or thunk.
- **Per-turn** — pass `compactor=` to `react()`; it fires once per ReACT iteration after the tool observation.

`LLMSummarizeCompactor` is also a sync `Compactor` — it hides the async backend call internally (worker thread when called from an already-running event loop) so callers don't have to think about sync vs async.

Use `pin_react_initiator` (from `mellea.stdlib.components.react`) as the predicate so the goal and tool registration survive compaction.
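
How a sync method can hide an async call is worth a sketch. This is one common standard-library approach, not necessarily what `LLMSummarizeCompactor` does internally: `run_sync` and `fake_summarize` are hypothetical names. With no event loop running, the coroutine is driven by `asyncio.run`; when a loop is already running in the calling thread, the coroutine is run on a fresh loop in a worker thread instead.

```python
import asyncio
from concurrent.futures import ThreadPoolExecutor


def run_sync(coro):
    """Run a coroutine to completion from synchronous code."""
    try:
        asyncio.get_running_loop()
    except RuntimeError:
        # No loop is running in this thread, so starting one here is safe.
        return asyncio.run(coro)
    # A loop is already running; asyncio.run() would raise in this thread.
    # Run the coroutine on its own loop in a worker thread and wait for it.
    with ThreadPoolExecutor(max_workers=1) as pool:
        return pool.submit(asyncio.run, coro).result()


async def fake_summarize(text: str) -> str:
    """Hypothetical stand-in for an async backend call."""
    await asyncio.sleep(0)
    return text[:10] + "..."


async def caller() -> str:
    # Called from inside a running loop: takes the worker-thread path.
    return run_sync(fake_summarize("abcdefghijklmnop"))


print(run_sync(fake_summarize("hello world, this is long")))  # plain sync call
print(asyncio.run(caller()))                                  # inside a loop
```

The trade-off is that the worker-thread path blocks the calling event loop until the coroutine finishes, which is acceptable for an occasional compaction step but would stall other tasks if called frequently.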

## Concepts Demonstrated

- **Sampling Results**: Working with `SamplingResult` objects
- **Context Inspection**: Accessing generation and validation contexts
- **Multiple Attempts**: Examining different generation attempts
- **Context Trees**: Understanding how contexts link together
- **Validation Context**: Inspecting how requirements were evaluated
- **Compaction Protocol**: Sync `Compactor` for per-`add()` shrinking
- **Pin Predicates**: Auto-protect leading system messages or the user's initial prompt during compaction

## Key APIs

@@ -48,8 +76,23 @@ gen_ctx.previous_node.node_data
val_ctx.node_data
```

```python
# Wire a compactor into a ChatContext (Pattern 1 — runs on every add())
from mellea.stdlib.context import ChatContext, WindowCompactor, ThresholdCompactor

ctx = ChatContext(compactor=WindowCompactor(size=5)) # default: pin_system
ctx = ChatContext(window_size=5) # sugar for the line above
ctx = ChatContext(
    compactor=ThresholdCompactor(WindowCompactor(size=5), threshold=8000),
)

# Manual compaction (Pattern 2)
ctx = WindowCompactor(size=0).compact(ctx) # drop body, keep pinned prefix
```

## Related Documentation

- See `mellea/stdlib/context/` for context and compactor implementations
- See `mellea/stdlib/sampling/` for sampling strategies
- See `mellea/stdlib/frameworks/react.py` for the ReACT loop
- See `docs/dev/spans.md` for context architecture details
63 changes: 63 additions & 0 deletions docs/examples/context/custom_compactor.py
@@ -0,0 +1,63 @@
# pytest: unit
"""Implementing the Compactor protocol — anything with ``compact()`` works.

The protocol is structurally typed: a class with a ``compact(ctx, *,
backend=None) -> ChatContext`` method is a valid Compactor. No
inheritance is required.
"""

from mellea.stdlib.components.chat import Message
from mellea.stdlib.context import ChatContext, Compactor
from mellea.stdlib.context.chat import _rebuild_chat_context


class TruncateOldest:
    """Drop only the very first body component each call.

    Demonstrates the smallest possible Compactor implementation. Pattern
    1 (wired into ``ChatContext``) means each ``add()`` removes the
    oldest item then appends, so the context never grows.
    """

    def compact(self, ctx, *, backend=None):
        items = ctx.as_list()
        if len(items) <= 1:
            return ctx
        return _rebuild_chat_context(items[1:], compactor=ctx._compactor)


def pattern_1_wired_into_context():
    """Pattern 1: compactor lives on the context, runs in ``add()``."""
    ctx = ChatContext(compactor=TruncateOldest())
    for i in range(4):
        ctx = ctx.add(Message("user", f"msg {i}"))
    return [m.content for m in ctx.as_list()]
    # → ['msg 3'] (only the newest item survives, as the test below asserts)


def pattern_2_manual_call():
    """Pattern 2: caller invokes ``compact()`` directly between turns."""
    ctx = ChatContext(window_size=10_000)  # permissive: no auto-compaction
    for i in range(5):
        ctx = ctx.add(Message("user", f"msg {i}"))
    truncated = TruncateOldest().compact(ctx)
    return [m.content for m in truncated.as_list()]


def structural_typing_check():
    """The Compactor protocol is satisfied structurally, no inheritance."""
    c: Compactor = TruncateOldest()  # mypy-checked Protocol assignment
    return type(c).__name__


if __name__ == "__main__":
    for fn in [pattern_1_wired_into_context, pattern_2_manual_call]:
        print(f"--- {fn.__name__} ---")
        print(fn())
    print(f"structural typing: {structural_typing_check()} satisfies Compactor")


def test_custom_compactor_examples():
    assert pattern_1_wired_into_context() == ["msg 3"]
    assert pattern_2_manual_call() == ["msg 1", "msg 2", "msg 3", "msg 4"]
    assert structural_typing_check() == "TruncateOldest"
235 changes: 235 additions & 0 deletions docs/examples/context/react_compaction.py
@@ -0,0 +1,235 @@
# pytest: unit
"""Compose the ReACT loop with a sync `Compactor`.

Two integration points are available, and they're complementary:

1. **Per-add**: the `ChatContext`'s own compactor runs every time the
   ReACT loop appends a Message, ToolMessage, or thunk. This is fine
   for cheap strategies like `WindowCompactor`.
2. **Per-turn**: pass `compactor=` to ``react(...)`` to invoke a
   compactor once per ReACT iteration after the tool observation. Use
   it for heavier strategies that should fire at turn boundaries
   instead of on every component append.

In both cases use ``pin_react_initiator`` (from
``mellea.stdlib.components.react``) so the goal and tool registration
survive compaction.

This example exercises the wiring end-to-end against a fake backend so
no LLM is required.
"""

from __future__ import annotations

import asyncio
from collections.abc import Sequence
from dataclasses import dataclass

from mellea.backends.tools import MelleaTool
from mellea.core.backend import Backend, BaseModelSubclass
from mellea.core.base import (
    C,
    CBlock,
    Component,
    Context,
    GenerateLog,
    ModelOutputThunk,
    ModelToolCall,
)
from mellea.stdlib.components.react import (
    MELLEA_FINALIZER_TOOL,
    ReactInitiator,
    _mellea_finalize_tool,
    pin_react_initiator,
)
from mellea.stdlib.context import ChatContext, WindowCompactor
from mellea.stdlib.frameworks.react import react

# --------------------------------------------------------------------------- #
# Fake backend so the example runs without an LLM #
# --------------------------------------------------------------------------- #


@dataclass
class _ScriptedTurn:
    value: str
    tool_calls: dict[str, ModelToolCall] | None = None


class ScriptedBackend(Backend):
    """Returns pre-scripted responses; no real model is called."""

    def __init__(self, script: list[_ScriptedTurn]) -> None:
        self._script = iter(script)

    async def _generate_from_context(
        self,
        action: Component[C] | CBlock,
        ctx: Context,
        *,
        format: type[BaseModelSubclass] | None = None,
        model_options: dict | None = None,
        tool_calls: bool = False,
    ) -> tuple[ModelOutputThunk[C], Context]:
        turn = next(self._script)
        mot: ModelOutputThunk = ModelOutputThunk(
            value=turn.value, tool_calls=turn.tool_calls
        )
        mot._generate_log = GenerateLog(is_final_result=True)
        return mot, ctx.add(action).add(mot)

    async def generate_from_raw(
        self,
        actions: Sequence[Component[C] | CBlock],
        ctx: Context,
        *,
        format: type[BaseModelSubclass] | None = None,
        model_options: dict | None = None,
        tool_calls: bool = False,
    ) -> list[ModelOutputThunk]:
        raise NotImplementedError


def _tool(name: str, return_value: str = "ok") -> MelleaTool:
    def _fn() -> str:
        return return_value

    return MelleaTool.from_callable(_fn, name=name)


def _tool_call(tool_name: str, tool: MelleaTool, thought: str) -> _ScriptedTurn:
    tc = ModelToolCall(name=tool_name, func=tool, args={})
    return _ScriptedTurn(value=thought, tool_calls={tool_name: tc})


def _final(answer: str) -> _ScriptedTurn:
    finalizer = MelleaTool.from_callable(_mellea_finalize_tool, MELLEA_FINALIZER_TOOL)
    tc = ModelToolCall(
        name=MELLEA_FINALIZER_TOOL, func=finalizer, args={"answer": answer}
    )
    return _ScriptedTurn(value="", tool_calls={MELLEA_FINALIZER_TOOL: tc})


# --------------------------------------------------------------------------- #
# Pattern A — per-add compaction wired into the ChatContext #
# --------------------------------------------------------------------------- #


async def per_add_compaction():
    """A `WindowCompactor(pin_react_initiator)` on the ChatContext compacts
    on every ``add()``: Messages, ToolMessages, thunks. The ReactInitiator
    stays pinned across the whole loop.
    """
    search = _tool("search")
    backend = ScriptedBackend(
        [
            _tool_call("search", search, "step 1"),
            _tool_call("search", search, "step 2"),
            _tool_call("search", search, "step 3"),
            _final("done"),
        ]
    )
    ctx = ChatContext(
        compactor=WindowCompactor(size=3, pin_predicate=pin_react_initiator)
    )
    result, ctx = await react(
        goal="find info", context=ctx, backend=backend, tools=[search], loop_budget=10
    )
    return (
        result.value,
        any(isinstance(c, ReactInitiator) for c in ctx.as_list()),
        len(ctx.as_list()),
    )


# --------------------------------------------------------------------------- #
# Pattern B — per-turn compaction passed to react() #
# --------------------------------------------------------------------------- #


async def per_turn_compaction():
    """Pass ``compactor=`` to ``react`` for once-per-turn invocation.

    Use a permissive ``ChatContext`` (large window) so the per-add path is
    effectively disabled; only the per-turn hook drives compaction.
    """
    search = _tool("search")
    backend = ScriptedBackend(
        [
            _tool_call("search", search, "step 1"),
            _tool_call("search", search, "step 2"),
            _tool_call("search", search, "step 3"),
            _final("done"),
        ]
    )
    result, ctx = await react(
        goal="find info",
        context=ChatContext(window_size=10_000),
        backend=backend,
        tools=[search],
        loop_budget=10,
        compactor=WindowCompactor(size=2, pin_predicate=pin_react_initiator),
    )
    return (result.value, any(isinstance(c, ReactInitiator) for c in ctx.as_list()))


# --------------------------------------------------------------------------- #
# Pattern C — LLM-driven summarisation #
# --------------------------------------------------------------------------- #


async def llm_summarize_compaction():
    """Wire :class:`LLMSummarizeCompactor` into ``react()``.

    ``LLMSummarizeCompactor`` implements the sync :class:`Compactor`
    protocol: its ``compact`` method internally orchestrates the async
    backend call (running it on a worker thread when invoked from inside
    an event loop). From ``react()``'s perspective it's just another
    sync compactor.

    To keep the scripted backend simple, this example sets ``keep_n``
    large enough that summarisation never fires (no LLM call is needed).
    Real usage would pair it with ``ThresholdCompactor`` so it only
    activates once the conversation crosses a token budget. See
    ``TestLLMSummarizeCompactor`` in ``test/stdlib/test_compactor.py`` for
    unit tests that exercise the actual summary path.
    """
    from mellea.stdlib.context import LLMSummarizeCompactor

    search = _tool("search")
    backend = ScriptedBackend([_tool_call("search", search, "step 1"), _final("done")])
    result, ctx = await react(
        goal="find info",
        context=ChatContext(window_size=10_000),
        backend=backend,
        tools=[search],
        loop_budget=10,
        # keep_n=1000 → no summarisation triggers in this short script;
        # the example just shows the async-backed compactor is wired correctly.
        compactor=LLMSummarizeCompactor(keep_n=1000, pin_predicate=pin_react_initiator),
    )
    return (result.value, any(isinstance(c, ReactInitiator) for c in ctx.as_list()))


if __name__ == "__main__":
    print(f"per_add_compaction: {asyncio.run(per_add_compaction())}")
    print(f"per_turn_compaction: {asyncio.run(per_turn_compaction())}")
    print(f"llm_summarize_compact: {asyncio.run(llm_summarize_compaction())}")


def test_per_add_compaction():
    answer, has_initiator, _length = asyncio.run(per_add_compaction())
    assert answer == "done"
    assert has_initiator


def test_per_turn_compaction():
    answer, has_initiator = asyncio.run(per_turn_compaction())
    assert answer == "done"
    assert has_initiator


def test_llm_summarize_compaction():
    answer, has_initiator = asyncio.run(llm_summarize_compaction())
    assert answer == "done"
    assert has_initiator