vmap doesn't seem to be well supported for batched sparse matrices.
We may need a rough workaround along the following lines. Consider a simple batch of size 2 with two sparse matrices A1 and A2, and two dense matrices B1 and B2.
We can build a block-diagonal sparse matrix A12 = [A1 0; 0 A2] and a vertically stacked dense matrix B12 = [B1; B2]. We then compute AB12 = A12 B12 = [A1 0; 0 A2] [B1; B2] = [A1 B1; A2 B2], and finally split AB12 back into AB1 = A1 B1 and AB2 = A2 B2.
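A minimal sketch of this trick with plain PyTorch COO tensors (the helper name `sparse_block_diag` and the overall structure are illustrative, not an existing API):

```python
import torch

def sparse_block_diag(A_list):
    """Stack a list of 2D sparse COO matrices into one block-diagonal
    sparse COO matrix [A1 0 ...; 0 A2 ...; ...]."""
    indices, values = [], []
    row_off = col_off = 0
    for A in A_list:
        A = A.coalesce()
        idx = A.indices().clone()
        idx[0] += row_off  # shift rows to this block's position
        idx[1] += col_off  # shift columns to this block's position
        indices.append(idx)
        values.append(A.values())
        row_off += A.shape[0]
        col_off += A.shape[1]
    return torch.sparse_coo_tensor(
        torch.cat(indices, dim=1), torch.cat(values), (row_off, col_off)
    )

# Batch of 2: A12 B12 = [A1 B1; A2 B2]
A1, A2 = torch.eye(3).to_sparse(), (2 * torch.eye(3)).to_sparse()
B1, B2 = torch.randn(3, 4), torch.randn(3, 4)
A12 = sparse_block_diag([A1, A2])
B12 = torch.cat([B1, B2], dim=0)
AB12 = torch.sparse.mm(A12, B12)
AB1, AB2 = torch.split(AB12, [A1.shape[0], A2.shape[0]], dim=0)
```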
This seems to be the approach taken by PyTorch Geometric: https://pytorch-geometric.readthedocs.io/en/latest/advanced/batching.html https://github.com/pyg-team/pytorch_geometric/blob/master/torch_geometric/utils/sparse.py#L384-L420
torch.bmm supports COO but not CSR: https://github.com/pytorch/pytorch/issues/98675
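For reference, a quick sketch of the path that does work today: torch.bmm accepts a 3D (batched) sparse COO tensor against a dense one, while the CSR analogue is what the issue above reports as unsupported.

```python
import torch

# Batched sparse COO tensor: index rows are (batch, row, col)
i = torch.tensor([[0, 0, 1, 1],    # batch index
                  [0, 1, 0, 1],    # row index
                  [0, 1, 1, 0]])   # column index
v = torch.ones(4)
A = torch.sparse_coo_tensor(i, v, (2, 2, 2))
B = torch.randn(2, 2, 3)

out = torch.bmm(A, B)  # works for batched COO x dense; CSR is not supported
```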
Also, bmm likely faces memory issues in the backward pass (TBC): https://github.com/pytorch/pytorch/issues/41128
As such, it would make sense to extend our SparseMatMul op to support batched operations: https://github.com/cai4cai/torchsparsegradutils/blob/main/torchsparsegradutils/sparse_matmul.py
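One possible shape for that extension, reusing the block-diagonal helper sketched above and assuming the `sparse_mm` entry point exposed by torchsparsegradutils (`batched_sparse_mm` is just an illustrative name):

```python
import torch
from torchsparsegradutils import sparse_mm  # op from the file linked above

def batched_sparse_mm(A_list, B_list):
    # Same block-diagonal trick as sketched earlier, but routing the one
    # big multiply through sparse_mm to get its memory-efficient backward.
    A12 = sparse_block_diag(A_list)  # helper from the sketch above
    B12 = torch.cat(B_list, dim=0)
    AB12 = sparse_mm(A12, B12)
    return torch.split(AB12, [A.shape[0] for A in A_list], dim=0)
```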