lanl / scico

Scientific Computational Imaging COde
BSD 3-Clause "New" or "Revised" License

Replace manual GPU<->host copies in `scico.solver` with host callback #204

Closed: Michael-T-McCann closed this issue 2 years ago.

Michael-T-McCann commented 2 years ago

In order to use `scipy.optimize.minimize`, we currently copy arrays back to the host manually (https://github.com/lanl/scico/blob/ec3dea364b03b63578e962cbe90f0d25db7a4fb8/scico/solver.py#L164-L166) and send the result back to the GPU if needed (https://github.com/lanl/scico/blob/ec3dea364b03b63578e962cbe90f0d25db7a4fb8/scico/solver.py#L200-L201).
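
A minimal sketch of that manual round trip follows. It is not the scico code. The helper name `minimize_on_host` is hypothetical, and the sketch assumes a recent JAX in which arrays expose a `.sharding` attribute, so that the result can be returned with the same placement as `x0`.

```python
import jax
import jax.numpy as jnp
import numpy as np
from scipy.optimize import minimize


def minimize_on_host(f, x0, method="L-BFGS-B"):
    """Hypothetical helper: minimize scalar JAX function f with SciPy on the host."""
    val_grad = jax.jit(jax.value_and_grad(f))

    def fun(x_host):
        # SciPy passes a flat float64 NumPy array; restore shape/dtype for JAX
        x = jnp.asarray(x_host.reshape(x0.shape), dtype=x0.dtype)
        v, g = val_grad(x)
        # copy value and gradient from the device back to the host for SciPy
        return float(v), np.asarray(jax.device_get(g), dtype=np.float64).ravel()

    # device -> host copy of the initial point
    x0_host = np.asarray(jax.device_get(x0), dtype=np.float64).ravel()
    res = minimize(fun, x0_host, jac=True, method=method)
    # host -> device copy, reusing x0's sharding so the result lands where x0 was
    x = jax.device_put(res.x.reshape(x0.shape).astype(x0.dtype), x0.sharding)
    return x, res
```

Every objective evaluation pays a device-to-host copy of the gradient, which is the overhead the issue proposes to clean up.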

A cleaner solution would be to use JAX's host callback mechanism, which we use, e.g., for astra and svmbir. It is simpler in this case, as we do not need `custom_vjp`, etc. Also note that the current code probably won't work correctly on multi-GPU systems.
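
As a rough illustration of the callback approach, the sketch below uses `jax.pure_callback`, which in current JAX versions plays the role of the deprecated `jax.experimental.host_callback`. The host routine `_host_solve` and its toy objective (minimizing the squared distance to `b`) are hypothetical and written in NumPy. JAX handles the device/host transfers, and because the solver is not differentiated through, no `custom_vjp` is needed.

```python
import jax
import jax.numpy as jnp
import numpy as np
from scipy.optimize import minimize


def _host_solve(b):
    # Runs on the host and receives a NumPy array; toy problem: min ||x - b||^2
    b = np.asarray(b, dtype=np.float64)
    fun = lambda x: (np.sum((x - b) ** 2), 2.0 * (x - b))  # value and gradient
    res = minimize(fun, np.zeros_like(b), jac=True, method="L-BFGS-B")
    # dtype must match the declared result spec exactly
    return res.x.astype(np.float32)


@jax.jit
def solve(b):
    # Declare the shape/dtype of the callback's output so it can be traced
    out_spec = jax.ShapeDtypeStruct(b.shape, jnp.float32)
    return jax.pure_callback(_host_solve, out_spec, b)
```

Calling `solve(jnp.array([1.0, -2.0, 0.5]))` should return approximately the input vector. The result is placed on a device by JAX rather than by hand-written `device_put` calls.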

bwohlberg commented 2 years ago

Possibly related is this cryptic comment in the same function:

https://github.com/lanl/scico/blob/982e86e17911c6c1713654d9371a47ae6467c0db/scico/solver.py#L196

I assume the comment was intended to be something like "if x0 was originally a BlockArray res.x should be converted back to one here". Let's make clearing this up part of the resolution of this issue.

bwohlberg commented 2 years ago

Comment by @lukepfister from #96 copied here since that issue has been closed again:

Also, to give some more context: that TODO concerns device placement. If you have some array x on device 0 and you run the solver, does the result go back to device 0? Per this line https://github.com/lanl/scico/blob/ec3dea364b03b63578e962cbe90f0d25db7a4fb8/scico/solver.py#L200 it should, but there are no tests that actually check this.

I'd also be concerned that this code might not work as expected if the array is sharded across multiple devices, but we don't really deal with that in the library.
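
A test for the device-placement question could look roughly like the following sketch. `host_roundtrip` is a hypothetical stand-in for the solver's copy-to-host/copy-back pattern, and the sketch assumes a recent JAX where `jax.Array` provides `.sharding` and `.devices()`. It loops over every visible device, so on a multi-GPU machine it checks each one.

```python
import jax
import jax.numpy as jnp
import numpy as np


def host_roundtrip(x0):
    # Hypothetical stand-in for the solver: copy to host, compute, copy back
    x_host = np.asarray(jax.device_get(x0))
    y_host = 2.0 * x_host
    # Reuse the input's sharding so the result is placed like x0
    return jax.device_put(y_host, x0.sharding)


def test_result_stays_on_input_device():
    for dev in jax.devices():
        x0 = jax.device_put(jnp.arange(4.0), dev)
        y = host_roundtrip(x0)
        # The result should live on exactly the same device(s) as the input
        assert y.devices() == x0.devices() == {dev}
```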

lukepfister commented 2 years ago

`snp.reshape` will return a `BlockArray` if the `shape` argument is a tuple of tuples, so the correct comment would be:

"if x0 was originally a BlockArray then res.x is converted back to one here".
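
The underlying idea, flattening block-structured input into one vector for SciPy and restoring the block structure afterwards, can be illustrated without scico. The sketch below swaps in `jax.flatten_util.ravel_pytree` on a plain tuple of arrays in place of `BlockArray` and `snp.reshape`. It also shows the tuple-of-tuples of shapes that would drive the reshape.

```python
import jax.numpy as jnp
from jax.flatten_util import ravel_pytree

# A block-structured initial point: a tuple of arrays standing in for a BlockArray
x0 = (jnp.ones((2, 3)), jnp.zeros(4))

# Shapes of the blocks form a tuple of tuples, e.g. ((2, 3), (4,))
shapes = tuple(a.shape for a in x0)

# Flatten to a single 1-D vector (as a SciPy solver would require)
flat, unravel = ravel_pytree(x0)

# Pretend the solver returned a modified flat vector, then restore the blocks
restored = unravel(flat * 2)
```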

bwohlberg commented 2 years ago

Comment corrected in 95658bc.

bwohlberg commented 2 years ago

How much of the functionality in this module could be replaced by introducing a dependency on jaxopt? Since we're planning on adding that dependency anyway (see #196), is there any point in improving scico.solver, or does it make more sense to replace its functionality with jaxopt?

lukepfister commented 2 years ago

Quick thoughts, things to check for:

bwohlberg commented 2 years ago

Is this resolved by #253? If so, it would be best to mark as such.