[Open] Ebanflo42 opened this issue 1 year ago
No, I don't know of any implementation similar to bcoo_dot_general_sampled with sparse operands.
> I noticed that the function is actually expecting dense arrays

Yes, this is the intent of bcoo_dot_general_sampled. It essentially implements the generalized SDDMM operation, which is important because it appears in the transpose of a sparse-dense dot product.
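For intuition, here's a minimal sketch of what that generalized SDDMM computes, written with plain jnp ops and dense operands; the helper name sddmm_data and its argument layout are just for illustration and are not the actual bcoo_dot_general_sampled signature:

import jax
import jax.numpy as jnp

def sddmm_data(lhs, rhs, out_indices):
    # lhs: dense (m, k); rhs: dense (k, n)
    # out_indices: (nse, 2) integer array of (row, col) positions to evaluate
    rows, cols = out_indices[:, 0], out_indices[:, 1]
    # one dot product per sampled output position: (lhs @ rhs)[row, col]
    return jax.vmap(lambda i, j: lhs[i] @ rhs[:, j])(rows, cols)

The returned values line up with out_indices, so they can be paired with those indices to build a sparse result on that pattern.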
> also slicing a sparse array appears to return a dense array?

I don't think this is true in general; for example:
In [1]: from jax.experimental import sparse
In [2]: import jax.numpy as jnp
In [3]: sparse.BCOO.fromdense(jnp.arange(10))[:5]
Out[3]: BCOO(int32[5], nse=5)
> Are there any workarounds I missed?
Not that I know of – with some thought, you could probably implement this logic from scratch for the particular matrix layouts you're interested in.
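For what it's worth, a rough from-scratch sketch for one such layout (a BCOO left operand with n_batch=0, a dense right operand, and an explicit (row, col) index array for the output) might look something like this; the name sampled_spmm and the exact argument layout are mine, not part of JAX:

import jax
import jax.numpy as jnp

def sampled_spmm(lhs, rhs, out_indices):
    # lhs: BCOO of shape (m, k) with n_batch=0; rhs: dense array of shape (k, n)
    # out_indices: (p, 2) integer array of (row, col) positions at which to evaluate lhs @ rhs
    lhs_rows, lhs_cols = lhs.indices[:, 0], lhs.indices[:, 1]

    def one_entry(i, j):
        # (lhs @ rhs)[i, j] is the sum over lhs's stored entries in row i of data * rhs[col, j]
        contrib = jnp.where(lhs_rows == i, lhs.data * rhs[lhs_cols, j], 0.0)
        return contrib.sum()

    rows, cols = out_indices[:, 0], out_indices[:, 1]
    return jax.vmap(one_entry)(rows, cols)

This does O(nse) work per sampled output entry, so it's only a starting point, but it never materializes a dense product, and the returned values can be wrapped back into a BCOO on the output pattern.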
Thanks for your fast response as usual!
> I don't think this is true in general;
You are right. Taking that into account, I think the following general idea works:
>>> matrix = sparse.BCOO.fromdense(jnp.eye(10))
>>> matrix
BCOO(float32[10, 10], nse=10)
>>> sampled_product = sparse.BCOO(((matrix@matrix)[matrix.indices[:, 0], matrix.indices[:, 1]].data[:, 0], matrix.indices), shape=matrix.shape)
>>> sampled_product
BCOO(float32[10, 10], nse=10)
>>>
As long as the operands have the same indices and shape, this should work. Although it does a few unnecessary multiplies in the general case, I think it will be pretty efficient after JIT?
I don't totally understand what you have in mind, but it looks like you're trying to remove the batch dimension in the indexing results. If that's the case, you could use result.update_layout(n_batch=0) rather than manually creating the new BCOO matrix from the data and index buffers.
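To make that concrete, here is one way the suggestion could slot into the snippet above, keeping the final reconstruction on the original sparsity pattern (the names gathered and values are mine, and I'm assuming the indexing result carries a single batch dimension as described above):

import jax.numpy as jnp
from jax.experimental import sparse

matrix = sparse.BCOO.fromdense(jnp.eye(10))
gathered = (matrix @ matrix)[matrix.indices[:, 0], matrix.indices[:, 1]]
# flatten the batch dimension of the indexing result instead of slicing .data by hand
values = gathered.update_layout(n_batch=0).data
sampled_product = sparse.BCOO((values, matrix.indices), shape=matrix.shape)

Since BCOO is a registered pytree, this can in principle also be wrapped in jax.jit, though I haven't verified that this particular indexing path compiles.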
Hi,

I'm wondering if it is possible to efficiently implement a function analogous to bcoo_dot_general_sampled with sparse operands. I noticed that the function is actually expecting dense arrays, and also slicing a sparse array appears to return a dense array? My use case is simply trying to do online accumulation of gradients for a sparse recurrent network, with the sparsity constraint enforced on the gradients. Casting to and from a dense array in between would be way too slow and memory-intensive. Are there any workarounds I missed? How difficult would this be to implement?