sanchit-gandhi / whisper-jax

JAX implementation of OpenAI's Whisper model for up to 70x speed-up on TPU.
Apache License 2.0

AttributeError: module 'jax.tree_util' has no attribute 'register_pytree_with_keys_class' #34

Open silvacarl2 opened 1 year ago

silvacarl2 commented 1 year ago

I cannot seem to get rid of this on google colab:

AttributeError: module 'jax.tree_util' has no attribute 'register_pytree_with_keys_class'

silvacarl2 commented 1 year ago

LOL:

RuntimeError: As of JAX 0.4.0, JAX only supports TPU VMs, not the older Colab TPUs.

We recommend trying Kaggle Notebooks (https://www.kaggle.com/code, click on "New Notebook" near the top) which offer TPU VMs. You have to create an account, log in, and verify your account to get accelerator support. Once you do that, there's a new "TPU 1VM v3-8" accelerator option. This gives you a TPU notebook environment similar to Colab, but using the newer TPU VM architecture. This should be a less buggy, more performant, and overall better experience than the older TPU node architecture.

It is also possible to use Colab together with a self-hosted Jupyter kernel running on a Cloud TPU VM. See https://research.google.com/colaboratory/local-runtimes.html for details.

silvacarl2 commented 1 year ago

yes this is going to be rough:

Cloud TPU v3: $8.00 per hour
Cloud TPU v4: $12.88 per hour

silvacarl2 commented 1 year ago

by the way, it works perfectly on Kaggle!

silvacarl2 commented 1 year ago

this is also depressing:

Creating TPU node "whisper-tpu-test-server" failed. Error: Request failed with unknown error

silvacarl2 commented 1 year ago

I get this same error on TPU V2.XX:

AttributeError: module 'jax.tree_util' has no attribute 'register_pytree_with_keys_class'

Except for Kaggle, all TPU regions are maxed out for TPU V3.xx

sanchit-gandhi commented 1 year ago

Hey @silvacarl2 - it's super sad that Google Colab no longer really supports its TPU architectures. I also initially tried writing the notebook in Colab, but got the same recommendation as you, so I switched to Kaggle notebooks.

There is indeed a bit of a queue to get a TPU VM v3-8 with Kaggle, but once you do, it's worth it (20 hours free per week is very generous).

sanchit-gandhi commented 1 year ago

Also, preemptible TPUs will be cheaper if you're happy to accept the risk of losing a TPU in return for a lower cost.

silvacarl2 commented 1 year ago

yes, this works great on Kaggle.

we wanted to test it out and benchmark it, but we could not yet get it to run on anything except Kaggle.

silvacarl2 commented 1 year ago

no worries. this is awesome work because it FLIES on Kaggle.

talipturkmen commented 1 year ago

you can upgrade jax in colab using: !pip install --upgrade jax jaxlib

it works for me after the upgrade.
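To confirm whether the upgrade took effect, one option is to check the installed version and look for the attribute named in the error before importing whisper-jax. The following is only a minimal sketch: the helper names `module_has_attr` and `installed_version` are hypothetical, made up for illustration, and the only name taken from this thread is `register_pytree_with_keys_class`.

```python
import importlib


def module_has_attr(module_name: str, attr_name: str) -> bool:
    """Return True if the module can be imported and exposes attr_name.

    Hypothetical helper for illustration; not part of jax or whisper-jax.
    """
    try:
        module = importlib.import_module(module_name)
    except ImportError:
        # The module is not installed at all (or one of its imports failed).
        return False
    return hasattr(module, attr_name)


def installed_version(module_name: str):
    """Return the module's __version__ string, or None if unavailable."""
    try:
        module = importlib.import_module(module_name)
    except ImportError:
        return None
    return getattr(module, "__version__", None)


if __name__ == "__main__":
    print("jax version:", installed_version("jax"))
    # Attribute name copied from the AttributeError reported in this issue.
    if module_has_attr("jax.tree_util", "register_pytree_with_keys_class"):
        print("jax.tree_util exposes the attribute from the error message")
    else:
        print("attribute still missing; try upgrading jax and jaxlib")
```

If the attribute is still reported as missing right after running pip in a notebook, it may be that the old version is still loaded in the running session, in which case restarting the runtime and re-running the check is worth trying.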