jpadilla / pyjwt

JSON Web Token implementation in Python
https://pyjwt.readthedocs.io
MIT License

Fix Issue https://github.com/jpadilla/pyjwt/issues/914 #915

Closed: jaferrando closed this 8 months ago

jaferrando commented 1 year ago

Fix for https://github.com/jpadilla/pyjwt/issues/914

jaferrando commented 1 year ago

Quoting an earlier comment: "the change will need some more test coverage"

The test cases are already there, specifically https://github.com/jpadilla/pyjwt/blob/95638cf04f83b3e2c289b28810501e53195ff938/tests/test_jwks_client.py#L222.

Actually, the second call to get_jwk_set in that test should fail with the error mentioned in the issue. I've investigated why it does not, and the root cause seems to be that fetch_data feeds the cache not with a PyJWKSet but with the dict read directly from the URL. Because the cache holds a dict instead of the expected type, get_jwk_set obtains a dict from the cache, and the test passes.

In my case, I was feeding the client's cache directly from an externally stored JWKS to avoid fetching it from the URL. To do so, I manually created the JWKSetCache object and fed it a PyJWKSet. That's why in my case the cache contained the expected data type, and my code failed with the error.
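
Both paths can be reproduced with a minimal, self-contained model. FakeJWKSet, FakeClientError, and SimpleCache are hypothetical stand-ins for PyJWKSet, PyJWKClientError, and JWKSetCache (the real pyjwt classes are not imported), and the URL fetch is replaced by a hard-coded dict; only the type flow is meant to mirror the code discussed in this comment.

```python
from typing import Any, Optional


class FakeJWKSet:
    """Hypothetical stand-in for PyJWKSet: holds a list of key dicts."""

    def __init__(self, keys: list) -> None:
        self.keys = keys

    @classmethod
    def from_dict(cls, data: dict) -> "FakeJWKSet":
        return cls(list(data["keys"]))


class FakeClientError(Exception):
    """Hypothetical stand-in for PyJWKClientError."""


class SimpleCache:
    """Simplified JWKSetCache: stores whatever it is given, no expiry."""

    def __init__(self) -> None:
        self.value: Optional[FakeJWKSet] = None

    def put(self, jwk_set: FakeJWKSet) -> None:
        # The annotation is not checked at runtime, so a dict is stored too.
        self.value = jwk_set

    def get(self) -> Optional[FakeJWKSet]:
        return self.value


def fetch_data(cache: SimpleCache) -> Any:
    # Stand-in for the URL fetch: json.load() would produce a plain dict.
    jwk_set: Any = {"keys": [{"kty": "oct", "kid": "k1"}]}
    cache.put(jwk_set)  # the raw dict, not a FakeJWKSet, goes into the cache
    return jwk_set


def get_jwk_set(cache: SimpleCache) -> FakeJWKSet:
    data: Any = cache.get()
    if data is None:
        data = fetch_data(cache)
    if not isinstance(data, dict):
        raise FakeClientError("The JWKS endpoint did not return a JSON object")
    return FakeJWKSet.from_dict(data)


# Path 1: the cache is only filled by fetch_data, so it holds a dict and
# the second (cached) call succeeds, as in the existing test.
fetched_cache = SimpleCache()
get_jwk_set(fetched_cache)
second = get_jwk_set(fetched_cache)
print(type(fetched_cache.get()).__name__)  # dict

# Path 2: the cache is pre-fed with a set object, so the cached value
# fails the isinstance(data, dict) check.
prefed_cache = SimpleCache()
prefed_cache.put(FakeJWKSet([{"kty": "oct", "kid": "k1"}]))
prefed_error: Optional[Exception] = None
try:
    get_jwk_set(prefed_cache)
except FakeClientError as exc:
    prefed_error = exc
print(prefed_error)  # The JWKS endpoint did not return a JSON object
```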

Here's the result of my code inspection:

The first time get_jwk_set is called, it gets None from the cache and calls fetch_data. This is where, in my opinion, the code does not honor the declared type hints:

    def fetch_data(self) -> Any:
        jwk_set: Any = None
        try:
            r = urllib.request.Request(url=self.uri, headers=self.headers)
            with urllib.request.urlopen(r, timeout=self.timeout) as response:
                jwk_set = json.load(response)
        except (URLError, TimeoutError) as e:
            raise PyJWKClientConnectionError(
                f'Fail to fetch data from the url, err: "{e}"'
            )
        else:
            return jwk_set
        finally:
            if self.jwk_set_cache is not None:
                self.jwk_set_cache.put(jwk_set)  # This puts the raw JSON dict into the cache
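
Note that the finally clause runs on both outcomes: after a successful fetch it caches the raw dict (before the value returned from else reaches the caller), and after a connection error jwk_set is still None, so put(None) clears the cache. A minimal sketch of that control flow, assuming a hypothetical fetch_json callable in place of urllib and a recording stand-in for the cache (ConnectionFailure and RecordingCache are illustrative names, not pyjwt API):

```python
from typing import Any, Callable, List, Optional


class ConnectionFailure(Exception):
    """Hypothetical stand-in for PyJWKClientConnectionError."""


class RecordingCache:
    """Stand-in cache that records every put(); put(None) clears it."""

    def __init__(self) -> None:
        self.value: Any = None
        self.puts: List[Any] = []

    def put(self, jwk_set: Any) -> None:
        self.puts.append(jwk_set)
        self.value = jwk_set  # storing None empties the cache


def fetch_data(fetch_json: Callable[[], Any], cache: Optional[RecordingCache]) -> Any:
    jwk_set: Any = None
    try:
        jwk_set = fetch_json()  # stands in for json.load(urlopen(...))
    except OSError as e:  # URLError and TimeoutError both subclass OSError
        raise ConnectionFailure(
            f'Fail to fetch data from the url, err: "{e}"'
        ) from e
    else:
        return jwk_set
    finally:
        # Runs on success (caching the raw dict before the caller gets it)
        # and on failure (jwk_set is still None, so the cache is cleared).
        if cache is not None:
            cache.put(jwk_set)


cache = RecordingCache()
result = fetch_data(lambda: {"keys": []}, cache)


def failing_fetch() -> Any:
    raise TimeoutError("timed out")


try:
    fetch_data(failing_fetch, cache)
except ConnectionFailure:
    pass

print(cache.puts)  # [{'keys': []}, None]
```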

Then JWKSetCache.put, which according to its type hint should receive a PyJWKSet, receives a dictionary instead:

class JWKSetCache:
    def __init__(self, lifespan: int) -> None:
        self.jwk_set_with_timestamp: Optional[PyJWTSetWithTimestamp] = None
        self.lifespan = lifespan

    def put(self, jwk_set: PyJWKSet) -> None:  # BUT this expects a PyJWKSet, not a dict (PyJWKSet is not a dict subclass)
        if jwk_set is not None:
            self.jwk_set_with_timestamp = PyJWTSetWithTimestamp(jwk_set)  # The dict is used to construct the PyJWTSetWithTimestamp
        else:
            # clear cache
            self.jwk_set_with_timestamp = None
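
Python does not enforce annotations at runtime, so put(jwk_set: PyJWKSet) silently accepts the dict, and because fetch_data declares jwk_set: Any, a static type checker will not flag the call either. As a design sketch only (not the change proposed in this PR), a hypothetical stricter cache that checks the type on entry would surface the mismatch at the put call instead of later; FakeJWKSet, LooseCache, and StrictCache are illustrative names:

```python
from typing import Any, Optional


class FakeJWKSet:
    """Hypothetical stand-in for PyJWKSet."""

    def __init__(self, keys: list) -> None:
        self.keys = keys


class LooseCache:
    """Mirrors the quoted put(): the annotation is documentation only."""

    def __init__(self) -> None:
        self.value: Optional[FakeJWKSet] = None

    def put(self, jwk_set: FakeJWKSet) -> None:
        self.value = jwk_set


class StrictCache(LooseCache):
    """Hypothetical variant that enforces the annotated type on entry."""

    def put(self, jwk_set: FakeJWKSet) -> None:
        if jwk_set is not None and not isinstance(jwk_set, FakeJWKSet):
            raise TypeError(
                f"expected FakeJWKSet or None, got {type(jwk_set).__name__}"
            )
        super().put(jwk_set)


raw: Any = {"keys": []}  # typed Any, like jwk_set in fetch_data

loose = LooseCache()
loose.put(raw)  # accepted silently
print(type(loose.value).__name__)  # dict

strict = StrictCache()
strict_error: Optional[TypeError] = None
try:
    strict.put(raw)
except TypeError as exc:
    strict_error = exc
print(strict_error)  # expected FakeJWKSet or None, got dict
```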

And that dictionary will find its way through to the final object stored in the cache:

class PyJWTSetWithTimestamp:
    def __init__(self, jwk_set: PyJWKSet):
        self.jwk_set = jwk_set  # The dict ends up as the value of the cached element
        self.timestamp = time.monotonic()

    def get_jwk_set(self) -> PyJWKSet:
        return self.jwk_set  # This will NOT return a PyJWKSet but a dict

    def get_timestamp(self) -> float:
        return self.timestamp

Then, the JWKSetCache class will blindly return the dict instead of the PyJWKSet.

class JWKSetCache:
    def __init__(self, lifespan: int) -> None:
        self.jwk_set_with_timestamp: Optional[PyJWTSetWithTimestamp] = None
        self.lifespan = lifespan

    def put(self, jwk_set: PyJWKSet) -> None:
        if jwk_set is not None:
            self.jwk_set_with_timestamp = PyJWTSetWithTimestamp(jwk_set)
        else:
            # clear cache
            self.jwk_set_with_timestamp = None

    def get(self) -> Optional[PyJWKSet]:
        if self.jwk_set_with_timestamp is None or self.is_expired():
            return None

        return self.jwk_set_with_timestamp.get_jwk_set()  # So the cache returns a dict, not a PyJWKSet
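
get() also depends on is_expired(), which is not quoted here. A minimal sketch of a lifespan check of that kind, assuming "expired" means more than lifespan seconds have passed since put(), measured with time.monotonic() (the clock PyJWTSetWithTimestamp uses); TimedCache and its injectable clock parameter are hypothetical, added only to make the example deterministic. Note that the stored value passes through untouched, whatever its type:

```python
import time
from typing import Any, Callable, Optional


class TimedCache:
    """Sketch of a lifespan-based cache; stores values without type checks."""

    def __init__(
        self, lifespan: float, clock: Callable[[], float] = time.monotonic
    ) -> None:
        self.lifespan = lifespan
        self.clock = clock
        self.value: Any = None
        self.timestamp: Optional[float] = None

    def put(self, value: Any) -> None:
        if value is None:
            # put(None) clears the cache, as in the quoted JWKSetCache.put
            self.value, self.timestamp = None, None
        else:
            self.value, self.timestamp = value, self.clock()

    def is_expired(self) -> bool:
        # Assumed rule: expired once more than `lifespan` seconds have passed.
        return (
            self.timestamp is not None
            and self.clock() > self.timestamp + self.lifespan
        )

    def get(self) -> Any:
        if self.timestamp is None or self.is_expired():
            return None
        return self.value  # returned as-is: a dict in, a dict out


# Usage with a fake clock so the example is deterministic.
now = [100.0]
cache = TimedCache(lifespan=300, clock=lambda: now[0])
cache.put({"keys": []})
print(cache.get())  # {'keys': []}
now[0] += 301
print(cache.get())  # None
```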

And finally, get_jwk_set always gets a dict, whether it takes the data from the URL or from the cache:

    def get_jwk_set(self, refresh: bool = False) -> PyJWKSet:
        data = None
        if self.jwk_set_cache is not None and not refresh:
            data = self.jwk_set_cache.get()  # So here we get a dict, which is why this function works even though its code is wrong

        if data is None:
            data = self.fetch_data()

        if not isinstance(data, dict):
            raise PyJWKClientError("The JWKS endpoint did not return a JSON object")

        return PyJWKSet.from_dict(data)
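
One way to make get_jwk_set tolerate both kinds of cache contents while keeping its signature is to return an already-parsed set object as-is and only parse dicts. This is an illustrative sketch under that assumption, not the actual diff of this PR; FakeJWKSet, FakeClientError, and the cached/fetch parameters are hypothetical stand-ins for the client's cache and fetch_data:

```python
from typing import Any, Callable, Optional


class FakeJWKSet:
    """Hypothetical stand-in for PyJWKSet."""

    def __init__(self, keys: list) -> None:
        self.keys = keys

    @classmethod
    def from_dict(cls, data: dict) -> "FakeJWKSet":
        return cls(list(data["keys"]))


class FakeClientError(Exception):
    """Hypothetical stand-in for PyJWKClientError."""


def get_jwk_set(cached: Optional[Any], fetch: Callable[[], Any]) -> FakeJWKSet:
    data = cached
    if data is None:
        data = fetch()

    # A set object (e.g. one pre-loaded into the cache) is already parsed.
    if isinstance(data, FakeJWKSet):
        return data

    # Otherwise keep the original contract: only a JSON object is accepted.
    if not isinstance(data, dict):
        raise FakeClientError("The JWKS endpoint did not return a JSON object")

    return FakeJWKSet.from_dict(data)


raw = {"keys": [{"kty": "oct", "kid": "k1"}]}
preloaded = FakeJWKSet.from_dict(raw)

from_fetch = get_jwk_set(None, lambda: raw)  # dict from the "URL"
from_dict_cache = get_jwk_set(raw, lambda: None)  # dict from the cache
from_set_cache = get_jwk_set(preloaded, lambda: None)  # set from the cache
print(from_set_cache is preloaded)  # True
```

Returning the cached object unchanged keeps the dict path identical to the current behavior, so existing callers that only ever cache dicts are unaffected.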

Any code that expects to receive a PyJWKSet from the cache, or any code that, as in my case, feeds the cache some other way than through a fetch_data call, will fail.

The code in this PR fixes the problem reported in issue #914 without affecting any other current use of the module, as both the received and returned types of the function are preserved and no other parts of the module are changed. However, it does not fix use cases that call fetch_data and access the cache directly instead of going through the fixed get_jwk_set.

The mentioned test passes with the fixed version of get_jwk_set too. I'll add an assert to the existing tests that looks up the key in the result of the get_signing_key functions by accessing it as a PyJWKSet object. With a dict, this raises a TypeError: a dict's keys attribute is the built-in dict.keys method, not the list of keys that a PyJWKSet exposes.
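
The failure such an assertion relies on can be shown in isolation: on a set object, keys is a list, while on a dict, keys is the bound dict.keys method, so code that iterates it as a list raises TypeError. FakeJWKSet and key_ids are hypothetical names, not pyjwt API:

```python
from typing import Any, List, Optional


class FakeJWKSet:
    """Hypothetical stand-in for PyJWKSet: `keys` is a list attribute."""

    def __init__(self, keys: List[dict]) -> None:
        self.keys = keys


def key_ids(jwk_set: Any) -> List[str]:
    # Code written against the set type: iterate the `keys` list.
    return [key["kid"] for key in jwk_set.keys]


raw = {"keys": [{"kty": "oct", "kid": "k1"}]}

print(key_ids(FakeJWKSet(raw["keys"])))  # ['k1']

dict_error: Optional[TypeError] = None
try:
    key_ids(raw)  # raw.keys is the dict.keys method, not a list
except TypeError as exc:
    dict_error = exc
print(type(dict_error).__name__)  # TypeError
```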

auvipy commented 10 months ago

Is it possible to pull from master and fix the build errors/test failures?