PyLops / pylops

PyLops – A Linear-Operator Library for Python
https://pylops.readthedocs.io
GNU Lesser General Public License v3.0

Feature: jax integration #590

Closed mrava87 closed 3 months ago

mrava87 commented 4 months ago

Motivation

This PR introduces a new backend in PyLops to enable using JAX arrays.

As a by-product of JAX-enabled operators, we inherit JAX features like jit, automatic differentiation, and automatic vectorization.
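To make those inherited features concrete, here is a minimal plain-JAX sketch (not the PyLops API; `matvec` and the diagonal operator are illustrative stand-ins for an operator's forward pass) showing what `jit`, `grad`, and `vmap` buy you once an operator works on JAX arrays:

```python
import jax
import jax.numpy as jnp

# Hypothetical stand-in for a PyLops operator's matvec: y = D x with D diagonal
def matvec(d, x):
    return d * x

d = jnp.arange(1.0, 5.0)  # diagonal entries
x = jnp.ones(4)

# jit: trace once, then run the whole matvec as one compiled XLA computation
y = jax.jit(matvec)(d, x)

# automatic differentiation: gradient of a scalar loss through the operator
loss = lambda d, x: jnp.sum(matvec(d, x) ** 2)
g = jax.grad(loss)(d, x)  # analytically 2 * d * x**2

# automatic vectorization: apply the same operator to a batch of vectors
X = jnp.stack([x, 2 * x])
Y = jax.vmap(matvec, in_axes=(None, 0))(d, X)
```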

Highlights

mrava87 commented 4 months ago

Hey, great job :-)

I left just a few nits ;-)

Thanks!

mrava87 commented 4 months ago

@cako you may want to look at this as a supporting document: https://github.com/PyLops/pylops_notebooks/blob/master/developement-cupy/Timing_CupyJAX.ipynb. It contains timings for most of the methods ported to the JAX backend and a comparison with NumPy and CuPy.

cako commented 3 months ago

@mrava87 nicely done! Going to leave it as approved, but please have a look at some of the comments and my commits.

By the way, I ran the notebook... it seems like JAX is generally slower than CuPy? Am I reading this wrong?


This is also what I see when running it both locally and on Colab. My guess is that when an operator's forward pass consists of only a few steps, each calling np/cp functions, CuPy is already very well optimized, so JAX's jit does not buy much; and for some reason the equivalent jax.numpy methods appear to be slower. In the one case where the operator's matvec/rmatvec contains a for loop (NonStationaryConvolve1D), JAX does seem to shine.
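For intuition on why the for-loop case favors jit: a sample-varying filter forces a Python-level loop of many small dot products, each a separate (cheap but launch-bound) kernel on GPU, whereas a tracing compiler can fuse the whole loop. This is a hypothetical NumPy sketch of the kind of matvec NonStationaryConvolve1D performs, not the PyLops implementation:

```python
import numpy as np

def nonstat_matvec(filters, x):
    """Apply a sample-varying filter (correlation form): one local
    filter per output sample, hence the hard-to-vectorize loop."""
    n, k = filters.shape
    h = k // 2
    xp = np.pad(x, h)          # zero-pad so every window has full length k
    y = np.zeros(n)
    for i in range(n):         # one small dot product per output sample
        y[i] = filters[i] @ xp[i:i + k]
    return y
```

With NumPy/CuPy each loop iteration dispatches separately; under `jax.jit` the loop is traced once and compiled into a single fused computation, which would explain the speedup observed only for this operator.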

I have read a lot about JAX being heavily optimized for GPUs/TPUs, so this was also an exercise in comparing it with CuPy, but so far what I observe is that CuPy does better ;)