lrnzgiusti opened 1 year ago
We currently do not support vmap over sparse tensors. Could you tell us a bit more about your use case please? (cc @cpuhrsch)
TLDR: Filtering on topological spaces (i.e. graphs, simplicial/cell complexes) requires K sparse-dense matmuls, i.e. out = sum_k S_k @ X @ W_k, where each S_k is sparse.
My specific use case is to implement a differentiable filtering operation on topological spaces (graphs as the simplest example, higher-order relational structures in the general case). Looking around, it seems the only way to do this today is with a for loop like this:
```python
out = torch.stack([Sk.mm(X).mm(W[k]) for k, Sk in enumerate(G.adj)], dim=0).sum(dim=0)
```
However, with vmap the loop above becomes unnecessary:
```python
from functorch import vmap

# batch torch.sparse.mm over the leading (K) dimension of S
mm = vmap(torch.sparse.mm, in_dims=(0, None))
comp = mm(S, X)                      # K x N x Fin
out = torch.bmm(comp, W).sum(dim=0)  # N x Fout
```
Here S is a K×N×N sparse tensor, X is an N×Fin dense matrix, and W is a K×Fin×Fout dense tensor.
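Until a batching rule lands, one workaround (a minimal sketch under the shapes above; the flattening trick is my own construction, not part of the API, and the sizes are illustrative) is to stack the K sparse slices into a single (K·N)×N sparse matrix, so that one sparse-dense matmul replaces the whole loop:

```python
import torch

K, N, Fin, Fout = 3, 5, 4, 2  # illustrative sizes

# toy inputs matching the shapes above
S = torch.stack([(torch.rand(N, N) > 0.7).float() for _ in range(K)]).to_sparse()
X = torch.rand(N, Fin)
W = torch.rand(K, Fin, Fout)

# Flatten the K sparse N x N slices into one (K*N) x N sparse matrix:
# slice k's row i becomes row k*N + i of the big matrix.
Sc = S.coalesce()
k, i, j = Sc.indices()
S_big = torch.sparse_coo_tensor(torch.stack([k * N + i, j]), Sc.values(), (K * N, N))

comp = torch.sparse.mm(S_big, X).reshape(K, N, Fin)  # stack of S_k @ X
out = torch.bmm(comp, W).sum(dim=0)                  # sum_k S_k @ X @ W_k
```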
Maybe this is too specific to my use case, but I think it can be very useful for all the folks interested in machine learning on graphs.
We would also be interested in a performant vmap for sparse-dense matrix-vector multiplications.
We use the sparse matrix to represent different interpolations in MR data, for example in non-uniform FFTs or volume-to-slice projections. In both cases, it is much more convenient (and faster) to construct the sparse matrix once and reuse it in matmuls compared to other Python-only approaches.
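For this fixed-sparse-matrix case, a stopgap that avoids vmap entirely (a sketch with made-up sizes) is to stack the batch of vectors as columns of a dense matrix, since a single sparse-dense matmul then covers the whole batch:

```python
import torch

N, M, B = 6, 8, 4                 # output dim, input dim, batch size (illustrative)

A = (torch.rand(N, M) > 0.7).float().to_sparse()  # interpolation matrix, built once
v = torch.rand(B, M)                               # batch of vectors

# One sparse-dense matmul computes all B matrix-vector products:
# column b of A @ v.T is A @ v[b].
out = torch.sparse.mm(A, v.T).T    # shape (B, N)
```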
Hello @zou3519, @samdow.
TLDR: I got the following warning
UserWarning: There is a performance drop because we have not yet implemented the batching rule for aten::_sparse_mm. Please file us an issue on GitHub so that we can prioritize its implementation
I was trying to use vmap to batch sparse-dense matrix multiplications:
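(The exact snippet was not included in the report; roughly, a call of this shape, with illustrative sizes, triggers the fallback:)

```python
import torch
from functorch import vmap

N, B, F = 4, 3, 2
S = torch.eye(N).to_sparse()   # fixed sparse matrix
X = torch.rand(B, N, F)        # batch of dense right-hand sides

# vmap over the dense argument only; functorch falls back to a loop here
out = vmap(torch.sparse.mm, in_dims=(None, 0))(S, X)
```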
which yields correct results, but emits:
.../functorch/_src/vmap.py:489: UserWarning: There is a performance drop because we have not yet implemented the batching rule for aten::_sparse_mm. Please file us an issue on GitHub so that we can prioritize its implementation. (Triggered internally at /Users/runner/work/pytorch/pytorch/pytorch/aten/src/ATen/functorch/BatchedFallback.cpp:84.)
Are there plans to implement a batching rule for this operation in the near future?
Thanks