jazzband / djangorestframework-simplejwt

A JSON Web Token authentication plugin for the Django REST Framework.
https://django-rest-framework-simplejwt.readthedocs.io/
MIT License
3.98k stars 661 forks source link

Allow other header claims in tokens #531

Open 73VW opened 2 years ago

73VW commented 2 years ago

As defined in RFC 7515, section 4.1, tokens can include several more header claims than just the typ and alg claims that this library currently allows.

I tried to include a kid header claim, since I use signed tokens, but I couldn't.

Using PyJWT I was able to add it to the token string, but when I called the RefreshToken(token) constructor it removed all custom headers.

I have checked in the doc and nothing seems to cover this use case.

I haven't dug much into the code, though.

As for the kid claim, I suggest including it in the header by default when the token is signed.

(AuthLib documentation for reference)

This is somehow related to #491 as kid might be useful when combined with JWK endpoint.

73VW commented 2 years ago

Seems that the first part of my issue can be done using what has been done in !517

Sadly I couldn't find any issue related to this.

Any clue when this will be on PyPI?

73VW commented 2 years ago

Well after digging into the code I have managed to include the kid header claim as I wanted without using what's in !517.

I've had to redefine quite a few classes.

views.py

import jwt
import rest_framework_simplejwt.views as original_views
from authlib.jose import JsonWebKey
from django.conf import settings
from rest_framework_simplejwt.backends import TokenBackend
from rest_framework_simplejwt.serializers import (TokenObtainPairSerializer,
                                                  TokenRefreshSerializer)
from rest_framework_simplejwt.settings import api_settings
from rest_framework_simplejwt.tokens import AccessToken, RefreshToken, Token

class TokenBackendWithHeaders(TokenBackend):
    """Token backend whose ``encode`` accepts extra JWT header claims."""

    def encode(self, payload, headers=None):
        """
        Returns an encoded token for the given payload dictionary.

        ``payload`` is copied before the ``aud``/``iss`` claims are added, so
        the caller's dictionary is left untouched.

        ``headers`` is an optional dictionary of additional header claims
        (for example ``kid``) that is handed to ``jwt.encode`` as-is.  It
        defaults to ``None`` instead of ``{}`` so that no mutable default
        object is shared between calls.
        """
        jwt_payload = payload.copy()
        if self.audience is not None:
            jwt_payload["aud"] = self.audience
        if self.issuer is not None:
            jwt_payload["iss"] = self.issuer

        token = jwt.encode(jwt_payload, self.signing_key,
                           algorithm=self.algorithm, headers=headers)
        if isinstance(token, bytes):
            # For PyJWT <= 1.7.1
            return token.decode("utf-8")
        # For PyJWT >= 2.0.0a1
        return token

class TokenWithAnotherTokenBackend(Token):
    """
    Token base class that carries custom header claims in ``self.headers``
    and passes them to the backend when the token is serialized.
    """

    # Built once, when this class body is executed, from the values that
    # api_settings holds at that moment.
    _token_backend = TokenBackendWithHeaders(
        api_settings.ALGORITHM,
        api_settings.SIGNING_KEY,
        api_settings.VERIFYING_KEY,
        api_settings.AUDIENCE,
        api_settings.ISSUER,
        api_settings.JWK_URL,
        api_settings.LEEWAY,
    )

    def __init__(self, token=None, verify=True):
        """
        Initializes the token through the next class in the MRO, then starts
        with an empty set of custom header claims.

        Note that the headers always start empty, even when an encoded
        ``token`` is given.
        """
        # Cooperative call: this class is used as one of several bases, so
        # the parent initializer is resolved through the MRO rather than
        # being hard-coded to ``Token``.
        super().__init__(token, verify)
        self.headers = {}

    def __str__(self):
        """
        Signs and returns a token as a base64 encoded string.
        """
        return self.get_token_backend().encode(self.payload, self.headers)

class AccessTokenWithAnotherTokenBackend(AccessToken, TokenWithAnotherTokenBackend):
    """
    Access token that combines ``AccessToken`` with
    ``TokenWithAnotherTokenBackend``; it adds nothing of its own.
    """

class RefreshTokenWithAnotherTokenBackend(RefreshToken, TokenWithAnotherTokenBackend):
    """Refresh token whose derived access tokens inherit its header claims."""

    @property
    def access_token(self):
        """
        Builds an access token out of this refresh token.

        Every payload claim is carried over unless it is listed in the
        `no_copy_claims` attribute; every header claim is carried over too.
        """
        new_access = AccessTokenWithAnotherTokenBackend()

        # Anchor the access token's "exp" claim on the moment this refresh
        # token was instantiated, so that a pair created together expires
        # relative to the same point in time.
        new_access.set_exp(from_time=self.current_time)

        excluded = self.no_copy_claims
        for name, claim_value in self.payload.items():
            if name not in excluded:
                new_access[name] = claim_value

        # Header claims are copied wholesale.
        new_access.headers.update(self.headers)

        return new_access

class TokenObtainPairSerializerDifferentToken(TokenObtainPairSerializer):
    """Pair serializer that issues refresh tokens carrying a ``kid`` header."""

    token_class = RefreshTokenWithAnotherTokenBackend

    @classmethod
    def get_token(cls, user):
        """
        Returns a token for ``user`` whose ``kid`` header claim is the
        thumbprint of the configured verifying key.
        """
        verifying_key = settings.SIMPLE_JWT['VERIFYING_KEY']
        jwk = JsonWebKey.import_key(verifying_key, {'kty': 'RSA'})

        user_token = cls.token_class.for_user(user)

        # Custom header claims go here.
        user_token.headers['kid'] = jwk.thumbprint()

        return user_token

class TokenRefreshSerializerDifferentToken(TokenRefreshSerializer):
    """Refresh serializer built on ``RefreshTokenWithAnotherTokenBackend``."""

    # This whole method had to be redefined because the original code
    # hard-codes "RefreshToken"; it is replaced here by
    # "RefreshTokenWithAnotherTokenBackend".  A PR fixing this was already
    # merged, so a newer version of simple-jwt should include the changes
    # from https://github.com/jazzband/djangorestframework-simplejwt/pull/517
    def validate(self, attrs):
        """
        Returns a dict with a fresh ``access`` token and, when refresh-token
        rotation is enabled, a re-issued ``refresh`` token as well.
        """
        current = RefreshTokenWithAnotherTokenBackend(attrs['refresh'])
        result = {'access': str(current.access_token)}

        # Without rotation only the access token is returned.
        if not api_settings.ROTATE_REFRESH_TOKENS:
            return result

        if api_settings.BLACKLIST_AFTER_ROTATION:
            try:
                # Try to blacklist the refresh token that was handed in.
                current.blacklist()
            except AttributeError:
                # The `blacklist` method is absent when the blacklist app
                # is not installed; that is fine.
                pass

        # Re-stamp the token before it is serialized again.
        current.set_jti()
        current.set_exp()
        current.set_iat()

        result['refresh'] = str(current)
        return result

class TokenObtainPairView(original_views.TokenObtainPairView):
    """
    Same as the library's ``TokenObtainPairView``, except that it uses
    ``TokenObtainPairSerializerDifferentToken`` as its serializer.
    """
    serializer_class = TokenObtainPairSerializerDifferentToken

class TokenRefreshView(original_views.TokenRefreshView):
    """
    Same as the library's ``TokenRefreshView``, except that it uses
    ``TokenRefreshSerializerDifferentToken`` as its serializer.
    """
    serializer_class = TokenRefreshSerializerDifferentToken

urls.py

"""e_abeilles URL Configuration

The `urlpatterns` list routes URLs to views. For more information please see:
    https://docs.djangoproject.com/en/4.0/topics/http/urls/
Examples:
Function views
    1. Add an import:  from my_app import views
    2. Add a URL to urlpatterns:  path('', views.home, name='home')
Class-based views
    1. Add an import:  from other_app.views import Home
    2. Add a URL to urlpatterns:  path('', Home.as_view(), name='home')
Including another URLconf
    1. Import the include() function: from django.urls import include, path
    2. Add a URL to urlpatterns:  path('blog/', include('blog.urls'))
"""
from django.contrib import admin
from django.urls import path, include

from rest_framework_simplejwt import views as jwt_views
from my_package import views

# Routes for the admin site and the two token endpoints.  The token views
# come from ``my_package.views`` (imported as ``views``), not from
# ``jwt_views``.
urlpatterns = [
    path('admin/', admin.site.urls),
    # Token obtain endpoint.
    path('api/token/', views.TokenObtainPairView.as_view(),
         name='token_obtain_pair'),
    # Token refresh endpoint.
    path('api/token/refresh/', views.TokenRefreshView.as_view(),
         name='token_refresh'),
]
73VW commented 2 years ago

Could this be included in the base code? I can open a PR if you wish!

Andrew-Chen-Wang commented 2 years ago

What we’ve done in the past is have a callable or a dotted import string in SIMPLE_JWT settings. In the serializer, we can pass the token to your function. This is similar to the authorization callable.

73VW commented 2 years ago

@Andrew-Chen-Wang That might be possible but I don't think this is the way to go as it involves encoding a token -> sending it to the callback -> decoding it -> adding a header while reencoding it -> sending it back.

Performance-wise, adding a header before encoding the token would be much better, don't you think?

Andrew-Chen-Wang commented 2 years ago

Yes, it definitely would be. I just worry about the ordering and people missing something with override classes. But please open a PR and we shall deliberate :)

rj76 commented 2 years ago

For anyone interested, here is the same for sliding tokens

import jwt
import rest_framework_simplejwt.views as original_views
from authlib.jose import JsonWebKey
from django.conf import settings
from rest_framework_simplejwt.backends import TokenBackend
from rest_framework_simplejwt.serializers import TokenObtainSlidingSerializer
from rest_framework_simplejwt.settings import api_settings
from rest_framework_simplejwt.tokens import SlidingToken, Token

class TokenBackendWithHeaders(TokenBackend):
    """Token backend whose ``encode`` accepts extra JWT header claims."""

    def encode(self, payload, headers=None):
        """
        Returns an encoded token for the given payload dictionary.

        ``payload`` is copied before the ``aud``/``iss`` claims are added, so
        the caller's dictionary is left untouched.

        ``headers`` is an optional dictionary of additional header claims
        (for example ``kid``) that is handed to ``jwt.encode`` as-is.  It
        defaults to ``None`` instead of ``{}`` so that no mutable default
        object is shared between calls.
        """
        jwt_payload = payload.copy()
        if self.audience is not None:
            jwt_payload["aud"] = self.audience
        if self.issuer is not None:
            jwt_payload["iss"] = self.issuer

        token = jwt.encode(jwt_payload, self.signing_key,
                           algorithm=self.algorithm, headers=headers)
        if isinstance(token, bytes):
            # For PyJWT <= 1.7.1
            return token.decode("utf-8")
        # For PyJWT >= 2.0.0a1
        return token

class TokenWithAnotherTokenBackend(Token):
    """
    Token base class that carries custom header claims in ``self.headers``
    and passes them to the backend when the token is serialized.
    """

    # Built once, when this class body is executed, from the values that
    # api_settings holds at that moment.
    _token_backend = TokenBackendWithHeaders(
        api_settings.ALGORITHM,
        api_settings.SIGNING_KEY,
        api_settings.VERIFYING_KEY,
        api_settings.AUDIENCE,
        api_settings.ISSUER,
        api_settings.JWK_URL,
        api_settings.LEEWAY,
    )

    def __init__(self, token=None, verify=True):
        """
        Initializes the token through the next class in the MRO, then starts
        with an empty set of custom header claims.

        Note that the headers always start empty, even when an encoded
        ``token`` is given.
        """
        # Cooperative call: this class is used as one of several bases, so
        # the parent initializer is resolved through the MRO rather than
        # being hard-coded to ``Token``.
        super().__init__(token, verify)
        self.headers = {}

    def __str__(self):
        """
        Signs and returns a token as a base64 encoded string.
        """
        return self.get_token_backend().encode(self.payload, self.headers)

class SlidingokenWithAnotherTokenBackend(SlidingToken, TokenWithAnotherTokenBackend):
    """
    Sliding token that combines ``SlidingToken`` with
    ``TokenWithAnotherTokenBackend``; it adds nothing of its own.
    """

class TokenObtainSlidingSerializerDifferentToken(TokenObtainSlidingSerializer):
    """Sliding serializer that issues tokens carrying a ``kid`` header."""

    token_class = SlidingokenWithAnotherTokenBackend

    @classmethod
    def get_token(cls, user):
        """
        Returns a token for ``user`` whose ``kid`` header claim is the
        thumbprint of the configured verifying key.
        """
        verifying_key = settings.SIMPLE_JWT['VERIFYING_KEY']
        jwk = JsonWebKey.import_key(verifying_key, {'kty': 'RSA'})

        user_token = cls.token_class.for_user(user)

        # Custom header claims go here.
        user_token.headers['kid'] = jwk.thumbprint()

        return user_token

class TokenObtainSlidingView(original_views.TokenObtainSlidingView):
    """
    Same as the library's ``TokenObtainSlidingView``, except that it uses
    ``TokenObtainSlidingSerializerDifferentToken`` as its serializer.
    """
    serializer_class = TokenObtainSlidingSerializerDifferentToken
nixsiow commented 1 year ago

Has this been incorporated or solved in the latest codebase as I am currently facing the exact same issue of trying to add a 'kid' claim into the header of the signed token? So strange that this is not mentioned anywhere in the docs.

Andrew-Chen-Wang commented 1 year ago

This is not implemented.

steven-jeanneret commented 1 year ago

Would a new setting, EXTRA_JWT_HEADERS, be a solution?

I'm facing this problem where I want to add kid in the headers.

adamJLev commented 1 month ago

At this point it would be good to just have the kid header always added by default, no? It's part of the JWK standard now: https://datatracker.ietf.org/doc/html/rfc7517#section-4.5