laurentS / slowapi

A rate limiter for Starlette and FastAPI
https://pypi.org/project/slowapi/
MIT License
1.24k stars 79 forks source link

Limiters on different routes whose handler functions share the same name get mixed up, so a request is checked against multiple routes' limits #173

Open · cnjack2024 opened this issue 1 year ago

cnjack2024 commented 1 year ago

Code:

import slowapi

from fastapi import FastAPI, Request
from fastapi.middleware.cors import CORSMiddleware

def get_remote_client(request: Request) -> str:
    """Return the rate-limit key for *request*: the client's host address.

    Falls back to ``"127.0.0.1"`` when the connection carries no client
    information (``request.client`` is ``None``) or the host is empty.
    """
    client = request.client
    # Starlette types `Request.client` as optional; dereferencing `.host`
    # unconditionally raises AttributeError when it is None.
    if client is None or not client.host:
        return "127.0.0.1"
    return client.host

# Application wiring for the reproduction: a FastAPI app with permissive CORS,
# and a single Limiter keyed by get_remote_client stored on app.state
# (presumably where slowapi expects to find it — confirm against slowapi docs).
app = FastAPI()

# NOTE(review): wildcard origins/methods/headers combined with
# allow_credentials=True is very permissive; the CORS spec does not allow a
# literal "*" together with credentials. Fine for a repro, confirm elsewhere.
app.add_middleware(
    CORSMiddleware,
    allow_credentials=True,
    allow_headers=["*"],
    allow_methods=["*"],
    allow_origins=["*"],
)

limiter = slowapi.Limiter(key_func=get_remote_client)
app.state.limiter = limiter

# Reproduction, part 1: a handler named `test`, nested in class A so its
# __qualname__ is "A.test" while its __name__ is just "test".
class A:
    # Decorators apply bottom-up: the handler is first wrapped by the
    # "1 per 10 second" limit, then the wrapped function is registered for
    # POST /test1.
    @app.post("/test1")
    @limiter.limit("1 per 10 second")
    async def test(request: Request):
        # Same bare name as B.test; when limits are keyed by __name__ (as the
        # issue reports), the two routes' limits get conflated.
        print("test1")

        return {"OK": True}

# Reproduction, part 2: a second handler also named `test`, nested in class B
# so its __qualname__ is "B.test" while its __name__ is again just "test".
class B:
    # Decorators apply bottom-up: the handler is first wrapped by the
    # "1 per 10 second" limit, then the wrapped function is registered for
    # POST /test2.
    @app.post("/test2")
    @limiter.limit("1 per 10 second")
    async def test(request: Request):
        # Same bare name as A.test; when limits are keyed by __name__ (as the
        # issue reports), the two routes' limits get conflated.
        print("test2")

        return {"OK": True}

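Solution:

Replace `view_func.__name__` with `view_func.__qualname__` when building the limit key (`view_func.__name__ -> view_func.__qualname__`). In the example above, both handlers have `__name__ == "test"`, so their limits collide. Their qualified names, `A.test` and `B.test`, are distinct. The proposed changes: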

# Excerpt of the limiter's _check_request_limit as proposed in this issue.
# Only the lines that build the lookup key are shown; the rest of the method
# body appears to be omitted.
def _check_request_limit(
    self,
    request: Request,
    endpoint_func: Optional[Callable[..., Any]],
    in_middleware: bool = True,
) -> None:
    """
    Determine if the request is within limits.

    Excerpt: only the key-building prologue is shown here.

    :param request: incoming request; its ``"path"`` entry is read as the
        endpoint URL (a falsy value becomes ``""``).
    :param endpoint_func: the route handler, or ``None``; when falsy, the
        function-name key is the empty string.
    :param in_middleware: not referenced in the lines shown.
    """
    # NOTE(review): endpoint_url is not used in the visible lines; presumably
    # it is consumed further down in the full method.
    endpoint_url = request["path"] or ""
    view_func = endpoint_func

    # Key on "<module>.<__qualname__>" rather than __name__, so handlers that
    # share a bare name but live in different classes (e.g. A.test vs B.test)
    # get distinct keys.
    # NOTE(review): same format as _get_route_name; consider calling it so the
    # key format is defined in one place.
    endpoint_func_name = (
        f"{view_func.__module__}.{view_func.__qualname__}" if view_func else ""
    )
def __limit_decorator(
    self,
    limit_value: StrOrCallableStr,
    key_func: Optional[Callable[..., str]] = None,
    shared: bool = False,
    scope: Optional[StrOrCallableStr] = None,
    per_method: bool = False,
    methods: Optional[List[str]] = None,
    error_message: Optional[str] = None,
    exempt_when: Optional[Callable[..., bool]] = None,
    cost: Union[int, Callable[..., int]] = 1,
    override_defaults: bool = True,
) -> Callable[..., Any]:
    """
    Excerpt: presumably builds the decorator that attaches a limit to a route.

    Only the opening lines are shown; the remainder of ``decorator`` and the
    method's return statement are not part of this excerpt.
    """
    # `scope` is kept only when `shared` is True; otherwise _scope is None.
    _scope = scope if shared else None

    def decorator(func: Callable[..., Response]):
        # Falls back to self._key_func (presumably the limiter-wide key
        # function) when no key_func was given for this limit.
        keyfunc = key_func or self._key_func
        # "<module>.<__qualname__>" keeps same-named handlers in different
        # classes (e.g. A.test vs B.test) apart. Presumably this is the key the
        # limits are registered under; the code using it is not shown.
        # NOTE(review): same format as _get_route_name; consider reusing it.
        name = f"{func.__module__}.{func.__qualname__}"
def exempt(self, obj):
    """
    Decorator to mark a view as exempt from rate limits.

    Excerpt: only the line that builds the view's key is shown.
    """
    # Builds the same "<module>.<__qualname__>" string as the limit decorator,
    # so an exempt A.test is not confused with B.test.
    # NOTE(review): this uses %-formatting while the other call sites use an
    # f-string; reusing _get_route_name would keep the formats from drifting.
    name = "%s.%s" % (obj.__module__, obj.__qualname__)
def _get_route_name(handler: Callable):
    return f"{handler.__module__}.{handler.__qualname__}"
thentgesMindee commented 10 months ago

Hi! Feel free to open a PR. Unfortunately, no one is working on this package full-time.