Closed: zohimchandani closed this 1 year ago
How did you install JAX? What was the pip command line? What is your environment/container? XLA needs ptxas and can't find it. So either it isn't installed, or it is installed somewhere XLA doesn't look.
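A quick way to check the ptxas question is a small script like the one below. This is only a rough sketch: the helper name `find_ptxas` and the `/usr/local/cuda/bin` directory are assumptions for illustration, and it only looks at `PATH` plus that one directory, which may not match where XLA itself searches.

```python
import os
import shutil

# Assumed default CUDA toolkit location on Linux; adjust for your system.
ASSUMED_CUDA_BIN_DIRS = ["/usr/local/cuda/bin"]


def find_ptxas(extra_dirs=ASSUMED_CUDA_BIN_DIRS):
    """Return the path to a ptxas executable, or None if none is found.

    Checks PATH first, then each directory in extra_dirs (hypothetical
    fallback locations, not necessarily the ones XLA uses).
    """
    found = shutil.which("ptxas")
    if found:
        return found
    for directory in extra_dirs:
        candidate = os.path.join(directory, "ptxas")
        if os.path.isfile(candidate) and os.access(candidate, os.X_OK):
            return candidate
    return None


if __name__ == "__main__":
    path = find_ptxas()
    if path:
        print(f"ptxas found at: {path}")
    else:
        print("ptxas not found on PATH or in the assumed CUDA directories")
```

If this reports that ptxas is missing, that points to the CUDA toolkit not being installed or not being on `PATH`. If it is found only in the fallback directory, adding that directory to `PATH` may be enough.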
Fix below:
pip install --upgrade "jax[cuda11_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
Running the following code snippet and getting an error: