dbt-labs / dbt-snowflake

dbt-snowflake contains all of the code enabling dbt to work with Snowflake
https://getdbt.com
Apache License 2.0

[Feature] Improve key-pair auth performance #1082

Closed colin-rogers-dbt closed 4 months ago

colin-rogers-dbt commented 5 months ago


Describe the feature

@VersusFacit's analysis:

Params:
- Unix time command (acceptable imprecision for the orders of magnitude we're dealing with here)
- 5000-node project
- average of 2 runs for each authentication method

Results:
- dbt-snowflake, user/password auth: 427.83s user 42.43s system 16% cpu 46:49.12 total
- dbt-snowflake, key-pair auth: 1011.76s user 44.96s system 32% cpu 53:42.99 total

~430 vs ~1010 seconds of user CPU time is quite a dramatic difference!

Avenues for investigation:

VersusFacit commented 5 months ago

Useful docs I used for testing: https://docs.snowflake.com/en/user-guide/key-pair-auth

# Generate an unencrypted 2048-bit RSA private key in PKCS#8 PEM format
openssl genrsa 2048 | openssl pkcs8 -topk8 -inform PEM -out rsa_key.p8 -nocrypt
# Derive the matching public key
openssl rsa -in rsa_key.p8 -pubout -out rsa_key.pub
-- Register the public key with your Snowflake user (yes, this will require admin help); don't include the PEM header/footer strings
ALTER USER <you> SET RSA_PUBLIC_KEY='';
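
For reference, a minimal sketch of how a key generated this way is typically passed to the Snowflake Connector for Python, loosely following the Snowflake key-pair auth docs linked above (the account/user values are placeholders, and this is not dbt-snowflake's actual code):

import snowflake.connector
from cryptography.hazmat.primitives import serialization

# Load the unencrypted PKCS#8 PEM key produced by the openssl commands above.
with open("rsa_key.p8", "rb") as f:
    private_key = serialization.load_pem_private_key(f.read(), password=None)

# The connector expects DER-encoded PKCS#8 bytes for its private_key argument.
der_key = private_key.private_bytes(
    encoding=serialization.Encoding.DER,
    format=serialization.PrivateFormat.PKCS8,
    encryption_algorithm=serialization.NoEncryption(),
)

conn = snowflake.connector.connect(
    account="<account>",  # placeholder
    user="<you>",         # placeholder
    private_key=der_key,
)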
amardatar commented 4 months ago

Hey team! I found this issue after switching from a 2048-bit key-pair to a 4096-bit one, and found my dbt run times increasing from ~5 minutes to ~15 minutes.

I had a bit of a dig on this, and figured I'd share some findings (and can put together some suggestions for changes as well if there's a preference on how this is handled in the project).

The core issue seems to be that the private key is being read (and validated - I'll get to that) on every dbt node, which eats up the time.

First - this took me too long to find, but the easiest solution seems to be simply using the reuse_connections profile config. Maybe the key-pair authentication section of the docs could suggest using that config?

Anyway - in terms of testing what was going on, I ran a few tests using the Snowflake Connector for Python and found that (with a fixed private key) execution times were virtually the same across the password, 2048-bit key, and 4096-bit key options.

I had a look at dbt-snowflake and found the above, i.e. that the private key was being read on each node. Adding a bit of caching somewhat resolved the issue and substantially reduced run times.

I was a bit surprised by this, so I decided to check how long loading keys actually took. My test script looked like this:

import time

from cryptography.hazmat.primitives import serialization


def benchmark_key_import(key: str, unsafe_skip_rsa_key_validation: bool = False, n_tries: int = 1000):
    start = time.time()
    for _ in range(0, n_tries):
        # Parse the PEM-encoded key; validation is skipped when the flag is set
        private_key = serialization.load_pem_private_key(
            data=bytes(key, 'utf-8'),
            password=None,
            unsafe_skip_rsa_key_validation=unsafe_skip_rsa_key_validation,
        )
        # Re-serialize to unencrypted PKCS#8 DER, as dbt-snowflake does before handing the key to the connector
        key_bytes = private_key.private_bytes(
            encoding=serialization.Encoding.DER,
            format=serialization.PrivateFormat.PKCS8,
            encryption_algorithm=serialization.NoEncryption(),
        )
    end = time.time()
    print(end - start)

Some results from that:

The cost of that validation is pretty substantial, which is why I didn't want to immediately put a PR together.

The cryptography docs don't provide much detail on exactly what's unsafe about skipping validation, and I don't know nearly enough about the security elements to say for sure. However, the Snowflake Connector for Python is also using cryptography, and either requires bytes (which it reads with validation) or an instance of RSAPrivateKey (which would already be validated). Essentially, this means that dbt-snowflake can (and perhaps should) skip validation since it's already being done later by the Snowflake Connector and there's no value in doing it twice.
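
To make the double work concrete, here is a rough sketch of the two parses described above (illustrative function names, not the actual dbt-snowflake or connector code paths): the adapter loads and validates the PEM key, serializes it back to DER bytes, and the connector then loads (and validates) those bytes again.

from cryptography.hazmat.primitives import serialization

def adapter_side(pem_text: str) -> bytes:
    # First parse: cryptography validates the RSA key here.
    key = serialization.load_pem_private_key(pem_text.encode("utf-8"), password=None)
    # The adapter hands unencrypted PKCS#8 DER bytes to the connector.
    return key.private_bytes(
        encoding=serialization.Encoding.DER,
        format=serialization.PrivateFormat.PKCS8,
        encryption_algorithm=serialization.NoEncryption(),
    )

def connector_side(der_bytes: bytes):
    # Second parse: the bytes are deserialized, validating the same key again.
    return serialization.load_der_private_key(der_bytes, password=None)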

Caching would of course help as well; I imagine there aren't any cases where data changes during an execution such that a cached result became invalid (and if it did, then it could probably be stored in a dictionary instead to maintain it). But my sense, based on the above, is that skipping validation would be the more sensible solution and would effectively remove the need for caching.

Beyond that, and as mentioned at the top, I think enabling the reuse_connections config is the ideal option, since it means the re-validation happening in the Snowflake Connector is skipped as well. Enabling this config did result in the shortest run times in my testing (and run times that were largely equal across the password and private-key auth methods). This might be academic, but I'd be interested to know if there's any particular reason for it being disabled by default, and whether there's any telemetry on how often it's enabled.

mikealfare commented 4 months ago

Thanks for the thorough analysis @amardatar! I agree that updating the default behavior of reuse_connections looks like the correct approach; it improves performance for user/password auth as well. I made that update and translated your performance testing into a dbt-snowflake test.

aranke commented 4 months ago

Hi @amardatar, thanks for the thorough analysis; it is much appreciated!

The Problem

What you've stumbled into looks like a known issue in cryptography 😞 (link):

The cause of this is almost certainly our upgrade to OpenSSL 3.0, which has a new algorithm for checking the validity of RSA private keys, which is known to be much slower for valid RSA keys.

Unfortunately we're in between a rock and a hard place and don't have a way to avoid this cost without silently accepting invalid keys which can have all sorts of negative consequences.

I don't have a better suggestion than "try to avoid loading the same key multiple times".

The description from the PR that implemented the unsafe_skip_rsa_key_validation option has this note, which doesn't inspire confidence.

This is a significant performance improvement but is only safe if you know the key is valid. If you use this when the key is invalid OpenSSL makes no guarantees about what might happen. Infinite loops, crashes, and all manner of terrible things become possible if that occurs. Beware, beware, beware.

Given that this keypair is user-configurable from YAML, I don't feel comfortable bypassing this check.

A Potential Solution

Caching would of course help as well; I imagine there aren't any cases where data changes during an execution such that a cached result became invalid (and if it did, then it could probably be stored in a dictionary instead to maintain it).

This is a great observation, so I took the liberty of modifying your script into a reprex to investigate the performance impact.

# benchmark.py
import time
from cryptography.hazmat.primitives import serialization
from functools import cache

@cache
def cached_load_pem_private_key(data, password, unsafe_skip_rsa_key_validation):
    return serialization.load_pem_private_key(
        data=data,
        password=password,
        unsafe_skip_rsa_key_validation=unsafe_skip_rsa_key_validation,
    )

def benchmark_key_import(
    key: str,
    unsafe_skip_rsa_key_validation: bool = False,
    n_tries: int = 1000,
    method=serialization.load_pem_private_key,
):
    print(
        f"unsafe_skip_rsa_key_validation={unsafe_skip_rsa_key_validation}, n_tries={n_tries}, method={method}"
    )
    cached_load_pem_private_key.cache_clear()

    start = time.time()
    for _ in range(0, n_tries):
        private_key = method(
            data=bytes(key, "utf-8"),
            password=None,
            unsafe_skip_rsa_key_validation=unsafe_skip_rsa_key_validation,
        )
        key_bytes = private_key.private_bytes(
            encoding=serialization.Encoding.DER,
            format=serialization.PrivateFormat.PKCS8,
            encryption_algorithm=serialization.NoEncryption(),
        )
    end = time.time()

    print(end - start)
    print(cached_load_pem_private_key.cache_info())
    print()

with open("key.pem", "r") as f:
    key = f.read()

    benchmark_key_import(key)
    benchmark_key_import(key, unsafe_skip_rsa_key_validation=True)

    benchmark_key_import(key, method=cached_load_pem_private_key)
    benchmark_key_import(
        key, method=cached_load_pem_private_key, unsafe_skip_rsa_key_validation=True
    )
❯ openssl genrsa -out key.pem 4096
❯ python benchmark.py
unsafe_skip_rsa_key_validation=False, n_tries=1000, method=<built-in function load_pem_private_key>
345.70546793937683
CacheInfo(hits=0, misses=0, maxsize=None, currsize=0)

unsafe_skip_rsa_key_validation=True, n_tries=1000, method=<built-in function load_pem_private_key>
0.16813015937805176
CacheInfo(hits=0, misses=0, maxsize=None, currsize=0)

unsafe_skip_rsa_key_validation=False, n_tries=1000, method=<functools._lru_cache_wrapper object at 0x102551170>
0.4363410472869873
CacheInfo(hits=999, misses=1, maxsize=None, currsize=1)

unsafe_skip_rsa_key_validation=True, n_tries=1000, method=<functools._lru_cache_wrapper object at 0x102551170>
0.09387397766113281
CacheInfo(hits=999, misses=1, maxsize=None, currsize=1)

From the results above, caching serialization.load_pem_private_key (and maybe even private_key.private_bytes?) gets us most of the runtime improvement, so I'd prefer to go down this route instead.
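
A minimal sketch of that route, caching the whole PEM-to-DER round trip so both load_pem_private_key and private_bytes run only once per distinct key (the helper name is illustrative, not dbt-snowflake's actual code):

from functools import cache
from typing import Optional

from cryptography.hazmat.primitives import serialization

@cache
def pem_to_der_private_key(pem_data: bytes, passphrase: Optional[bytes] = None) -> bytes:
    # Cache miss: parse (and validate) the key, then serialize to unencrypted PKCS#8 DER.
    private_key = serialization.load_pem_private_key(pem_data, password=passphrase)
    return private_key.private_bytes(
        encoding=serialization.Encoding.DER,
        format=serialization.PrivateFormat.PKCS8,
        encryption_algorithm=serialization.NoEncryption(),
    )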

Thanks again for starting the conversation @amardatar, would love to hear your perspective.

amardatar commented 4 months ago

Hey @aranke - no worries, and thanks for the context - I didn't know that, but makes sense!

So my suggestion on using unsafe_skip_rsa_key_validation is really just predicated on the fact that the Snowflake Connector is already doing validation - which means it's being done twice at the moment, and arguably can be skipped once. That's not necessarily a good reason to do this though - I wouldn't expect the behaviour to change in the Snowflake Connector, but perhaps relying on that behaviour isn't ideal either.

In terms of caching - I did try a very basic implementation in dbt-snowflake, and didn't leave enough info in the comment above, but it's probably useful so I'll just run through it here.

Just summarising the below:

In more detail:

My first attempt at caching was a straightforward cache of the result of _get_private_key - and while that did roughly halve the time a run took, it didn't save as much time as it should have. I believe this is because _get_private_key returns bytes, and providing those to the Snowflake Connector results in it re-validating the key. That makes sense: half the number of key validations, half the time.

I tried instead returning and using the response of serialization.load_pem_private_key - which is an instance of PrivateKeyTypes - but instead got the error TypeError: expected bytes-like object, not RSAPrivateKey with the stack trace pointing to this line.

I don't know enough about how this is all being used to know if it's okay to avoid serialising the key in some way - if that can be done, then I think the performance benefits will be realised since the key won't need to be re-validated (by dbt-snowflake or by the Snowflake Connector). Without doing that, caching the bytes still leaves a lot of performance on the table.

Maybe another option would be to skip validation after the first time (and return an instance of PrivateKeyTypes from _get_private_key): the very first usage by dbt-snowflake is validated, but subsequent usages are not. The main reason for caution I can think of here is that I'm not sure how any serialised results of the class are being used; it's possible that at some point the private key could be changed in serialised form and used (without validation) in a subsequent call.
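
A rough sketch of that option, assuming a hypothetical helper standing in for _get_private_key that returns the loaded key object instead of bytes (not the adapter's actual implementation): the key is parsed and validated once, the object is cached, and later callers reuse it without re-serialising or re-validating.

from functools import cache
from typing import Optional

from cryptography.hazmat.primitives import serialization
from cryptography.hazmat.primitives.asymmetric.types import PrivateKeyTypes

@cache
def get_private_key_object(pem_data: bytes, passphrase: Optional[bytes] = None) -> PrivateKeyTypes:
    # Full validation happens only on the first call for a given key;
    # cache hits return the already-validated object.
    return serialization.load_pem_private_key(pem_data, password=passphrase)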

FishtownBuildBot commented 4 months ago

Opened a new issue in dbt-labs/docs.getdbt.com: https://github.com/dbt-labs/docs.getdbt.com/issues/5779