Closed: bwohlberg closed this issue 2 years ago
I ran ct_svmbir_tv_multi.py on a GPU and it seemed to run fine. It printed out some warnings, but the reconstructions were fine.
My jax versions are:
```
jax     0.2.26                  pypi_0  pypi
jaxlib  0.1.75+cuda11.cudnn82   pypi_0  pypi
```
CUDA: 11.4, cuDNN: 8.2.4.15
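For reference, a minimal sketch of how these versions and the visible devices can be checked from Python (assuming a standard jax installation):

```python
import jax
import jaxlib.version

# Report the installed jax/jaxlib versions.
print("jax:", jax.__version__)
print("jaxlib:", jaxlib.version.__version__)

# A CUDA-enabled jaxlib should report GPU devices here rather than only CPU.
print("devices:", jax.devices())
```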
@bwohlberg Could it be because I uninstalled and reinstalled jax instead of using the upgrade command?
Warnings:
```
WARNING *** You are using ptxas 11.0.221, which is older than 11.1. ptxas before 11.1 is known to miscompile XLA code, leading to incorrect results or invalid-address errors.
You may not need to update to CUDA 11.1; cherry-picking the ptxas binary is often sufficient.
```
Output:
```
Solving on GPU (NVIDIA GeForce RTX 2080 Ti)
```
Strange. I guess it's possible that the way you upgraded jax may play a role. I did not see the ptxas warnings when I tried it.
It seems strange that the error depends on CPU vs. GPU and on the type of jaxlib GPU installation, since svmbir runs on the CPU. The only issue I can think of is the data conversion between GPU device arrays and NumPy arrays, but that seems unlikely.
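For context, a minimal sketch of the kind of device/host conversion in question (illustrative only; in the real example this transfer happens inside the scico/svmbir interface):

```python
import numpy as np
import jax.numpy as jnp

# On a GPU machine, jax arrays live in device memory by default.
x_dev = jnp.ones((8, 8), dtype=jnp.float32)

# Calling a CPU-only library such as svmbir requires a device-to-host copy;
# np.asarray on a jax array performs that transfer implicitly.
x_host = np.asarray(x_dev)

# Results computed on the CPU are then copied back to the device.
y_dev = jnp.asarray(x_host * 2.0)
```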
Agreed. I will try again to confirm that it wasn't just some temporary environment problem.
I ran the example again with an old commit (49b5b04) and an older version of jax (jaxlib==0.1.70, jax==0.2.19) and it ran fine. Then I used Mike's jax upgrade command
```
pip install --upgrade "jax[cuda]" -f https://storage.googleapis.com/jax-releases/jax_releases.html
```
with the current code, and it still ran fine.
The problem does indeed seem to have been due to a corrupted jaxlib/jax installation: cleaning up old versions and re-installing appears to have resolved it. It is strange, though, that it only manifested in interaction with svmbir. Closing the issue as resolved.
Example script ct_svmbir_tv_multi.py runs without problem on a CPU, but fails when run on a GPU. It's worth noting that the ADMM solve uses the svmbir prox, while the Linearized ADMM solve uses the svmbir forward and backward operators. I suspect that this problem may be related to the svmbir mask for the operators.
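For readers unfamiliar with the distinction drawn above, here is a minimal generic sketch of the two usage patterns; this is illustrative only (the class and method names are hypothetical, not scico's or svmbir's actual API):

```python
import numpy as np

class CpuProxLoss:
    """Stand-in for a loss term whose proximal operator is evaluated by an
    external CPU library (the role svmbir plays in the ADMM solve)."""

    def prox(self, v, lam):
        # Placeholder quadratic prox; the real example dispatches to svmbir here.
        return v / (1.0 + lam)

class ProjectorPair:
    """Stand-in for the forward/backward projector pair that the
    Linearized ADMM solve applies directly."""

    def __init__(self, A):
        self.A = A

    def fwd(self, x):
        return self.A @ x    # forward projection

    def adj(self, y):
        return self.A.T @ y  # backward (adjoint) projection
```

The first pattern only ever calls into the external library through its prox, while the second exercises the forward and adjoint operators directly, which is where an operator mask would come into play.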