pytorch / functorch

functorch is JAX-like composable function transforms for PyTorch.
https://pytorch.org/functorch/
BSD 3-Clause "New" or "Revised" License

Performance drop for the batching rule for aten::_sparse_mm #1075

Open lrnzgiusti opened 1 year ago

lrnzgiusti commented 1 year ago

Hello @zou3519 , @samdow.

TLDR: I got the following warning: UserWarning: There is a performance drop because we have not yet implemented the batching rule for aten::_sparse_mm. Please file us an issue on GitHub so that we can prioritize its implementation.

I was trying to use vmap to batch sparse-dense matrix multiplications:

import torch
from functorch import vmap

# Batch of two sparse 4x4 matrices, stored as a (2, 4, 4) sparse COO tensor
A = torch.sparse_coo_tensor(
    indices=torch.tensor([[0, 0, 0, 0, 1, 1, 1, 1],
                          [0, 1, 2, 3, 0, 1, 2, 3],
                          [0, 1, 2, 3, 0, 1, 2, 3]]),
    values=torch.tensor([1., 1., 1., 1., 2., 2., 2., 2.]),
    size=(2, 4, 4))

X = torch.tensor([[-0.0533, -1.3950, -0.2621],
                  [-1.0800,  0.3210,  0.7954],
                  [ 0.7737,  0.3655,  0.5691],
                  [-0.3505, -1.0423, -2.0650]])

# Map torch.sparse.mm over the batch dimension of A, broadcasting X to every slice
bspmm = vmap(torch.sparse.mm, in_dims=(0, None))
Z = bspmm(A, X)

In [1]: A.shape
Out[1]: torch.Size([2, 4, 4])

In [2]: X.shape
Out[2]: torch.Size([4, 3])

In [3]: Z.shape
Out[3]: torch.Size([2, 4, 3])

which yields correct results, but emits:

.../functorch/_src/vmap.py:489: UserWarning: There is a performance drop because we have not yet implemented the batching rule for aten::_sparse_mm. Please file us an issue on GitHub so that we can prioritize its implementation. (Triggered internally at /Users/runner/work/pytorch/pytorch/pytorch/aten/src/ATen/functorch/BatchedFallback.cpp:84.)
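A possible workaround (an untested sketch, not something confirmed in this thread): torch.bmm accepts a sparse COO tensor as its first operand, so the batched product can be written without vmap by expanding X along the batch dimension. Whether this beats the fallback loop will depend on K and the sparsity, so it is worth benchmarking.

import torch

# Reusing A (2, 4, 4) sparse COO and X (4, 3) dense from the snippet above.
# torch.bmm supports a sparse COO first operand with a dense second operand,
# so expanding X over the batch dimension sidesteps the vmap fallback.
# Depending on the PyTorch version, a .contiguous() on the expanded X may be needed.
K = A.shape[0]
Z_bmm = torch.bmm(A, X.expand(K, -1, -1))  # dense result, shape (2, 4, 3)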

Are there plans to implement batching for this operation in the near future?

Thanks

zou3519 commented 1 year ago

We currently do not support vmap over sparse tensors. Could you tell us a bit more about your use case please? (cc @cpuhrsch)

lrnzgiusti commented 1 year ago

TLDR: Filtering on topological spaces (e.g. graphs, simplicial/cell complexes) requires K sparse-dense matmul ops (i.e. sum_k S^k · X · W_k), where each S^k is sparse.

My specific use case is implementing a differentiable filtering operation on topological spaces (take graphs as an example; higher-order relational structures in the general case). From looking around, it seems the only way to do this is with a for loop like this:

out = torch.stack([Sk.mm(X).mm(W[k]) for k, Sk in enumerate(G.adj)], dim=0).sum(dim=0)

However, with vmap the for loop above becomes unnecessary:

import torch
from functorch import vmap

mm = vmap(torch.sparse.mm, in_dims=(0, None))

comp = mm(S, X)
out = torch.bmm(comp, W).sum(dim=0)

where S is a K×N×N sparse tensor, X is an N×Fin dense matrix, and W is a K×Fin×Fout dense tensor.
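For concreteness, here is a minimal self-contained sketch at those shapes with random data (K, N, Fin, Fout below are placeholder sizes), checking that the loop and the vmap formulations agree:

import torch
from functorch import vmap

K, N, Fin, Fout = 3, 5, 4, 2

S = torch.rand(K, N, N).to_sparse()  # batch of K sparse NxN matrices (COO)
X = torch.rand(N, Fin)
W = torch.rand(K, Fin, Fout)

# Loop formulation: sum_k S^k @ X @ W_k
out_loop = torch.stack([S[k].mm(X).mm(W[k]) for k in range(K)], dim=0).sum(dim=0)

# vmap formulation (goes through the slow fallback and emits the warning above)
mm = vmap(torch.sparse.mm, in_dims=(0, None))
out_vmap = torch.bmm(mm(S, X), W).sum(dim=0)

assert torch.allclose(out_loop, out_vmap, atol=1e-5)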

This may be too specific to my use case, but I think it could be very useful for everyone interested in machine learning on graphs.

fzimmermann89 commented 3 weeks ago

We would also be interested in a performant vmap for sparse-dense matrix-vector multiplications.

We use the sparse matrix to represent different interpolations of MR data, for example in non-uniform FFTs or volume-to-slice projections. In both cases, it is much more convenient (and faster) to construct the sparse matrix once and use it in a matmul, compared to other Python-only approaches.
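To make the pattern concrete, here is a toy sketch (illustrative only; our real operators are the NUFFT interpolation and slice-projection matrices mentioned above, built the same way but much larger). A fixed interpolation is encoded once as a sparse matrix and then reapplied to many data vectors:

import torch

# Toy 1-D linear interpolation from a 4-point grid to 3 off-grid locations,
# encoded once as a 3x4 sparse COO matrix.
indices = torch.tensor([[0, 0, 1, 1, 2, 2],   # output sample index
                        [0, 1, 1, 2, 2, 3]])  # input grid index
values = torch.tensor([0.7, 0.3, 0.5, 0.5, 0.2, 0.8])
P = torch.sparse_coo_tensor(indices, values, size=(3, 4))

x = torch.rand(4)  # one channel of gridded data
y = torch.sparse.mm(P, x.unsqueeze(-1)).squeeze(-1)  # interpolated samples, shape (3,)

# Many channels can share one P via a single matmul over columns...
X = torch.rand(4, 8)
Y = torch.sparse.mm(P, X)  # shape (3, 8)

# ...but applying a different P per slice/frame is exactly where a
# batching rule (vmap) for sparse-dense products would help.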