JAX cannot find GPU devices, but PyTorch (torch 1.11 + cu102) can.
Environment: nvcc -V reports 10.2, jax 0.4.13, jaxlib 0.4.13
I installed JAX using the following commands:
pip install jax==0.4.13 -f https://storage.googleapis.com/jax-releases/jax_releases.html
pip install jaxlib[cu102]==0.4.13 -f https://storage.googleapis.com/jax-releases/jax_releases.html
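One way to confirm the symptom is to ask each framework which backend it actually sees. The script below is only an illustrative sketch (the helper name report_backends is hypothetical); it imports JAX and PyTorch lazily so it still runs if either one is missing. If JAX reports a "cpu" backend while PyTorch reports CUDA as available, the installed jaxlib is not using the GPU. As far as I know, jaxlib 0.4.13 GPU builds target CUDA 11.x/12.x rather than 10.2, and cu102 is not an extra that jaxlib defines, so the jaxlib installed by the commands above most likely cannot use the GPU in this environment.

```python
"""Report which accelerator backends JAX and PyTorch can see.

Illustrative sketch: both libraries are imported lazily, so the script
still runs (and records None for the version) if either one is missing.
"""
import importlib


def report_backends() -> dict:
    report = {}

    try:
        jax = importlib.import_module("jax")
        report["jax_version"] = jax.__version__
        # "cpu" here means JAX found no usable GPU backend.
        report["jax_backend"] = jax.default_backend()
        report["jax_devices"] = [str(d) for d in jax.devices()]
    except ImportError:
        report["jax_version"] = None

    try:
        torch = importlib.import_module("torch")
        report["torch_version"] = torch.__version__
        report["torch_cuda_available"] = torch.cuda.is_available()
        # CUDA version PyTorch was built against (e.g. "10.2"), or None.
        report["torch_cuda_build"] = torch.version.cuda
    except ImportError:
        report["torch_version"] = None

    return report


if __name__ == "__main__":
    for key, value in report_backends().items():
        print(f"{key}: {value}")
```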
The training run failed with the following error:
jaxlib.xla_extension.XlaRuntimeError: INVALID_ARGUMENT: DLPack tensor is on GPU, but no GPU backend was provided.
2024-04-13 16:21:57.832 | ERROR | silk.cli:main:116 - run failed, *.logfile might be found in : var/silk-cli/run/training/2024-04-13/16-21-46
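The error message suggests that a PyTorch CUDA tensor is handed to JAX through DLPack while JAX only has a CPU backend. Assuming that is how the silk training code passes tensors to JAX (an assumption, not checked against its source), a stopgap is to copy the tensor to host memory whenever JAX cannot see a GPU. The helper torch_to_jax below is hypothetical; it avoids the INVALID_ARGUMENT error but leaves JAX computing on the CPU, so a jaxlib build that matches the local CUDA setup is still the real fix.

```python
"""Hypothetical helper: pass a PyTorch tensor to JAX without the DLPack GPU error.

Sketch only, assuming the training code converts torch tensors to JAX
arrays via DLPack. When JAX has no GPU backend, the tensor is copied
through host memory instead of being shared zero-copy.
"""


def torch_to_jax(tensor):
    # Imported inside the function so this module loads without them.
    import jax
    import jax.dlpack
    import jax.numpy as jnp
    import torch.utils.dlpack

    if tensor.is_cuda and jax.default_backend() == "gpu":
        # Zero-copy path: both frameworks can see the GPU. Passes a
        # DLPack capsule, the form accepted by jax 0.4.13.
        return jax.dlpack.from_dlpack(torch.utils.dlpack.to_dlpack(tensor))

    # Fallback: copy through NumPy on the host; jnp.asarray then places
    # the data on JAX's default device (the CPU when no GPU backend exists).
    return jnp.asarray(tensor.detach().cpu().numpy())
```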
Could you please give me some suggestions or instructions? Thanks for your time!
Haolin