diff --git a/src/database/setup.py b/src/database/setup.py index f02f379..cfd2306 100644 --- a/src/database/setup.py +++ b/src/database/setup.py @@ -1,11 +1,11 @@ +import functools + +from loguru import logger from sqlalchemy.engine import URL from sqlalchemy.ext.asyncio import AsyncEngine, create_async_engine from config import DatabaseConfiguration, get_config -_user_engine = None -_expdb_engine = None - def _create_engine(db_config: DatabaseConfiguration) -> AsyncEngine: db_url = URL.create( @@ -16,6 +16,8 @@ def _create_engine(db_config: DatabaseConfiguration) -> AsyncEngine: port=db_config.port, database=db_config.database, ) + + logger.info("Creating database engine for {db_url}", db_url=db_url) return create_async_engine( db_url, echo=db_config.echo, @@ -23,26 +25,27 @@ def _create_engine(db_config: DatabaseConfiguration) -> AsyncEngine: ) +@functools.cache def user_database() -> AsyncEngine: - global _user_engine # noqa: PLW0603 - if _user_engine is None: - _user_engine = _create_engine(get_config().openml_database) - return _user_engine + return _create_engine(get_config().openml_database) +@functools.cache def expdb_database() -> AsyncEngine: - global _expdb_engine # noqa: PLW0603 - if _expdb_engine is None: - _expdb_engine = _create_engine(get_config().expdb_database) - return _expdb_engine + return _create_engine(get_config().expdb_database) async def close_databases() -> None: """Close all database connections.""" - global _user_engine, _expdb_engine # noqa: PLW0603 - if _user_engine is not None: - await _user_engine.dispose() - _user_engine = None - if _expdb_engine is not None: - await _expdb_engine.dispose() - _expdb_engine = None + for db in (user_database, expdb_database): + if db.cache_info().currsize == 1: + engine = db() + logger.info("Disposing of engine connected to {db_url}", db_url=engine.url) + try: + await engine.dispose() + except Exception: # noqa: BLE001 + logger.exception( + "Issue disposing of database engine for {db_url}", + db_url=engine.url, + ) + db.cache_clear()