lanl / scico

Scientific Computational Imaging COde
BSD 3-Clause "New" or "Revised" License

Example ct_svmbir_tv_multi.py broken on GPU device #136

Closed: bwohlberg closed this issue 2 years ago

bwohlberg commented 2 years ago

Example script ct_svmbir_tv_multi.py runs without problems on a CPU but fails when run on a GPU:

$ python examples/scripts/ct_svmbir_tv_multi.py 
Solving on GPU (NVIDIA GeForce RTX 2080 Ti)

Iter  Time      Objective  Primal Rsdl  Dual Rsdl
-------------------------------------------------
   0  6.31e+00  9.014e+00    6.612e-01  5.694e-01
  10  5.17e+01  1.618e+01    2.415e-02  2.954e-02
  20  9.58e+01  1.628e+01    1.147e-02  1.347e-02
  30  1.40e+02  1.632e+01    7.485e-03  8.041e-03
  40  1.81e+02  1.634e+01    5.548e-03  5.379e-03
  49  2.12e+02  1.635e+01    4.512e-03  3.832e-03
PSNR: 22.92 dB

Iter  Time      Objective  Primal Rsdl  Dual Rsdl
-------------------------------------------------
   0  8.07e-01  5.399e+00    1.072e+00  1.104e+00
Traceback (most recent call last):
  File "examples/scripts/ct_svmbir_tv_multi.py", line 127, in <module>
    x_ladmm = solver_ladmm.solve()
  File "scico/scico/optimize/_ladmm.py", line 340, in solve
    self.step()
  File "scico/scico/optimize/_ladmm.py", line 314, in step
    self.x = self.f.prox(proxarg, self.mu, v0=self.x)
  File "scico/scico/linop/radon_svmbir.py", line 239, in prox
    raise ValueError("Result contains NaNs")
ValueError: Result contains NaNs
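One way to confirm that the failure depends on the JAX backend rather than on the script itself is to rerun it with JAX restricted to the CPU. A minimal sketch, assuming the environment variables JAX_PLATFORMS (newer JAX releases) and JAX_PLATFORM_NAME (older 0.2.x releases) are honored; force_cpu_backend is a hypothetical helper, not part of scico.

```python
import os


def force_cpu_backend() -> None:
    """Ask JAX to use only the CPU backend.

    Call this before jax is imported: JAX reads these environment
    variables when it initializes, so later changes may be ignored.
    """
    # Newer JAX releases read JAX_PLATFORMS; older 0.2.x releases read
    # JAX_PLATFORM_NAME. Setting both is an assumption meant to cover either.
    os.environ["JAX_PLATFORMS"] = "cpu"
    os.environ["JAX_PLATFORM_NAME"] = "cpu"
```

Calling force_cpu_backend() at the top of the example script, before import jax, should make it report a CPU solve; the same variable can instead be set on the command line when launching the script.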

It's worth noting that

  1. it fails in the Linearized ADMM solve, after the ADMM solve completes without problems, and
  2. the ADMM solve uses the svmbir prox, while the Linearized ADMM solve uses the svmbir forward and backward operators.

I suspect that this problem may be related to the svmbir mask for the operators.
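Since the prox path works while the forward/backward path fails, it may help to check each operator's output separately for non-finite values. A minimal NumPy sketch, assuming each stage is a plain array-to-array function; first_nonfinite_stage and the toy stages are hypothetical stand-ins for the svmbir forward and back projectors, not scico APIs.

```python
from typing import Callable, Optional, Sequence, Tuple

import numpy as np

# A stage is a (name, function) pair; each function maps an array to an array.
Stage = Tuple[str, Callable[[np.ndarray], np.ndarray]]


def first_nonfinite_stage(x: np.ndarray, stages: Sequence[Stage]) -> Optional[str]:
    """Apply the stages in order and return the name of the first one whose
    output contains NaN or Inf, or None if every output is finite."""
    y = np.asarray(x)
    for name, fn in stages:
        y = np.asarray(fn(y))
        if not np.all(np.isfinite(y)):
            return name
    return None


# Toy stand-ins (hypothetical) for a forward projector, its adjoint, and a
# step that divides by a value that can be zero.
A_toy = np.array([[1.0, 0.0], [0.0, 0.0]])
toy_stages = [
    ("forward", lambda v: A_toy @ v),
    ("backward", lambda v: A_toy.T @ v),
    ("normalize", lambda v: v / v[1]),
]

with np.errstate(divide="ignore", invalid="ignore"):
    print(first_nonfinite_stage(np.array([1.0, 2.0]), toy_stages))  # prints normalize
```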

smajee commented 2 years ago

I ran ct_svmbir_tv_multi.py on a GPU and it seemed to run fine. It printed some warnings, but the reconstructions were fine.

My jax versions are:

jax     0.2.26                    pypi_0  pypi
jaxlib  0.1.75+cuda11.cudnn82     pypi_0  pypi

CUDA: 11.4, cuDNN: 8.2.4.15

@bwohlberg Could it be because I uninstalled and reinstalled jax instead of using the upgrade command?
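When comparing environments like this, it may help to report versions programmatically rather than copying them from conda list. A standard-library sketch (Python 3.8+); the package list is an assumption, and in a damaged install the metadata it reads could in principle disagree with the modules actually imported, so comparing against jax.__version__ is a useful cross-check.

```python
from importlib.metadata import PackageNotFoundError, version
from typing import Dict, Optional, Sequence


def package_versions(
    names: Sequence[str] = ("jax", "jaxlib", "numpy", "svmbir"),
) -> Dict[str, Optional[str]]:
    """Return the installed version of each distribution, or None if missing."""
    result: Dict[str, Optional[str]] = {}
    for name in names:
        try:
            result[name] = version(name)
        except PackageNotFoundError:
            result[name] = None
    return result


for pkg, ver in package_versions().items():
    print(f"{pkg:8s} {ver or 'not installed'}")
```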

Warnings:

WARNING *** You are using ptxas 11.0.221, which is older than 11.1. ptxas before 11.1 is known to miscompile XLA code, leading to incorrect results or invalid-address errors.

You may not need to update to CUDA 11.1; cherry-picking the ptxas binary is often sufficient.
Solving on GPU (NVIDIA GeForce RTX 2080 Ti)
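The warning concerns the ptxas assembler that XLA uses when compiling GPU code. A small sketch for checking which ptxas is on the PATH; it assumes ptxas --version prints a version string such as "V11.0.221", and since XLA may locate ptxas through the CUDA installation directory rather than PATH, the result is only indicative.

```python
import re
import shutil
import subprocess
from typing import Optional, Tuple

# Assumed format: ptxas --version ends with a line such as
# "Cuda compilation tools, release 11.0, V11.0.221".
_VERSION_RE = re.compile(r"\bV(\d+)\.(\d+)\.(\d+)")


def parse_ptxas_version(text: str) -> Optional[Tuple[int, int, int]]:
    """Extract (major, minor, patch) from ptxas --version output."""
    m = _VERSION_RE.search(text)
    if m is None:
        return None
    major, minor, patch = (int(g) for g in m.groups())
    return (major, minor, patch)


def ptxas_version() -> Optional[Tuple[int, int, int]]:
    """Return the version of the first ptxas on PATH, or None if absent."""
    exe = shutil.which("ptxas")
    if exe is None:
        return None
    proc = subprocess.run([exe, "--version"], capture_output=True, text=True, check=False)
    return parse_ptxas_version(proc.stdout)


found = ptxas_version()
if found is None:
    print("ptxas not found on PATH (or version not recognized)")
elif found < (11, 1, 0):
    print(f"ptxas {found} is older than 11.1")
else:
    print(f"ptxas {found} is 11.1 or newer")
```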

Output:

[Screenshot of output: 2021-12-17, 6:10:58 PM]
bwohlberg commented 2 years ago

Strange. I suppose it's possible that the way you upgraded jax plays a role. I did not see the ptxas warnings when I tried it.

smajee commented 2 years ago

It seems strange that the error depends on CPU vs. GPU, and on how the GPU build of jaxlib was installed, since svmbir runs on the CPU. The only possible cause I can think of is the data conversion between a GPU device array and a NumPy array, but that seems unlikely.
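The host/device conversion mentioned here can be sketched as follows: np.asarray copies a JAX device array to host memory, the CPU-only function runs on the NumPy copy, and the result is checked for NaNs, raising the same message as in the traceback. call_cpu_operator and the functions passed to it are hypothetical; this is not how scico's svmbir wrapper is implemented, only an illustration of the round trip.

```python
from typing import Any, Callable

import numpy as np


def call_cpu_operator(fn: Callable[[np.ndarray], np.ndarray], x: Any) -> np.ndarray:
    """Run a NumPy-only (CPU) function on an array that may live on a GPU.

    np.asarray copies a JAX device array to host memory; the caller can
    move the result back with jax.numpy.asarray if needed.
    """
    x_host = np.asarray(x)
    # fn is a hypothetical stand-in for a CPU-only library call such as svmbir.
    y = np.asarray(fn(x_host))
    if np.isnan(y).any():
        raise ValueError("Result contains NaNs")
    return y


y = call_cpu_operator(lambda v: 2.0 * v, np.ones((2, 3), dtype=np.float32))
print(y.dtype, y.shape)  # float32 (2, 3)
```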

bwohlberg commented 2 years ago

Agreed. I will try again to confirm that it wasn't just some temporary environment problem.

smajee commented 2 years ago

I ran the example again with an old commit (49b5b04) and an older version of jax (jaxlib==0.1.70, jax==0.2.19), and it ran fine. Then I upgraded jax using Mike's command

pip install --upgrade "jax[cuda]" -f https://storage.googleapis.com/jax-releases/jax_releases.html

and ran the current code, and it still ran fine.

bwohlberg commented 2 years ago

The problem does indeed seem to have been due to a corrupted jax/jaxlib installation: cleaning up old versions and reinstalling seems to have resolved it. It is strange, though, that it only manifested in interaction with svmbir. Closing the issue as resolved.
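Leftover metadata from earlier jax or jaxlib installs is one possible sign of this kind of corrupted environment. A rough standard-library sketch that lists watched packages with more than one installed distribution visible on sys.path; duplicate_distributions is a hypothetical helper, and a non-empty result is a hint rather than proof (the same directory appearing twice on sys.path can also produce duplicates).

```python
from collections import defaultdict
from importlib.metadata import distributions
from typing import Dict, List, Sequence


def duplicate_distributions(
    names: Sequence[str] = ("jax", "jaxlib"),
) -> Dict[str, List[str]]:
    """Map each watched package to the versions found on sys.path,
    keeping only packages that appear more than once."""
    wanted = {n.lower() for n in names}
    found: Dict[str, List[str]] = defaultdict(list)
    for dist in distributions():
        name = (dist.metadata.get("Name") or "").lower()
        if name in wanted:
            found[name].append(dist.version)
    return {n: v for n, v in found.items() if len(v) > 1}


dups = duplicate_distributions()
print(dups if dups else "no duplicate jax/jaxlib distributions found")
```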