Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
218 changes: 204 additions & 14 deletions src/agent/graph.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,14 @@
"""
Agent graph module with proper async lifecycle management.

This module provides the AgentGraph class for managing LangGraph agents
with PostgreSQL checkpointing and connection pool management.
"""

import asyncio
import os
from typing import Any
from contextlib import asynccontextmanager
from typing import Any, AsyncIterator, ClassVar

from langchain_core.callbacks.base import Callbacks
from langchain_core.embeddings import Embeddings
Expand All @@ -18,17 +26,70 @@
from agent.profiles.base import InputState, OutputState
from util.logging import logging

LANGGRAPH_DB_URI = f"postgresql://{os.getenv('POSTGRES_USER')}:{os.getenv('POSTGRES_PASSWORD')}@postgres:5432/{os.getenv('POSTGRES_LANGGRAPH_DB')}?sslmode=disable"

if not os.getenv("POSTGRES_LANGGRAPH_DB"):
logging.warning("POSTGRES_LANGGRAPH_DB undefined; falling back to MemorySaver.")
def _build_langgraph_db_uri() -> str | None:
"""
Build the PostgreSQL connection URI lazily.

Returns None if required environment variables are not set,
avoiding construction of invalid URIs with 'None' values.
"""
user = os.getenv("POSTGRES_USER")
password = os.getenv("POSTGRES_PASSWORD")
db = os.getenv("POSTGRES_LANGGRAPH_DB")

if not all([user, password, db]):
logging.warning(
"PostgreSQL environment variables not fully configured. "
"POSTGRES_USER, POSTGRES_PASSWORD, and POSTGRES_LANGGRAPH_DB are required. "
"Falling back to MemorySaver."
)
return None

return f"postgresql://{user}:{password}@postgres:5432/{db}?sslmode=disable"


class AgentGraph:
"""
LangGraph agent with PostgreSQL checkpointing and proper resource management.

This class manages the lifecycle of LangGraph agents with optional PostgreSQL
checkpointing. It supports both async context manager usage (recommended) and
explicit initialization/cleanup.

Usage (recommended - async context manager):
```python
async with AgentGraph(profiles=["cross_database"]) as graph:
result = await graph.ainvoke(user_input, profile, ...)
```

Usage (explicit):
```python
graph = AgentGraph(profiles=["cross_database"])
try:
await graph.initialize()
result = await graph.ainvoke(user_input, profile, ...)
finally:
await graph.close()
```
"""

# Class-level tracking for debugging
_instance_count: ClassVar[int] = 0

def __init__(
self,
profiles: list[ProfileName],
) -> None:
"""
Initialize the agent graph.

Args:
profiles: List of profile names to create graphs for.
"""
AgentGraph._instance_count += 1
self._instance_id = AgentGraph._instance_count

# Get base models
llm: BaseChatModel = get_llm("openai", "gpt-4o-mini")
embedding: Embeddings = get_embedding("openai", "text-embedding-3-large")
Expand All @@ -40,23 +101,67 @@ def __init__(
# The following are set asynchronously by calling initialize()
self.graph: dict[str, CompiledStateGraph] | None = None
self.pool: AsyncConnectionPool[AsyncConnection[dict[str, Any]]] | None = None
self._initialized: bool = False
self._closed: bool = False

logging.debug(f"AgentGraph instance {self._instance_id} created")

def __del__(self) -> None:
if self.pool:
asyncio.run(self.close_pool())

# REMOVED: __del__ with asyncio.run() — this was causing production crashes!
#
# def __del__(self) -> None:
# if self.pool:
# asyncio.run(self.close_pool()) # RuntimeError in async context!


async def __aenter__(self) -> "AgentGraph":
    """
    Async context manager entry — initializes the graph.

    ``async with`` does not call ``__aexit__`` when ``__aenter__`` raises,
    so if ``initialize()`` fails part-way, ``close()`` is awaited here
    before the error is re-raised. This releases anything that was already
    opened (such as the connection pool).

    Returns:
        This instance, initialized.

    Raises:
        Whatever ``initialize()`` raises; the instance is closed first.
    """
    try:
        await self.initialize()
    except BaseException:
        # BaseException so that task cancellation also triggers cleanup.
        await self.close()
        raise
    return self

async def __aexit__(self, exc_type, exc_val, exc_tb) -> None:
    """
    Async context manager exit — releases resources by awaiting ``close()``.

    The exception details are not inspected and nothing truthy is returned,
    so an exception raised inside the ``async with`` body keeps propagating.
    """
    await self.close()

async def initialize(self) -> dict[str, CompiledStateGraph]:
    """
    Initialize the graph with checkpointing.

    This method is idempotent while the instance is open — calling it
    multiple times is safe and returns the already-compiled graphs.

    Returns:
        Dict mapping profile names to compiled state graphs.

    Raises:
        RuntimeError: If the instance has been closed.
    """
    # Check for closure first: close() does not reset _initialized, so the
    # idempotency shortcut below would otherwise hand back the existing
    # graphs after the connection pool has been closed.
    if self._closed:
        raise RuntimeError("Cannot initialize a closed AgentGraph")

    if self._initialized:
        return self.graph or {}

    checkpointer: BaseCheckpointSaver[str] = await self._create_checkpointer()
    # Compile every profile's graph against the same checkpointer.
    self.graph = {
        profile: graph.compile(checkpointer=checkpointer)
        for profile, graph in self.uncompiled_graph.items()
    }
    self._initialized = True

    logging.debug(f"AgentGraph instance {self._instance_id} initialized")
    return self.graph

async def create_checkpointer(self) -> BaseCheckpointSaver[str]:
if not os.getenv("POSTGRES_LANGGRAPH_DB"):
async def _create_checkpointer(self) -> BaseCheckpointSaver[str]:
"""
Create the appropriate checkpointer based on environment configuration.

Returns:
MemorySaver if PostgreSQL is not configured, otherwise AsyncPostgresSaver.
"""
db_uri = _build_langgraph_db_uri()

if db_uri is None:
return MemorySaver()

self.pool = AsyncConnectionPool(
conninfo=LANGGRAPH_DB_URI,
conninfo=db_uri,
max_size=20,
open=False,
timeout=30,
Expand All @@ -66,13 +171,48 @@ async def create_checkpointer(self) -> BaseCheckpointSaver[str]:
},
)
await self.pool.open()

checkpointer = AsyncPostgresSaver(self.pool)
await checkpointer.setup()

logging.debug(f"PostgreSQL checkpointer created for instance {self._instance_id}")
return checkpointer

async def close_pool(self) -> None:
async def close(self) -> None:
    """
    Release resources held by this instance, closing the connection pool.

    Safe to call repeatedly: once the instance is marked closed, later calls
    do nothing. An error raised while closing the pool is logged as a warning
    rather than propagated, and the pool reference is dropped either way.
    """
    if not self._closed:
        if self.pool:
            try:
                await self.pool.close()
                logging.debug(f"Connection pool closed for instance {self._instance_id}")
            except Exception as e:
                # Best-effort shutdown: report the failure and carry on.
                logging.warning(f"Error closing connection pool: {e}")
            finally:
                self.pool = None

        self._closed = True
        logging.debug(f"AgentGraph instance {self._instance_id} closed")

# Backwards compatibility alias
async def close_pool(self) -> None:
    """Backwards-compatible alias that simply awaits ``close()``."""
    await self.close()

@property
def is_initialized(self) -> bool:
    """Whether ``initialize()`` has completed, as recorded in ``_initialized``."""
    return self._initialized

@property
def is_closed(self) -> bool:
    """Whether ``close()`` has completed, as recorded in ``_closed``."""
    return self._closed

async def ainvoke(
self,
Expand All @@ -83,10 +223,32 @@ async def ainvoke(
thread_id: str,
enable_postprocess: bool = True,
) -> OutputState:
"""
Invoke the agent with a user query.

Args:
user_input: The user's query text.
profile: Which agent profile to use.
callbacks: LangChain callbacks for streaming/logging.
thread_id: Unique identifier for the conversation thread.
enable_postprocess: Whether to run postprocessing (web search fallback).

Returns:
OutputState containing the agent's response and metadata.

Raises:
RuntimeError: If the graph is closed or not initialized.
"""
if self._closed:
raise RuntimeError("Cannot invoke on a closed AgentGraph")

if self.graph is None:
self.graph = await self.initialize()

if profile not in self.graph:
logging.warning(f"Profile '{profile}' not found, returning empty response")
return OutputState()

result: OutputState = await self.graph[profile].ainvoke(
InputState(user_input=user_input),
config=RunnableConfig(
Expand All @@ -98,3 +260,31 @@ async def ainvoke(
),
)
return result


@asynccontextmanager
async def create_agent_graph(
    profiles: list[ProfileName],
) -> AsyncIterator[AgentGraph]:
    """
    Async context manager that builds, initializes and finally closes an AgentGraph.

    Example:

    ```python
    async with create_agent_graph(["cross_database"]) as graph:
        result = await graph.ainvoke(...)
    ```

    ``close()`` is awaited on the way out whether the body finished normally,
    raised, or ``initialize()`` itself failed.

    Args:
        profiles: List of profile names to create graphs for.

    Yields:
        An initialized AgentGraph instance.
    """
    agent_graph = AgentGraph(profiles=profiles)
    try:
        await agent_graph.initialize()
        yield agent_graph
    finally:
        await agent_graph.close()