mbg / wai-rate-limit

Rate limiting for Servant and as WAI middleware
MIT License

Add the ability to specify custom error body #5

Open chshersh opened 2 years ago

chshersh commented 2 years ago

First of all, thanks for writing this great library! 😊 It works like clockwork 🕐

While using servant-rate-limit, I stumbled on an issue: it's not possible to specify a custom error body when the exception is thrown. The body is empty, which results in some JSON parsing errors. It's hardcoded here:

https://github.com/mbg/wai-rate-limit/blob/4fca1af8695e6eb7b199c2c7476c483366605ffa/servant-rate-limit/src/Servant/RateLimit/Server.hs#L59-L64

It would be great to be able to specify a custom error body (and headers). Here is the design I came up with.

The main idea is to introduce a typeclass that allows returning a custom error body and a list of headers based on the request.

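-- | A typeclass for types that tell how to produce the error body and relevant headers from a 'Request'
-- for errors with the 429 code (Rate Limit Exceeded).
class HasRateLimitErrBody err where
  -- | Build the error-body value for this request.
  getErrBodySetter :: Request -> IO err
  -- | Render it; 'ByteString' is the lazy one expected by 'errBody' in 'ServerError'.
  mkErrBody  :: err -> Request -> (ByteString, [Header])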

A few possible instances of the HasRateLimitErrBody typeclass:

-- | Empty error body with no text and no headers
data EmptyErrorBody = EmptyErrorBody

instance HasRateLimitErrBody EmptyErrorBody where
  getErrBodySetter _ = pure EmptyErrorBody
  mkErrBody _ _ = ("", [])  -- to emulate the existing behaviour of 'RateLimit'

-- | Simple hardcoded error body as JSON
data SimpleJsonErrBody = SimpleJsonErrBody

instance HasRateLimitErrBody SimpleJsonErrBody where
  getErrBodySetter _ = pure SimpleJsonErrBody
  mkErrBody _ _ =
    let errBody = Aeson.encode $
          Aeson.object
            [ ( "message",
                "We received too many requests from your device in a short time, please try again in a few minutes."
              )
            ]
        headers = [(hContentType, renderHeader $ contentType (Proxy @JSON))]
    in (errBody, headers)
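
To see how the type parameter would pick the body, here is a minimal, dependency-free sketch of the proposed class. It assumes simplified stand-ins rather than the real types: `Request` is a one-field record instead of WAI's, headers are `(String, String)` pairs, and bodies are `String`s rather than lazy `ByteString`s. `PathJsonErrBody` and `rateLimitError` are hypothetical names; `rateLimitError` mimics the `getErrBodySetter @errBody req` call in the proposed server code, where the error-body type is chosen only by a type application.

```haskell
{-# LANGUAGE AllowAmbiguousTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeApplications #-}

-- Stand-in for WAI's 'Request'; only the path is modelled.
newtype Request = Request { requestPath :: String }

-- Stand-in for an HTTP header (name, value).
type Header = (String, String)

-- Same shape as the proposed class, with 'String' bodies.
class HasRateLimitErrBody err where
  getErrBodySetter :: Request -> IO err
  mkErrBody        :: err -> Request -> (String, [Header])

-- | Empty body and no headers: the current behaviour.
data EmptyErrorBody = EmptyErrorBody

instance HasRateLimitErrBody EmptyErrorBody where
  getErrBodySetter _ = pure EmptyErrorBody
  mkErrBody _ _ = ("", [])

-- | Hypothetical: a hand-written JSON message that mentions the path.
data PathJsonErrBody = PathJsonErrBody

instance HasRateLimitErrBody PathJsonErrBody where
  getErrBodySetter _ = pure PathJsonErrBody
  mkErrBody _ req =
    ( "{\"message\":\"Too many requests to " ++ requestPath req ++ "\"}"
    , [("Content-Type", "application/json;charset=utf-8")]
    )

-- | Hypothetical helper: (status, body, headers) for a rejected request.
-- 'err' appears only in the constraint, so callers pick it with a type
-- application, like 'getErrBodySetter @errBody' in the proposal.
rateLimitError :: forall err. HasRateLimitErrBody err
               => Request -> IO (Int, String, [Header])
rateLimitError req = do
  setter <- getErrBodySetter @err req
  let (body, headers) = mkErrBody setter req
  pure (429, body, headers)
```

Loaded into GHCi (with `:set -XTypeApplications` at the prompt if your GHC does not enable it by default), `rateLimitError @EmptyErrorBody (Request "/search")` returns `(429,"",[])`, i.e. the current behaviour.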

After that, we can implement a data type similar to `RateLimit` but with a custom `errBody`:

data RateLimitCustom errBody strategy policy

You may want to introduce a breaking change in the form of replacing the existing RateLimit with this one, or reimplementing RateLimit as `type RateLimit = RateLimitCustom EmptyErrorBody`, but I personally tend to avoid introducing breaking changes (at least unless there's a clear migration guide).

After that, it's possible to implement a HasServer instance for the newly introduced RateLimitCustom. It's almost the same as the existing instance, with a few changes:

instance
    ( HasServer api ctx
    , HasContextEntry ctx (Backend key)
+   , HasRateLimitErrBody errBody
    , HasRateLimitStrategy strategy
    , HasRateLimitPolicy policy
    , key ~ RateLimitPolicyKey policy
-   ) => HasServer (RateLimit strategy policy :> api) ctx
+   ) => HasServer (RateLimitCustom errBody strategy policy :> api) ctx
    where

-   type ServerT (RateLimit strategy policy :> api) m = ServerT api m
+   type ServerT (RateLimitCustom errBody strategy policy :> api) m = ServerT api m

    hoistServerWithContext _ pc nt s =
        hoistServerWithContext (Proxy :: Proxy api) pc nt s

    route _ context subserver = do
        -- retrieve the backend from the Servant context
        let backend = getContextEntry context

        -- retrieve the rate-limiting policy used to identify clients
        let policy = policyGetIdentifier @policy

        -- retrieve the rate-limiting strategy used to limit access
        let strategy = strategyValue @strategy @key backend policy

        let rateCheck = withRequest $ \req -> do
                -- apply the rate-limiting strategy to the request
                allowRequest <- liftIO $ strategyOnRequest strategy req

                -- fail if the rate limit has been exceeded
                unless allowRequest $ do
+                   errBodySetter <- liftIO $ getErrBodySetter @errBody req
+                   let (customErrBody, customHeaders) = mkErrBody errBodySetter req
                    delayedFailFatal $ ServerError{
                        errHTTPCode = 429,
                        errReasonPhrase = "Rate limit exceeded",
-                       errBody = "",
+                       errBody = customErrBody,
-                       errHeaders = []
+                       errHeaders = customHeaders
                    }

        -- add the check for whether the rate limit has been exceeded to the
        -- server and return it
        route (Proxy :: Proxy api) context $
            subserver `addAcceptCheck` rateCheck
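
The control flow of `rateCheck` can also be modelled without Servant, as a rough sketch. Here `ServerError` is a hypothetical record mirroring only the four fields set above (not Servant's type), and `Either` stands in for `DelayedIO`: `Left` plays the role of `delayedFailFatal`, `Right ()` lets the request through. As in the proposal, the custom body is only built when the request is rejected.

```haskell
-- Mirrors only the fields used above; not Servant's 'ServerError'.
data ServerError = ServerError
  { errHTTPCode     :: Int
  , errReasonPhrase :: String
  , errBody         :: String
  , errHeaders      :: [(String, String)]
  } deriving (Eq, Show)

-- | Run the strategy; the error-body builder is consulted only when the
-- request is rejected. 'Left' plays the role of 'delayedFailFatal'.
rateCheck
  :: (req -> IO Bool)                          -- strategy: allow this request?
  -> (req -> IO (String, [(String, String)]))  -- custom body and headers
  -> req
  -> IO (Either ServerError ())
rateCheck allow mkErr req = do
  allowed <- allow req
  if allowed
    then pure (Right ())
    else do
      (body, headers) <- mkErr req
      pure $ Left ServerError
        { errHTTPCode     = 429
        , errReasonPhrase = "Rate limit exceeded"
        , errBody         = body
        , errHeaders      = headers
        }
```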

If this is something you would like to have in your library, I'm happy to contribute the implementation 🤗

mbg commented 2 years ago

Hi @chshersh! I am glad to hear that you are having a good experience with this library so far 😄 Thanks for raising this issue and exploring a potential solution in so much detail as well! That's very much appreciated. I will think a little about this design and then get back to you soon.

In the meantime, I am wondering to what extent the servant-errors package can address your needs here. In theory, it should allow you to just drop in an errorMw @JSON @'["error", "status"] middleware to get a similar effect.

I still see merit in the ability to specify custom messages or headers (e.g. to include information about how much capacity is left etc., similar to what you get from the GitHub API) though.

chshersh commented 2 years ago

@mbg Don't worry about replying urgently 🤗 This issue is not urgent as I've already implemented the above-mentioned workaround. Just wanted to upstream the implementation if this is something you want to see in your library. Take as much time as you want to come up with the design that satisfies you 🙂

servant-errors is a wonderful package! It does look very close to what I need. However, IIUC, servant-errors specifies error messages in a single place, via the HasErrorBody typeclass, by pattern matching on the StatusCode and the Text error message. This is not ideal for my use case, as I would like to have different error messages for different endpoints (e.g. You've exceeded rate limit for phone verification and You've exceeded rate limit for search queries).

I could in theory circumvent this problem by implementing my own naming scheme for errors and then pattern matching on a string in the encodeError function from servant-errors. But this looks sub-optimal to me and, most importantly, I still need the ability to provide custom error messages from servant-rate-limit to achieve this 😅

Here is the current behaviour without servant-errors:

$ curl -XPOST -H "Content-Type: application/json" localhost:8001/ -d '{"msg": "Dmitrii"}' -v
Note: Unnecessary use of -X or --request, POST is already inferred.
*   Trying 127.0.0.1:8001...
* Connected to localhost (127.0.0.1) port 8001 (#0)
> POST / HTTP/1.1
> Host: localhost:8001
> User-Agent: curl/7.81.0
> Accept: */*
> Content-Type: application/json
> Content-Length: 18
> 
* Mark bundle as not supporting multiuse
< HTTP/1.1 429 Rate limit exceeded
< Transfer-Encoding: chunked
< Date: Thu, 19 May 2022 12:45:01 GMT
< Server: Warp/3.3.20
< 
* Connection #0 to host localhost left intact

And here with servant-errors:

$ curl -XPOST -H "Content-Type: application/json" localhost:8001/ -d '{"msg": "Dmitrii"}' -v
Note: Unnecessary use of -X or --request, POST is already inferred.
*   Trying 127.0.0.1:8001...
* Connected to localhost (127.0.0.1) port 8001 (#0)
> POST / HTTP/1.1
> Host: localhost:8001
> User-Agent: curl/7.81.0
> Accept: */*
> Content-Type: application/json
> Content-Length: 18
> 
* Mark bundle as not supporting multiuse
< HTTP/1.1 429 Rate limit exceeded
< Transfer-Encoding: chunked
< Date: Thu, 19 May 2022 12:45:31 GMT
< Server: Warp/3.3.20
< Content-Type: application/json;charset=utf-8
< 
* Connection #0 to host localhost left intact
{"status":429,"error":"Rate limit exceeded"}

But maybe I'm mistaken here so cc'ing @epicallan for more input!


Here is my minimal code snippet reproducing the error with servant-rate-limit and servant-errors (adapted from the servant-errors README):

#!/usr/bin/env cabal
{- cabal:
build-depends:
  , base               ^>= 4.14.0.0
  , aeson              < 2.0
  , bytestring
  , servant-server
  , servant-errors     ^>= 0.1.7.0
  , servant-rate-limit ^>= 0.2.0.0
  , text
  , time-units-types
  , wai                ^>= 3.2
  , wai-rate-limit     ^>= 0.3
  , warp
-}

{-# LANGUAGE DataKinds #-}
{-# LANGUAGE DeriveGeneric #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeOperators #-}

module Main where

import           Data.Aeson (FromJSON, ToJSON)
import           Data.ByteString (ByteString)
import           Data.Proxy (Proxy(..))
import           Data.Text (Text)
import           Data.Time.TypeLevel
import           GHC.Generics (Generic)
import           Network.Wai (Application)
import           Network.Wai.Handler.Warp (run)
import           Network.Wai.Middleware.Servant.Errors (errorMw, HasErrorBody(..))
import           Network.Wai.RateLimit.Backend
import           Servant 
import           Servant.RateLimit
import           Servant.RateLimit.Server ()

-- | A greet message data type for use as Request Body
newtype Greet = Greet { msg :: Text }
  deriving (Generic, Show, FromJSON, ToJSON)

type TestApi
    =  RateLimit (SlidingWindow ('Second 10) 1) (IPAddressPolicy "sliding:")
    :> ReqBody '[JSON] Greet 
    :> Post '[JSON] Greet

-- servant application
main :: IO ()
main = do
    putStrLn "Starting the server..."
    run 8001
        $ errorMw @JSON @'["error", "status"]  -- comment this line to check usage without servant-errors
        $ serveWithContext
            api
            (backend :. EmptyContext )
            handler
  where
    handler = return . id
    api = Proxy @TestApi

    -- simple backend to always throw rate limit error
    backend :: Backend ByteString
    backend = MkBackend
        { backendGetUsage       = \_   -> pure 1000
        , backendIncAndGetUsage = \_ _ -> pure 1000
        , backendExpireIn       = \_ _ -> pure ()
        }

arianvp commented 2 years ago

Another option would be to go the route that we wanted to take with servant-auth (where we need to solve the exact same issue, but for auth errors instead of rate limit errors!).

Instead of introducing a new typeclass, we can add an (optional) error handler to Servant's Context and, if it's present, use it to customize the error:

https://github.com/haskell-servant/servant-auth/pull/168 https://github.com/haskell-servant/servant/issues/1585

main = do
    putStrLn "Starting the server..."
    run 8001
        $ serveWithContext
            api
            (myRateLimitHandler :. backend :. EmptyContext )
            handler
  where
    handler = return . id
    api = Proxy @TestApi
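    -- 'RateLimitHandler' is the proposed (not yet existing) context entry;
    -- customise the 'ServerError' fields here
    myRateLimitHandler = RateLimitHandler $ delayedFailFatal err401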

    -- simple backend to always throw rate limit error
    backend :: Backend ByteString
    backend = MkBackend
        { backendGetUsage       = \_   -> pure 1000
        , backendIncAndGetUsage = \_ _ -> pure 1000
        , backendExpireIn       = \_ _ -> pure ()
        }
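
For readers unfamiliar with how such a handler would be found: Servant's `Context` is a list indexed by types and searched by type. The sketch below is a simplified, dependency-free re-implementation of that idea; the names mirror Servant's `Context`, `:.`, `EmptyContext` and `getContextEntry`, but this is not Servant's code. `RateLimitHandler` (here mapping a path to a body) and `Backend` are hypothetical stand-ins. It also touches on the concern about defaults raised in the reply below: without a matching entry, the lookup fails at compile time rather than falling back to anything.

```haskell
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE KindSignatures #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE TypeOperators #-}

import Data.Kind (Type)

infixr 5 :.

-- | A list of values whose types are tracked at the type level.
data Context (xs :: [Type]) where
  EmptyContext :: Context '[]
  (:.)         :: x -> Context xs -> Context (x ': xs)

-- | Find the first entry of type 'val'. There is no instance for '[]',
-- so a missing entry is a type error.
class HasContextEntry (xs :: [Type]) val where
  getContextEntry :: Context xs -> val

instance {-# OVERLAPPING #-} HasContextEntry (val ': xs) val where
  getContextEntry (x :. _) = x

instance {-# OVERLAPPABLE #-} HasContextEntry xs val
      => HasContextEntry (other ': xs) val where
  getContextEntry (_ :. rest) = getContextEntry rest

-- | Hypothetical handler: turns the request path into a 429 body.
newtype RateLimitHandler = RateLimitHandler (String -> String)

-- | Stand-in for the storage backend.
newtype Backend = Backend String

-- | What a combinator might do: look the handler up by its type.
rateLimitBody :: HasContextEntry ctx RateLimitHandler
              => Context ctx -> String -> String
rateLimitBody ctx path =
  let RateLimitHandler f = getContextEntry ctx in f path
```

Dropping the `RateLimitHandler` entry from the context makes `rateLimitBody` fail to type-check (no `HasContextEntry '[] RateLimitHandler` instance), so an optional handler with a default would need extra machinery.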

mbg commented 2 years ago

@chshersh: Thanks for the follow-up! It's really useful in understanding what your use case is and how servant-errors doesn't (easily) address it. I thought more about the error messages over the weekend and it really feels like a more general problem. I note that the Handler monad is an instance of MonadCatch and MonadThrow, which would allow you to write a helper function that can catch ServerErrors and transform them into ServerErrors which have a response body containing the error message you want. A minimal example to adjust the error message for 429 errors:

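-- Needs ScopedTypeVariables for the pattern signature; 'ByteString' is the
-- lazy one, matching the type of 'errBody'.
withCustomServerError :: MonadCatch m => ByteString -> m a -> m a
withCustomServerError msg action = action `catch` \(err :: ServerError) ->
    if errHTTPCode err /= 429
    then throwM err
    else throwM err{ errBody = msg }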

Simply install this around the server handlers for the relevant endpoints in your API. If you are using this in combination with servant-errors, it will then transform the response to JSON for you. It should keep the new errBody as the value for the "error" key.

This should also allow you to adjust the headers using the errHeaders field.

@arianvp: Thank you for your suggestion and the interesting pointer to the relevant discussion for servant-auth. I hadn't thought about using the context for this at all yet. I suppose that would run into the same limitation as with servant-errors though in that sticking an error handler into the context would mean that either there's only one or that the RateLimit combinator needs to be configurable so that it can identify a particular error handler. I am also not sure how easy it would be to not specify any context in cases where no customisation is desired and fall back to some default handler.


I will summarise some thoughts about different points in the design space if the above withCustomServerError solution is insufficient, both so I can remember and in case anyone would like to comment:

  1. The approach @chshersh suggested. It seems like a sensible suggestion. The main downside I see is the need for an extra parameter for the RateLimit (or variant) combinator. This is not a huge trade-off and, as you pointed out, it could be hidden behind a type alias in cases where no customisation is required, but it might still be a breaking change, since RateLimit would then no longer be a concrete type but a type alias.
  2. We could move the entire rate limiting configuration into some type-level specification that can be named:

         type MyRateLimitingCfg
             =   SlidingWindow ('Second 10) 1
             :<> IPAddressPolicy "sliding:"
             :<> ErrorMessage "You've exceeded rate limit for testing"

         type API = RateLimit MyRateLimitingCfg :> ...

     Or as a type-level list. This is essentially the approach that e.g. `aeson-deriving` takes, but it feels misleading: we would expect each kind of option to appear once, and combining them in this way suggests they can be used multiple times.
  3. An error handling configuration in the context, as suggested by @arianvp. This would have to be tagged somehow to establish a relation with a particular `RateLimit` combinator. There is also the question of whether it must always be supplied in the context.
  4. Extra combinator(s) that can be included in the API specification:

     ```haskell
     type API
         =  IncludeRateLimitHeaders (SlidingWindow ('Second 10) 1) (IPAddressPolicy "sliding:")
         :> ErrorMessageFor 429 "You've exceeded the rate limit for testing"
         :> RateLimit (SlidingWindow ('Second 10) 1) (IPAddressPolicy "sliding:")
     ```

     Something like ErrorMessageFor in this example might be out of scope for this library though. I would also have to play around with implementing this to see how well something like IncludeRateLimitHeaders works, independent of RateLimit.
  5. We could attach this to the existing rate limiting policies by extending the HasRateLimitPolicy class with the methods from @chshersh's suggested HasRateLimitErrBody class, and make sure that it is easy to define new policies. This would mean that no changes to RateLimit are required.

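
As a rough illustration of option 2, a named configuration could be inspected by a type class that walks the `:<>` chain and reflects the message with `symbolVal` from `GHC.TypeLits`. Everything here is hypothetical: `:<>`, `ErrorMessage` and `HasErrorMessage` are not part of servant-rate-limit, and the two placeholder types stand in for the real strategy and policy. `GHC.TypeLits` has its own `ErrorMessage`, so only the needed names are imported.

```haskell
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE KindSignatures #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeOperators #-}

import Data.Proxy (Proxy (..))
import GHC.TypeLits (KnownSymbol, Symbol, symbolVal)

-- Hypothetical configuration building blocks (type-level only).
data a :<> b
infixr 5 :<>

data SlidingWindowPlaceholder    -- stands in for the real strategy
data IPAddressPolicyPlaceholder  -- stands in for the real policy
data ErrorMessage (msg :: Symbol)

-- | Extract the custom message from a configuration, if there is one.
class HasErrorMessage cfg where
  errorMessage :: Proxy cfg -> Maybe String

instance KnownSymbol msg => HasErrorMessage (ErrorMessage msg) where
  errorMessage _ = Just (symbolVal (Proxy :: Proxy msg))

instance HasErrorMessage SlidingWindowPlaceholder where
  errorMessage _ = Nothing

instance HasErrorMessage IPAddressPolicyPlaceholder where
  errorMessage _ = Nothing

-- | Search the left component first, then the rest of the chain.
instance (HasErrorMessage a, HasErrorMessage b) => HasErrorMessage (a :<> b) where
  errorMessage _ = case errorMessage (Proxy :: Proxy a) of
    Just msg -> Just msg
    Nothing  -> errorMessage (Proxy :: Proxy b)

type MyRateLimitingCfg
  =   SlidingWindowPlaceholder
  :<> IPAddressPolicyPlaceholder
  :<> ErrorMessage "You've exceeded rate limit for testing"
```

Nothing in this sketch stops `ErrorMessage` from appearing twice; it silently takes the leftmost one, which is exactly the ambiguity noted for option 2.
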
chshersh commented 2 years ago

@mbg I appreciate such a detailed response 🤗 Thanks for providing so many alternative solutions to the problem! I will share my thoughts on possible solutions.

The proposed withCustomServerError function doesn't work because the RateLimitedException is not thrown by the handler itself. It's thrown externally by the HasServer instance for RateLimit, so the wrapper can't be used to customize individual errors.

In other words, applying the following patch to my minimal code example doesn't show My Error in the response (I tried with and without servant-errors).

    run 8001
        $ errorMw @JSON @'["error", "status"]
        $ serveWithContext
            api
            (backend :. EmptyContext)
-           handler
+           (withCustomServerError "My Error" . handler)
  where
    handler = return . id

I could probably wrap my whole server in a single catch-and-rethrow, but this solution has the same drawback as servant-errors: there's no straightforward way to customize individual messages. Or maybe I'm just doing something wrong, so feel free to correct me 😅
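
Why the wrapper misses the error comes down to ordering. The following sketch is a simplified model using only `Control.Exception` from base, not Servant itself: `ServerError` is a cut-down stand-in, `withCustomServerError` is the helper above specialised to `IO` with `String` bodies, and the hypothetical `serve` plays the router, running the rate check before the handler. An error raised by `serve` never passes through the `catch` installed around the handler.

```haskell
{-# LANGUAGE ScopedTypeVariables #-}

import Control.Exception (Exception, catch, throwIO, try)

-- Simplified stand-in for Servant's 'ServerError'.
data ServerError = ServerError { errHTTPCode :: Int, errBody :: String }
  deriving (Eq, Show)

instance Exception ServerError

-- | The helper from above, specialised to IO.
withCustomServerError :: String -> IO a -> IO a
withCustomServerError msg action = action `catch` \(err :: ServerError) ->
  if errHTTPCode err /= 429
    then throwIO err
    else throwIO err { errBody = msg }

-- | Hypothetical router: the rate check runs before the handler,
-- outside of whatever the handler is wrapped in.
serve :: Bool -> IO a -> IO (Either ServerError a)
serve allowed handler = try $ do
  if allowed then pure () else throwIO (ServerError 429 "")
  handler
```

In this model, `serve False (withCustomServerError "My Error" (pure ()))` yields a `Left` whose `errBody` is still empty, matching what the patch above shows; only an error raised inside the wrapped action gets rewritten.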

Currently, it looks like the three proposed solutions are not flexible enough, as they only allow catching all the exceptions in one single place globally, not per endpoint.


I agree with your thoughts on all 5 new suggestions 👍🏻 I find options 1 and 5 easier to implement, and either would work fine for me.

For option 1, it's possible to implement it without any breaking changes. You can keep RateLimit as a data type and implement its instance in a single line by reusing RateLimitCustom. I expect something like the following code to work:

  route _ = route (Proxy :: Proxy (RateLimitCustom EmptyErrorBody strategy policy :> api))

So no boilerplate 🙂
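
The delegation idea can be checked in miniature. In this sketch, a hypothetical class `Describe` with a `Proxy` argument stands in for `HasServer` (whose `route` also takes a `Proxy`), and the `RateLimit` instance forwards to `RateLimitCustom EmptyErrorBody`. The strategy, policy and body types are empty placeholders, not the library's. Because the method takes a `Proxy`, the forwarding instance passes a fresh `Proxy` for the target type instead of reusing its own argument.

```haskell
{-# LANGUAGE ScopedTypeVariables #-}

import Data.Proxy (Proxy (..))

-- Uninhabited combinators, as in the proposal.
data RateLimitCustom errBody strategy policy
data RateLimit strategy policy

-- Placeholder error bodies, strategy and policy.
data EmptyErrorBody
data JsonErrorBody
data SlidingWindow
data IPAddressPolicy

class ErrBodyName e where
  errBodyName :: Proxy e -> String

instance ErrBodyName EmptyErrorBody where errBodyName _ = "empty body"
instance ErrBodyName JsonErrorBody  where errBodyName _ = "JSON body"

-- | Stand-in for 'HasServer': the method takes a 'Proxy', like 'route'.
class Describe api where
  describe :: Proxy api -> String

instance ErrBodyName e => Describe (RateLimitCustom e strategy policy) where
  describe _ = "rate limited, 429 with " ++ errBodyName (Proxy :: Proxy e)

-- | The one-line instance: forward to the custom combinator.
instance Describe (RateLimit strategy policy) where
  describe _ = describe (Proxy :: Proxy (RateLimitCustom EmptyErrorBody strategy policy))
```

For the real `HasServer` instance, the `ServerT` equation and `hoistServerWithContext` would need the same forwarding as `route`.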

For option 5, it looks less convenient to me because I have several endpoints with the same rate limit policy, but I might want different error messages for them. If this can only be done via the HasRateLimitPolicy instance, this solution will result in unnecessary boilerplate and more room for mistakes.

mbg commented 2 years ago

Hi again @chshersh. Sorry for the delay in getting back to you and sorry for suggesting something that doesn't actually work. I think your approach is the one to go for then and I would very much welcome a PR which implements it :)