lepture / authlib

The ultimate Python library in building OAuth, OpenID Connect clients and servers. JWS,JWE,JWK,JWA,JWT included.
https://authlib.org/
BSD 3-Clause "New" or "Revised" License

Ask to transform the inline function "load_key" into a method of OpenIDMixin #610

Open danilovmy opened 6 months ago

danilovmy commented 6 months ago

Hello. To make the existing code more usable, I propose reworking the OpenIDMixin.parse_id_token method. It contains an inline function definition, def load_key(header, _), and I don't see any reason why this function is not a method of OpenIDMixin. Moreover, the function calls other methods on self, which marks it as a method of OpenIDMixin rather than a standalone function. Lastly, if it were a method, it could easily be tested and reused for other purposes. For example, there is currently a minor bug: it does not raise an error if the newly loaded keys still do not contain the desired 'kid'.

Before:

#  authlib\integrations\base_client\sync_openid.py
class OpenIDMixin(object):
    ...

    def parse_id_token(self, token, nonce, claims_options=None, leeway=120):
        """Return an instance of UserInfo from token's ``id_token``."""
        if 'id_token' not in token:
            return None

        def load_key(header, _):
            jwk_set = JsonWebKey.import_key_set(self.fetch_jwk_set())
            try:
                return jwk_set.find_by_kid(header.get('kid'))
            except ValueError:
                # re-try with new jwk set
                jwk_set = JsonWebKey.import_key_set(self.fetch_jwk_set(force=True))
                return jwk_set.find_by_kid(header.get('kid'))
        ...

After suggested refactoring:

#  authlib\integrations\base_client\sync_openid.py
class OpenIDMixin(object):
    ...
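    def load_key(self, header, _=None, force=False):
        # The decoder calls key(header, payload), as the inline version did;
        # the payload is unused, and force must not sit in that second slot.
        jwk_set = JsonWebKey.import_key_set(self.fetch_jwk_set(force=force))
        try:
            return jwk_set.find_by_kid(header.get('kid'))
        except ValueError:
            if not force:  # re-try with new jwk set
                return self.load_key(header, force=True)
            raise RuntimeError('Missing "kid" in "jwk_set"')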

    def parse_id_token(self, token, nonce, claims_options=None, leeway=120):
        """Return an instance of UserInfo from token's ``id_token``."""
        if 'id_token' not in token:
            return None
        ...
        claims = _jwt.decode(
            token['id_token'], key=self.load_key,
            ...
        )
        ...
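
To illustrate the testability argument, here is a minimal sketch of how load_key could be exercised once it is a method. FakeKeySet and FakeClient are hypothetical stand-ins, not authlib classes; they only mimic the behaviour visible in the snippets above (find_by_kid raising ValueError, and fetch_jwk_set accepting force). The sketch also assumes the decoder calls the key callable with two positional arguments (header, payload), as the inline load_key(header, _) signature suggests, so force is kept keyword-only in practice.

```python
class FakeKeySet:
    """Hypothetical stand-in for a key set: raises ValueError on an unknown kid."""

    def __init__(self, keys):
        self._keys = {key["kid"]: key for key in keys}

    def find_by_kid(self, kid):
        if kid not in self._keys:
            raise ValueError(f"no key with kid {kid!r}")
        return self._keys[kid]


class FakeClient:
    """Hypothetical client with load_key as a method, as proposed in this issue."""

    def __init__(self, cached, remote):
        self._cached = cached    # key set currently held in the cache
        self._remote = remote    # key set a forced re-fetch would return
        self.forced_fetches = 0  # counts forced re-fetches for the assertions

    def fetch_jwk_set(self, force=False):
        if force:
            self.forced_fetches += 1
            self._cached = self._remote
        return self._cached

    def load_key(self, header, _=None, force=False):
        # The second positional argument (the payload) is accepted and ignored.
        jwk_set = FakeKeySet(self.fetch_jwk_set(force=force))
        try:
            return jwk_set.find_by_kid(header.get("kid"))
        except ValueError:
            if not force:  # retry once with a freshly fetched key set
                return self.load_key(header, force=True)
            raise RuntimeError('Missing "kid" in "jwk_set"')


# A rotated key is found after exactly one forced re-fetch.
client = FakeClient(
    cached=[{"kid": "old"}],
    remote=[{"kid": "old"}, {"kid": "new"}],
)
rotated_key = client.load_key({"kid": "new"}, None)

# A kid that is missing even after the re-fetch raises RuntimeError.
try:
    client.load_key({"kid": "missing"}, None)
    missing_error = None
except RuntimeError as exc:
    missing_error = str(exc)
```

With the inline version, the same checks would require driving the whole parse_id_token call with a signed token.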