cai4cai / torchsparsegradutils

A collection of utility functions to work with PyTorch sparse tensors
Apache License 2.0

Modify custom `torch.autograd.Function`s to use `setup_context` instead of passing `ctx` as an argument to `forward` #37

Open tvercaut opened 1 year ago

tvercaut commented 1 year ago

See https://pytorch.org/docs/master/notes/extending.func.html. This should help make our code compatible with `vmap`.
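For reference, here is a minimal sketch of the pattern that note describes, assuming PyTorch 2.0 or later. `Square` is a hypothetical toy function, not something from this repository, and `generate_vmap_rule = True` is only one of the two options the note gives for vmap support (the other is writing a `vmap` staticmethod by hand).

```python
import torch


class Square(torch.autograd.Function):
    """Hypothetical toy function (not from torchsparsegradutils)."""

    # Ask PyTorch to derive the vmap rule from forward/backward.
    generate_vmap_rule = True

    @staticmethod
    def forward(x):
        # New style: forward takes no ctx argument.
        return x**2

    @staticmethod
    def setup_context(ctx, inputs, output):
        # Everything backward needs is saved here, not in forward.
        (x,) = inputs
        ctx.save_for_backward(x)

    @staticmethod
    def backward(ctx, grad_output):
        (x,) = ctx.saved_tensors
        return 2 * x * grad_output


x = torch.arange(6.0).reshape(2, 3).requires_grad_()
Square.apply(x).sum().backward()  # regular autograd path

# vmap maps the function over dim 0 of its input.
batched = torch.func.vmap(Square.apply)(x.detach())
```

Whether the automatically generated rule is enough for our sparse functions is a separate question, since it relies on every operation inside `forward` and `backward` already supporting vmap.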

theo-barfoot commented 1 year ago

I have given this a go. See: SparseMatMul on the batched-mm-vmap branch.

This implementation only works for the forward pass of the COO batched mm.

It fails for:

The CSR methods fail (at least in the first instance) because the function `_add_batch_dim` does not support CSR tensors.