joaospinto opened this issue 2 months ago
Thanks for the question! JAX's sparse-sparse matmul is pretty inefficient, but for good reason: for completely unstructured sparsity, when the indices are not known statically, the `nse` you're seeing is the worst-case non-aggregated `nse` for the resulting matmul! For example, if the left-hand matrix had a full column of entries and the right-hand matrix had a full row of entries, the result of the matmul would have a 100% fill factor. And with non-static indices, there's no way to rule out that worst case at compile time.
With static sparsity patterns, of course, you could rule out this worst case and do better. JAX doesn't currently have any implementation of that sort of sparsity.
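To make that concrete, here is a minimal sketch of the worst case being described, using toy 4×4 matrices of my own (not taken from this thread):

```python
import jax.numpy as jnp
from jax.experimental import sparse

# G has one fully dense column, H has one fully dense row.
G = sparse.BCOO.fromdense(jnp.zeros((4, 4)).at[:, 0].set(1.0))  # nse = 4
H = sparse.BCOO.fromdense(jnp.zeros((4, 4)).at[0, :].set(1.0))  # nse = 4

out = G @ H             # every entry of the product is nonzero: 100% fill factor
print(out.nse)          # the stored nse reflects this worst case
print(out.todense())    # a dense 4x4 matrix of ones
```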
I feel that until something like that happens, the utility of the `sparse` module might be pretty limited. For example, this is somewhat preventing me from using sparse matrix inputs in https://github.com/joaospinto/jax_alqp.
I agree with your assessment here: `jax.experimental.sparse` is very limited (it's one of the reasons it's still experimental!) and probably is not a good fit for your application.
Thanks for confirming. I hope the example I provided motivates future work in this direction. :)
FYI I added a prototype of what I needed here and here.
Given that in the dense case we require all shapes to be statically determined, it would not be unreasonable for the sparse case to demand that sparsity patterns be statically determined. Obviously, as shown by the code above, this is not something that needs to be natively supported, but it would be cool if it were.
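As a rough sketch of the idea (my own toy code, not the prototype linked above): keep the sparsity pattern as static NumPy index arrays that are fixed at trace time, and let only the nonzero values flow through JAX. For example, a sparse matrix-vector product whose gather/scatter pattern is fully known at compile time might look like this:

```python
import jax
import jax.numpy as jnp
import numpy as np

# Static pattern of a 4x4 matrix with 5 nonzeros (chosen arbitrarily).
rows = np.array([0, 1, 2, 3, 3])
cols = np.array([0, 1, 2, 0, 3])


@jax.jit
def static_sparse_matvec(data, x):
    # y[i] = sum over stored entries (i, j) of A[i, j] * x[j];
    # `rows`/`cols` are compile-time constants, so only `data` and `x` are traced.
    contrib = data * x[cols]
    return jax.ops.segment_sum(contrib, rows, num_segments=4)


data = jnp.array([1.0, 2.0, 3.0, 4.0, 5.0])  # nonzero values
x = jnp.arange(4.0)
print(static_sparse_matvec(data, x))
```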
Interesting! For what it's worth, this example illustrates the fundamental challenge with building a general sparse computing API. What is meant by "sparsity" is very specific to each particular context: in your case, only static sparsity patterns make sense. In other cases, static sparsity would not be possible.
One of the reasons we've not done much more work on `jax.experimental.sparse` is exactly this: it's nearly impossible to write one general tool that serves most use cases. Rather, people want specific sparse kernels for the particular constraints and sparsity structure of their own application, so writing purpose-specific kernels is more effective in the long run.
The approach I described above does run into one issue. Every matrix operation I do gets unrolled into JAX's expression graph, making it blow up in size. This makes JIT compilation take north of 20 minutes in cases that aren't particularly complicated, and also makes the runtimes 10x-100x slower than they ought to be (and the MLIR BC files are over 1MB!). This is a bit surprising to me, as my code is doing something quite simple.
One solution would be to force the same (ideally dimension-independent) code path (for things like sparse matrix multiplication) to be reused in different places, while still preserving a statically determined number of non-zeros (`nse`) for each sparse matrix. Should I be using a pure callback in this situation? I'd be interested in hearing your thoughts.
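For what it's worth, here is a sketch (my own, untested, and not an endorsed solution; all names are hypothetical) of what the pure-callback route might look like: dispatch the sparse-sparse product to SciPy on the host, with a statically chosen upper bound `out_nse` on the output's nonzeros so the callback has a fixed output shape:

```python
import jax
import jax.numpy as jnp
import numpy as np
import scipy.sparse


def spspmm_via_callback(g_data, g_indices, h_data, h_indices, *, shape_g, shape_h, out_nse):
    def host_fn(gd, gi, hd, hi):
        G = scipy.sparse.coo_matrix((gd, (gi[:, 0], gi[:, 1])), shape=shape_g)
        H = scipy.sparse.coo_matrix((hd, (hi[:, 0], hi[:, 1])), shape=shape_h)
        out = (G @ H).tocoo()
        # Pad to the static out_nse; this assumes out.nnz <= out_nse.
        data = np.zeros(out_nse, dtype=gd.dtype)
        indices = np.zeros((out_nse, 2), dtype=np.int32)
        data[: out.nnz] = out.data
        indices[: out.nnz, 0] = out.row
        indices[: out.nnz, 1] = out.col
        return data, indices

    result_shapes = (
        jax.ShapeDtypeStruct((out_nse,), g_data.dtype),
        jax.ShapeDtypeStruct((out_nse, 2), jnp.int32),
    )
    return jax.pure_callback(host_fn, result_shapes, g_data, g_indices, h_data, h_indices)
```

Note that `jax.pure_callback` runs on the host and is not differentiated through by default, so gradients would require something like a `jax.custom_jvp` wrapper around it.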
I don't know of a good solution. XLA is not well set up for general unstructured sparse computation, unfortunately. Now, if you have block sparsity then you may be able to make some progress, because you can take advantage of vectorized operations for parts of your computation.
See https://github.com/google/jax/pull/23674 for an example of an efficient block-sparse approach for one particular operation of interest.
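To illustrate the block-sparse idea (this is my own sketch, not the approach in the linked PR): store only the nonzero dense blocks plus their static block coordinates, so each block multiply is an ordinary vectorized dense operation and only the block-level pattern is sparse:

```python
import jax
import jax.numpy as jnp
import numpy as np


def block_sparse_matvec(blocks, block_rows, block_cols, x, *, block_size, n_rows):
    """y = A @ x where A is given by dense (b, b) blocks at static block coordinates."""
    b = block_size
    # Gather the slice of x that each stored block acts on: (n_blocks, b).
    x_blocks = x.reshape(-1, b)[block_cols]
    # Dense, vectorized per-block products: (n_blocks, b).
    contrib = jnp.einsum("kij,kj->ki", blocks, x_blocks)
    # Scatter-add each block's contribution into its output block row.
    y = jnp.zeros((n_rows // b, b), contrib.dtype)
    return y.at[block_rows].add(contrib).reshape(n_rows)


# Example: an 8x8 matrix made of 4x4 blocks, with 3 nonzero blocks.
blocks = jax.random.normal(jax.random.PRNGKey(0), (3, 4, 4))
block_rows = np.array([0, 1, 1])  # block-row index of each stored block (static)
block_cols = np.array([0, 0, 1])  # block-column index of each stored block (static)
x = jnp.arange(8.0)
y = block_sparse_matvec(blocks, block_rows, block_cols, x, block_size=4, n_rows=8)
```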
Suppose I have two sparse matrices, `G` and `H`. If I do `G @ H`, the result comes back with an `nse` of 6400. However, the sparsity patterns of `G` and `H` might have been statically known (this is certainly the case in my application), and the real `nse` might be easy to statically upper-bound well below the 6400 reported here.

I feel that until something like that happens, the utility of the `sparse` module might be pretty limited. For example, this is somewhat preventing me from using sparse matrix inputs in https://github.com/joaospinto/jax_alqp.

What are your thoughts? Maybe @jakevdp, as I've seen you comment on other sparsity-related posts?
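As an illustrative sketch (random matrices of my own, not the original `G` and `H`), the kind of behavior being described looks like this: the product of two BCOO matrices reports a worst-case `nse` far above the number of entries that are actually nonzero.

```python
import jax
import jax.numpy as jnp
from jax.experimental import sparse

key_g, key_h = jax.random.split(jax.random.PRNGKey(0))
G = sparse.random_bcoo(key_g, (80, 80), nse=80)  # 80 stored nonzeros
H = sparse.random_bcoo(key_h, (80, 80), nse=80)  # 80 stored nonzeros

out = G @ H
print(out.nse)                          # worst-case bound on the product's nonzeros
print(int((out.todense() != 0).sum()))  # number of entries that are truly nonzero
```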