Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 27 additions & 5 deletions src/agent/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,11 @@
from psycopg_pool import AsyncConnectionPool

from agent.models import get_embedding, get_llm
from agent.profiles import ProfileName, create_profile_graphs
from agent.profiles import ProfileName
from agent.profiles.base import InputState, OutputState
from agent.profiles.cross_database import create_cross_database_graph
from agent.profiles.react_to_me import create_reactome_graph
from mcp.mcp_tools import create_mcp_tools
from util.logging import logging

LANGGRAPH_DB_URI = f"postgresql://{os.getenv('POSTGRES_USER')}:{os.getenv('POSTGRES_PASSWORD')}@postgres:5432/{os.getenv('POSTGRES_LANGGRAPH_DB')}?sslmode=disable"
Expand All @@ -33,9 +36,9 @@ def __init__(
llm: BaseChatModel = get_llm("openai", "gpt-4o-mini")
embedding: Embeddings = get_embedding("openai", "text-embedding-3-large")

self.uncompiled_graph: dict[str, StateGraph] = create_profile_graphs(
profiles, llm, embedding
)
self.llm = llm
self.embedding = embedding
self.profiles = profiles

# The following are set asynchronously by calling initialize()
self.graph: dict[str, CompiledStateGraph] | None = None
Expand All @@ -46,10 +49,27 @@ def __del__(self) -> None:
asyncio.run(self.close_pool())

async def initialize(self) -> dict[str, CompiledStateGraph]:
    """Asynchronously finish construction: start MCP tooling, build the
    per-profile graphs, and compile each one with a Postgres checkpointer.

    Returns
    -------
    dict[str, CompiledStateGraph]
        Compiled graph per (lower-cased) profile name.
    """
    # Spawns the MCP server subprocess; keep the manager so close_pool()
    # can shut it down later.
    mcp_tools, self.mcp_manager = await create_mcp_tools(
        os.getenv("MCP_SERVER_PATH")
    )

    uncompiled_graphs: dict[str, StateGraph] = {}
    for profile in map(str.lower, self.profiles):
        if profile == ProfileName.React_to_Me.lower():
            # Only the Reactome profile receives MCP tools.
            uncompiled_graphs[profile] = create_reactome_graph(
                self.llm, self.embedding, mcp_tools
            )
        elif profile == ProfileName.Cross_Database_Prototype.lower():
            uncompiled_graphs[profile] = create_cross_database_graph(
                self.llm, self.embedding
            )
        # NOTE(review): unrecognized profile names are silently skipped —
        # confirm this is intended rather than raising a ValueError.

    checkpointer: BaseCheckpointSaver[str] = await self.create_checkpointer()

    return {
        profile: graph.compile(checkpointer=checkpointer)
        for profile, graph in uncompiled_graphs.items()
    }

async def create_checkpointer(self) -> BaseCheckpointSaver[str]:
Expand All @@ -73,6 +93,8 @@ async def create_checkpointer(self) -> BaseCheckpointSaver[str]:
async def close_pool(self) -> None:
    """Release async resources: the Postgres connection pool and the MCP
    server process.

    Safe to call on a never-initialized instance: ``mcp_manager`` is only
    assigned inside ``initialize()``, and ``__del__`` may invoke this
    method before (or without) ``initialize()`` ever running, so both
    attributes are read defensively with ``getattr``.
    """
    pool = getattr(self, "pool", None)
    if pool:
        await pool.close()
    # Guard: accessing self.mcp_manager directly would raise
    # AttributeError when initialize() was never awaited.
    mcp_manager = getattr(self, "mcp_manager", None)
    if mcp_manager:
        await mcp_manager.stop()

async def ainvoke(
self,
Expand Down
97 changes: 82 additions & 15 deletions src/agent/profiles/react_to_me.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from langchain_core.embeddings import Embeddings
from langchain_core.language_models.chat_models import BaseChatModel
from langchain_core.messages import AIMessage, HumanMessage
from langchain_core.messages import AIMessage, HumanMessage, ToolMessage
from langchain_core.runnables import Runnable, RunnableConfig
from langgraph.graph.state import StateGraph

Expand All @@ -20,9 +20,18 @@ def __init__(
self,
llm: BaseChatModel,
embedding: Embeddings,
mcp_tools: list | None = None,
) -> None:
super().__init__(llm, embedding)

self.llm = llm

# optional MCP tools - if provided the LLM can call them instead of using RAG
self.mcp_tools = mcp_tools or []

# bind tools to LLM once at init time, reused on every call_model invocation
self.llm_with_tools = self.llm.bind_tools(self.mcp_tools) if self.mcp_tools else None

# Create runnables (tasks & tools)
self.unsafe_answer_generator: Runnable = create_unsafe_answer_generator(
llm, streaming=True
Expand Down Expand Up @@ -73,28 +82,86 @@ async def generate_unsafe_response(
async def call_model(
    self, state: ReactToMeState, config: RunnableConfig
) -> ReactToMeState:
    """Produce an answer for the rephrased user input.

    Without MCP tools this delegates to the existing Reactome RAG chain.
    With MCP tools, the tool-bound LLM is invoked in a bounded
    tool-calling loop: each requested tool is executed and its result is
    fed back until the model returns a direct answer or the iteration
    budget is exhausted.

    Returns
    -------
    ReactToMeState
        New chat-history entries (user turn + AI turn) and the answer.
    """
    # No MCP tools - fall back to existing RAG behaviour unchanged.
    if not self.mcp_tools:
        result: dict[str, Any] = await self.reactome_rag.ainvoke(
            {
                "input": state["rephrased_input"],
                "chat_history": (
                    state["chat_history"]
                    if state["chat_history"]
                    else [HumanMessage(state["user_input"])]
                ),
            },
            config,
        )
        return ReactToMeState(
            chat_history=[
                HumanMessage(state["user_input"]),
                AIMessage(result["answer"]),
            ],
            answer=result["answer"],
        )

    llm_with_tools = self.llm_with_tools

    messages = list(state["chat_history"] or []) + [
        HumanMessage(state["rephrased_input"])
    ]

    response = await llm_with_tools.ainvoke(messages, config)

    # Tool-calling loop, bounded to prevent an infinite tool-call
    # back-and-forth (the bound below matches max_iterations).
    max_iterations = 15
    iteration = 0
    while response.tool_calls and iteration < max_iterations:
        iteration += 1
        tool_results = []

        for tool_call in response.tool_calls:
            # Find the matching tool; execution triggers the MCP client's
            # call to the Reactome API.
            tool = next(
                (t for t in self.mcp_tools if t.name == tool_call["name"]),
                None,
            )
            if tool is None:
                # The LLM requested a tool we don't have — report the
                # error back instead of crashing with StopIteration so
                # the model can recover.
                tool_results.append(
                    ToolMessage(
                        content=f"Error: unknown tool '{tool_call['name']}'",
                        tool_call_id=tool_call["id"],
                    )
                )
                continue

            result = await tool.ainvoke(tool_call["args"])

            # tool_call_id links this result back to the specific request
            # the LLM made.
            tool_results.append(
                ToolMessage(
                    content=str(result),
                    tool_call_id=tool_call["id"],
                )
            )

        # Send tool results back to the LLM for the next step / final answer.
        messages = messages + [response] + tool_results
        response = await llm_with_tools.ainvoke(messages, config)

    # Loop hit max iterations - the LLM never gave a direct answer.
    if response.tool_calls:
        answer = "I was unable to complete the research in time. Please try rephrasing your question."
        return ReactToMeState(
            chat_history=[
                HumanMessage(state["user_input"]),
                AIMessage(answer),
            ],
            answer=answer,
        )

    # LLM gave a direct answer.
    return ReactToMeState(
        chat_history=[
            HumanMessage(state["user_input"]),
            AIMessage(response.content),
        ],
        answer=response.content,
    )


def create_reactome_graph(
    llm: BaseChatModel,
    embedding: Embeddings,
    mcp_tools: list | None = None,
) -> StateGraph:
    """Build the (uncompiled) Reactome profile graph.

    Parameters
    ----------
    llm : BaseChatModel
        Chat model used by the graph's runnables.
    embedding : Embeddings
        Embedding model for retrieval.
    mcp_tools : list | None
        Optional MCP tools; when provided the LLM may call them instead
        of the RAG pipeline.
    """
    return ReactToMeGraphBuilder(llm, embedding, mcp_tools).uncompiled_graph
Empty file added src/mcp/__init__.py
Empty file.
95 changes: 95 additions & 0 deletions src/mcp/mcp_client.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
import asyncio
import json


class MCPToolError(Exception):
    """Error reported by the MCP server as a JSON-RPC error response."""


class MCPClient:
"""
Minimal JSON-RPC client for communicating with the Reactome MCP server
over stdin/stdout.
"""

def __init__(self, process: asyncio.subprocess.Process, timeout: float = 30.0):
self.process = process
self.timeout = timeout
self.request_id = 0

async def call(self, method: str, params: dict | None = None) -> dict:
"""
Send a JSON-RPC request and return the result.

Raises
------
MCPToolError
If the server returns a JSON-RPC error response.
asyncio.TimeoutError
If the server does not respond within timeout seconds.
RuntimeError
If the server closes the connection unexpectedly.
"""
if params is None:
params = {}

self.request_id += 1

request = {
"jsonrpc": "2.0",
"id": self.request_id,
"method": method,
"params": params,
}

message = json.dumps(request) + "\n"
self.process.stdin.write(message.encode("utf-8"))
await self.process.stdin.drain()

# Wait for response with timeout so chatbot never hangs indefinitely
response_line = await asyncio.wait_for(
self.process.stdout.readline(),
timeout=self.timeout,
)

if not response_line:
raise RuntimeError("MCP server closed the connection.")

try:
response = json.loads(response_line.decode("utf-8").strip())
except json.JSONDecodeError as e:
raise RuntimeError(f"MCP server returned invalid JSON: {e}")

# JSON-RPC error response — server understood request but returned an error
if "error" in response:
error = response["error"]
raise MCPToolError(
f"MCP error {error.get('code')}: {error.get('message')}"
)

return response.get("result", {})

async def call_tool(self, tool_name: str, arguments: dict | None = None) -> str:
"""
Call a specific MCP tool and return the text result.

Parameters
----------
tool_name : str
Name of the tool (e.g. "reactome_search").
arguments : dict | None
Tool arguments.
"""
if arguments is None:
arguments = {}

result = await self.call(
"tools/call",
{"name": tool_name, "arguments": arguments},
)

# MCP returns content as list of typed blocks — extract text blocks
content = result.get("content", [])
text_parts = [block["text"] for block in content if block.get("type") == "text"]
return "\n".join(text_parts)
66 changes: 66 additions & 0 deletions src/mcp/mcp_process_manager.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
import asyncio
from pathlib import Path


class MCPConnectionError(Exception):
    """Raised when the MCP server process fails to start or has crashed."""


class MCPProcessManager:
    """Manages lifecycle of the Reactome MCP server process."""

    def __init__(self, mcp_server_path: str):
        # Fail fast with a helpful message if the built server is missing.
        self.mcp_server_path = Path(mcp_server_path)
        if not self.mcp_server_path.exists():
            raise FileNotFoundError(
                f"MCP server not found at: {self.mcp_server_path}\n"
                f"Make sure reactome-mcp is cloned and built with 'npm run build'"
            )
        # Set by start(); None whenever no server is running.
        self.process: asyncio.subprocess.Process | None = None

    async def start(self) -> asyncio.subprocess.Process:
        """Start the MCP server process and return it.

        Raises
        ------
        MCPConnectionError
            If the server exits during its startup grace period.
        """
        self.process = await asyncio.create_subprocess_exec(
            "node",
            str(self.mcp_server_path),
            stdin=asyncio.subprocess.PIPE,
            stdout=asyncio.subprocess.PIPE,
            stderr=asyncio.subprocess.PIPE,
        )

        # Allow server time to initialize before checking if it survived
        await asyncio.sleep(1)

        if self.process.returncode is not None:
            # Process already exited — read stderr to find out why
            stderr_output = await self.process.stderr.read()
            raise MCPConnectionError(
                f"MCP server failed to start:\n{stderr_output.decode('utf-8')}"
            )

        return self.process

    async def stop(self) -> None:
        """Stop the MCP server — graceful terminate, falls back to kill."""
        if not self.process:
            return

        try:
            self.process.terminate()
            await asyncio.wait_for(self.process.wait(), timeout=5.0)

        except asyncio.TimeoutError:
            # Graceful shutdown stalled — force-kill and reap the child.
            self.process.kill()
            await self.process.wait()

        except ProcessLookupError:
            # Process already exited on its own before terminate();
            # nothing left to clean up.
            pass

        finally:
            self.process = None

    async def __aenter__(self):
        await self.start()
        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        await self.stop()
        return False
Loading