dmontagu / fastapi-utils

Reusable utilities for FastAPI
MIT License
1.83k stars 163 forks source link

[FEATURE] Async FastAPISessionMaker #287

Open ct-jby opened 3 months ago

ct-jby commented 3 months ago

Hello

Describe the solution you'd like: I would like to be able to use an async database session manager, i.e. an async version of `FastAPISessionMaker`.

ct-jby commented 3 months ago

Maybe something like this?

from collections.abc import Iterator
from contextlib import asynccontextmanager
from sqlalchemy.exc import SQLAlchemyError
from sqlalchemy.ext.asyncio import AsyncEngine, create_async_engine, AsyncSession, async_sessionmaker, async_scoped_session
from typing import Any, AsyncIterator

class SQLAlchemyAsyncSessionMaker:
    """
    A convenience class for managing a (cached) sqlalchemy async engine and async sessionmaker.

    Intended for use creating ORM async sessions injected into endpoint functions by FastAPI.
    The engine and sessionmaker are created lazily on first use and then reused; call
    `reset_cache` to drop them so fresh ones are built on next use.
    """

    def __init__(self, database_uri: str, connect_args: dict[str, Any] | None = None):
        """
        `database_uri` should be any sqlalchemy-compatible database URI that names an async driver.

        In particular, `sqlalchemy.ext.asyncio.create_async_engine(database_uri)` should work to create an engine.

        Typically, this would look like:

            "<scheme>+<async driver>://<user>:<password>@<host>:<port>/<database>"

        A concrete example looks like "postgresql+asyncpg://db_user:password@db:5432/app"

        `connect_args`, if given, is forwarded to `get_async_engine` when the engine is created.
        """
        self.database_uri = database_uri
        # A real assignment (not just an annotation) so the attribute actually exists.
        self.connect_args = connect_args

        self._cached_async_engine: AsyncEngine | None = None
        self._cached_async_sessionmaker: async_sessionmaker | None = None

    @property
    def cached_async_engine(self) -> AsyncEngine:
        """
        Returns a lazily-cached sqlalchemy async engine for the instance's database_uri.
        """
        if self._cached_async_engine is None:
            self._cached_async_engine = self.get_new_async_engine()
        return self._cached_async_engine

    @property
    def cached_async_sessionmaker(self) -> async_sessionmaker:
        """
        Returns a lazily-cached sqlalchemy async sessionmaker using the instance's (lazily-cached) async engine.
        """
        if self._cached_async_sessionmaker is None:
            self._cached_async_sessionmaker = self.get_new_async_sessionmaker(self.cached_async_engine)
        return self._cached_async_sessionmaker

    def get_new_async_engine(self) -> AsyncEngine:
        """
        Returns a new sqlalchemy async engine using the instance's database_uri and connect_args.
        """
        return get_async_engine(self.database_uri, self.connect_args)

    def get_new_async_sessionmaker(
        self, async_engine: AsyncEngine | None = None
    ) -> async_sessionmaker:
        """
        Returns a new async sessionmaker for the provided sqlalchemy async engine. If no engine is provided, the
        instance's (lazily-cached) async engine is used.
        """
        if async_engine is None:
            async_engine = self.cached_async_engine
        return get_async_sessionmaker_for_async_engine(async_engine)

    async def get_async_db(self) -> AsyncIterator[AsyncSession]:
        """
        An async generator that yields a sqlalchemy ORM async session and cleans it up once resumed after yielding.

        The session comes from `_get_async_db`: it is committed when this generator is resumed normally,
        rolled back if an `Exception` is thrown into it at the `yield`, and closed in every case.

        Can be used directly as a FastAPI dependency, or consumed from inside a separate dependency.
        """
        # Use the property (not the raw attribute) so the sessionmaker is built on first use.
        # Wrapping `_get_async_db` as a context manager forwards any exception thrown into
        # this generator on to `_get_async_db`, so its rollback branch actually runs;
        # a plain `async for ... yield` would abandon the inner generator instead.
        async with asynccontextmanager(_get_async_db)(self.cached_async_sessionmaker) as async_session:
            yield async_session

    @asynccontextmanager
    async def context_async_session(self) -> AsyncIterator[AsyncSession]:
        """
        A context-manager wrapped version of the `get_async_db` method.

        This makes it possible to get a context-managed ORM async session for the relevant database_uri without
        needing to rely on FastAPI's dependency injection.

        Usage looks like:

            async_session_maker = SQLAlchemyAsyncSessionMaker(database_uri)
            async with async_session_maker.context_async_session() as async_session:
                await async_session.execute(...)
                ...
        """
        # `get_async_db()` returns an async generator, which does not support `async with`
        # directly; `asynccontextmanager` adapts it and throws exceptions raised in the
        # caller's block back into it so the session is rolled back.
        async with asynccontextmanager(self.get_async_db)() as async_session:
            yield async_session

    def reset_cache(self) -> None:
        """
        Resets the engine and sessionmaker caches.

        After calling this method, the next time you try to use the cached engine or sessionmaker,
        new ones will be created.
        """
        self._cached_async_engine = None
        self._cached_async_sessionmaker = None

def get_async_engine(
    uri: str, connect_args: dict[str, Any] | None = None
) -> AsyncEngine:
    """
    Returns a sqlalchemy async engine with pool_pre_ping enabled.

    `connect_args`, when given, is passed to `create_async_engine` as its `connect_args`;
    `None` (the default) is treated as "no extra connect arguments".

    This function may be updated over time to reflect recommended async engine configuration for use with FastAPI.
    """
    # SQLAlchemy merges `connect_args` into a dict of connection parameters, so it must be a
    # mapping; forwarding `None` would fail when the engine is created.
    return create_async_engine(uri, pool_pre_ping=True, connect_args=connect_args or {})

def get_async_sessionmaker_for_async_engine(async_engine: AsyncEngine) -> async_sessionmaker:
    """
    Build an `async_sessionmaker` bound to `async_engine`.

    Sessions it produces are configured with autocommit and autoflush disabled.

    The options chosen here may change over time to follow the recommended async
    sessionmaker configuration for use with FastAPI.
    """
    session_options = {"autocommit": False, "autoflush": False}
    return async_sessionmaker(bind=async_engine, **session_options)

@asynccontextmanager
async def context_async_session(async_engine: AsyncEngine) -> AsyncIterator[AsyncSession]:
    """
    This async contextmanager yields a managed async session for the provided engine.

    Usage is similar to `SQLAlchemyAsyncSessionMaker.context_async_session`, except that you have to provide the engine to use:

        async with context_async_session(async_engine) as async_session:
            await async_session.execute(...)

    The session comes from `_get_async_db`: it is committed when the block exits normally,
    rolled back if the block raises an `Exception`, and closed in every case.

    A new async sessionmaker is created for each call, so the SQLAlchemyAsyncSessionMaker.context_async_session
    method may be preferable in performance-sensitive contexts.
    """
    # Named so it does not shadow the imported `async_sessionmaker` class.
    session_factory = get_async_sessionmaker_for_async_engine(async_engine)
    # Adapting `_get_async_db` with `asynccontextmanager` forwards an exception raised in the
    # caller's block into it (triggering rollback); a plain `async for ... yield` would leave
    # the inner generator suspended and skip the rollback.
    async with asynccontextmanager(_get_async_db)(session_factory) as async_session:
        yield async_session

async def _get_async_db(async_sessionmaker: async_sessionmaker) -> AsyncIterator[AsyncSession]:
    """
    A generator function that yields an ORM async session using the provided async sessionmaker, and cleans it up when resumed.
    """
    async_session = async_sessionmaker()
    try:
        yield async_session
        await async_session.commit()
    except Exception as exc:
        await async_session.rollback()
        raise exc
    finally:
        await async_session.close()