trinodb / trino-python-client

Python client for Trino
Apache License 2.0

Make JWTAuthentication accept token providers to support token refresh #466

Open sugibuchi opened 2 weeks ago

sugibuchi commented 2 weeks ago

Describe the feature

Extend JWTAuthentication to support both static JWT tokens and JWT tokens dynamically produced by given Callable objects.

This extension aims to refresh JWT tokens in JWTAuthentication.

Context

We are currently trying to use JWT access tokens issued by Azure Entra ID (aka. Azure Active Directory) to authenticate clients (more precisely, DBT workflows running in Azure) accessing our Trino.

Retrieving access tokens from Entra ID is a straightforward process. We can accomplish this by using the TokenCredential implementations provided by the azure-identity package. For instance,

from azure.identity import DefaultAzureCredential
from trino.auth import JWTAuthentication
from trino.dbapi import connect

# "api://..." is the application URI of our Trino registered in Entra ID; it appears as "aud" in the JWT tokens.
scope = "api://xxxxxxxxx-xxxx-xxxx-xxxxx-xxxxxxxxxxxx/.default"

credential = DefaultAzureCredential()
jwt_token = credential.get_token(scope).token

conn = connect(
    auth=JWTAuthentication(jwt_token),
    http_scheme="https",
    ...
)

The problem here is that a token passed to JWTAuthentication cannot be refreshed: the current version of JWTAuthentication accepts only a static token and sets it on the requests.Session object as a static Authorization header value.

For example, JWT tokens issued by Entra ID usually expire in one hour or less. Therefore, an application using JWTAuthentication will eventually fail when its access token expires, unless the application frequently recreates the JWTAuthentication instance with a new token.
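
To see how long a given token remains valid, one can read its exp claim. The following is a minimal sketch, assuming the token is a standard three-segment JWT with a numeric exp claim. The helper names jwt_expiry and seconds_remaining are hypothetical, and the signature is deliberately not verified, so this is only suitable for diagnostics:

```python
import base64
import json
import time
from typing import Optional


def jwt_expiry(token: str) -> float:
    """Return the `exp` claim (seconds since the epoch) without verifying the token."""
    # A JWT consists of three base64url segments: header.payload.signature
    payload_b64 = token.split(".")[1]
    # base64url segments omit padding; restore it before decoding
    payload_b64 += "=" * (-len(payload_b64) % 4)
    claims = json.loads(base64.urlsafe_b64decode(payload_b64))
    return float(claims["exp"])


def seconds_remaining(token: str, now: Optional[float] = None) -> float:
    """Return how many seconds are left before the token expires."""
    current = time.time() if now is None else now
    return jwt_expiry(token) - current
```

For a token obtained as in the sample above, seconds_remaining(jwt_token) would show how soon a connection built with a static token is expected to start failing.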

Proposal

Extend JWTAuthentication to accept both a static JWT token and a Callable object as the init argument.

https://github.com/trinodb/trino-python-client/blob/0.328.0/trino/auth.py#L146-L153

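from typing import Callable, Union

from requests import Session


class JWTAuthentication(Authentication):

    def __init__(self, token: Union[str, Callable[[], str]]):
        # Wrap a static token in a provider so both cases share one code path
        if isinstance(token, str):
            self.token_provider = lambda: token
        else:
            self.token_provider = token

    def set_http_session(self, http_session: Session) -> Session:
        # _BearerAuth receives the provider instead of a fixed token string
        http_session.auth = _BearerAuth(self.token_provider)
        return http_session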

We also need to extend _BearerAuth. https://github.com/trinodb/trino-python-client/blob/0.328.0/trino/auth.py#L142

        r.headers["Authorization"] = "Bearer " + self.token_provider()
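
To illustrate why this one-line change is enough, here is a dependency-free sketch of the extended auth object. It assumes that requests invokes the session's auth callable while preparing each outgoing request. BearerAuthSketch is a hypothetical stand-in for _BearerAuth (the real class lives in trino/auth.py and works with requests objects), and SimpleNamespace stands in for a prepared request:

```python
from types import SimpleNamespace
from typing import Callable


class BearerAuthSketch:
    """Stand-in for _BearerAuth that asks the provider for a token on each call."""

    def __init__(self, token_provider: Callable[[], str]):
        self.token_provider = token_provider

    def __call__(self, r):
        # Evaluated per request, so a refreshed token is picked up automatically
        r.headers["Authorization"] = "Bearer " + self.token_provider()
        return r


# Hypothetical provider that returns a different token on each call
tokens = iter(["token-1", "token-2"])
auth = BearerAuthSketch(lambda: next(tokens))

# Two fake requests receive two different Authorization headers
first = auth(SimpleNamespace(headers={}))
second = auth(SimpleNamespace(headers={}))
print(first.headers["Authorization"])
print(second.headers["Authorization"])
```

Because the header is computed at call time rather than stored at construction time, no JWTAuthentication instance has to be recreated when the token changes.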

With this extension, we can rewrite the sample code above as follows:

conn = connect(
    auth=JWTAuthentication(lambda: credential.get_token(scope).token),
    http_scheme="https",
    ...
)

This code won't fail when a token expires, because the credential (DefaultAzureCredential) automatically caches and refreshes access tokens.
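
For token sources that do not cache on their own, the provider passed to JWTAuthentication could do the caching itself. The following is only a sketch under stated assumptions: CachingTokenProvider is a hypothetical helper, not part of the Trino client, and fetch is assumed to return a (token, expiry as epoch seconds) pair. The demo uses a fake clock and a fake fetch function instead of a real identity provider:

```python
import threading
import time
from typing import Callable, Optional, Tuple


class CachingTokenProvider:
    """Callable that reuses a token until shortly before it expires."""

    def __init__(
        self,
        fetch: Callable[[], Tuple[str, float]],
        margin: float = 300.0,
        clock: Callable[[], float] = time.time,
    ):
        self._fetch = fetch      # returns (token, expires_at in epoch seconds)
        self._margin = margin    # refresh this many seconds before expiry
        self._clock = clock
        self._lock = threading.Lock()
        self._token: Optional[str] = None
        self._expires_at = 0.0

    def __call__(self) -> str:
        with self._lock:
            # Fetch on first use, or once the token is inside the refresh margin
            if self._token is None or self._clock() >= self._expires_at - self._margin:
                self._token, self._expires_at = self._fetch()
            return self._token


# Demo with a fake clock and a fake fetch function (both hypothetical)
now = [1000.0]
calls = []

def fake_fetch() -> Tuple[str, float]:
    calls.append(now[0])
    return f"token-{len(calls)}", now[0] + 3600.0

provider = CachingTokenProvider(fake_fetch, margin=300.0, clock=lambda: now[0])
a = provider()       # first call fetches a token
b = provider()       # second call reuses the cached token
now[0] = 4400.0      # move the clock inside the refresh margin
c = provider()       # triggers a new fetch
```

An instance of such a class matches the Callable[[], str] shape proposed above, so it could be passed to JWTAuthentication in place of the lambda.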

Describe alternatives you've considered

We have implemented a custom JWTAuthentication with this extension ourselves. However, JWT token refresh is a common concern that appears in many use cases. It would be nice to have this extension as part of the Trino Python client.

Are you willing to submit PR?

hashhar commented 1 week ago

As far as I understand even if the token is refreshed any already running queries would still fail? If so we should make it clear once we add the functionality you propose.

Also since you seem to have an idea of how to implement this already would you like to send a PR? The proposal looks good so far.

hashhar commented 5 days ago

And also once we have this implemented we can think of changing https://github.com/trinodb/trino-python-client/pull/462 to align with whatever "interface" we come up with for this issue.

hashhar commented 5 days ago

I just noticed I forgot to tag you @sugibuchi. Wanted to check if you're still willing to send a PR for this since it seems you have some ideas already.

sugibuchi commented 3 days ago

@hashhar

Thank you very much for your comment.

I checked the source code of the Trino server. JWT tokens received by the Trino server are parsed and verified (including an expiration check) in JwtAuthenticator. However, claims in the parsed JWT tokens, including expiration time, are not exposed outside JwtAuthenticator except for user identities.

https://github.com/trinodb/trino/blob/450/core/trino-main/src/main/java/io/trino/server/security/jwt/JwtAuthenticator.java#L66-L75

Therefore,

> As far as I understand even if the token is refreshed any already running queries would still fail?

This cannot happen, because the expiration time written in a JWT token is not used after authentication in JwtAuthenticator. Queries that are already running are supposed to continue after the token expires.

> Also since you seem to have an idea of how to implement this already would you like to send a PR?

Yes. I can prepare a PR for this change.