cabouman / mbirjax

MBIRJAX is a Python package for Model Based Iterative Reconstruction (MBIR) of images from tomographic data.
https://mbirjax.readthedocs.io
BSD 3-Clause "New" or "Revised" License

ValueError with MBIRJAX v0.4.2 #39

Closed. Opened by dyang37; closed 3 months ago.

dyang37 commented 3 months ago

I encountered the following error when performing reconstructions with the latest version of MBIRJAX (v0.4.2, commit 02f9c77). The same reconstruction ran without errors on the previous version of MBIRJAX (v0.4.1, commit 77b778c).

Here's the error log:

Computing Hessian diagonal
Starting VCD iterations
Traceback (most recent call last):
  File "/depot/bouman/users/yang1467/mbirjax_applications/nsi/demo_nsi.py", line 78, in <module>
    recon, recon_params = ct_model.recon(sino, weights=weights)
                          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/depot/bouman/users/yang1467/mbirjax/mbirjax/tomography_model.py", line 536, in recon
    recon, loss_vectors = self.vcd_recon(sinogram, partitions, partition_sequence, weights=weights,
                          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/depot/bouman/users/yang1467/mbirjax/mbirjax/tomography_model.py", line 670, in vcd_recon
    fm_rmse[i] = self.get_forward_model_loss(error_sinogram, sigma_y, weights)
                 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/depot/bouman/users/yang1467/mbirjax/mbirjax/tomography_model.py", line 908, in get_forward_model_loss
    (error_sinogram * error_sinogram) * (weights / avg_weight)))
    ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~^~~~~~~~~~~~~~~~~~~~~~~~
  File "/scratch/gilbreth/yang1467/.conda/envs/mbirjax/lib/python3.12/site-packages/jax/_src/numpy/array_methods.py", line 265, in deferring_binary_op
    return binary_op(*args)
           ^^^^^^^^^^^^^^^^
ValueError: Received incompatible devices for jitted computation. Got argument x1 of jax.numpy.multiply with shape float32[200,480,384] and device ids [0] on platform GPU and argument x2 of jax.numpy.multiply with shape float32[200,480,384] and device ids [0] on platform CPU
--------------------
For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these.
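The final ValueError says that the two operands of the multiply in get_forward_model_loss live on different devices: the squared error sinogram is on the GPU and the normalized weights are on the CPU. JAX does not implicitly move committed arrays between devices inside a jitted computation, so both operands have to be placed on the same device first. The sketch below illustrates that idea under stated assumptions: `colocate` and `weighted_squared_error` are hypothetical helpers, not MBIRJAX functions, and the second one only mimics the shape of the expression shown at line 908 of tomography_model.py. Whether moving the weights on the caller's side avoids the error in v0.4.2 is an untested assumption.

```python
import jax
import jax.numpy as jnp


def colocate(reference, *arrays):
    """Return copies of `arrays` placed on the same device as `reference`.

    Hypothetical helper (not part of MBIRJAX). Assumes `reference` is a
    jax.Array that lives on exactly one device.
    """
    # jax.Array.devices() returns a set of devices; unpack the single one.
    (device,) = reference.devices()
    # device_put commits each array to that device, copying it if needed.
    return tuple(jax.device_put(a, device) for a in arrays)


def weighted_squared_error(error_sinogram, weights):
    """Illustrative only: mimics (error * error) * (weights / avg_weight)."""
    # Move the weights next to the error sinogram before mixing them.
    (weights,) = colocate(error_sinogram, weights)
    avg_weight = jnp.mean(weights)
    return (error_sinogram * error_sinogram) * (weights / avg_weight)
```

On a GPU machine, a caller could try `(weights,) = colocate(sino, weights)` before `ct_model.recon(sino, weights=weights)`. Committing to the reference array's device, rather than always to the GPU, keeps the helper correct on CPU-only machines as well.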

To reproduce the error, run nsi/demo_nsi.py in mbirjax_applications on Gilbreth.

Some more information on the dataset and hardware:

cabouman commented 3 months ago

Fixed this bug in the new release, v0.4.3.
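To confirm that an environment is running a release that includes the fix, you can compare the installed version against 0.4.3. The sketch below assumes the distribution is installed under the name `mbirjax` and that versions are plain dotted numbers. `parse_release` and `has_fix` are hypothetical helpers, and pre-release or development suffixes such as `0.4.3rc1` are deliberately not handled.

```python
from importlib.metadata import PackageNotFoundError, version


def parse_release(v):
    """Turn a plain release string like '0.4.3' into a comparable tuple (0, 4, 3).

    Simplified: raises ValueError on suffixes such as 'rc1' or '.dev0'.
    """
    return tuple(int(part) for part in v.split("."))


def has_fix(installed, fixed_in="0.4.3"):
    """True if `installed` is the release reported to contain the fix, or later."""
    return parse_release(installed) >= parse_release(fixed_in)


if __name__ == "__main__":
    try:
        installed = version("mbirjax")  # assumes the distribution name is 'mbirjax'
    except PackageNotFoundError:
        print("mbirjax is not installed")
    else:
        print(installed, "includes the fix" if has_fix(installed) else "upgrade to >= 0.4.3")
```

Tuple comparison is used here to avoid a third-party dependency. For full PEP 440 ordering, `packaging.version.Version` is the more robust choice.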