diff --git a/README.md b/README.md index d4d6a61b..137ee4d8 100644 --- a/README.md +++ b/README.md @@ -82,6 +82,7 @@ Some examples require extra dependencies. See each sample's directory for specif * [pydantic_converter](pydantic_converter) - Data converter for using Pydantic models. * [schedules](schedules) - Demonstrates a Workflow Execution that occurs according to a schedule. * [sentry](sentry) - Report errors to Sentry. +* [tool_registry_incident_triage](tool_registry_incident_triage) - LLM-driven incident triage activity using `temporalio.contrib.tool_registry`. Demonstrates `AgenticSession`, MCP HTTP integration, human-in-the-loop, and a testable activity refactor. * [trio_async](trio_async) - Use asyncio Temporal in Trio-based environments. * [updatable_timer](updatable_timer) - A timer that can be updated while sleeping. * [worker_specific_task_queues](worker_specific_task_queues) - Use unique task queues to ensure activities run on specific workers. diff --git a/tool_registry_incident_triage/README.md b/tool_registry_incident_triage/README.md new file mode 100644 index 00000000..67e9a0f7 --- /dev/null +++ b/tool_registry_incident_triage/README.md @@ -0,0 +1,41 @@ +# Python: incident-triage tool-registry sample + +Demonstrates `temporalio.contrib.tool_registry` end-to-end: long-running `AgenticSession` activity, MCP HTTP integration, human-in-the-loop via companion workflow, and a testable activity refactor. + +## What's here + +| File | Purpose | +|---|---| +| `triage_types.py` | `AlertPayload`, `TriageResult`, `ApprovalRequest`, `ApprovalResponse` records. | +| `triage_activity.py` | The activity. Defines `TriageDeps` (record of I/O callables), `build_triage_registry(alert, session, deps)` returning `(registry, get_result)`, and the activity entrypoint that wires production deps. | +| `triage_workflow.py` | Workflow that schedules the activity with `agentic` timeout profile. | +| `approval_workflow.py` | Companion HITL workflow: deterministic ID per alert, two signals (request/decision), one query (pending). | +| `worker.py` | Worker registration. | +| `client.py` | Demo client to start a workflow. | +| `tests/test_triage_activity.py` | Unit tests demonstrating `MockProvider` + `TriageDeps` pattern. Run: `pytest tests/`. | + +## Run + +```bash +# 1. Run a Temporal dev server (separate terminal) +temporal server start-dev + +# 2. Set up env +export ANTHROPIC_API_KEY=sk-ant-... +export PROM_MCP=http://localhost:7070/mcp +export K8S_MCP=http://localhost:7071/mcp + +# 3. Start the worker +python worker.py + +# 4. Start a workflow +python client.py +``` + +Tests don't need a Temporal server or API key. + +## Requires + +- `temporalio` with `tool_registry` contrib (currently the `feat/tool-registry` branch — install from source or wait for the next release). +- `anthropic` Python SDK (peer dep). +- `httpx` for MCP HTTP calls. diff --git a/tool_registry_incident_triage/__init__.py b/tool_registry_incident_triage/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tool_registry_incident_triage/approval_workflow.py b/tool_registry_incident_triage/approval_workflow.py new file mode 100644 index 00000000..947ef7a2 --- /dev/null +++ b/tool_registry_incident_triage/approval_workflow.py @@ -0,0 +1,45 @@ +"""Companion HITL workflow. + +The triage agent's request_human_approval tool calls signal_with_start +against a deterministic ID per alert group. This workflow stores the latest +agent request, exposes it as a query, and returns the operator's decision. + +Same shape as the TypeScript reference's approval workflow (workers/typescript/ +workflows/approval.ts) — deterministic ID, request signal, decision signal, +pending-approval query, two condition() blocks. +""" +from __future__ import annotations + +from temporalio import workflow + +from triage_types import ApprovalRequest, ApprovalResponse + + +@workflow.defn(name="approvalWorkflow") +class ApprovalWorkflow: + def __init__(self) -> None: + self._request: ApprovalRequest | None = None + self._response: ApprovalResponse | None = None + + @workflow.run + async def run(self, _key: str) -> ApprovalResponse: + # Block until the agent signals a request AND the operator responds. + await workflow.wait_condition(lambda: self._request is not None) + await workflow.wait_condition(lambda: self._response is not None) + assert self._response is not None + return self._response + + @workflow.signal(name="approval-request") + def request(self, req: ApprovalRequest) -> None: + # LLM retry: re-attached signals overwrite the request. Operator only + # ever sees the latest version, since the agent may refine its ask + # across retries. + self._request = req + + @workflow.signal(name="approval-decision") + def decide(self, res: ApprovalResponse) -> None: + self._response = res + + @workflow.query(name="pending-approval") + def pending(self) -> ApprovalRequest | None: + return self._request diff --git a/tool_registry_incident_triage/client.py b/tool_registry_incident_triage/client.py new file mode 100644 index 00000000..08ae19f2 --- /dev/null +++ b/tool_registry_incident_triage/client.py @@ -0,0 +1,118 @@ +"""Client CLI for the Python triage workers. + +Usage: + python -m client pending # list pending approval workflows + python -m client approve + python -m client reject + python -m client trigger # post a synthetic alert (skips webhook) +""" +from __future__ import annotations + +import asyncio +import os +import sys +from datetime import datetime, timezone + +from temporalio.client import Client + +from approval_workflow import ApprovalWorkflow +from triage_workflow import IncidentTriageWorkflow +from triage_types import AlertPayload, ApprovalResponse + + +async def make_client() -> Client: + address = os.environ["TEMPORAL_ADDRESS"] + namespace = os.environ["TEMPORAL_NAMESPACE"] + api_key = os.environ["TEMPORAL_API_KEY"] + return await Client.connect( + address, + namespace=namespace, + rpc_metadata={"authorization": f"Bearer {api_key}"}, + tls=True, + ) + + +async def pending() -> None: + client = await make_client() + any_found = False + async for wf in client.list_workflows( + 'WorkflowType="approvalWorkflow" AND ExecutionStatus="Running"' + ): + any_found = True + handle = client.get_workflow_handle(wf.id) + try: + req = await handle.query("pending-approval") + except Exception: # noqa: BLE001 + req = None + print(f"\n{wf.id} (started {wf.start_time})") + if req: + print(f" message: {req.message}") + print(f" diagnosis: {req.diagnosis}") + print(f" proposed: {req.proposedAction}") + print(f" approve: python -m client approve {wf.id} \"\"") + print(f" reject: python -m client reject {wf.id} \"\"") + else: + print(" (workflow exists but agent has not requested approval yet)") + if not any_found: + print("(no pending approval workflows)") + + +async def decide(decision: str, workflow_id: str, reason: str) -> None: + client = await make_client() + handle = client.get_workflow_handle(workflow_id) + response = ApprovalResponse(decision=decision, reason=reason) # type: ignore[arg-type] + await handle.signal("approval-decision", response) + print(f"signaled {workflow_id}: {decision} — {reason}") + + +async def trigger(alertname: str, service: str) -> None: + client = await make_client() + task_queue = os.environ.get("TEMPORAL_TASK_QUEUE", "triage-python") + workflow_id = f"triage-{alertname.lower()}-{service.lower()}" + alert = AlertPayload( + status="firing", + labels={"alertname": alertname, "service": service, "severity": "critical", "runbook": "synthetic"}, + annotations={ + "summary": f"Synthetic test alert for {service}", + "description": "Triggered manually via client.py to exercise the triage flow.", + }, + startsAt=datetime.now(timezone.utc).isoformat(), + ) + handle = await client.start_workflow( + IncidentTriageWorkflow.run, + alert, + id=workflow_id, + task_queue=task_queue, + start_signal="alert-update", + start_signal_args=[alert], + ) + print(f"started triage workflow: {handle.id} on {task_queue}") + + +def main() -> None: + args = sys.argv[1:] + if not args: + print("Usage: python -m client ...", file=sys.stderr) + sys.exit(1) + + cmd = args[0] + if cmd == "pending": + asyncio.run(pending()) + elif cmd == "approve": + if len(args) < 3: + print("Usage: python -m client approve ", file=sys.stderr); sys.exit(1) + asyncio.run(decide("approved", args[1], " ".join(args[2:]))) + elif cmd == "reject": + if len(args) < 3: + print("Usage: python -m client reject ", file=sys.stderr); sys.exit(1) + asyncio.run(decide("rejected", args[1], " ".join(args[2:]))) + elif cmd == "trigger": + if len(args) < 3: + print("Usage: python -m client trigger ", file=sys.stderr); sys.exit(1) + asyncio.run(trigger(args[1], args[2])) + else: + print(f"Unknown command: {cmd}", file=sys.stderr); sys.exit(1) + + +if __name__ == "__main__": + main() diff --git a/tool_registry_incident_triage/pyproject.toml b/tool_registry_incident_triage/pyproject.toml new file mode 100644 index 00000000..4b77500a --- /dev/null +++ b/tool_registry_incident_triage/pyproject.toml @@ -0,0 +1,31 @@ +[project] +name = "triage-python" +version = "0.0.1" +description = "Python triage worker for the temporal-incident-triage-gke showcase" +requires-python = ">=3.11" +dependencies = [ + # Pin the Temporal Python SDK from the unmerged feat/tool-registry branch. + # Once the SDK publishes, swap to: temporalio[tool-registry]>=X.Y.Z + "temporalio[tool-registry] @ file:///Users/alex/Documents/checkouts/temporal/sdk-python", + "anthropic>=0.40.0", + "httpx>=0.27.0", +] + +[project.optional-dependencies] +dev = [ + "pytest>=8.0.0", + "pytest-asyncio>=0.24.0", +] + +[build-system] +requires = ["setuptools>=68"] +build-backend = "setuptools.build_meta" + +[tool.setuptools] +# Top-level Python files in this directory (no proper package layout — they're +# imported as flat modules from worker.py / client.py / tests/). +py-modules = ["triage_workflow", "approval_workflow", "triage_activity", "worker", "client", "triage_types"] + +[tool.pytest.ini_options] +asyncio_mode = "auto" +testpaths = ["tests"] diff --git a/tool_registry_incident_triage/tests/__init__.py b/tool_registry_incident_triage/tests/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tool_registry_incident_triage/tests/test_triage_activity.py b/tool_registry_incident_triage/tests/test_triage_activity.py new file mode 100644 index 00000000..3f10b99f --- /dev/null +++ b/tool_registry_incident_triage/tests/test_triage_activity.py @@ -0,0 +1,240 @@ +"""Unit tests for the Python triage activity's tool registry. + +Drives the registry directly with `MockProvider.run_loop` — bypasses +`agentic_session` (which would require a real Anthropic client). Asserts that +the agent's tool-call sequence produces the expected final result. + +No API keys, no Temporal, no shell exec, no MCP HTTP — all stubbed via the +injected `TriageDeps`. + +Mirrors workers/typescript/triage_activity.test.ts. +""" +from __future__ import annotations + +from datetime import datetime, timezone +from typing import Any + +import pytest +from temporalio.contrib.tool_registry.testing import ResponseBuilder + +from triage_activity import build_triage_registry, TriageDeps +from triage_types import AlertPayload, ApprovalResponse + + +# ── Fixtures ──────────────────────────────────────────────────────────────── + + +def make_alert() -> AlertPayload: + return AlertPayload( + status="firing", + labels={"alertname": "HighLatencyP99", "service": "api", "runbook": "rollback-or-scale"}, + annotations={"summary": "P99 > 1s", "description": "P99 above threshold for 1m."}, + startsAt=datetime.now(timezone.utc).isoformat(), + ) + + +def make_deps(**overrides: Any) -> TriageDeps: + async def default_list(base_url: str) -> list[dict[str, Any]]: + if "7071" in base_url: + return [{ + "name": "prometheus_query", + "description": "instant PromQL query", + "inputSchema": {"type": "object", "properties": {"query": {"type": "string"}}, "required": ["query"]}, + }] + return [{ + "name": "kubectl_describe", + "description": "describe a k8s resource", + "inputSchema": { + "type": "object", + "properties": {"resource": {"type": "string"}, "name": {"type": "string"}, "namespace": {"type": "string"}}, + "required": ["resource", "name"], + }, + }] + + async def default_call(_url: str, name: str, args: dict[str, Any]) -> str: + return f"(mocked {name} → {args})" + + async def default_approve(_alert: AlertPayload, _req: Any) -> ApprovalResponse: + return ApprovalResponse(decision="approved", reason="default-mock") + + async def default_exec(cmd: str) -> tuple[str, str]: + return f"(mocked exec: {cmd})", "" + + deps = TriageDeps( + mcp_list_tools=overrides.get("mcp_list_tools", default_list), + mcp_call_tool=overrides.get("mcp_call_tool", default_call), + request_human_approval=overrides.get("request_human_approval", default_approve), + exec_shell_command=overrides.get("exec_shell_command", default_exec), + ) + return deps + + +class FakeSession: + """Stub for AgenticSession with just .results so build_triage_registry works.""" + def __init__(self) -> None: + self.results: list[Any] = [] + + +async def async_run_loop(script: list[dict[str, Any]], registry: Any) -> None: + """Async variant of MockProvider.run_loop. + + The shipped MockProvider uses sync `registry.dispatch()` which rejects async + handlers (TypeError). Our triage handlers are async (httpx, asyncio + subprocess, Temporal client). This helper iterates the same script but + calls `await registry.adispatch(...)` instead. + """ + for response in script: + if response.get("_mock_stop"): + return + for block in response.get("content", []): + if block.get("type") == "tool_use": + await registry.adispatch(block["name"], block.get("input", {})) + + +async def drive(deps: TriageDeps, script: list[dict[str, Any]]) -> tuple[Any, list[Any]]: + session = FakeSession() + registry, get_result = await build_triage_registry(make_alert(), session, deps) + await async_run_loop(script, registry) + return get_result(), session.results + + +# ── Tests ─────────────────────────────────────────────────────────────────── + + +@pytest.mark.asyncio +async def test_happy_path_resolved(): + """investigate → propose → approve → execute → report_resolved.""" + approval_calls = 0 + async def counting_approve(_alert: AlertPayload, _req: Any) -> ApprovalResponse: + nonlocal approval_calls + approval_calls += 1 + return ApprovalResponse(decision="approved", reason="go ahead") + + deps = make_deps(request_human_approval=counting_approve) + action = "kubectl rollout restart deploy/api -n demo-app" + + result, mcp_results = await drive(deps, [ + ResponseBuilder.tool_call("prometheus_query", {"query": "up{service='api'}"}), + ResponseBuilder.tool_call("kubectl_describe", {"resource": "pod", "name": "api-xyz", "namespace": "demo-app"}), + ResponseBuilder.tool_call("propose_remediation", {"action": action, "justification": "leak; restart reclaims memory"}), + ResponseBuilder.tool_call("request_human_approval", { + "message": "Restart api?", "diagnosis": "memory leak", "proposedAction": action, + }), + ResponseBuilder.tool_call("execute_remediation", {"action": action}), + ResponseBuilder.tool_call("report_resolved", {"summary": "restarted; latency normal"}), + ResponseBuilder.done("done"), + ]) + + assert result.status == "resolved" + assert "restart" in result.summary + assert len(result.remediations) == 1 + assert result.remediations[0].action == action + assert approval_calls == 1 + kinds = [r["kind"] for r in mcp_results] + assert kinds == ["remediation", "approval", "executed", "final"] + + +@pytest.mark.asyncio +async def test_rejected_approval_unresolved(): + """Operator rejects → agent reports unresolved with reason in session results.""" + async def reject(_alert: AlertPayload, _req: Any) -> ApprovalResponse: + return ApprovalResponse(decision="rejected", reason="off-hours; defer until tomorrow") + + deps = make_deps(request_human_approval=reject) + + result, mcp_results = await drive(deps, [ + ResponseBuilder.tool_call("propose_remediation", {"action": "kubectl scale ...", "justification": "transient"}), + ResponseBuilder.tool_call("request_human_approval", { + "message": "Scale?", "diagnosis": "transient", "proposedAction": "kubectl scale ...", + }), + ResponseBuilder.tool_call("report_unresolved", {"summary": "operator deferred"}), + ResponseBuilder.done("done"), + ]) + + assert result.status == "unresolved" + assert "deferred" in result.summary + approval = next((r for r in mcp_results if r.get("kind") == "approval"), None) + assert approval is not None + assert approval["decision"] == "rejected" + assert "off-hours" in approval["reason"] + + +@pytest.mark.asyncio +async def test_execute_refuses_without_approval(): + """Guard rail: execute_remediation rejects calls when no approval is in flight.""" + deps = make_deps() + result, _ = await drive(deps, [ + ResponseBuilder.tool_call("execute_remediation", {"action": "rm -rf /"}), + ResponseBuilder.tool_call("report_unresolved", {"summary": "tried to skip approval"}), + ResponseBuilder.done("done"), + ]) + assert result.status == "unresolved" + + +@pytest.mark.asyncio +async def test_execute_refuses_when_action_does_not_match(): + """Guard rail: execute_remediation rejects calls whose action ≠ approved one.""" + executed_cmd: list[str] = [] + async def record_exec(cmd: str) -> tuple[str, str]: + executed_cmd.append(cmd) + return "ran", "" + + deps = make_deps( + request_human_approval=lambda a, r: _approve(a, r), + exec_shell_command=record_exec, + ) + + async def _approve(_alert: AlertPayload, _req: Any) -> ApprovalResponse: + return ApprovalResponse(decision="approved", reason="ok") + + result, _ = await drive(deps, [ + ResponseBuilder.tool_call("propose_remediation", {"action": "kubectl restart api", "justification": "x"}), + ResponseBuilder.tool_call("request_human_approval", { + "message": "Restart?", "diagnosis": "x", "proposedAction": "kubectl restart api", + }), + # Agent attempts a DIFFERENT action than what was approved. + ResponseBuilder.tool_call("execute_remediation", {"action": "kubectl scale deploy/api --replicas=10"}), + ResponseBuilder.tool_call("report_unresolved", {"summary": "guard tripped"}), + ResponseBuilder.done("done"), + ]) + + assert result.status == "unresolved" + assert executed_cmd == [], "exec_shell_command should not have been called" + + +@pytest.mark.asyncio +async def test_mcp_tools_registered(): + """Both MCP sidecars' tools + per-language tools all appear in the registry.""" + deps = make_deps() + session = FakeSession() + registry, _ = await build_triage_registry(make_alert(), session, deps) + schemas = registry.to_anthropic() + names = [t["name"] for t in schemas] + for expected in [ + "prometheus_query", "kubectl_describe", + "propose_remediation", "request_human_approval", + "execute_remediation", "report_resolved", "report_unresolved", + ]: + assert expected in names, f"{expected} should be in registry" + + +@pytest.mark.asyncio +async def test_mcp_dispatch_forwards_to_sidecar(): + """Tool dispatch reaches mcp_call_tool with the right URL + name + args.""" + calls: list[dict[str, Any]] = [] + async def record_call(url: str, name: str, args: dict[str, Any]) -> str: + calls.append({"url": url, "name": name, "args": args}) + return f"result for {name}" + + deps = make_deps(mcp_call_tool=record_call) + + await drive(deps, [ + ResponseBuilder.tool_call("prometheus_query", {"query": "up{}"}), + ResponseBuilder.tool_call("report_unresolved", {"summary": "test"}), + ResponseBuilder.done("done"), + ]) + + assert len(calls) == 1 + assert calls[0]["name"] == "prometheus_query" + assert calls[0]["args"] == {"query": "up{}"} + assert "7071" in calls[0]["url"] diff --git a/tool_registry_incident_triage/triage_activity.py b/tool_registry_incident_triage/triage_activity.py new file mode 100644 index 00000000..5a62c296 --- /dev/null +++ b/tool_registry_incident_triage/triage_activity.py @@ -0,0 +1,385 @@ +"""triage_incident_activity — the agentic loop (Python port). + +Mirrors workers/typescript/activities/triage.ts: + - Pulls Prometheus + Kubernetes tools from MCP sidecars (localhost:7071/7072) + via JSON-RPC over HTTP, registers them on the ToolRegistry. + - Defines per-language tools: propose_remediation, request_human_approval, + execute_remediation, report_resolved, report_unresolved. + - Opens an agentic_session, runs the loop, returns the parsed result. + +Structure for testability: + - build_triage_registry() returns the (registry, get_result) pair. Pure-ish: + takes all I/O dependencies as injected callables so unit tests can + substitute them. + - triage_incident_activity() opens the agentic_session, calls + build_triage_registry with real deps, runs the LLM loop. +""" +from __future__ import annotations + +import asyncio +import dataclasses +import json +import os +import subprocess +from typing import Any, Awaitable, Callable + +import httpx +from temporalio import activity +from temporalio.client import Client +from temporalio.common import WorkflowIDConflictPolicy +from temporalio.contrib.tool_registry import ( + ToolRegistry, + agentic_session, +) + +from approval_workflow import ApprovalWorkflow +from triage_types import ( + AlertPayload, + ApprovalRequest, + ApprovalResponse, + ProposedRemediation, + TriageResult, +) + + +SYSTEM_PROMPT = """You are an SRE on-call agent triaging a production alert. + +You have these tools (sourced from MCP sidecars + per-language helpers): + - prometheus_query(query) instant PromQL query + - prometheus_query_range(query, start, end, step) + - prometheus_alerts() what is currently firing + - kubectl_get(resource, namespace?) list K8s resources + - kubectl_describe(resource, name, namespace?) + - kubectl_logs(pod, namespace, tail?) + - propose_remediation(action, justification) record but do NOT execute + - request_human_approval(message, diagnosis, proposedAction) + blocks until operator says approve|reject + - execute_remediation(action) ONLY callable AFTER approval was approved. + Pass the same action you got approved. + - report_resolved(summary) ends the loop with status=resolved + - report_unresolved(summary) ends the loop with status=unresolved + +Workflow: + 1. Read the alert. Use prometheus_query to confirm the symptom is currently true. + 2. Use kubectl_get/describe/logs and prometheus_query_range to find root cause. + 3. propose_remediation with a specific action. + 4. request_human_approval, attaching your diagnosis and the proposed action. + 5. If approved: execute_remediation, then prometheus_query to verify, then report_resolved. + 6. If rejected: report_unresolved with the operator's reason. + +Be terse. Conversation history is heartbeated to Temporal — keep tool inputs short. +""" + + +# ── Injectable dependencies (override in tests) ──────────────────────────── + + +@dataclasses.dataclass +class TriageDeps: + """Pluggable I/O for the triage activity. Tests substitute their own.""" + + mcp_list_tools: Callable[[str], Awaitable[list[dict[str, Any]]]] + mcp_call_tool: Callable[[str, str, dict[str, Any]], Awaitable[str]] + request_human_approval: Callable[[AlertPayload, ApprovalRequest], Awaitable[ApprovalResponse]] + exec_shell_command: Callable[[str], Awaitable[tuple[str, str]]] + + +async def _mcp_list_tools(base_url: str) -> list[dict[str, Any]]: + async with httpx.AsyncClient(timeout=5.0) as client: + r = await client.post( + base_url, + json={"jsonrpc": "2.0", "id": 1, "method": "tools/list"}, + ) + data = r.json() + if "error" in data: + raise RuntimeError(f"mcp tools/list {base_url}: {data['error']['message']}") + return data.get("result", {}).get("tools", []) or [] + + +async def _mcp_call_tool(base_url: str, name: str, args: dict[str, Any]) -> str: + async with httpx.AsyncClient(timeout=30.0) as client: + r = await client.post( + base_url, + json={ + "jsonrpc": "2.0", + "id": int(asyncio.get_event_loop().time() * 1000), + "method": "tools/call", + "params": {"name": name, "arguments": args}, + }, + ) + data = r.json() + if "error" in data: + return f"MCP error: {data['error']['message']}" + blocks = data.get("result", {}).get("content", []) or [] + return "\n".join(b.get("text", "") for b in blocks) + + +async def _exec_shell_command(cmd: str) -> tuple[str, str]: + proc = await asyncio.create_subprocess_shell( + cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE + ) + try: + stdout, stderr = await asyncio.wait_for(proc.communicate(), timeout=60) + except asyncio.TimeoutError: + proc.kill() + raise + return stdout.decode("utf-8", errors="replace"), stderr.decode("utf-8", errors="replace") + + +def default_deps() -> TriageDeps: + return TriageDeps( + mcp_list_tools=_mcp_list_tools, + mcp_call_tool=_mcp_call_tool, + request_human_approval=_real_request_human_approval, + exec_shell_command=_exec_shell_command, + ) + + +PROM_MCP = os.environ.get("MCP_PROMETHEUS_URL", "http://localhost:7071/") +K8S_MCP = os.environ.get("MCP_KUBERNETES_URL", "http://localhost:7072/") + + +# ── Registry builder (testable surface) ────────────────────────────────────── + + +async def build_triage_registry( + alert: AlertPayload, + session: Any, # AgenticSession or test stub with .results: list + deps: TriageDeps, + *, + prom_mcp: str = PROM_MCP, + k8s_mcp: str = K8S_MCP, +) -> tuple[ToolRegistry, Callable[[], TriageResult | None]]: + """Build a populated ToolRegistry plus a get_result() accessor. + + Pure modulo deps — MockProvider.run_loop(messages, registry) drives the + registry without any real MCP, Temporal, or shell dependency. + """ + registry = ToolRegistry() + + # MCP-sourced tools. + try: + prom_tools = await deps.mcp_list_tools(prom_mcp) + except Exception: + prom_tools = [] + try: + k8s_tools = await deps.mcp_list_tools(k8s_mcp) + except Exception: + k8s_tools = [] + + for tool in prom_tools: + name = tool["name"] + + def make_handler(n: str) -> Callable[[dict[str, Any]], Awaitable[str]]: + async def h(inp: dict[str, Any]) -> str: + return await deps.mcp_call_tool(prom_mcp, n, inp) + return h + + registry.handler({ + "name": name, + "description": tool.get("description", ""), + "input_schema": tool.get("inputSchema", {"type": "object"}), + })(make_handler(name)) + + for tool in k8s_tools: + name = tool["name"] + + def make_handler(n: str) -> Callable[[dict[str, Any]], Awaitable[str]]: + async def h(inp: dict[str, Any]) -> str: + return await deps.mcp_call_tool(k8s_mcp, n, inp) + return h + + registry.handler({ + "name": name, + "description": tool.get("description", ""), + "input_schema": tool.get("inputSchema", {"type": "object"}), + })(make_handler(name)) + + # Per-language tools. + remediations: list[ProposedRemediation] = [] + approved_action: str | None = None + final: TriageResult | None = None + + @registry.handler({ + "name": "propose_remediation", + "description": "Record a remediation you would apply. Does NOT execute it.", + "input_schema": { + "type": "object", + "properties": {"action": {"type": "string"}, "justification": {"type": "string"}}, + "required": ["action", "justification"], + }, + }) + def propose(inp: dict[str, Any]) -> str: + r = ProposedRemediation(action=str(inp["action"]), justification=str(inp["justification"])) + remediations.append(r) + session.results.append({"kind": "remediation", **dataclasses.asdict(r)}) + return "recorded" + + @registry.handler({ + "name": "request_human_approval", + "description": "Block until operator decides. Returns JSON {decision, reason}.", + "input_schema": { + "type": "object", + "properties": { + "message": {"type": "string"}, + "diagnosis": {"type": "string"}, + "proposedAction": {"type": "string"}, + }, + "required": ["message", "diagnosis", "proposedAction"], + }, + }) + async def request_approval(inp: dict[str, Any]) -> str: + nonlocal approved_action + req = ApprovalRequest( + message=str(inp["message"]), + diagnosis=str(inp["diagnosis"]), + proposedAction=str(inp["proposedAction"]), + ) + response = await deps.request_human_approval(alert, req) + if response.decision == "approved": + approved_action = req.proposedAction + session.results.append({"kind": "approval", **dataclasses.asdict(response)}) + return json.dumps(dataclasses.asdict(response)) + + @registry.handler({ + "name": "execute_remediation", + "description": "Execute the previously-approved action. Errors if no approval has been granted.", + "input_schema": { + "type": "object", + "properties": {"action": {"type": "string"}}, + "required": ["action"], + }, + }) + async def execute(inp: dict[str, Any]) -> str: + action = str(inp["action"]) + if approved_action is None: + return "ERROR: no approval has been granted. Call request_human_approval first." + if action != approved_action: + return f"ERROR: requested action does not match approved action. Approved: {approved_action}" + try: + stdout, stderr = await deps.exec_shell_command(action) + session.results.append({ + "kind": "executed", + "action": action, + "stdout": stdout[:2000], + "stderr": stderr[:2000], + }) + return (stdout or stderr or "ok")[:4000] + except Exception as e: # noqa: BLE001 + return f"EXEC ERROR: {e}" + + @registry.handler({ + "name": "report_resolved", + "description": "Ends the loop with status=resolved.", + "input_schema": { + "type": "object", + "properties": {"summary": {"type": "string"}}, + "required": ["summary"], + }, + }) + def report_resolved(inp: dict[str, Any]) -> str: + nonlocal final + final = TriageResult(status="resolved", summary=str(inp["summary"]), remediations=list(remediations)) + session.results.append({"kind": "final", **dataclasses.asdict(final)}) + return "ok" + + @registry.handler({ + "name": "report_unresolved", + "description": "Ends the loop with status=unresolved.", + "input_schema": { + "type": "object", + "properties": {"summary": {"type": "string"}}, + "required": ["summary"], + }, + }) + def report_unresolved(inp: dict[str, Any]) -> str: + nonlocal final + final = TriageResult(status="unresolved", summary=str(inp["summary"]), remediations=list(remediations)) + session.results.append({"kind": "final", **dataclasses.asdict(final)}) + return "ok" + + return registry, lambda: final + + +def build_prompt(alert: AlertPayload) -> str: + return ( + f"Alert fired: {alert.labels.get('alertname')} on {alert.labels.get('service', 'unknown')}.\n" + f"Summary: {alert.annotations.get('summary', '(none)')}\n" + f"Description: {alert.annotations.get('description', '(none)')}\n" + f"Runbook hint: {alert.labels.get('runbook', '(none)')}\n\n" + "Investigate, propose, get approval, and either fix or report unresolved." + ) + + +# ── Activity entrypoint ───────────────────────────────────────────────────── + + +@activity.defn(name="triage_incident_activity") +async def triage_incident_activity(alert: AlertPayload) -> TriageResult: + deps = default_deps() + async with agentic_session() as session: + registry, get_result = await build_triage_registry(alert, session, deps) + await session.run_tool_loop( + registry=registry, + provider="anthropic", + system=SYSTEM_PROMPT, + prompt=build_prompt(alert), + ) + final = get_result() + if final is None: + raise RuntimeError("Agent ended the loop without calling report_resolved or report_unresolved") + return final + + +# ── Real HITL bridge ───────────────────────────────────────────────────────── + + +async def _real_request_human_approval( + alert: AlertPayload, request: ApprovalRequest +) -> ApprovalResponse: + """signal_with_start ApprovalWorkflow with deterministic ID per alert group.""" + api_key = os.environ.get("TEMPORAL_API_KEY") + address = os.environ.get("TEMPORAL_ADDRESS") + namespace = os.environ.get("TEMPORAL_NAMESPACE") + if not (api_key and address and namespace): + raise RuntimeError("Missing TEMPORAL_ADDRESS / TEMPORAL_NAMESPACE / TEMPORAL_API_KEY") + + client = await Client.connect( + address, + namespace=namespace, + rpc_metadata={"authorization": f"Bearer {api_key}"}, + tls=True, + ) + + key = f"{alert.labels.get('alertname', 'unknown')}-{alert.labels.get('service', 'unknown')}" + approval_workflow_id = f"approval-{key.lower()}" + task_queue = os.environ.get("TEMPORAL_TASK_QUEUE", "triage-python") + + handle = await client.start_workflow( + ApprovalWorkflow.run, + key, + id=approval_workflow_id, + task_queue=task_queue, + start_signal="approval-request", + start_signal_args=[request], + # If the activity retries while the approval workflow is still running, + # attach to the existing one rather than starting a new approval. The + # operator should not get a second prompt for the same incident. + id_conflict_policy=WorkflowIDConflictPolicy.USE_EXISTING, + ) + + # Heartbeat every 30 seconds while waiting on the approval workflow. + # AgenticSession only heartbeats between LLM turns, so a multi-hour + # operator wait inside this handler would otherwise trigger heartbeat + # timeout in 120s and kill the activity. The ticker keeps the activity + # alive until the operator decides. + async def _ticker() -> None: + while True: + await asyncio.sleep(30) + activity.heartbeat() + + ticker_task = asyncio.create_task(_ticker()) + try: + return await handle.result() + finally: + ticker_task.cancel() diff --git a/tool_registry_incident_triage/triage_types.py b/tool_registry_incident_triage/triage_types.py new file mode 100644 index 00000000..cc71ef1a --- /dev/null +++ b/tool_registry_incident_triage/triage_types.py @@ -0,0 +1,41 @@ +"""Shared types between workflow, activity, and client.""" +from __future__ import annotations + +import dataclasses +from typing import Literal + + +@dataclasses.dataclass +class AlertPayload: + status: str + labels: dict[str, str] + annotations: dict[str, str] + startsAt: str + endsAt: str | None = None + fingerprint: str | None = None + + +@dataclasses.dataclass +class ProposedRemediation: + action: str + justification: str + + +@dataclasses.dataclass +class TriageResult: + status: Literal["resolved", "unresolved"] + summary: str + remediations: list[ProposedRemediation] + + +@dataclasses.dataclass +class ApprovalRequest: + message: str + diagnosis: str + proposedAction: str + + +@dataclasses.dataclass +class ApprovalResponse: + decision: Literal["approved", "rejected"] + reason: str diff --git a/tool_registry_incident_triage/triage_workflow.py b/tool_registry_incident_triage/triage_workflow.py new file mode 100644 index 00000000..5d4e5a77 --- /dev/null +++ b/tool_registry_incident_triage/triage_workflow.py @@ -0,0 +1,50 @@ +"""Single-activity workflow that delegates the agentic loop to the triage activity. + +Workflow ID is set deterministically by the webhook receiver +(triage-${alertname}-${service}), so re-fires from AlertManager re-attach +to the running workflow rather than spawning a new one. +""" +from __future__ import annotations + +from datetime import timedelta + +from temporalio import workflow + +with workflow.unsafe.imports_passed_through(): + from triage_activity import triage_incident_activity + from triage_types import AlertPayload, TriageResult + + +@workflow.defn(name="incidentTriageWorkflow") +class IncidentTriageWorkflow: + def __init__(self) -> None: + self._current_alert: AlertPayload | None = None + self._result: TriageResult | None = None + + @workflow.run + async def run(self, initial_alert: AlertPayload) -> TriageResult: + self._current_alert = initial_alert + # Single activity — matches lexicon-temporal's `agenticHitl` profile: + # 8h start-to-close (operator may take hours), 120s heartbeat (Claude + # turn worst case), 1 attempt (AgenticSession heartbeat is the resume). + self._result = await workflow.execute_activity( + triage_incident_activity, + self._current_alert, + start_to_close_timeout=timedelta(hours=8), + heartbeat_timeout=timedelta(seconds=120), + ) + return self._result + + @workflow.signal(name="alert-update") + def alert_update(self, alert: AlertPayload) -> None: + # Webhook may re-fire with refreshed alert state. The agent reads + # the latest via the current-alert query. + self._current_alert = alert + + @workflow.query(name="current-alert") + def current_alert(self) -> AlertPayload | None: + return self._current_alert + + @workflow.query(name="triage-result") + def triage_result(self) -> TriageResult | None: + return self._result diff --git a/tool_registry_incident_triage/worker.py b/tool_registry_incident_triage/worker.py new file mode 100644 index 00000000..f7258966 --- /dev/null +++ b/tool_registry_incident_triage/worker.py @@ -0,0 +1,52 @@ +"""Temporal worker for the Python triage workflow. + +Connects to Temporal Cloud, polls the task queue from TEMPORAL_TASK_QUEUE +(typically `triage-python`), registers IncidentTriageWorkflow + ApprovalWorkflow ++ the triage activity. +""" +from __future__ import annotations + +import asyncio +import os +import sys + +from temporalio.client import Client +from temporalio.worker import Worker + +from approval_workflow import ApprovalWorkflow +from triage_activity import triage_incident_activity +from triage_workflow import IncidentTriageWorkflow + + +async def main() -> None: + address = os.environ.get("TEMPORAL_ADDRESS") + namespace = os.environ.get("TEMPORAL_NAMESPACE") + api_key = os.environ.get("TEMPORAL_API_KEY") + task_queue = os.environ.get("TEMPORAL_TASK_QUEUE", "triage-python") + + if not (address and namespace and api_key): + print("Missing TEMPORAL_ADDRESS / TEMPORAL_NAMESPACE / TEMPORAL_API_KEY", file=sys.stderr) + sys.exit(1) + + print(f"connecting to {address} (ns={namespace}) on task queue {task_queue}") + + client = await Client.connect( + address, + namespace=namespace, + rpc_metadata={"authorization": f"Bearer {api_key}"}, + tls=True, + ) + + worker = Worker( + client, + task_queue=task_queue, + workflows=[IncidentTriageWorkflow, ApprovalWorkflow], + activities=[triage_incident_activity], + ) + + print(f"worker ready — polling {task_queue}") + await worker.run() + + +if __name__ == "__main__": + asyncio.run(main())