|
|
""" |
|
|
SQLAlchemy base classes and database configuration. |
|
|
|
|
|
This module provides the declarative base, a lazily created async engine and session factory, and the ``get_db`` session dependency.
|
|
""" |
|
|
|
|
|
from typing import AsyncGenerator

from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine
from sqlalchemy.orm import DeclarativeBase, sessionmaker

from app.config import get_settings
|
|
|
|
|
# Application settings, read once when this module is imported.
settings = get_settings()

# Both globals start unset; _ensure_engine_and_session() creates them on
# demand so that importing this module does not build the engine (and so does
# not import the async DB driver).
engine = AsyncSessionLocal = None
|
|
|
|
|
|
|
|
def _ensure_engine_and_session():
    """Initialize the engine and session factory if not already created.

    Creating them lazily avoids importing async drivers (e.g., asyncpg) during
    module import, which breaks tests that patch out DB dependencies.

    The two module globals are initialized independently: an existing engine
    is reused when only the session factory is missing, and an already-set
    session factory is never overwritten. This way a failure part-way through
    (or a value injected by a test) cannot leave ``engine`` set while
    ``AsyncSessionLocal`` stays ``None`` forever.

    Raises:
        RuntimeError: If the ``asyncpg`` driver is not installed.
        ModuleNotFoundError: If any other module needed by the engine is missing.
    """
    global engine, AsyncSessionLocal

    if engine is None:
        try:
            engine = create_async_engine(
                settings.database_url,
                echo=settings.debug,
                pool_size=settings.db_pool_size,
                max_overflow=settings.db_max_overflow,
                # Test pooled connections for liveness before handing them out.
                pool_pre_ping=True,
                # Recycle pooled connections older than this many seconds.
                pool_recycle=3600,
            )
        except ModuleNotFoundError as e:
            # Translate the common "driver not installed" case into an
            # actionable message; re-raise anything else untouched.
            if "asyncpg" in str(e):
                raise RuntimeError(
                    "Database driver 'asyncpg' is not installed. This is expected in unit tests that patch get_db. "
                    "If you are running the app for real, install asyncpg or configure a supported driver."
                ) from e
            raise

    if AsyncSessionLocal is None:
        AsyncSessionLocal = sessionmaker(
            engine,
            class_=AsyncSession,
            expire_on_commit=False,
            autocommit=False,
            autoflush=False,
        )
|
|
|
|
|
|
|
|
class Base(DeclarativeBase):
    """Declarative base class that every SQLAlchemy ORM model inherits from."""
|
|
|
|
|
|
|
|
async def get_db() -> AsyncGenerator[AsyncSession, None]:
    """
    Dependency that provides a database session.

    The session is committed once the consumer is done with it without error,
    and rolled back (then the error re-raised) if an exception is thrown back
    into the generator. The session is closed when the ``async with`` block
    exits.

    Yields:
        AsyncSession: Database session

    Raises:
        RuntimeError: If the session factory could not be initialized.

    Usage:
        async def my_route(db: AsyncSession = Depends(get_db)):
            # Use db session
            pass
    """
    _ensure_engine_and_session()
    # Explicit check instead of ``assert``: asserts are stripped under ``python -O``.
    if AsyncSessionLocal is None:
        raise RuntimeError("Database session factory is not initialized.")

    # Exiting the ``async with`` closes the session, so no explicit close() is needed.
    async with AsyncSessionLocal() as session:
        try:
            yield session
            await session.commit()
        except Exception:
            await session.rollback()
            raise
|
|
|