I encountered the following error when running a reconstruction with the latest version of MBIRJAX (v0.4.2, commit 02f9c77).
The same reconstruction runs without errors with the previous version of MBIRJAX (v0.4.1, commit 77b778c).
Here's the error log:
Computing Hessian diagonal
Starting VCD iterations
Traceback (most recent call last):
File "/depot/bouman/users/yang1467/mbirjax_applications/nsi/demo_nsi.py", line 78, in <module>
recon, recon_params = ct_model.recon(sino, weights=weights)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/depot/bouman/users/yang1467/mbirjax/mbirjax/tomography_model.py", line 536, in recon
recon, loss_vectors = self.vcd_recon(sinogram, partitions, partition_sequence, weights=weights,
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/depot/bouman/users/yang1467/mbirjax/mbirjax/tomography_model.py", line 670, in vcd_recon
fm_rmse[i] = self.get_forward_model_loss(error_sinogram, sigma_y, weights)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/depot/bouman/users/yang1467/mbirjax/mbirjax/tomography_model.py", line 908, in get_forward_model_loss
(error_sinogram * error_sinogram) * (weights / avg_weight)))
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~^~~~~~~~~~~~~~~~~~~~~~~~
File "/scratch/gilbreth/yang1467/.conda/envs/mbirjax/lib/python3.12/site-packages/jax/_src/numpy/array_methods.py", line 265, in deferring_binary_op
return binary_op(*args)
^^^^^^^^^^^^^^^^
ValueError: Received incompatible devices for jitted computation. Got argument x1 of jax.numpy.multiply with shape float32[200,480,384] and device ids [0] on platform GPU and argument x2 of jax.numpy.multiply with shape float32[200,480,384] and device ids [0] on platform CPU
--------------------
For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these.
To reproduce the error, go to mbirjax_applications and run nsi/demo_nsi.py on Gilbreth.
Some more information on the dataset and hardware:
Sinogram shape: (200, 480, 384). Recon shape: (480, 384, 384).
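In the log, x1 is `error_sinogram * error_sinogram` (on the GPU) and x2 is `weights / avg_weight` (on the CPU), so it looks like `weights` ends up committed to the CPU in v0.4.2. The sketch below is a hypothetical helper, not MBIRJAX code: the name `weighted_squared_error` and the summed loss are assumptions, and only the product on line 908 comes from the traceback. It illustrates one way to avoid the mismatch, by explicitly placing `weights` on the device that holds the error sinogram (via `jax.Array.devices()` and `jax.device_put`) before multiplying. It assumes single-device (unsharded) arrays.

```python
import jax
import jax.numpy as jnp


def weighted_squared_error(error_sinogram, weights):
    """Hypothetical helper (not MBIRJAX code): weighted sum of squared errors.

    `weights` is moved to the device holding `error_sinogram` before the
    elementwise multiply, so both operands live on the same device and JAX
    does not raise "Received incompatible devices".
    """
    # Array.devices() returns the set of devices the array lives on;
    # this sketch assumes a single-device (unsharded) array.
    (target_device,) = error_sinogram.devices()
    weights = jax.device_put(weights, target_device)
    avg_weight = jnp.mean(weights)
    return jnp.sum((error_sinogram * error_sinogram) * (weights / avg_weight))


if __name__ == "__main__":
    # Small stand-ins for the (200, 480, 384) arrays in the report.
    err = jnp.ones((2, 3, 4), dtype=jnp.float32)
    # Commit the weights to the CPU, mimicking the "platform CPU" argument in the log.
    w = jax.device_put(jnp.full((2, 3, 4), 2.0, dtype=jnp.float32), jax.devices("cpu")[0])
    print(weighted_squared_error(err, w))
```

If the weights are moved to the CPU inside MBIRJAX itself, calling `jax.device_put` on them before `ct_model.recon` may not be enough; the placement would then have to be fixed wherever v0.4.2 moves the weights.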