Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 27 additions & 5 deletions src/agent/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,11 @@
from psycopg_pool import AsyncConnectionPool

from agent.models import get_embedding, get_llm
from agent.profiles import ProfileName, create_profile_graphs
from agent.profiles import ProfileName
from agent.profiles.base import InputState, OutputState
from agent.profiles.cross_database import create_cross_database_graph
from agent.profiles.react_to_me import create_reactome_graph
from mcp.mcp_tools import create_mcp_tools
from util.logging import logging

LANGGRAPH_DB_URI = f"postgresql://{os.getenv('POSTGRES_USER')}:{os.getenv('POSTGRES_PASSWORD')}@postgres:5432/{os.getenv('POSTGRES_LANGGRAPH_DB')}?sslmode=disable"
Expand All @@ -33,9 +36,9 @@ def __init__(
llm: BaseChatModel = get_llm("openai", "gpt-4o-mini")
embedding: Embeddings = get_embedding("openai", "text-embedding-3-large")

self.uncompiled_graph: dict[str, StateGraph] = create_profile_graphs(
profiles, llm, embedding
)
self.llm = llm
self.embedding = embedding
self.profiles = profiles

# The following are set asynchronously by calling initialize()
self.graph: dict[str, CompiledStateGraph] | None = None
Expand All @@ -46,10 +49,27 @@ def __del__(self) -> None:
asyncio.run(self.close_pool())

async def initialize(self) -> dict[str, CompiledStateGraph]:

mcp_tools, self.mcp_manager = await create_mcp_tools(
os.getenv("MCP_SERVER_PATH")
)

uncompiled_graphs: dict[str, StateGraph] = {}
for profile in map(str.lower, self.profiles):
if profile == ProfileName.React_to_Me.lower():
uncompiled_graphs[profile] = create_reactome_graph(
self.llm, self.embedding, mcp_tools
)
elif profile == ProfileName.Cross_Database_Prototype.lower():
uncompiled_graphs[profile] = create_cross_database_graph(
self.llm, self.embedding
)

checkpointer: BaseCheckpointSaver[str] = await self.create_checkpointer()

return {
profile: graph.compile(checkpointer=checkpointer)
for profile, graph in self.uncompiled_graph.items()
for profile, graph in uncompiled_graphs.items()
}

async def create_checkpointer(self) -> BaseCheckpointSaver[str]:
Expand All @@ -73,6 +93,8 @@ async def create_checkpointer(self) -> BaseCheckpointSaver[str]:
async def close_pool(self) -> None:
if self.pool:
await self.pool.close()
if self.mcp_manager:
await self.mcp_manager.stop()

async def ainvoke(
self,
Expand Down
139 changes: 124 additions & 15 deletions src/agent/profiles/react_to_me.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,15 @@

from langchain_core.embeddings import Embeddings
from langchain_core.language_models.chat_models import BaseChatModel
from langchain_core.messages import AIMessage, HumanMessage
from langchain_core.messages import AIMessage, HumanMessage, ToolMessage
from langchain_core.runnables import Runnable, RunnableConfig
from langgraph.graph.state import StateGraph

from agent.profiles.base import BaseGraphBuilder, BaseState
from agent.tasks.unsafe_question import create_unsafe_answer_generator
from retrievers.reactome.rag import create_reactome_rag
from agent.models import get_llm
from mcp.query_router import create_query_router, ROUTE_RAG, ROUTE_MCP_SEARCH, ROUTE_MCP_ANALYSIS


class ReactToMeState(BaseState):
Expand All @@ -20,9 +22,29 @@ def __init__(
self,
llm: BaseChatModel,
embedding: Embeddings,
mcp_tools: list | None = None,
) -> None:
super().__init__(llm, embedding)

self.llm = llm

# optional MCP tools - if provided the LLM can call them instead of using RAG
self.mcp_tools = mcp_tools or []

# pre-bind two route-specific tool subsets to LLM once at init time - search tools for
# lookup/retrieval routes and analysis tool for enrichment routes - avoids rebinding on every message
# produces two separate LLM instances: llm_with_search_tools and llm_with_analysis_tools
search_tools = [t for t in self.mcp_tools if t.name in (
"search_reactome", "get_pathway", "get_database_info", "get_species"
)]
analysis_tools = [t for t in self.mcp_tools if t.name == "analyze_identifiers"]

self.llm_with_search_tools = self.llm.bind_tools(search_tools) if search_tools else None
self.llm_with_analysis_tools = self.llm.bind_tools(analysis_tools) if analysis_tools else None

# create router with cheap model - only used when mcp_tools available
self.query_router = create_query_router(get_llm("openai", "gpt-4o-mini")) if self.mcp_tools else None

# Create runnables (tasks & tools)
self.unsafe_answer_generator: Runnable = create_unsafe_answer_generator(
llm, streaming=True
Expand Down Expand Up @@ -73,28 +95,115 @@ async def generate_unsafe_response(
async def call_model(
self, state: ReactToMeState, config: RunnableConfig
) -> ReactToMeState:
result: dict[str, Any] = await self.reactome_rag.ainvoke(
{
"input": state["rephrased_input"],
"chat_history": (
state["chat_history"]
if state["chat_history"]
else [HumanMessage(state["user_input"])]
),
},
config,
)
# no MCP tools - fall back to existing RAG behaviour unchanged
if not self.mcp_tools:
result: dict[str, Any] = await self.reactome_rag.ainvoke(
{
"input": state["rephrased_input"],
"chat_history": (
state["chat_history"]
if state["chat_history"]
else [HumanMessage(state["user_input"])]
),
},
config,
)
return ReactToMeState(
chat_history=[
HumanMessage(state["user_input"]),
AIMessage(result["answer"]),
],
answer=result["answer"],
)

# route question to correct path
route = await self.query_router(state["rephrased_input"])

if route == ROUTE_RAG:
# question is general knowledge, use RAG directly, no tools needed
result: dict[str, Any] = await self.reactome_rag.ainvoke(
{
"input": state["rephrased_input"],
"chat_history": (
state["chat_history"]
if state["chat_history"]
else [HumanMessage(state["user_input"])]
),
},
config,
)
return ReactToMeState(
chat_history=[
HumanMessage(state["user_input"]),
AIMessage(result["answer"]),
],
answer=result["answer"],
)

if route == ROUTE_MCP_SEARCH:
llm_with_tools = self.llm_with_search_tools
elif route == ROUTE_MCP_ANALYSIS:
llm_with_tools = self.llm_with_analysis_tools
else:
llm_with_tools = self.llm_with_search_tools

messages = list(state["chat_history"] or []) + [
HumanMessage(state["rephrased_input"])
]

response = await llm_with_tools.ainvoke(messages, config)

# tool calling loop - max 15 iterations to prevent infinite loop
max_iterations = 15
iteration = 0
while response.tool_calls and iteration < max_iterations:
iteration += 1
tool_results = []

for tool_call in response.tool_calls:
# find matching tool and execute it - triggers MCP client to Reactome API
tool = next(
t for t in self.mcp_tools if t.name == tool_call["name"]
)

result = await tool.ainvoke(tool_call["args"])

# tool_call_id links this result back to the specific request the LLM made
tool_results.append(
ToolMessage(
content=str(result),
tool_call_id=tool_call["id"],
)
)

# send tool results back to LLM for final answer
messages = messages + [response] + tool_results
response = await llm_with_tools.ainvoke(messages, config)

# loop hit max iterations - LLM never gave direct answer
if response.tool_calls:
answer = "I was unable to complete the research in time. Please try rephrasing your question."
return ReactToMeState(
chat_history=[
HumanMessage(state["user_input"]),
AIMessage(answer),
],
answer=answer,
)

# LLM gave direct answer
return ReactToMeState(
chat_history=[
HumanMessage(state["user_input"]),
AIMessage(result["answer"]),
AIMessage(response.content),
],
answer=result["answer"],
answer=response.content,
)


def create_reactome_graph(
llm: BaseChatModel,
embedding: Embeddings,
mcp_tools: list | None = None
) -> StateGraph:
return ReactToMeGraphBuilder(llm, embedding).uncompiled_graph
return ReactToMeGraphBuilder(llm, embedding, mcp_tools).uncompiled_graph
Empty file added src/mcp/__init__.py
Empty file.
91 changes: 91 additions & 0 deletions src/mcp/mcp_client.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
import asyncio
import json


class MCPToolError(Exception):
    """Error reported by the MCP server in a JSON-RPC error response."""


class MCPClient:
"""
JSON-RPC client for communicating with the Reactome MCP server over stdin/stdout.

Args:
process: Running MCP server subprocess from MCPProcessManager.
timeout: Seconds to wait for a response before raising TimeoutError.
"""

def __init__(self, process: asyncio.subprocess.Process, timeout: float = 30.0):
self.process = process
self.timeout = timeout
self.request_id = 0

async def call(self, method: str, params: dict | None = None) -> dict:
"""
Send a JSON-RPC request and return the result.

Raises:
MCPToolError: If the server returns an error response.
asyncio.TimeoutError: If no response within timeout seconds.
RuntimeError: If the server closes the connection or returns invalid JSON.
"""
if params is None:
params = {}

self.request_id += 1

request = {
"jsonrpc": "2.0",
"id": self.request_id,
"method": method,
"params": params,
}

message = json.dumps(request) + "\n"
self.process.stdin.write(message.encode("utf-8"))
await self.process.stdin.drain()

# Wait for response with timeout so chatbot never hangs indefinitely
response_line = await asyncio.wait_for(
self.process.stdout.readline(),
timeout=self.timeout,
)

if not response_line:
raise RuntimeError("MCP server closed the connection.")

try:
response = json.loads(response_line.decode("utf-8").strip())
except json.JSONDecodeError as e:
raise RuntimeError(f"MCP server returned invalid JSON: {e}")

# JSON-RPC error response — server understood request but returned an error
if "error" in response:
error = response["error"]
raise MCPToolError(
f"MCP error {error.get('code')}: {error.get('message')}"
)

return response.get("result", {})

async def call_tool(self, tool_name: str, arguments: dict | None = None) -> str:
"""
Call a specific MCP tool and return its text output.

Args:
tool_name: Name of the tool e.g. 'reactome_search'.
arguments: Tool arguments as key-value pairs.
"""
if arguments is None:
arguments = {}

result = await self.call(
"tools/call",
{"name": tool_name, "arguments": arguments},
)

# MCP returns content as list of typed blocks — extract text blocks
content = result.get("content", [])
text_parts = [block["text"] for block in content if block.get("type") == "text"]
return "\n".join(text_parts)
Loading