Michael-T-McCann closed this issue 2 years ago.
Possibly related is this cryptic comment in the same function:
https://github.com/lanl/scico/blob/982e86e17911c6c1713654d9371a47ae6467c0db/scico/solver.py#L196
I assume the comment was intended to be something like "if x0 was originally a BlockArray res.x should be converted back to one here". Let's make clearing this up part of the resolution of this issue.
Comment by @lukepfister from #96 copied here since that issue has been closed again:
Also, to give some more context: that TODO concerns device placement. If you have some array x on device 0 and you run the solver, does the result go back to device 0? Per this line https://github.com/lanl/scico/blob/ec3dea364b03b63578e962cbe90f0d25db7a4fb8/scico/solver.py#L200 it should, but there are no tests that actually check this.
I'd also be concerned that this code might not work as expected if the array is sharded across multiple devices, but we don't really deal with that in the library.
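A test for the device round trip could look roughly like the sketch below. It is a hypothetical stand-in for the solver's host copy, not scico's code: the helper name `minimize_on_host` is an assumption. It records the single device holding `x0`, runs `scipy.optimize.minimize` on a NumPy copy (using `jax.value_and_grad` for the objective and gradient), and uses `jax.device_put` to send the solution back. The `(device,) = x0.devices()` unpacking deliberately fails for an array sharded over more than one device, which is the untested case mentioned above.

```python
import jax
import jax.numpy as jnp
import numpy as np
import scipy.optimize as spopt


def minimize_on_host(fun, x0):
    """Hypothetical helper: minimize fun with SciPy on the host, return on x0's device."""
    # Record the device holding x0; raises ValueError if x0 spans several devices.
    (device,) = x0.devices()
    value_and_grad = jax.jit(jax.value_and_grad(fun))

    def fg(x):
        # SciPy passes float64 NumPy arrays; evaluate in x0's dtype with JAX.
        v, g = value_and_grad(jnp.asarray(x, dtype=x0.dtype))
        return float(v), np.asarray(g, dtype=np.float64)

    # Copy x0 to host memory for SciPy; jac=True means fg returns (value, gradient).
    res = spopt.minimize(fg, np.asarray(x0, dtype=np.float64), jac=True, method="BFGS")
    # Send the solution back to the device that held x0.
    return jax.device_put(jnp.asarray(res.x, dtype=x0.dtype), device)


x0 = jnp.array([3.0, -2.0])
x = minimize_on_host(lambda v: jnp.sum((v - 1.0) ** 2), x0)
# The solution should live on the same device as x0.
assert x.devices() == x0.devices()
```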
`snp.reshape` will return a BlockArray if the shape argument is a tuple of tuples, so the correct version of the comment at solver.py#L196 would be "if x0 was originally a BlockArray then res.x is converted back to one here".
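The reshape in question turns the flat solution vector back into the block structure of `x0`. As a library-agnostic sketch of the same idea, the snippet below models a block-structured `x0` as a plain tuple of arrays (an assumption standing in for a BlockArray; it does not use scico) and uses `jax.flatten_util.ravel_pytree`, a different mechanism from `snp.reshape`, to flatten it for a solver and restore the original structure afterwards.

```python
import jax.numpy as jnp
import numpy as np
from jax.flatten_util import ravel_pytree

# Block-structured x0 modeled as a tuple of differently shaped arrays
# (a stand-in for a BlockArray, not scico's type).
x0 = (jnp.ones((2, 3)), jnp.zeros((4,)))

# Flatten to one 1D vector, as scipy.optimize.minimize expects, and keep
# the function that restores the original block structure.
x0_flat, unravel = ravel_pytree(x0)

# A solver would work on x0_flat here; simulate a flat result res_x.
res_x = np.asarray(x0_flat) + 1.0

# Convert the flat result back to blocks; unravel requires the original dtype.
x = unravel(jnp.asarray(res_x, dtype=x0_flat.dtype))
```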
Comment corrected in 95658bc.
How much of the functionality in this module could be replaced by introducing a dependency on `jaxopt`? That is, since we're planning on adding that dependency anyway (see #196), is there any point in improving `scico.solver`, or does it make more sense to replace its functionality with `jaxopt`?
Quick thoughts, things to check for:
Is this resolved by #253? If so, it would be best to mark it as such.
In order to use `scipy.optimize.minimize`, we currently manually copy arrays back to the host (https://github.com/lanl/scico/blob/ec3dea364b03b63578e962cbe90f0d25db7a4fb8/scico/solver.py#L164-L166) and send the result back to the GPU if needed (https://github.com/lanl/scico/blob/ec3dea364b03b63578e962cbe90f0d25db7a4fb8/scico/solver.py#L200-L201). A cleaner solution would be to use JAX's host callback mechanism, which we already use, e.g., for `astra` and `svmbir`. It is simpler in this case, as we do not need `custom_vjp`, etc. Also note that the current code probably won't work correctly on multi-GPU systems.
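A rough sketch of the callback approach follows. Note that `jax.experimental.host_callback` has since been deprecated in favor of `jax.pure_callback`, so the sketch uses the latter. The objective is a fixed, hypothetical NumPy quadratic evaluated on the host, and the names `_host_minimize` and `minimize_via_callback` are assumptions for illustration. JAX moves the input to the host and places the result back on the device, so no manual copies appear in the code.

```python
import jax
import jax.numpy as jnp
import numpy as np
import scipy.optimize as spopt


def _host_minimize(x0):
    # Runs on the host with a NumPy array; the quadratic objective below is a
    # hypothetical example, minimized at target = [0, 1, ..., n - 1].
    target = np.arange(x0.size, dtype=np.float64)

    def fun(x):
        return np.sum((x - target) ** 2)

    def jac(x):
        return 2.0 * (x - target)

    res = spopt.minimize(fun, np.asarray(x0, dtype=np.float64), jac=jac, method="L-BFGS-B")
    # The returned shape and dtype must match the declared result type.
    return res.x.astype(x0.dtype)


@jax.jit
def minimize_via_callback(x0):
    # Declare the callback's result shape/dtype so JAX can trace through it.
    out_type = jax.ShapeDtypeStruct(x0.shape, x0.dtype)
    return jax.pure_callback(_host_minimize, out_type, x0)


x = minimize_via_callback(jnp.zeros(4, dtype=jnp.float32))
```

Because the callback is opaque to JAX, differentiating through it would still require something like `custom_vjp`; the sketch only covers the forward solve.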