kventinel closed this issue 3 years ago
Thanks for the full repro Colab! This looks more like a JAX issue with the CUDA 11 Colab runtime: the same error reproduces with just
from jax import random
a = random.split(random.PRNGKey(1), 2)
So I would suggest filing this bug at https://github.com/google/jax.
In the meantime, if a CUDA 10 runtime is acceptable, I believe just running
!pip install -q git+https://www.github.com/google/neural-tangents
should work.
I have the same issue running jax.random.PRNGKey(0), and I cannot downgrade CUDA to 10 because it is incompatible with my NVIDIA driver. What should I do in this case? Is there any workaround?
Could you be using an older JAX version by any chance? JAX seems to support CUDA 11: https://github.com/google/jax#pip-installation
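To answer the version question, a quick check like the one below may help. It is a minimal sketch, not an official diagnostic: it reads the installed `jax` and `jaxlib` versions with the standard-library `importlib.metadata` and, if JAX imports, reports `jax.default_backend()` and `jax.devices()`. The helper names `installed_versions` and `jax_backend_info` are hypothetical, and no particular minimum version is implied. In JAX releases of that era, `jaxlib` was the package whose build was tied to a specific CUDA version, so its version string is the one worth comparing against the installation instructions.

```python
from importlib.metadata import PackageNotFoundError, version


def installed_versions(packages=("jax", "jaxlib")):
    """Return {package: installed version string, or None if not installed}."""
    result = {}
    for name in packages:
        try:
            result[name] = version(name)
        except PackageNotFoundError:
            result[name] = None
    return result


def jax_backend_info():
    """Return (default backend, device list) from JAX, or None if JAX cannot be imported."""
    try:
        import jax
    except ImportError:
        return None
    # default_backend() is e.g. "cpu" or "gpu"; devices() lists what JAX actually sees.
    return jax.default_backend(), [str(d) for d in jax.devices()]


if __name__ == "__main__":
    for name, ver in installed_versions().items():
        print(f"{name}: {ver if ver is not None else 'not installed'}")
    info = jax_backend_info()
    if info is None:
        print("JAX could not be imported")
    else:
        backend, devices = info
        print(f"default backend: {backend}")
        print(f"devices: {devices}")
```

If the backend reports `cpu` on a GPU machine, JAX most likely did not find a CUDA-enabled `jaxlib`, which points back to the version mismatch asked about above.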
This should no longer be an issue, since Colab now ships CUDA 11; see this working example: https://colab.research.google.com/gist/romanngg/6ad518bfafeeceda1e41bb54f043fbf9/no-jax-error-with-cuda-11.ipynb
Please feel free to reopen if there are still other problems!
https://colab.research.google.com/drive/1VHzY55vHtMPsvXR302WoYYAJTj74jy1S?usp=sharing
After that, I get an error: