Closed jaferrando closed 8 months ago
the change will need some more test coverage
The test cases are already there. Specifically https://github.com/jpadilla/pyjwt/blob/95638cf04f83b3e2c289b28810501e53195ff938/tests/test_jwks_client.py#L222.
Actually, the second call to get_jwk_set in that test should fail with the error mentioned in the issue. I've investigated why it does not, and the root cause seems to be that fetch_data feeds the cache not with a PyJWKSet but with the dict read directly from the URL. Because the cache holds a dict rather than the expected type, get_jwk_set obtains a dict from the cache, and the test passes.
In my case I was feeding the cache in the client directly from an externally stored JWKS, to avoid the fetch from the URL. To do so, I manually created the JWKSetCache object and fed it with a PyJWKSet. That's why in my case the cache contained the expected data type and my code failed with the error.
Here's the result of my code inspection:
The first time get_jwk_set is called, it gets None from the cache and calls fetch_data. In my opinion, this is where the declared type hints are not honored:
def fetch_data(self) -> Any:
    jwk_set: Any = None
    try:
        r = urllib.request.Request(url=self.uri, headers=self.headers)
        with urllib.request.urlopen(r, timeout=self.timeout) as response:
            jwk_set = json.load(response)
    except (URLError, TimeoutError) as e:
        raise PyJWKClientConnectionError(
            f'Fail to fetch data from the url, err: "{e}"'
        )
    else:
        return jwk_set
    finally:
        if self.jwk_set_cache is not None:
            self.jwk_set_cache.put(jwk_set)  # --> This puts the JSON dict into the cache
Then the JWKSetCache put method that should receive a PyJWKSet as per the type hint will receive a dictionary:
class JWKSetCache:
    def __init__(self, lifespan: int) -> None:
        self.jwk_set_with_timestamp: Optional[PyJWTSetWithTimestamp] = None
        self.lifespan = lifespan
    def put(self, jwk_set: PyJWKSet) -> None:  # --> BUT this expects a PyJWKSet, not a dict (PyJWKSet is not a dict subclass)
        if jwk_set is not None:
            self.jwk_set_with_timestamp = PyJWTSetWithTimestamp(jwk_set)  # --> The dict goes into the PyJWTSetWithTimestamp constructor
        else:
            # clear cache
            self.jwk_set_with_timestamp = None
And that dictionary will find its way through to the final object stored in the cache:
=================
class PyJWTSetWithTimestamp:
    def __init__(self, jwk_set: PyJWKSet):
        self.jwk_set = jwk_set  # --> The dict ends up as the value in the cached element
        self.timestamp = time.monotonic()
    def get_jwk_set(self) -> PyJWKSet:
        return self.jwk_set  # --> This will NOT return a PyJWKSet but a dict
    def get_timestamp(self) -> float:
        return self.timestamp
Then, the JWKSetCache class will blindly return the dict instead of the PyJWKSet.
=================
class JWKSetCache:
    def __init__(self, lifespan: int) -> None:
        self.jwk_set_with_timestamp: Optional[PyJWTSetWithTimestamp] = None
        self.lifespan = lifespan
    def put(self, jwk_set: PyJWKSet) -> None:
        if jwk_set is not None:
            self.jwk_set_with_timestamp = PyJWTSetWithTimestamp(jwk_set)
        else:
            # clear cache
            self.jwk_set_with_timestamp = None
    def get(self) -> Optional[PyJWKSet]:
        if self.jwk_set_with_timestamp is None or self.is_expired():
            return None
        return self.jwk_set_with_timestamp.get_jwk_set()  # --> So the cache returns a dict, not a PyJWKSet
And finally, get_jwk_set always gets a dict, whether it takes the data from the URL or from the cache:
def get_jwk_set(self, refresh: bool = False) -> PyJWKSet:
    data = None
    if self.jwk_set_cache is not None and not refresh:
        data = self.jwk_set_cache.get()  # --> So here we get a dict, which hides the fact that this function's code is wrong
    if data is None:
        data = self.fetch_data()
    if not isinstance(data, dict):
        raise PyJWKClientError("The JWKS endpoint did not return a JSON object")
    return PyJWKSet.from_dict(data)
Any code expecting to receive the PyJWKSet type from the cache, or any code that, as in my case, feeds the cache by some means other than a fetch_data call, will fail.
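To make the two paths concrete, here is a minimal sketch that reproduces the behaviour without PyJWT installed. StandInJWKSet, StandInCache and the module-level get_jwk_set are hypothetical stand-ins written for this comment, not the library's classes; expiry handling and the URL fetch are left out, and ValueError stands in for PyJWKClientError.

```python
from typing import Any, Optional


class StandInJWKSet:
    """Hypothetical stand-in for PyJWKSet: holds the list found under "keys"."""

    def __init__(self, keys: list) -> None:
        self.keys = keys

    @classmethod
    def from_dict(cls, obj: dict) -> "StandInJWKSet":
        return cls(obj.get("keys", []))


class StandInCache:
    """Hypothetical stand-in for JWKSetCache, without lifespan or expiry."""

    def __init__(self) -> None:
        self._value: Optional[Any] = None

    def put(self, jwk_set: Any) -> None:
        self._value = jwk_set

    def get(self) -> Optional[Any]:
        return self._value


def get_jwk_set(cache: StandInCache) -> StandInJWKSet:
    """Mirrors only the cache branch of the quoted get_jwk_set (no fetch)."""
    data = cache.get()
    if not isinstance(data, dict):
        # ValueError is an assumption standing in for PyJWKClientError.
        raise ValueError("The JWKS endpoint did not return a JSON object")
    return StandInJWKSet.from_dict(data)


raw = {"keys": [{"kid": "example-key", "kty": "oct"}]}

# Path 1: the cache is fed the way fetch_data feeds it, with the raw dict.
cache_from_fetch = StandInCache()
cache_from_fetch.put(raw)
result_from_fetch = get_jwk_set(cache_from_fetch)

# Path 2: the cache is fed with the type the hints ask for.
cache_from_object = StandInCache()
cache_from_object.put(StandInJWKSet.from_dict(raw))
try:
    get_jwk_set(cache_from_object)
    typed_cache_failed = False
except ValueError:
    typed_cache_failed = True
```

In this sketch the path that violates the type hints succeeds, while the path that respects them is rejected by the isinstance check.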
The code in this PR fixes the problem declared in issue #914 without impacting any other current use of the module, as both the received and returned types of the function are preserved and no other parts of the module are changed. However, it does not fix use cases that use fetch_data and access the cache directly instead of going through the fixed get_jwk_set.
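For illustration only, one way a get_jwk_set could accept both shapes while keeping its return type is sketched below. This is not necessarily the diff in this PR: StandInJWKSet and get_jwk_set_tolerant are hypothetical names, the cache value and the fetch function are passed in as plain arguments, and ValueError stands in for PyJWKClientError.

```python
from typing import Any, Callable, Optional


class StandInJWKSet:
    """Hypothetical stand-in for PyJWKSet: holds the list found under "keys"."""

    def __init__(self, keys: list) -> None:
        self.keys = keys

    @classmethod
    def from_dict(cls, obj: dict) -> "StandInJWKSet":
        return cls(obj.get("keys", []))


def get_jwk_set_tolerant(
    cached: Optional[Any], fetch: Callable[[], Any]
) -> StandInJWKSet:
    """Return a key set whether the source holds a dict or a key-set object."""
    data = cached
    if data is None:
        # Cache miss: fall back to fetching, as the quoted code does.
        data = fetch()
    if isinstance(data, StandInJWKSet):
        # Already the declared type (e.g. a manually fed cache): return as-is.
        return data
    if not isinstance(data, dict):
        # ValueError is an assumption standing in for PyJWKClientError.
        raise ValueError("The JWKS endpoint did not return a JSON object")
    return StandInJWKSet.from_dict(data)


raw = {"keys": [{"kid": "example-key", "kty": "oct"}]}
prebuilt = StandInJWKSet.from_dict(raw)

from_dict_cache = get_jwk_set_tolerant(raw, fetch=lambda: raw)
from_object_cache = get_jwk_set_tolerant(prebuilt, fetch=lambda: raw)
from_empty_cache = get_jwk_set_tolerant(None, fetch=lambda: raw)
```

All three calls return a StandInJWKSet, so callers see the same return type regardless of how the cache was populated.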
The mentioned test also passes with the fixed version of get_jwk_set. I'll add an assert to the existing tests that looks up the key in the result of the get_signing_key functions by accessing it as a PyJWKSet object, which raises a type error when the result is a dict, as the dict has no keys attribute holding the key list.
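A rough sketch of why such an assert can tell the two types apart: on a plain dict, keys is a bound method rather than a list of key objects, so iterating over it raises TypeError. StandInJWKSet and key_ids are hypothetical names used only for this illustration.

```python
class StandInJWKSet:
    """Hypothetical stand-in for PyJWKSet: exposes the keys as a list attribute."""

    def __init__(self, keys: list) -> None:
        self.keys = keys


def key_ids(jwk_set) -> list:
    """Collect key ids by iterating the keys attribute, as a test might."""
    return [key["kid"] for key in jwk_set.keys]


raw = {"keys": [{"kid": "example-key", "kty": "oct"}]}

# With the key-set object, the attribute is a list and iteration works.
ids_from_object = key_ids(StandInJWKSet(raw["keys"]))

# With a raw dict, .keys is the dict method, which is not iterable.
try:
    key_ids(raw)
    dict_failed = False
except TypeError:
    dict_failed = True
```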
is it possible to pull from master and fix the build errors/test failures?
Fix for https://github.com/jpadilla/pyjwt/issues/914