google / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
http://jax.readthedocs.io/
Apache License 2.0

Subsampling sparse matrices during (or after) multiplication #17312

Open Ebanflo42 opened 10 months ago

Ebanflo42 commented 10 months ago

Hi,

I'm wondering whether it's possible to efficiently implement a function analogous to bcoo_dot_general_sampled with sparse operands. I noticed that the function actually expects dense arrays, and slicing a sparse array also appears to return a dense array? My use case is online accumulation of gradients for a sparse recurrent network, with the sparsity constraint enforced on the gradients; casting to and from a dense array in between would be far too slow and memory-intensive.

Are there any workarounds I missed? How difficult would this be to implement?

jakevdp commented 10 months ago

No, I don't know of any implementation similar to bcoo_dot_general_sampled with sparse operands.

I noticed that the function is actually expecting dense arrays

Yes, this is the intent of bcoo_dot_general_sampled. It essentially implements the generalized SDDMM operation, which is important because it appears in the transpose of a sparse-dense dot product.
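For reference, here is a minimal sketch of the dense-operand usage of bcoo_dot_general_sampled: it computes (A @ B) only at the nonzero positions of a sparse template matrix. The matrices here are made up for illustration, and the dimension_numbers follow the lax.dot_general convention:

```python
import jax.numpy as jnp
from jax.experimental import sparse

# Dense operands and a sparse "template" whose pattern we sample at.
A = jnp.arange(6.0).reshape(2, 3)
B = jnp.arange(12.0).reshape(3, 4)
S = sparse.BCOO.fromdense(jnp.eye(2, 4))

# Contract A's axis 1 with B's axis 0; no batch dimensions.
dn = (((1,), (0,)), ((), ()))

# Returns only the data buffer: (A @ B) evaluated at S.indices.
data = sparse.bcoo_dot_general_sampled(A, B, S.indices, dimension_numbers=dn)
sampled = sparse.BCOO((data, S.indices), shape=S.shape)
```

This avoids materializing the full dense product when only the template's nonzero entries are needed.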

also slicing a sparse array appears to return a dense array?

I don't think this is true in general; for example:

In [1]: from jax.experimental import sparse

In [2]: import jax.numpy as jnp

In [3]: sparse.BCOO.fromdense(jnp.arange(10))[:5]
Out[3]: BCOO(int32[5], nse=5)

Are there any workarounds I missed?

Not that I know of – with some thought, you could probably implement this logic from scratch for the particular matrix layouts you're interested in.

Ebanflo42 commented 10 months ago

Thanks for your fast response as usual!

I don't think this is true in general;

You are right. Taking that into account, I think the following general idea works:

>>> matrix = sparse.BCOO.fromdense(jnp.eye(10))
>>> matrix
BCOO(float32[10, 10], nse=10)
>>> sampled_product = sparse.BCOO(
...     ((matrix @ matrix)[matrix.indices[:, 0], matrix.indices[:, 1]].data[:, 0],
...      matrix.indices),
...     shape=matrix.shape)
>>> sampled_product
BCOO(float32[10, 10], nse=10)

As long as the operands have the same indices and shape, this should work. It performs a few unnecessary multiplies in the general case, but I think it will be reasonably efficient after JIT?
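The snippet above can be wrapped as a small helper (a sketch only: sampled_sparse_matmul is a hypothetical name, and this assumes 2-D BCOO operands whose product is sampled at a template's index set):

```python
import jax.numpy as jnp
from jax.experimental import sparse

def sampled_sparse_matmul(A, B, template):
    # Sparse-sparse product; for BCOO operands this returns a BCOO.
    prod = A @ B
    rows, cols = template.indices[:, 0], template.indices[:, 1]
    # Integer-array indexing of a BCOO introduces a batch dimension,
    # so take column 0 of the extracted data buffer to strip it.
    data = prod[rows, cols].data[:, 0]
    return sparse.BCOO((data, template.indices), shape=template.shape)
```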

jakevdp commented 10 months ago

I don't totally understand what you have in mind, but it looks like you're trying to remove the batch dimension in the indexing results. If that's the case, you could use result.update_layout(n_batch=0) rather than manually creating the new BCOO matrix from the data and index buffers.
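A minimal illustration of that suggestion (toy matrix; fromdense's n_batch argument is used here only to create a batched layout to collapse):

```python
import jax.numpy as jnp
from jax.experimental import sparse

# Create a BCOO whose leading dimension is stored as a batch dimension,
# then move it back into the sparse dimensions with update_layout.
M = sparse.BCOO.fromdense(jnp.eye(3), n_batch=1)
M2 = M.update_layout(n_batch=0)
```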