jax-ml / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
http://jax.readthedocs.io/
Apache License 2.0
30.45k stars 2.8k forks

Sparse matrix multiplication and output nse values #23576

Open joaospinto opened 2 months ago

joaospinto commented 2 months ago

Suppose I have two sparse matrices:

G of type BCOO(float64[80, 200], nse=129)
H of type BCOO(float64[200, 80], nse=129)

If I do G @ H, I get

DynamicJaxprTracer[BCOO(float64[80, 80], nse=6400)]

However, the sparsity patterns of G and H may be known statically (this is certainly the case in my application), and the true nse could then be bounded statically by a value well below the 6400 shown above.
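As a minimal sketch of the idea, assuming the sparsity patterns are available at trace time as Python sets of (row, col) pairs: the output pattern of a product can be computed symbolically, which gives an exact nse instead of a worst-case one. `product_pattern` is a hypothetical helper (not part of jax.experimental.sparse), and the tridiagonal pattern is an illustrative stand-in for an application's real structure.

```python
def product_pattern(lhs_pattern, rhs_pattern):
    """Hypothetical helper: positions (i, k) that can be nonzero in lhs @ rhs,
    given the static (i, j) positions of lhs and (j, k) positions of rhs."""
    # Group the rhs column indices by their row index j.
    rhs_by_row = {}
    for j, k in rhs_pattern:
        rhs_by_row.setdefault(j, set()).add(k)
    # Every lhs entry (i, j) combines with every rhs entry in row j.
    out = set()
    for i, j in lhs_pattern:
        for k in rhs_by_row.get(j, ()):
            out.add((i, k))
    return out


n = 80
# Illustrative static pattern: an n x n tridiagonal matrix.
tri = {(i, j) for i in range(n) for j in range(n) if abs(i - j) <= 1}
out = product_pattern(tri, tri)
print(len(tri), len(out), n * n)  # 238 394 6400
```

For a tridiagonal 80x80 operand the product is pentadiagonal, so 394 stored entries suffice where a pattern-agnostic bound would allow 6400.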

I feel that until something like that happens, the utility of the sparse module might be pretty limited. For example, this is somewhat preventing me from using sparse matrix inputs in https://github.com/joaospinto/jax_alqp.

What are your thoughts? Maybe @jakevdp, since I've seen you comment on other sparsity-related posts?

jakevdp commented 2 months ago

Thanks for the question! JAX's sparse-sparse matmul is pretty inefficient, but for good reason: for completely unstructured sparsity where the indices are not known statically, the nse you're seeing is the worst-case non-aggregated nse for the result of the matmul. For example, if the left-hand matrix had a full column of entries and the right-hand matrix had a full row of entries, the result of the matmul would have a 100% fill factor. And with non-static indices, there's no way to rule out that worst case at compile time.
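The worst case can be illustrated with plain Python, treating matrices as hypothetical sets of (row, col) positions: a single full column times a single full row fills every output position. The `worst_case_nse` helper below is an assumed formula for the best bound available without index information; it happens to give 6400 for two nse=129 operands with an 80x80 result, but this sketch does not claim that JAX computes its nse this way.

```python
def worst_case_nse(lhs_nse, rhs_nse, out_rows, out_cols):
    # Each output entry needs at least one matching (lhs, rhs) pair, so there
    # are at most lhs_nse * rhs_nse of them, and never more than the number
    # of output positions.
    return min(lhs_nse * rhs_nse, out_rows * out_cols)


m, n = 80, 80
# Left operand: nonzeros only in column 0 (a "full column of entries").
lhs = {(i, 0) for i in range(m)}
# Right operand: nonzeros only in row 0 (a "full row of entries").
rhs = {(0, j) for j in range(n)}
# Output positions receiving at least one product lhs[i, p] * rhs[q, j] with p == q.
out = {(i, j) for (i, p) in lhs for (q, j) in rhs if p == q}

print(len(lhs) + len(rhs), len(out))      # 160 6400
print(worst_case_nse(129, 129, 80, 80))   # 6400
```

Only 160 stored entries produce a completely dense 80x80 result, which is why a bound that ignores the indices cannot be tighter than the output size.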

With static sparsity patterns, of course, you could rule out this worst case and do better. JAX doesn't currently have any implementation of that sort of sparsity.

> I feel that until something like that happens, the utility of the sparse module might be pretty limited. For example, this is somewhat preventing me from using sparse matrix inputs in https://github.com/joaospinto/jax_alqp.

I agree with your assessment here: jax.experimental.sparse is very limited (it's one of the reasons it's still experimental!) and probably is not a good fit for your application.

joaospinto commented 2 months ago

Thanks for confirming. I hope the example I provided motivates future work in this direction. :)

joaospinto commented 1 month ago

FYI I added a prototype of what I needed here and here.

Given that in the dense case we require all shapes to be statically determined, it would not be unreasonable for the sparse case to require statically determined sparsity patterns. As my prototype shows, this does not need to be supported natively, but it would be cool if it were.
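A minimal sketch of one way to exploit static patterns in JAX, assuming the index arrays are concrete NumPy arrays known at trace time (this is not the prototype's code, and `static_pattern_matmul` is a hypothetical name): a NumPy symbolic phase enumerates the contributing (lhs, rhs) pairs and the exact output pattern, and a JAX numeric phase does one gather-multiply and one `jax.ops.segment_sum`.

```python
import numpy as np
import jax
import jax.numpy as jnp
from jax.experimental import sparse


def static_pattern_matmul(lhs_data, lhs_idx, rhs_data, rhs_idx, shape):
    """Hypothetical sparse @ sparse for static (nse, 2) NumPy index arrays.

    lhs_data / rhs_data may be traced; lhs_idx / rhs_idx must be concrete.
    """
    # Symbolic phase (NumPy, runs once at trace time): every pair (a, b) with
    # lhs column == rhs row contributes lhs_data[a] * rhs_data[b].
    match = lhs_idx[:, 1][:, None] == rhs_idx[:, 0][None, :]
    a, b = np.nonzero(match)
    out_keys = np.stack([lhs_idx[a, 0], rhs_idx[b, 1]], axis=1)
    # Deduplicate output positions; seg[p] is the output slot of pair p.
    out_idx, seg = np.unique(out_keys, axis=0, return_inverse=True)
    seg = seg.reshape(-1)  # flatten defensively: inverse shape varies by NumPy version
    # Numeric phase (JAX): a fixed number of vectorized ops, whatever the nse.
    prods = lhs_data[a] * rhs_data[b]
    out_data = jax.ops.segment_sum(prods, seg, num_segments=len(out_idx))
    return sparse.BCOO((out_data, jnp.asarray(out_idx, dtype=jnp.int32)), shape=shape)


g_dense = np.array([[1., 0., 2.],
                    [0., 3., 0.]], dtype=np.float32)
h_dense = np.array([[0., 4.],
                    [5., 0.],
                    [6., 0.]], dtype=np.float32)
g_idx, h_idx = np.argwhere(g_dense), np.argwhere(h_dense)  # static patterns
g_data = jnp.asarray(g_dense[g_dense != 0])  # row-major, same order as argwhere
h_data = jnp.asarray(h_dense[h_dense != 0])

# Close over the static indices so only the values are traced.
matmul = jax.jit(lambda gd, hd: static_pattern_matmul(gd, g_idx, hd, h_idx, (2, 2)))
out = matmul(g_data, h_data)
print(out.nse)  # 3: only positions that can actually be nonzero are stored
```

The symbolic phase builds an nse_lhs x nse_rhs boolean matrix, which is fine for small operands but would need a sort-based join for large ones.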

jakevdp commented 1 month ago

Interesting! For what it's worth, this example illustrates the fundamental challenge with building a general sparse computing API. What is meant by "sparsity" is very specific to each particular context: in your case, only static sparsity patterns make sense. In other cases, static sparsity would not be possible.

One of the reasons we've not done much more work on jax.experimental.sparse is because of this: it's nearly impossible to write one general tool that serves most use-cases. Rather, people want specific sparse kernels for the particular constraints and sparsity structure of their own application, and so writing purpose-specific kernels is more effective in the long run.

joaospinto commented 1 month ago

The approach I described above does run into one issue: every matrix operation I do gets unrolled into JAX's expression graph (the jaxpr), making it blow up in size. This makes JIT compilation take more than 20 minutes in cases that aren't particularly complicated, and also makes runtimes 10x-100x slower than they ought to be (and the MLIR bytecode files are over 1MB!). This is a bit surprising to me, as my code is doing something quite simple.

One solution would be to reuse the same (ideally dimension-independent) code path for operations such as sparse matrix multiplication wherever they appear, while still preserving a statically determined number of nonzeros for each sparse matrix. Should I be using a pure callback in this situation? I'd be interested in hearing your thoughts.
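A minimal sketch of why per-nonzero Python code inflates the traced program, using a hypothetical 50x50 tridiagonal pattern: the unrolled loop emits a few jaxpr equations per nonzero, while the vectorized version keeps the static indices as constant arrays and emits a fixed number of equations regardless of nse. Counting `jax.make_jaxpr(f)(...).jaxpr.eqns` is only a rough proxy for compile cost.

```python
import numpy as np
import jax
import jax.numpy as jnp

n = 50
# Hypothetical static pattern: a tridiagonal n x n matrix as (row, col) pairs.
rows = np.array([i for i in range(n) for j in range(n) if abs(i - j) <= 1])
cols = np.array([j for i in range(n) for j in range(n) if abs(i - j) <= 1])


def matvec_unrolled(data, x):
    # One Python-level update per nonzero: each iteration is traced separately,
    # so the jaxpr grows linearly with the number of nonzeros.
    y = jnp.zeros(n, dtype=x.dtype)
    for p in range(len(rows)):
        y = y.at[rows[p]].add(data[p] * x[cols[p]])
    return y


def matvec_vectorized(data, x):
    # The same computation as one gather, one multiply and one segment_sum;
    # the number of equations no longer depends on the number of nonzeros.
    return jax.ops.segment_sum(data * x[cols], rows, num_segments=n)


data = jnp.ones(len(rows))
x = jnp.arange(n, dtype=jnp.float32)
n_unrolled = len(jax.make_jaxpr(matvec_unrolled)(data, x).jaxpr.eqns)
n_vectorized = len(jax.make_jaxpr(matvec_vectorized)(data, x).jaxpr.eqns)
print(n_unrolled, n_vectorized)
```

The constant index arrays still scale with nse, but the number of operations XLA has to compile does not.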

jakevdp commented 1 month ago

I don't know of a good solution. XLA is not well set up for general unstructured sparse computation, unfortunately. Now, if you have block sparsity then you may be able to make some progress, because you can take advantage of vectorized operations for parts of your computation.
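A minimal sketch of the block-sparse idea, assuming a hypothetical BSR-like layout (dense b x b blocks plus static block-row and block-column indices); it is not the approach from the linked PR, whose details are not reproduced here. Each stored block is a small dense matmul, so all blocks can be processed in one batched `jnp.einsum` and accumulated with `jax.ops.segment_sum`.

```python
import numpy as np
import jax
import jax.numpy as jnp


def block_sparse_matvec(blocks, block_rows, block_cols, x, n_block_rows):
    """Compute y = A @ x for a block-sparse A (hypothetical BSR-like layout).

    blocks:       (nb, b, b) dense values of the stored blocks.
    block_rows:   (nb,) static block-row index of each stored block.
    block_cols:   (nb,) static block-column index of each stored block.
    x:            dense vector whose length is a multiple of b.
    n_block_rows: number of block rows in A.
    """
    b = blocks.shape[-1]
    x_blocks = x.reshape(-1, b)  # (n_block_cols, b)
    # One batched dense matmul over all stored blocks at once.
    partial = jnp.einsum("kij,kj->ki", blocks, x_blocks[block_cols])  # (nb, b)
    # Sum each block's contribution into its block row.
    y_blocks = jax.ops.segment_sum(partial, block_rows, num_segments=n_block_rows)
    return y_blocks.reshape(-1)


b = 4
block_rows = np.array([0, 0, 1, 2])  # static block pattern
block_cols = np.array([0, 2, 1, 2])
blocks = jnp.asarray(np.random.default_rng(0).standard_normal((4, b, b)).astype(np.float32))
x = jnp.arange(3 * b, dtype=jnp.float32)

# Close over the static block pattern so only values are traced.
matvec = jax.jit(lambda blk, v: block_sparse_matvec(blk, block_rows, block_cols, v, 3))
y = matvec(blocks, x)

# Dense reference, for checking only.
A = np.zeros((3 * b, 3 * b), dtype=np.float32)
for k, (r, c) in enumerate(zip(block_rows, block_cols)):
    A[r * b:(r + 1) * b, c * b:(c + 1) * b] = np.asarray(blocks[k])
```

The per-block work is dense and vectorized, so the cost of irregularity is paid once per block rather than once per nonzero.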

jakevdp commented 1 month ago

See https://github.com/google/jax/pull/23674 for an example of an efficient block-sparse approach for one particular operation of interest.