tvercaut opened 1 year ago
I have given this a go: see `SparseMatMul` on the `batched-mm-vmap` branch.
This implementation only works for the forward pass of the batched matrix multiplication (mm) with COO tensors.
It fails for:
The CSR methods fail (at least in the first instance) because the function `_add_batch_dim` does not support CSR tensors.
See https://pytorch.org/docs/master/notes/extending.func.html; this should help make our code compatible with `vmap`.
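The linked notes describe how to make a `torch.autograd.Function` work with `torch.func.vmap`: `forward` no longer takes `ctx`, saving for backward moves into a `setup_context` staticmethod, and the Function either sets `generate_vmap_rule = True` or defines its own `vmap` staticmethod. Below is a minimal sketch of that pattern, assuming PyTorch 2.x. It uses dense tensors only, so it does not address the CSR `_add_batch_dim` issue, and the class name `DenseMatMul` is hypothetical (it is not the `SparseMatMul` from the branch).

```python
import torch
from torch.func import vmap


class DenseMatMul(torch.autograd.Function):
    # Let torch.func derive the batching rule by vmapping forward/backward;
    # this only works because both are built from vmap-compatible ops.
    generate_vmap_rule = True

    @staticmethod
    def forward(a, b):
        # No ctx argument: required for torch.func compatibility.
        return a @ b

    @staticmethod
    def setup_context(ctx, inputs, output):
        # Save what backward needs, separately from forward.
        a, b = inputs
        ctx.save_for_backward(a, b)

    @staticmethod
    def backward(ctx, grad_out):
        a, b = ctx.saved_tensors
        # Standard matmul gradients: dL/dA = G B^T, dL/dB = A^T G.
        grad_a = grad_out @ b.transpose(-2, -1)
        grad_b = a.transpose(-2, -1) @ grad_out
        return grad_a, grad_b


torch.manual_seed(0)
a = torch.randn(4, 3, 5, requires_grad=True)
b = torch.randn(4, 5, 2, requires_grad=True)

# vmap maps the per-sample 2-D matmul over the leading batch dimension.
out = vmap(DenseMatMul.apply)(a, b)
out.sum().backward()
```

For a sparse version, the same structure would apply, but the inputs would still pass through `_add_batch_dim` when `vmap` wraps them, so CSR support would need to be resolved there first.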