basit-7 opened this issue 1 year ago
I'd check the JAX installation docs for their guidance. You'll probably need to reinstall JAX with CUDA/GPU support enabled.
@basit-7 See https://github.com/google/jax#installation. Just install JAX with GPU support:
pip install --upgrade "jax[cuda]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
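Before running a test script, it can help to confirm that the GPU-enabled build is the one actually in use. The sketch below assumes a JAX release that exposes `jax.default_backend()` and `jax.devices()`; the helper name `jax_backend_report` is hypothetical. On a working CUDA install the default backend should be reported as "gpu". If it reports "cpu", the CPU-only jaxlib is likely still being picked up.

```python
import importlib.util


def jax_backend_report() -> str:
    """Return JAX's default backend name, or a short status if JAX is unusable.

    Hypothetical helper for illustration; not part of JAX itself.
    """
    # Avoid an ImportError when JAX is not installed at all.
    if importlib.util.find_spec("jax") is None:
        return "not-installed"
    try:
        import jax

        backend = jax.default_backend()  # e.g. "gpu" or "cpu"
        print(f"default backend: {backend}")
        print(f"devices: {jax.devices()}")
        return backend
    except Exception as exc:  # a broken jaxlib/CUDA setup can fail at import or init
        return f"error: {exc}"


if __name__ == "__main__":
    status = jax_backend_report()
    if status != "gpu":
        print(f"JAX is not using a GPU (status: {status}).")
```

Running it right after the `pip install` step makes it clear whether a later test failure comes from the installation or from the test itself.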
After installing it, my test script still failed.
Same here. Does anyone have a solution? Thanks!
pip install --upgrade "jax[cuda11_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
This worked for me.
When I run all the unit tests, I get the following error.
I am working with a Quadro RTX 8000 and CUDA 11.0.
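When reporting a failure like this, listing the exact `jax` and `jaxlib` versions next to the GPU model and CUDA version makes it easier for others to reproduce. A minimal sketch using only the standard library's `importlib.metadata`; the helper name `installed_versions` is hypothetical:

```python
from importlib import metadata


def installed_versions(packages=("jax", "jaxlib")) -> dict:
    """Map each package name to its installed version, or None if it is missing.

    Hypothetical helper for illustration; not part of JAX itself.
    """
    versions = {}
    for name in packages:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            versions[name] = None
    return versions


if __name__ == "__main__":
    for name, version in installed_versions().items():
        print(f"{name}: {version or 'not installed'}")
```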