cai4cai / torchsparsegradutils

A collection of utility functions to work with PyTorch sparse tensors
Apache License 2.0

Add support for batched sparse matrix with batched dense matrix multiplication #34

Closed tvercaut closed 1 year ago

tvercaut commented 1 year ago

bmm supports the COO sparse layout but not CSR: https://github.com/pytorch/pytorch/issues/98675

Also, bmm likely faces memory issues in the backward pass (to be confirmed): https://github.com/pytorch/pytorch/issues/41128

As such, it would make sense to extend our SparseMatMul op to support batched operations: https://github.com/cai4cai/torchsparsegradutils/blob/main/torchsparsegradutils/sparse_matmul.py

tvercaut commented 1 year ago

vmap does not appear to be well supported for sparse matrices.

We may need a rough workaround along the following lines. Consider a simple batch of size 2, with two sparse matrices A1 and A2 and two dense matrices B1 and B2.

We can create a block-diagonal sparse matrix A12 = [A1 0; 0 A2] and a stacked dense matrix B12 = [B1; B2]. We then compute AB12 = A12 B12 = [A1 0; 0 A2] [B1; B2] = [A1 B1; A2 B2], and finally split AB12 back into AB1 = A1 B1 and AB2 = A2 B2.
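A minimal sketch of this trick for a batch of 2, using plain PyTorch. For simplicity the block-diagonal matrix is assembled with `torch.block_diag` on dense tensors and then converted to sparse; a real implementation would build the sparse block-diagonal directly from the COO indices of A1 and A2 to avoid materializing the dense zeros.

```python
import torch

torch.manual_seed(0)

# Batch of 2: A is (2, n, m), B is (2, m, k)
n, m, k = 3, 4, 5
A = torch.randn(2, n, m)
A[A.abs() < 1.0] = 0.0  # make A sparse-ish for illustration
B = torch.randn(2, m, k)

# Reference result via batched dense matmul
AB_ref = torch.bmm(A, B)

# Block-diagonal sparse matrix A12 = [A1 0; 0 A2], shape (2n, 2m).
# Note: for a real sparse input, the indices would be offset and
# concatenated directly instead of going through a dense tensor.
A12 = torch.block_diag(A[0], A[1]).to_sparse()

# Stacked dense matrix B12 = [B1; B2], shape (2m, k)
B12 = B.reshape(2 * m, k)

# One sparse-dense matmul, then reshape back into a batch
AB12 = torch.sparse.mm(A12, B12)  # dense, shape (2n, k)
AB = AB12.reshape(2, n, k)

assert torch.allclose(AB, AB_ref, atol=1e-6)
```

The reshape at the end is valid because the rows of AB12 are exactly the rows of A1 B1 followed by the rows of A2 B2.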

This seems to be the approach taken by PyTorch Geometric: https://pytorch-geometric.readthedocs.io/en/latest/advanced/batching.html https://github.com/pyg-team/pytorch_geometric/blob/master/torch_geometric/utils/sparse.py#L384-L420