Skip to content
10 changes: 9 additions & 1 deletion bin/chat-chainlit.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,9 @@
config: Config | None = Config.from_yaml()

profiles: list[ProfileName] = config.profiles if config else [ProfileName.React_to_Me]
llm_graph = AgentGraph(profiles)
llm_config: str = config.llm if config else "openai/gpt-4o-mini"
embedding_config: str = config.embedding if config else "openai/text-embedding-3-large"
llm_graph = AgentGraph(profiles, llm_config=llm_config, embedding_config=embedding_config)

POSTGRES_CHAINLIT_DB = os.getenv("POSTGRES_CHAINLIT_DB")
POSTGRES_USER = os.getenv("POSTGRES_USER")
Expand Down Expand Up @@ -81,6 +83,12 @@ async def resume(thread: ThreadDict) -> None:
await static_messages(config, TriggerEvent.on_chat_resume)


@cl.on_app_shutdown
async def on_shutdown() -> None:
    """Explicitly close the connection pool on application shutdown."""
    # AgentGraph.shutdown() is a no-op when no pool was opened (MemorySaver
    # fallback), so this is safe to call unconditionally.
    await llm_graph.shutdown()


@cl.on_chat_end
async def end() -> None:
    """Emit the configured static messages for the on_chat_end trigger."""
    await static_messages(config, TriggerEvent.on_chat_end)
Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ nltk = "^3.9.1"
[tool.poetry.group.dev.dependencies]
ruff = "^0.7.1"
pytest = "^8.3.3"
pytest-mock = "^3.14.0"
mypy = "^1.13.0"
black = "^24.10.0"
isort = "^5.13.2"
Expand Down
55 changes: 41 additions & 14 deletions src/agent/graph.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import asyncio
import os
from typing import Any

Expand All @@ -18,32 +17,52 @@
from agent.profiles.base import InputState, OutputState
from util.logging import logging

LANGGRAPH_DB_URI = f"postgresql://{os.getenv('POSTGRES_USER')}:{os.getenv('POSTGRES_PASSWORD')}@postgres:5432/{os.getenv('POSTGRES_LANGGRAPH_DB')}?sslmode=disable"

if not os.getenv("POSTGRES_LANGGRAPH_DB"):
logging.warning("POSTGRES_LANGGRAPH_DB undefined; falling back to MemorySaver.")
def _build_db_uri() -> str | None:
    """Return the LangGraph Postgres DSN, or None when env config is incomplete.

    All three of POSTGRES_USER, POSTGRES_PASSWORD and POSTGRES_LANGGRAPH_DB
    must be set (and non-empty) for a URI to be produced.
    """
    env = {
        "user": os.getenv("POSTGRES_USER"),
        "password": os.getenv("POSTGRES_PASSWORD"),
        "db": os.getenv("POSTGRES_LANGGRAPH_DB"),
    }
    if not all(env.values()):
        return None
    return "postgresql://{user}:{password}@postgres:5432/{db}?sslmode=disable".format(
        **env
    )


class AgentGraph:
def __init__(
    self,
    profiles: list[ProfileName],
    llm_config: str = "openai/gpt-4o-mini",
    embedding_config: str = "openai/text-embedding-3-large",
) -> None:
    """Build the uncompiled profile graphs from "<provider>/<model>" specs.

    Args:
        profiles: Agent profiles to build state graphs for.
        llm_config: Chat model spec in "<provider>/<model>" form.
        embedding_config: Embedding model spec in "<provider>/<model>" form.

    Raises:
        ValueError: If either spec is not of the form "<provider>/<model>";
            previously a missing "/" silently passed an empty model name to
            the factory functions.
    """
    provider, sep, model = llm_config.partition("/")
    if not (sep and provider and model):
        raise ValueError(
            f"llm_config must be '<provider>/<model>', got {llm_config!r}"
        )
    emb_provider, emb_sep, emb_model = embedding_config.partition("/")
    if not (emb_sep and emb_provider and emb_model):
        raise ValueError(
            f"embedding_config must be '<provider>/<model>', got {embedding_config!r}"
        )

    # Get base models
    llm: BaseChatModel = get_llm(provider, model)
    embedding: Embeddings = get_embedding(emb_provider, emb_model)

    self.uncompiled_graph: dict[str, StateGraph] = create_profile_graphs(
        profiles, llm, embedding
    )

    # Set asynchronously by initialize() / async context manager
    self.graph: dict[str, CompiledStateGraph] | None = None
    self.pool: AsyncConnectionPool[AsyncConnection[dict[str, Any]]] | None = None

def __del__(self) -> None:
if self.pool:
asyncio.run(self.close_pool())
# ------------------------------------------------------------------ #
# Async context manager — preferred lifecycle #
# ------------------------------------------------------------------ #

async def __aenter__(self) -> "AgentGraph":
    """Enter the async context: compile the graphs via initialize()."""
    compiled = await self.initialize()
    self.graph = compiled
    return self

async def __aexit__(self, *_: object) -> None:
    """Exit the async context, releasing the connection pool via shutdown()."""
    await self.shutdown()

# ------------------------------------------------------------------ #
# Initialisation helpers #
# ------------------------------------------------------------------ #

async def initialize(self) -> dict[str, CompiledStateGraph]:
checkpointer: BaseCheckpointSaver[str] = await self.create_checkpointer()
Expand All @@ -53,10 +72,12 @@ async def initialize(self) -> dict[str, CompiledStateGraph]:
}

async def create_checkpointer(self) -> BaseCheckpointSaver[str]:
if not os.getenv("POSTGRES_LANGGRAPH_DB"):
uri = _build_db_uri()
if uri is None:
logging.warning("POSTGRES_LANGGRAPH_DB undefined; falling back to MemorySaver.")
return MemorySaver()
self.pool = AsyncConnectionPool(
conninfo=LANGGRAPH_DB_URI,
conninfo=uri,
max_size=20,
open=False,
timeout=30,
Expand All @@ -70,9 +91,15 @@ async def create_checkpointer(self) -> BaseCheckpointSaver[str]:
await checkpointer.setup()
return checkpointer

async def close_pool(self) -> None:
async def shutdown(self) -> None:
    """Explicit lifecycle teardown — call this instead of relying on __del__."""
    pool = self.pool
    if not pool:
        # Nothing to release: MemorySaver path never opens a pool.
        return
    await pool.close()
    self.pool = None

# ------------------------------------------------------------------ #
# Invocation #
# ------------------------------------------------------------------ #

async def ainvoke(
self,
Expand Down
64 changes: 57 additions & 7 deletions src/agent/profiles/cross_database.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Any, Literal
from typing import Annotated, Any, Literal, TypedDict

from langchain_core.embeddings import Embeddings
from langchain_core.language_models.chat_models import BaseChatModel
Expand All @@ -15,8 +15,11 @@
create_uniprot_rewriter_w_reactome
from agent.tasks.cross_database.summarize_reactome_uniprot import \
create_reactome_uniprot_summarizer
from agent.tasks.flow_reasoner import create_flow_reasoner
from retrievers.reactome.rag import create_reactome_rag
from retrievers.uniprot.rag import create_uniprot_rag
from tools.reactome_topology import ReactomeTopologyTool
import re


class CrossDatabaseState(BaseState):
Expand All @@ -28,6 +31,8 @@ class CrossDatabaseState(BaseState):
uniprot_answer: str # LLM-generated answer from UniProt
uniprot_completeness: str # LLM-assessed completeness of the UniProt answer

flow_context: str # Topological flow data fetched by identify_flow() for mechanistic queries


class CrossDatabaseGraphBuilder(BaseGraphBuilder):
def __init__(
Expand All @@ -47,6 +52,8 @@ def __init__(
self.summarize_final_answer = create_reactome_uniprot_summarizer(
llm, streaming=True
)
self.flow_reasoner = create_flow_reasoner(llm)
self.topology_tool = ReactomeTopologyTool()

# Create graph
state_graph = StateGraph(CrossDatabaseState)
Expand All @@ -62,6 +69,8 @@ def __init__(
state_graph.add_node("rewrite_uniprot_answer", self.rewrite_uniprot_answer)
state_graph.add_node("assess_completeness", self.assess_completeness)
state_graph.add_node("decide_next_steps", self.decide_next_steps)
state_graph.add_node("identify_flow", self.identify_flow)
state_graph.add_node("verify_mechanism", self.verify_mechanism)
state_graph.add_node("generate_final_response", self.generate_final_response)
state_graph.add_node("postprocess", self.postprocess)
# Set up edges
Expand All @@ -84,8 +93,11 @@ def __init__(
"perform_web_search": "generate_final_response",
"rewrite_reactome_query": "rewrite_reactome_query",
"rewrite_uniprot_query": "rewrite_uniprot_query",
"identify_flow": "identify_flow",
},
)
state_graph.add_edge("identify_flow", "verify_mechanism")
state_graph.add_edge("verify_mechanism", "generate_final_response")
state_graph.add_edge("rewrite_reactome_query", "rewrite_reactome_answer")
state_graph.add_edge("rewrite_uniprot_query", "rewrite_uniprot_answer")
state_graph.add_edge("rewrite_reactome_answer", "generate_final_response")
Expand Down Expand Up @@ -203,14 +215,23 @@ async def assess_completeness(
uniprot_completeness=uniprot_completeness.binary_score,
)

async def decide_next_steps(self, state: CrossDatabaseState) -> Literal[
"generate_final_response",
"perform_web_search",
"rewrite_reactome_query",
"rewrite_uniprot_query",
]:
async def decide_next_steps(self, state: CrossDatabaseState) -> Literal["identify_flow", "generate_final_response", "perform_web_search", "rewrite_reactome_query", "rewrite_uniprot_query"]:
"""Decide the next step based on the research results and context."""
user_query = state.get("rephrased_input", "").lower()
reactome_answer = state.get("reactome_answer", "")
uniprot_answer = state.get("uniprot_answer", "")

# Tightened keyword matching for mechanistic flow detection
flow_pattern = r"\b(after|consequence|downstream|flow|precede|trigger|following|upstream|mechanism)\b"
is_mechanistic = bool(re.search(flow_pattern, user_query))

if is_mechanistic and reactome_answer and "error" not in reactome_answer.lower():
return "identify_flow"

reactome_complete = state["reactome_completeness"] != "No"
uniprot_complete = state["uniprot_completeness"] != "No"


if reactome_complete and uniprot_complete:
return "generate_final_response"
elif not reactome_complete and uniprot_complete:
Expand All @@ -220,6 +241,35 @@ async def decide_next_steps(self, state: CrossDatabaseState) -> Literal[
else:
return "perform_web_search"

async def identify_flow(
    self, state: CrossDatabaseState, config: RunnableConfig
) -> dict[str, Any]:
    """Fetch topological flow context for Reactome IDs found in the answer.

    Extracts Reactome stable identifiers (e.g. R-HSA-123456) from the current
    ``reactome_answer`` and queries the topology tool for each of the first
    five distinct IDs, concatenating any non-empty results.

    Returns:
        Partial state update with ``flow_context`` ("" when nothing found).
    """
    id_pattern = re.compile(r"R-[A-Z]{3}-\d+")
    reactome_ans = state.get("reactome_answer", "")
    # dict.fromkeys() de-duplicates while preserving first-seen order;
    # the previous list(set(...)) made the "top 5" selection arbitrary
    # and nondeterministic across runs.
    st_ids = list(dict.fromkeys(id_pattern.findall(reactome_ans)))

    segments: list[str] = []
    for st_id in st_ids[:5]:  # Limit to top 5 IDs for token safety
        context = self.topology_tool.get_flow_context(st_id)
        if context:
            segments.append(f"\n---\n{context}")

    return {"flow_context": "".join(segments)}

async def verify_mechanism(
    self, state: CrossDatabaseState, config: RunnableConfig
) -> dict[str, Any]:
    """Refine the Reactome answer against fetched topology data, if any."""
    topology = state.get("flow_context")
    if not topology:
        # identify_flow found nothing — leave the state untouched.
        return {}

    payload = {
        "input": state["rephrased_input"],
        "initial_answer": state["reactome_answer"],
        "flow_context": topology,
    }
    refined: str = await self.flow_reasoner.ainvoke(payload, config)
    return {"reactome_answer": refined}

async def generate_final_response(
self, state: CrossDatabaseState, config: RunnableConfig
) -> CrossDatabaseState:
Expand Down
41 changes: 41 additions & 0 deletions src/agent/tasks/flow_reasoner.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
from langchain_core.language_models.chat_models import BaseChatModel
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.runnables import Runnable

# System prompt for the mechanistic "flow reasoning" refinement step: given a
# draft RAG answer plus raw Reactome graph topology data, the LLM verifies the
# stated sequence of events and returns an enriched mechanistic description.
flow_reasoning_message = """
You are a senior curator and mechanistic reasoner for the Reactome Pathway Knowledgebase.
Your task is to take a biological explanation and the corresponding topological data (next/previous steps, participants) from the Reactome Graph to verify and enrich the answer.

Context provided:
1. Initial Explanation: The draft answer generated by the RAG system.
2. Topological Data: Raw data from the Reactome Graph about the reactions and pathways mentioned.

Objective:
- Verify the sequence of events: Does the Graph data confirm that Reaction A leads to Reaction B as described in the initial answer?
- Identify missing mechanistic links: If there's a gap between two steps in the explanation, use the topological data to fill it (e.g., mention a missing intermediate metabolite or catalyst).
- Correct errors: If the initial answer claimed a protein is an input but the graph says it's a catalyst, correct it.

Output Requirements:
- Provide a refined, highly accurate mechanistic description of the pathway/process.
- Highlight the "flow" of information or matter (e.g., "First X happens, which triggers Y, resulting in Z").
- Maintain all citations from the original context and add new ones if new IDs from the topology are used.
- If the topological data contradicts the initial findings, prioritize the topological data as the 'ground truth' of the Reactome Graph.

Strict Rule: Focus ONLY on the biological mechanism and flow. Do not add generic filler.
"""

# Two-message prompt: the system message above plus a human message that
# injects the three expected input keys.
flow_reasoning_prompt = ChatPromptTemplate.from_messages(
    [
        ("system", flow_reasoning_message),
        (
            "human",
            "Initial Explanation: {initial_answer} \n\n Topological Data: \n {flow_context} \n\n User Question: {input}",
        ),
    ]
)

def create_flow_reasoner(llm: BaseChatModel) -> Runnable:
    """Build the flow-reasoning chain: prompt | llm | string output parser.

    The returned runnable expects the keys ``initial_answer``,
    ``flow_context`` and ``input``, and is tagged with the LangSmith
    run name "flow_reasoning".
    """
    return (flow_reasoning_prompt | llm | StrOutputParser()).with_config(
        run_name="flow_reasoning"
    )
27 changes: 19 additions & 8 deletions src/retrievers/csv_chroma.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,10 @@
from typing import Annotated, Any, Coroutine, TypedDict

import chromadb.config
from langchain.chains.query_constructor.schema import AttributeInfo
from util.langchain_compat import AttributeInfo
from langchain.retrievers import EnsembleRetriever, MultiQueryRetriever
from langchain.retrievers.contextual_compression import ContextualCompressionRetriever
from langchain.retrievers.document_compressors import EmbeddingsFilter
from langchain.retrievers.merger_retriever import MergerRetriever
from langchain.retrievers.self_query.base import SelfQueryRetriever
from langchain_chroma.vectorstores import Chroma
Expand Down Expand Up @@ -71,15 +73,25 @@ def create_bm25_chroma_ensemble_retriever(
*,
descriptions_info: dict[str, str],
field_info: dict[str, list[AttributeInfo]],
) -> MergerRetriever:
return HybridRetriever.from_subdirectory(
) -> ContextualCompressionRetriever:
"""Create a HybridRetriever wrapped with EmbeddingsFilter-based
Contextual Compression to filter out low-relevance documents before
they are passed to the LLM, improving answer quality and reducing
hallucinations caused by noisy retrieval."""
base_retriever = HybridRetriever.from_subdirectory(
llm,
embedding,
embeddings_directory,
descriptions_info=descriptions_info,
field_info=field_info,
include_original=True,
)
embeddings_filter = EmbeddingsFilter(
embeddings=embedding, similarity_threshold=0.76
)
return ContextualCompressionRetriever(
base_compressor=embeddings_filter, base_retriever=base_retriever
)


class RetrieverDict(TypedDict):
Expand Down Expand Up @@ -177,7 +189,7 @@ def retrieve_documents(self, queries: list[str], run_manager) -> list[Document]:
)
},
)
doc_lists.append(bm25_docs + vector_docs)
doc_lists.extend([bm25_docs, vector_docs])
subdirectory_docs.extend(self.weighted_reciprocal_rank(doc_lists))
return subdirectory_docs

Expand Down Expand Up @@ -214,9 +226,8 @@ async def aretrieve_documents(
subdirectory_docs: list[Document] = []
for subdir_results in subdirectory_results.values():
results_iter = iter(await asyncio.gather(*subdir_results))
doc_lists: list[list[Document]] = [
bm25_results + vector_results
for bm25_results, vector_results in zip(results_iter, results_iter)
]
doc_lists: list[list[Document]] = []
for bm25_results, vector_results in zip(results_iter, results_iter):
doc_lists.extend([bm25_results, vector_results])
subdirectory_docs.extend(self.weighted_reciprocal_rank(doc_lists))
return subdirectory_docs
2 changes: 1 addition & 1 deletion src/retrievers/reactome/metadata_info.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from langchain.chains.query_constructor.base import AttributeInfo
from util.langchain_compat import AttributeInfo

pathway_id_description = "A Reactome Identifier unique to each pathway. A pathway name may appear multiple times in the dataset\
This ID allows for the specific identification and exploration of each pathway's details within the Reactome Database."
Expand Down
2 changes: 1 addition & 1 deletion src/retrievers/uniprot/metadata_info.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from langchain.chains.query_constructor.base import AttributeInfo
from util.langchain_compat import AttributeInfo

uniprot_descriptions_info = {
"uniprot_data": "Contains detailed protein information about gene names, protein names, subcellular localizations, family classifications, biological pathway associations, domains, motifs, disease associations, and functional descriptions. ",
Expand Down
Loading