Closed by mrava87 3 months ago
Hey, great job :-)
I just left a few nits ;-)
Thanks!
@cako you may want to look at this as a supporting document: https://github.com/PyLops/pylops_notebooks/blob/master/developement-cupy/Timing_CupyJAX.ipynb. It contains timings for most of the methods ported to the JAX backend and a comparison with NumPy and CuPy.
@mrava87 nicely done! Going to leave it as approved, but please have a look at some of the comments and my commits.
By the way, I ran the notebook... it seems like JAX is generally slower than CuPy? Am I reading this wrong?
This is also what I see when running it both locally and on Colab... my guess/suspicion is that when the operator has a limited number of steps, each calling np/cp methods, CuPy is already very well optimized, so JAX's jit does not really buy much... and for some reason the equivalent jax.numpy methods appear to be slower. In the one case where the operator's matvec/rmatvec has a Python for loop (NonStationaryConvolve1D), JAX does seem to shine... see the sketch below.
I read a lot about JAX being very optimized for GPUs/TPUs, so this was also an exercise to compare it with CuPy, but so far what I observe is that CuPy is somewhat better ;)
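To make that more concrete, here is a minimal timing sketch (pure JAX, not taken from the PR or the notebook; the toy convolution and array sizes are only illustrative) of why jit buys little when the matvec is a single fused library call, but can help a lot when a Python-level loop gets compiled into one kernel:

```python
import timeit
import jax
import jax.numpy as jnp

x = jnp.ones(100_000, dtype=jnp.float32)
h = jnp.ones(65, dtype=jnp.float32) / 65

def conv_single_call(x):
    # a single fused library call: jit has little left to optimize
    return jnp.convolve(x, h, mode="same")

def conv_python_loop(x):
    # many small ops issued from a Python loop (NonStationaryConvolve1D-style):
    # here jit can fuse the whole loop into one compiled kernel
    y = jnp.zeros_like(x)
    for shift in range(-32, 33):
        y = y + h[shift + 32] * jnp.roll(x, shift)
    return y

for fn in (conv_single_call, conv_python_loop):
    jitted = jax.jit(fn)
    _ = jitted(x).block_until_ready()  # compile once before timing
    t_eager = timeit.timeit(lambda: fn(x).block_until_ready(), number=20)
    t_jit = timeit.timeit(lambda: jitted(x).block_until_ready(), number=20)
    print(f"{fn.__name__}: eager {t_eager:.3f}s  jit {t_jit:.3f}s")
```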
Motivation
This PR introduces a new backend in PyLops that enables the use of JAX arrays.
As a by-product of JAX-enabled operators, we inherit JAX features like jit, automatic differentiation, and automatic vectorization.
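As a rough illustration (pure JAX, not part of this PR; the dense matvec below is just a stand-in for an operator whose forward/adjoint are built from jnp calls), these are the features that become available:

```python
import jax
import jax.numpy as jnp

A = jnp.arange(12.0).reshape(3, 4)

def matvec(x):
    # forward operator: a simple dense matrix-vector product
    return A @ x

x = jnp.ones(4)

# jit: compile the matvec once and reuse the compiled kernel
y = jax.jit(matvec)(x)

# automatic differentiation: gradient of a scalar loss through the operator
loss = lambda x: jnp.sum(matvec(x) ** 2)
g = jax.grad(loss)(x)

# automatic vectorization: apply the matvec to a batch of inputs at once
X = jnp.ones((10, 4))
Y = jax.vmap(matvec)(X)

print(y.shape, g.shape, Y.shape)  # (3,) (4,) (10, 3)
```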
Highlights
- `JaxOperator`
- `backend` module with new logic to detect whether `np`, `cp`, or `jnp` methods should be used based on the input type (a sketch of this dispatch is shown below)
- `jaxop`
- `gpu.rst` documentation page
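For reference, a minimal sketch (assumed, not the PR's actual implementation; the function name and dispatch logic are illustrative) of the kind of type-based module selection the backend performs:

```python
import numpy as np

def get_array_module(x):
    """Return the array module (np, cp, or jnp) matching the type of ``x``."""
    module_name = type(x).__module__
    if module_name.startswith("jax"):
        import jax.numpy as jnp  # only imported if a JAX array is passed
        return jnp
    if module_name.startswith("cupy"):
        import cupy as cp  # only imported if a CuPy array is passed
        return cp
    return np

# usage: the same matvec code then works for numpy, cupy, and jax inputs
def matvec(x):
    ncp = get_array_module(x)
    return ncp.cumsum(x)
```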