acidjunk / fastapi-postgres-boilerplate

My own fastapi postgres boilerplate
Apache License 2.0
46 stars 9 forks source link

Async #16

Open theobouwman opened 7 months ago

theobouwman commented 7 months ago

Is it possible to use this with async SQLAlchemy?

waza-ari commented 3 months ago

It is possible: I adapted it in my project to work with async SQLAlchemy. A few changes are required, though. Here is a summary:

core/config.py

from pydantic import Field
from pydantic_settings import BaseSettings, SettingsConfigDict

class Settings(BaseSettings):
    # Application settings, read from environment variables and the ``.env`` file
    # described by ``model_config`` below.

    model_config = SettingsConfigDict(
        env_file=".env",
        env_file_encoding="utf-8",
    )

    # --- Database ---
    # No default is given, so the value must be supplied by the environment / .env file.
    database_url: str = Field(description="The Database URL")


# Module-level instance; it is constructed (and therefore validated) when this module is imported.
settings = Settings()

core/db.py

from contextlib import asynccontextmanager
from contextvars import ContextVar
from typing import Any, AsyncGenerator, Optional
from uuid import uuid4

from sqlalchemy.ext.asyncio import (AsyncAttrs, AsyncSession,
                                    async_scoped_session, create_async_engine)
from sqlalchemy.orm import DeclarativeBase, sessionmaker
from starlette.middleware.base import (BaseHTTPMiddleware,
                                       RequestResponseEndpoint)
from starlette.requests import Request
from starlette.responses import Response
from starlette.types import ASGIApp

from .config import settings

class Base(AsyncAttrs, DeclarativeBase):
    """Declarative base class shared by all ORM models.

    Mixing in ``AsyncAttrs`` gives every model the ``awaitable_attrs`` accessor, which
    lets async code await the loading of attributes that are not loaded yet.
    """

# Keyword arguments passed to ``create_async_engine`` in ``Database.__init__``.
# NOTE(review): ``connect_timeout`` and ``options`` are libpq-style connection parameters
# (valid for psycopg); confirm the async driver named in DATABASE_URL accepts them —
# asyncpg, for example, expects different argument names.
ENGINE_ARGUMENTS = dict(
    connect_args=dict(connect_timeout=10, options="-c timezone=UTC"),
    pool_pre_ping=True,  # check pooled connections are alive before handing them out
    pool_size=60,
)
# Keyword arguments for the session factory built in ``Database.__init__``;
# ``class_`` makes that factory produce ``AsyncSession`` instances.
SESSION_ARGUMENTS = dict(
    autocommit=False,
    autoflush=True,
    # Keep loaded attribute values after commit instead of expiring them, so they can
    # still be read afterwards without an implicit refresh.
    expire_on_commit=False,
    class_=AsyncSession,
)

class Database:
    """Set up and contain our database connection.

    This is used to set up the database in a uniform way while allowing easy testing and session management.

    Session management is done using ``async_scoped_session`` with a special scopefunc, because we cannot use
    threading.local(). A ContextVar does the right thing with respect to asyncio and behaves similarly to
    threading.local(). We only store a random string in the contextvar and let the scoped session do the heavy
    lifting. This allows us to easily start a new session or get the existing one using the scoped_session
    mechanics.
    """

    def __init__(self) -> None:
        # Identifier of the current session scope; "" (the default) means no scope is open.
        self.request_context: ContextVar[str] = ContextVar(
            "request_context", default=""
        )
        self.engine = create_async_engine(settings.database_url, **ENGINE_ARGUMENTS)

        self.session_factory = sessionmaker(bind=self.engine, **SESSION_ARGUMENTS)

        # Sessions are registered per key returned by ``_scopefunc``.
        self.scoped_session = async_scoped_session(
            self.session_factory, self._scopefunc
        )

    def _scopefunc(self) -> str:
        """Return the key under which the current scope's session is registered.

        Outside of ``database_scope`` this is the ContextVar default ``""``, so such code
        shares the session registered under that key.
        """
        return self.request_context.get()

    @property
    def session(self) -> "AsyncSession":
        """The session for the current scope, created by the scoped registry on first use."""
        return self.scoped_session()

    @asynccontextmanager
    async def database_scope(self, **kwargs: Any) -> AsyncGenerator["Database", None]:
        """Create a new database session (scope).

        This creates a new database session to handle all the database connections from a single scope (request or
        workflow). This method should typically only be called in request middleware or at the start of workflows.

        The session is removed and the previous scope restored when the block exits, including when the block
        raises.

        Args:
            ``**kwargs``: Optional session kw args for this session

        Yields:
            This ``Database`` instance.
        """
        token = self.request_context.set(str(uuid4()))
        try:
            # Create this scope's session up front so ``kwargs`` apply to it.
            self.scoped_session(**kwargs)
            yield self
        finally:
            # Cleanup must run even if the body raised; otherwise the session (and its
            # pooled connection) leaks and the context var keeps pointing at a dead scope.
            try:
                await self.scoped_session.remove()
            finally:
                self.request_context.reset(token)

class DBSessionMiddleware(BaseHTTPMiddleware):
    """Run each HTTP request inside its own ``Database.database_scope``.

    Code handling the request can use ``database.session`` to get the per-request
    session, which is removed when the scope closes after ``call_next`` returns.

    NOTE(review): with ``BaseHTTPMiddleware``, ``call_next`` may return before a
    streaming response body has been fully sent — confirm streaming endpoints do not
    use the session from inside their body iterator.
    """

    def __init__(self, app: ASGIApp, database: Database):
        super().__init__(app)
        self.database = database

    async def dispatch(
        self, request: Request, call_next: RequestResponseEndpoint
    ) -> Response:
        # Returning from inside the ``async with`` still closes the scope before the
        # response is handed back.
        async with self.database.database_scope():
            return await call_next(request)

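Adding the middleware in main.py stays essentially the same. Keep in mind that lazy loading does not work transparently with asyncio: accessing a relationship that has not been loaded yet triggers implicit IO, which async SQLAlchemy does not allow. So if you serialize relationships through nested Pydantic models, you may want to configure eager loading for them. Here is an example from my budget model, which has a relationship to a category: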

from __future__ import annotations

import uuid
from typing import TYPE_CHECKING

from sqlalchemy import ForeignKey
from sqlalchemy.orm import Mapped, mapped_column, relationship

from ..core import Base
from .mixins.base_model import BaseModel

if TYPE_CHECKING:
    from .assignment import Assignment
    from .category import Category

class Budget(Base, BaseModel):
    """A budget with an amount and name, optionally linked to a category and
    referenced by assignments."""

    __tablename__ = "budget"
    # Fetch server-generated default values right after INSERT/UPDATE, so reading them
    # later does not require an implicit (lazy) load under asyncio.
    __mapper_args__ = {"eager_defaults": True}

    # lazy="selectin" loads these relationships together with the Budget, because implicit
    # lazy loading is not available with AsyncSession.
    assignments: Mapped[list["Assignment"]] = relationship(
        back_populates="budget", lazy="selectin"
    )
    # Declared once: the original listing defined ``amount`` a second time at the end of the
    # class, which silently replaced this identical definition.
    amount: Mapped[float] = mapped_column(nullable=False)
    name: Mapped[str] = mapped_column(nullable=False)
    category_id: Mapped[uuid.UUID | None] = mapped_column(
        ForeignKey("category.id"), nullable=True
    )
    # NOTE(review): category_id is nullable, so this relationship can be None at runtime;
    # consider an Optional annotation once verified against your SQLAlchemy version.
    category: Mapped["Category"] = relationship(
        "Category", back_populates="budgets", lazy="selectin"
    )
    description: Mapped[str | None] = mapped_column(nullable=True)