haskell-servant / servant-auth

160 stars 73 forks source link

combinator that forces authentication to have succeeded #149

Open cdepillabout opened 5 years ago

cdepillabout commented 5 years ago

It would be nice to have a combinator that forces authentication to have succeeded.

The ServerT type family instance for Auth looks like the following:

type ServerT (Auth auths v :> api) m = AuthResult v -> ServerT api m

Because AuthResult may represent a failure, every endpoint has to handle the possibility that authentication did not succeed. The following example is taken from the README:

-- Each handler must deal with the unauthenticated case itself: any result
-- other than 'Authenticated' makes every endpoint throw 'err401' ('throwAll').
protected :: AuthResult User -> Server Protected
protected (Authenticated user) = return (name user) :<|> return (email user)
protected _ = throwAll err401

It would be nice to have a combinator where you can be sure that authentication has succeeded. I imagine the ServerT type family instance would look like the following:

type ServerT (Auth' auths v :> api) m = v -> ServerT api m

I wrote up a small example implementation below. The only "tricky" part is a type class that the end user can optionally implement to define which ServantErr is thrown when authentication fails. This lets them, for example, redirect to a login page instead of just throwing an HTTP 401 error.

I'd be happy to send a PR adding this if something like this would be accepted into the library. I can clean up the comments / code if need be.

{-# LANGUAGE UndecidableInstances #-}

module Foobar where

import Prelude

import Control.Monad.IO.Class (liftIO)
import Data.Proxy (Proxy(Proxy))
import Servant (HasLink(toLink), MkLink, (:>))
import Servant.Auth.Server.Internal.AddSetCookie (AddSetCookies,
                                                  AddSetCookiesApi, Nat (S, Z),
                                                  SetCookieList (SetCookieCons, SetCookieNil),
                                                  addSetCookies)
import Servant.Auth.Server.Internal.Class (AreAuths, runAuths)
import Servant.Auth.Server.Internal.ConfigTypes (CookieSettings, JWTSettings)
import Servant.Auth.Server.Internal.Cookie (makeSessionCookie, makeXsrfCookie)
import Servant.Auth.Server.Internal.JWT (ToJWT)
import Servant.Auth.Server.Internal.Types (AuthResult(Authenticated), runAuthCheck)
import Servant.Server.Internal (Context, Delayed, DelayedIO, Handler,
                                HasContextEntry, HasServer, Router, ServantErr,
                                ServerT, addAuthCheck, delayedFail,
                                delayedFailFatal, err401, getContextEntry,
                                hoistServerWithContext, route, withRequest)

-- | Like 'Auth', except that authentication is required to succeed.
--
-- Handlers behind this combinator receive the authenticated value directly
-- rather than an 'AuthResult'. When authentication does not succeed, the
-- request is rejected before the handler runs.
--
-- The first parameter selects a 'HasAuthFailedErrHandler' instance, which
-- decides the 'ServantErr' used for that rejection. 'AuthReq' fixes it to
-- 'AuthErr401', giving an HTTP 401 (unauthorized). To customise the error,
-- declare your own datatype and give it a 'HasAuthFailedErrHandler' instance.
data AuthReq' err (auths :: [*]) val

-- | Error-handler tag used by 'AuthReq'.
--
-- It carries no values. Its 'HasAuthFailedErrHandler' instance keeps the
-- class default, so a failed authentication produces 'err401'.
data AuthErr401

-- | 'AuthReq'' with the default 'AuthErr401' handler, so authentication
-- failures are answered with HTTP 401. Used as @AuthReq auths val@.
type AuthReq = AuthReq' AuthErr401

-- | Defines which 'ServantErr' an 'AuthReq'' endpoint fails with when
-- authentication does not succeed.
--
-- The default definition of 'authFailedServerErr' is 'err401', so an instance
-- with an empty body (like the one for 'AuthErr401') yields a plain HTTP 401.
class HasAuthFailedErrHandler error where
  -- | The error to reject the request with. The proxy argument only selects
  -- the instance; the default definition ignores it.
  authFailedServerErr :: proxy error -> ServantErr
  authFailedServerErr _ = err401

-- | 'AuthErr401' relies entirely on the class default: 'err401'.
instance HasAuthFailedErrHandler AuthErr401 where

-- | 'AuthReq'' contributes nothing to a link: linking through it is the same
-- as linking to the sub-API.
instance HasLink sub => HasLink (AuthReq' err (auths :: [*]) v :> sub) where
  type MkLink (AuthReq' err (auths :: [*]) v :> sub) r = MkLink sub r
  toLink finish _ = toLink finish (Proxy :: Proxy sub)

-- | Server instance for 'AuthReq''.
--
-- Every scheme in @auths@ is run against the incoming request (via
-- 'runAuths').
--
-- * When the result is 'Authenticated', the handler receives the value, and
--   the cookies built by 'makeXsrfCookie' and 'makeSessionCookie' are
--   attached to its response with 'addSetCookies'.
-- * Any other 'AuthResult' rejects the request with the error chosen by
--   'authFailedServerErr'.
instance ( n ~ 'S ('S 'Z)
         , HasServer (AddSetCookiesApi n api) ctxs, AreAuths auths ctxs v
         , HasServer api ctxs -- this constraint is needed to implement hoistServer
         , AddSetCookies n (ServerT api Handler) (ServerT (AddSetCookiesApi n api) Handler)
         , ToJWT v
         , HasContextEntry ctxs CookieSettings
         , HasContextEntry ctxs JWTSettings
         , HasAuthFailedErrHandler error
         ) => HasServer (AuthReq' error auths v :> api) ctxs where
  type ServerT (AuthReq' error auths v :> api) m = v -> ServerT api m

  -- The authenticated value is a plain argument, so hoisting only has to
  -- reach the sub-API returned by the handler.
  hoistServerWithContext _ pc nt s = hoistServerWithContext (Proxy :: Proxy api) pc nt . s

  route
    :: forall env
     . Proxy (AuthReq' error auths v :> api)
    -> Context ctxs
    -> Delayed env (v -> ServerT api Handler)
    -> Router env
  route _ context subserver =
    route (Proxy :: Proxy (AddSetCookiesApi n api))
          context
          (fmap go subserver `addAuthCheck` authCheck)
    where
      -- Runs the auth schemes. Yields the authenticated value together with
      -- the cookies to set, or rejects the request.
      authCheck :: DelayedIO (v, SetCookieList ('S ('S 'Z)))
      authCheck = withRequest $ \req -> do
        authResult <- liftIO (runAuthCheck (runAuths (Proxy :: Proxy auths) context) req :: IO (AuthResult v))
        case authResult of
          Authenticated v -> liftIO $ do
            cookies <- makeCookies v
            pure (v, cookies)
          -- Use 'delayedFailFatal', not 'delayedFail'. A recoverable failure
          -- lets the router try the remaining alternatives: a later endpoint
          -- could then serve the unauthenticated request, or a higher-priority
          -- failure from another branch could replace the custom error
          -- (e.g. a redirect to a login page).
          _ -> delayedFailFatal $ authFailedServerErr (Proxy @error)

      jwtSettings :: JWTSettings
      jwtSettings = getContextEntry context

      cookieSettings :: CookieSettings
      cookieSettings = getContextEntry context

      -- Builds the two entries required by @n@:
      -- 1. the cookie from 'makeXsrfCookie';
      -- 2. the session cookie from 'makeSessionCookie', or 'Nothing' if it
      --    produced none.
      makeCookies :: v -> IO (SetCookieList ('S ('S 'Z)))
      makeCookies v = do
        xsrf <- makeXsrfCookie cookieSettings
        mSession <- makeSessionCookie cookieSettings jwtSettings v
        pure $ Just xsrf `SetCookieCons` (mSession `SetCookieCons` SetCookieNil)

      -- Hands the authenticated value to the handler and attaches the cookies
      -- to whatever it returns.
      go :: (v -> ServerT api Handler)
         -> (v, SetCookieList n)
         -> ServerT (AddSetCookiesApi n api) Handler
      go fn (val, cookies) = addSetCookies cookies $ fn val
cdepillabout commented 5 years ago

Here's a slightly more general version. It changes HasAuthFailedErrHandler so that it runs in IO and receives a value taken from the server's Context:

{-# LANGUAGE UndecidableInstances #-}

module Mokusei.Servant where

import Prelude

import Control.Monad.IO.Class (liftIO)
import Data.Proxy (Proxy(Proxy))
import Servant (HasLink(toLink), MkLink, (:>))
import Servant.Auth.Server.Internal.AddSetCookie (AddSetCookies,
                                                  AddSetCookiesApi, Nat (S, Z),
                                                  SetCookieList (SetCookieCons, SetCookieNil),
                                                  addSetCookies)
import Servant.Auth.Server.Internal.Class (AreAuths, runAuths)
import Servant.Auth.Server.Internal.ConfigTypes (CookieSettings, JWTSettings)
import Servant.Auth.Server.Internal.Cookie (makeSessionCookie, makeXsrfCookie)
import Servant.Auth.Server.Internal.JWT (ToJWT)
import Servant.Auth.Server.Internal.Types (AuthResult(Authenticated), runAuthCheck)
import Servant.Server.Internal (Context, Delayed, DelayedIO, Handler,
                                HasContextEntry, HasServer, Router, ServantErr,
                                ServerT, addAuthCheck, delayedFail, err401,
                                getContextEntry, hoistServerWithContext, route,
                                withRequest)

-- | Like 'Auth', except that authentication is required to succeed.
--
-- Handlers behind this combinator receive the authenticated value directly
-- rather than an 'AuthResult'. When authentication does not succeed, the
-- request is rejected before the handler runs.
--
-- The first parameter selects a 'HasAuthFailedErrHandler' instance. That
-- instance computes, in 'IO' and from a value of its context type, the
-- 'ServantErr' used for the rejection. 'AuthReq' fixes it to 'AuthErr401',
-- giving an HTTP 401 (unauthorized). To customise the error, declare your
-- own datatype and give it a 'HasAuthFailedErrHandler' instance.
data AuthReq' err (auths :: [*]) val

-- | Error-handler tag used by 'AuthReq'.
--
-- It carries no values. Its 'HasAuthFailedErrHandler' instance uses @()@ as
-- its context type and keeps the class default, so a failed authentication
-- produces 'err401'.
data AuthErr401

-- | 'AuthReq'' with the default 'AuthErr401' handler, so authentication
-- failures are answered with HTTP 401. Used as @AuthReq auths val@.
--
-- The 'AuthErr401' handler's context type is @()@, so the server's 'Context'
-- must contain a @()@ entry.
type AuthReq = AuthReq' AuthErr401

-- | Defines which 'ServantErr' an 'AuthReq'' endpoint fails with when
-- authentication does not succeed.
--
-- @ctx@ is the type of an extra value the handler receives. In the
-- 'AuthReq'' server instance this value is looked up in the 'Context'. The
-- functional dependency @error -> ctx@ means each handler type determines
-- which context value it needs. Because the method runs in 'IO', it may
-- perform effects while building the error.
--
-- The default definition of 'authFailedServerErr' ignores both arguments and
-- returns 'err401'.
class HasAuthFailedErrHandler ctx error | error -> ctx where
  -- | Computes the error to reject the request with. The proxy argument only
  -- selects the instance.
  authFailedServerErr :: ctx -> proxy error -> IO ServantErr
  authFailedServerErr _ _ = pure err401

-- | 'AuthErr401' needs no context (@()@) and keeps the class default, 'err401'.
instance HasAuthFailedErrHandler () AuthErr401 where

-- | 'AuthReq'' contributes nothing to a link: linking through it is the same
-- as linking to the sub-API.
instance HasLink sub => HasLink (AuthReq' err (auths :: [*]) v :> sub) where
  type MkLink (AuthReq' err (auths :: [*]) v :> sub) r = MkLink sub r
  toLink finish _ = toLink finish (Proxy :: Proxy sub)

-- | Server instance for 'AuthReq'' (context-aware error handler variant).
--
-- Every scheme in @auths@ is run against the request.
--
-- * On 'Authenticated', the handler receives the value, and the cookies from
--   'makeXsrfCookie' and 'makeSessionCookie' are added to its response.
-- * Otherwise, a value of type @errHandlerCtx@ is fetched from the 'Context'
--   and passed to 'authFailedServerErr', and the resulting 'ServantErr'
--   rejects the request. @errHandlerCtx@ is fixed by @error@ through the
--   class's functional dependency.
--
-- NOTE(review): the failure branch uses 'delayedFail'. Servant treats that as
-- recoverable, so the router may go on to try alternative endpoints, and a
-- higher-priority failure from another branch can replace a custom error such
-- as a login redirect. Servant's own auth combinators use 'delayedFailFatal'.
-- Consider switching (requires importing it) — confirm intent.
instance ( n ~ 'S ('S 'Z)
         , HasServer (AddSetCookiesApi n api) ctxs, AreAuths auths ctxs v
         , HasServer api ctxs -- this constraint is needed to implement hoistServer
         , AddSetCookies n (ServerT api Handler) (ServerT (AddSetCookiesApi n api) Handler)
         , ToJWT v
         , HasContextEntry ctxs CookieSettings
         , HasContextEntry ctxs JWTSettings
         , HasContextEntry ctxs errHandlerCtx -- value handed to authFailedServerErr
         , HasAuthFailedErrHandler errHandlerCtx error
         ) => HasServer (AuthReq' error auths v :> api) ctxs where
  type ServerT (AuthReq' error auths v :> api) m = v -> ServerT api m

  -- The target monad is named n1 because n is already bound by the instance
  -- context. Hoisting only has to reach the sub-API returned by the handler.
  hoistServerWithContext
    :: forall (m :: * -> *) (n1 :: * -> *)
     . Proxy (AuthReq' error auths v :> api)
    -> Proxy ctxs
    -> (forall x. m x -> n1 x)
    -> (v -> ServerT api m)
    -> (v -> ServerT api n1)
  hoistServerWithContext _ pc nt s = hoistServerWithContext (Proxy :: Proxy api) pc nt . s

  route
    :: forall env
     . Proxy (AuthReq' error auths v :> api)
    -> Context ctxs
    -> Delayed env (v -> ServerT api Handler)
    -> Router env
  route _ context subserver =
    route (Proxy :: Proxy (AddSetCookiesApi n api))
          context
          (fmap go subserver `addAuthCheck` authCheck)
    where
      -- Runs the auth schemes. Yields the authenticated value with its
      -- cookies, or rejects the request with the handler's error.
      authCheck :: DelayedIO (v, SetCookieList ('S ('S 'Z)))
      authCheck = withRequest $ \req -> do
        authResult <- liftIO $ (runAuthCheck (runAuths (Proxy :: Proxy auths) context) req :: IO (AuthResult v))
        case authResult of
          Authenticated v -> liftIO $ do
            cookies <- makeCookies v
            pure (v, cookies)
          -- The error is computed in IO from the Context value; see the
          -- NOTE(review) on the instance about 'delayedFail' being recoverable.
          _ -> do
            res <- liftIO $ authFailedServerErr (getContextEntry context :: errHandlerCtx) (Proxy @error)
            delayedFail res

      jwtSettings :: JWTSettings
      jwtSettings = getContextEntry context

      cookieSettings :: CookieSettings
      cookieSettings = getContextEntry context

      -- Builds the two entries required by @n@:
      -- 1. the cookie from 'makeXsrfCookie';
      -- 2. the session cookie from 'makeSessionCookie', or 'Nothing' if it
      --    produced none.
      makeCookies :: v -> IO (SetCookieList ('S ('S 'Z)))
      makeCookies v = do
        xsrf <- makeXsrfCookie cookieSettings
        ejwt <- makeSessionCookie cookieSettings jwtSettings v
        fmap (Just xsrf `SetCookieCons`) $ do
          case ejwt of
            Nothing  -> return $ Nothing `SetCookieCons` SetCookieNil
            Just jwt -> return $ Just jwt `SetCookieCons` SetCookieNil

      -- Hands the authenticated value to the handler and attaches the cookies
      -- to whatever it returns.
      go :: (v -> ServerT api Handler)
         -> (v, SetCookieList n)
         -> ServerT (AddSetCookiesApi n api) Handler
      go fn (authResult, cookies) = addSetCookies cookies $ fn authResult

The big downside is that this requires the user to add an extra entry to the Context when they run their server. Even the default AuthErr401 handler needs a `()` entry, because its context type is `()`.