long2ice / fastapi-limiter

A request rate limiter for fastapi
https://github.com/long2ice/fastapi-limiter
Apache License 2.0
487 stars 53 forks source link

Usage of library as a middleware utilitary #59

Open brunolnetto opened 1 month ago

brunolnetto commented 1 month ago

Hi guys,

I implemented a custom middleware based on fastapi-limiter. The implementation is here, but I also paste it below. The idea is: retrieve either username if authentication token is provided, or its IP address and user this logic to build the user identifier function. This part is working well. What is not working is the limiter call await limiter(request=request, response=Response()). Am I doing something incorrect according to the supposed library behavior? Thanks! :)

from fastapi import FastAPI, Request
from fastapi_limiter import FastAPILimiter
from fastapi_limiter.depends import RateLimiter
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.responses import Response
from typing import Callable, Awaitable
from redis import asyncio as aioredis
from typing import Union

from backend.app.utils.throttling import ip_identifier
from backend.app.utils.request import get_token, get_route
from backend.app.utils.throttling import ip_identifier
from backend.app.base.auth import get_current_user
from backend.app.base.exceptions import MissingTokenException, TooManyRequestsException
from backend.app.base.config import settings
from backend.app.data.auth import ROLES_METADATA
from backend.app.base.logging import logger

class RateLimiterPolicy:
    def __init__(
        self, 
        times: int = 5, 
        hours: int = 0,     
        minutes: int = 1, 
        seconds: int = 0, 
        milliseconds: int = 0
    ):
        self.times = times
        self.hours = hours
        self.minutes = minutes
        self.seconds = seconds
        self.milliseconds = milliseconds

async def init_redis_pool():
    try:
        redis = await aioredis.Redis.from_url(settings.redis_url)
        logger.info("Redis pool initialized successfully.")
        return redis
    except Exception as e:
        logger.error(f"Failed to initialize Redis pool: {e}")
        raise e

async def init_rate_limiter():
    redis = await init_redis_pool()
    await FastAPILimiter.init(redis)

    logger.info("Rate limiter initialized!")

# Given rate limiter, find throughput
def get_throughput(rate_limiter: RateLimiterPolicy):
    times = rate_limiter.times
    interval_seconds = rate_limiter.hours * 3600 + \
        rate_limiter.minutes * 60 + \
        rate_limiter.seconds + \
        rate_limiter.milliseconds / 1000

    return times / interval_seconds

def get_rate_limiter(
    user_identifier: Union[str, None], 
    policy: RateLimiterPolicy = RateLimiterPolicy()
):
    return RateLimiter(
        times=policy.times, 
        hours=policy.hours,
        minutes=policy.minutes,
        seconds=policy.seconds,
        milliseconds=policy.milliseconds,
        identifier=user_identifier
    )

async def get_user_rate_policy(request: Request) -> RateLimiterPolicy:
    route = get_route(request)

    # Check if the route requires authentication
    if settings.route_requires_authentication(route):
        token = get_token(request)

        if not token:
            raise MissingTokenException()

        current_user = await get_current_user(token)
        roles = current_user.user_roles
        rate_policies = [
            ROLES_METADATA[role]['rate_policy'] for role in roles
        ]

        # Get the most permissive rate policy
        rate_policy = max(rate_policies, key=get_throughput)

    else:
        # Default rate policy for non-authenticated routes
        rate_policy = RateLimiterPolicy()

    return rate_policy  

async def get_user_identity(request: Request) -> str:
    route = get_route(request)

    # Check if the route requires authentication
    if settings.route_requires_authentication(route):
        token = get_token(request)

        if not token:
            raise MissingTokenException()

        current_user = await get_current_user(token)
        user_identity = current_user.user_username

    else:
        # Handle non-authenticated routes
        user_identity = await ip_identifier(request)

    return user_identity

async def get_user_limiter(request: Request) -> RateLimiter:
    user_rate_policy = await get_user_rate_policy(request)
    return get_rate_limiter(get_user_identity, user_rate_policy)

class RateLimitMiddleware(BaseHTTPMiddleware):
    async def dispatch(self, request: Request, call_next: Callable) -> Response:
        try:
            limiter = await get_user_limiter(request)

            # Call the rate limiter, it will raise an exception if the limit is exceeded
            await limiter(request=request, response=Response())

            response = await call_next(request)

            return response

        except TooManyRequestsException as e:
            user_identity = await get_user_identity(request)
            logger.warning(f"Rate limit exceeded for {user_identity}: {str(e)}")
            return Response(status_code=429, content="Too many requests")

        except Exception as e:
            logger.error(f"Unexpected error in rate limiting middleware: {str(e)}")
            return Response(status_code=500, content="Internal Server Error")