jax-ml / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
http://jax.readthedocs.io/
Apache License 2.0
30.58k stars 2.81k forks

Subsampling sparse matrices during (or after) multiplication #17312

Open Ebanflo42 opened 1 year ago

Ebanflo42 commented 1 year ago

Hi,

I'm wondering if it is possible to efficiently implement a function analogous to bcoo_dot_general_sampled with sparse operands. I noticed that the function actually expects dense arrays, and slicing a sparse array also appears to return a dense array? My use case is online accumulation of gradients for a sparse recurrent network, with the sparsity constraint enforced on the gradients. Casting to and from dense arrays in between would be far too slow and memory-intensive.

Are there any workarounds I missed? How difficult would this be to implement?

jakevdp commented 1 year ago

No, I don't know of any implementation similar to bcoo_dot_general_sampled with sparse operands.

I noticed that the function is actually expecting dense arrays

Yes, this is the intent of bcoo_dot_general_sampled. It essentially implements the generalized SDDMM operation, which is important because it appears in the transpose of a sparse-dense dot product.
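To make the intended dense-operand usage concrete, here is a small sketch (the inputs `A`, `B`, and `indices` are made up for illustration): bcoo_dot_general_sampled evaluates dot_general of two dense arrays only at a given set of sparse indices, which is the SDDMM pattern. The dimension_numbers follow the usual jax.lax.dot_general convention.

```python
import jax
import jax.numpy as jnp
from jax.experimental import sparse

A = jnp.arange(6.0).reshape(2, 3)
B = jnp.arange(12.0).reshape(3, 4)
indices = jnp.array([[0, 0], [1, 2]])  # sample the product at these (row, col) pairs
dn = (((1,), (0,)), ((), ()))  # standard matmul dimension_numbers

# Computes dot_general(A, B) only at the sampled positions (SDDMM).
out = sparse.bcoo_dot_general_sampled(A, B, indices, dimension_numbers=dn)

# Should match the dense product gathered at the same positions:
dense = jax.lax.dot_general(A, B, dimension_numbers=dn)
expected = dense[indices[:, 0], indices[:, 1]]
```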

also slicing a sparse array appears to return a dense array?

I don't think this is true in general; for example:

In [1]: from jax.experimental import sparse

In [2]: import jax.numpy as jnp

In [3]: sparse.BCOO.fromdense(jnp.arange(10))[:5]
Out[3]: BCOO(int32[5], nse=5)

Are there any workarounds I missed?

Not that I know of – with some thought, you could probably implement this logic from scratch for the particular matrix layouts you're interested in.

Ebanflo42 commented 1 year ago

Thanks for your fast response as usual!

I don't think this is true in general;

You are right. Taking that into account, I think the following general idea works:

>>> matrix = sparse.BCOO.fromdense(jnp.eye(10))
>>> matrix
BCOO(float32[10, 10], nse=10)
>>> mm = matrix @ matrix
>>> gathered = mm[matrix.indices[:, 0], matrix.indices[:, 1]]
>>> sampled_product = sparse.BCOO((gathered.data[:, 0], matrix.indices), shape=matrix.shape)
>>> sampled_product
BCOO(float32[10, 10], nse=10)
>>> 

As long as the operands have the same indices and shape, this should work. It performs a few unnecessary multiplies in the general case, but I think it will be reasonably efficient after JIT?
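The idea above can be wrapped in a small jitted helper; `sampled_spsp_matmul` is a hypothetical name, and this sketch assumes (as the transcript shows) that array indexing of a BCOO returns a BCOO with a leading batch dimension, hence the `data[:, 0]`:

```python
import jax
import jax.numpy as jnp
from jax.experimental import sparse

@jax.jit
def sampled_spsp_matmul(x: sparse.BCOO, y: sparse.BCOO) -> sparse.BCOO:
    """Compute (x @ y) sampled at x's stored indices, staying sparse throughout."""
    prod = x @ y  # sparse-sparse matmul, returns a BCOO
    # Gather the product at x's nonzero coordinates; the result carries a
    # leading batch dimension, so data has shape (nse, 1).
    gathered = prod[x.indices[:, 0], x.indices[:, 1]]
    return sparse.BCOO((gathered.data[:, 0], x.indices), shape=x.shape)

m = sparse.BCOO.fromdense(jnp.eye(10))
result = sampled_spsp_matmul(m, m)
```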

jakevdp commented 1 year ago

I don't totally understand what you have in mind, but it looks like you're trying to remove the batch dimension in the indexing results. If that's the case, you could use result.update_layout(n_batch=0) rather than manually creating the new BCOO matrix from the data and index buffers.
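For reference, a minimal sketch of what that looks like, assuming the indexing behavior from the transcript above: update_layout(n_batch=0) folds the batch dimension produced by array indexing back into the sparse representation, so there is no need to rebuild the BCOO from its data and index buffers by hand.

```python
import jax.numpy as jnp
from jax.experimental import sparse

m = sparse.BCOO.fromdense(jnp.eye(10))
# Array indexing of a BCOO yields a result with a leading batch dimension:
g = (m @ m)[m.indices[:, 0], m.indices[:, 1]]
# update_layout moves the batch dimension into the sparse part of the
# representation, flattening the data buffer without manual reconstruction.
flat = g.update_layout(n_batch=0)
```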