Thank you for the detailed reproduction example.

At first glance, it looks like the issue is in

```python
grad_out_fmt=tuple([(sten.KeepAll(), torch.Tensor, sten.KeepAll(), torch.Tensor)]),
```

inside the `SparseLinear` class. Basically, for each occurrence of a sparse tensor in the autograd graph, its gradient has to be wrapped inside a `SparseTensorWrapper`. If this does not happen, autograd tries to operate on tensors of different types and breaks because of their shape mismatch. The easiest fix in your example is probably to replace the problematic line with

```python
grad_out_fmt=tuple([(sten.KeepAll(), torch.Tensor, sten.KeepAll(), sten.DenseTensor)]),
```

and continue implementing the other pieces from there. `DenseTensor` as the last argument should automatically wrap the `torch.Tensor` into a `SparseTensorWrapper` while still keeping the same dense data.

I see this API as suboptimal and will investigate it further, especially considering that our examples provide exactly such a template.
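A minimal sketch of the suggested change, assuming the `SparseLinear` module follows the structure of the STen examples (parameter setup simplified here; only the `grad_out_fmt` entry differs from the template):

```python
import torch
import sten

class SparseLinear(torch.nn.Module):
    def __init__(self, in_features, out_features):
        super().__init__()
        # Simplified: the actual example may build sparse parameters instead.
        self.weight = torch.nn.Parameter(torch.randn(out_features, in_features))
        self.bias = torch.nn.Parameter(torch.randn(out_features))
        self.op = sten.sparsified_op(
            orig_op=torch.nn.functional.linear,
            out_fmt=tuple(
                [(sten.KeepAll(), torch.Tensor, sten.KeepAll(), torch.Tensor)]
            ),
            # sten.DenseTensor as the last entry: the gradient arrives wrapped
            # in a SparseTensorWrapper but still holds the same dense data.
            grad_out_fmt=tuple(
                [(sten.KeepAll(), torch.Tensor, sten.KeepAll(), sten.DenseTensor)]
            ),
        )

    def forward(self, x):
        return self.op(x, self.weight, self.bias)
```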
Thanks for your quick response. I'll try to implement it in my actual code sometime early next week, but I believe it should work.
Description
I attempted to create some wrappers to simplify the use of `sten` and expand its functionality. Despite my efforts to work primarily with sparse tensors (which may be the source of the problem), during the backward pass PyTorch tries to match the shape of the gradient with the `dummy_shape`, which fails for obvious reasons.

Packages
Reproduction
It is straightforward to reproduce the error. Based on the SparseMLP example, I made a minimal "working" script. The only modification I made was to the `out_fmt` of `sparse_op` (a sketch of the change follows this paragraph). Afterwards, I just implemented every forward method, and only one of the many backward methods, because it fails before the next one is even called.
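For context, here is a hedged sketch of roughly what that modification looks like; `MyRandomFractionSparsifier` and `MyCscTensor` are placeholder names standing in for the sparsifier and sparse-format classes defined in the SparseMLP example:

```python
# Hypothetical illustration of the single change: out_fmt now requests a
# custom sparse format for the operator output instead of torch.Tensor.
self.op = sten.sparsified_op(
    orig_op=torch.nn.functional.linear,
    out_fmt=tuple(
        [(MyRandomFractionSparsifier(0.5), MyCscTensor, sten.KeepAll(), MyCscTensor)]
    ),
    # grad_out_fmt left exactly as in the original template:
    grad_out_fmt=tuple(
        [(sten.KeepAll(), torch.Tensor, sten.KeepAll(), torch.Tensor)]
    ),
)
```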
In `SparseLinear.forward` I just changed the `torch.Tensor`s to a sparse tensor.

The whole script:
Output
It should give the following output:
In addition, I discovered that the `dummy_shape` is recorded on the `loss.grad_fn`, which is expected, as it results in an error after the `register_bwd_op_impl` for `torch.nn.functional.mse_loss` returns.
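(As a side note on how this can be inspected: the snippet below is plain PyTorch, nothing STen-specific, and assumes `loss` is the one computed in the script above. It only lists the recorded backward nodes; the shape metadata itself lives on the C++ side.)

```python
# Walk the autograd graph starting at loss.grad_fn and print every
# recorded backward node. The per-input shape metadata checked by the
# engine is stored in C++ (input_metadata.h) and is not exposed here.
def walk(fn, depth=0):
    if fn is None:
        return
    print("  " * depth + type(fn).__name__)
    for next_fn, _ in getattr(fn, "next_functions", ()):
        walk(next_fn, depth + 1)

walk(loss.grad_fn)
```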
Also, the PyTorch part of the error message originates from here: https://github.com/pytorch/pytorch/blob/b95e1d76a86b7b66f0946f72ebd33889bfc19e03/torch/csrc/autograd/engine.cpp#L818, and `metadata.incompatible_shape_error_message(i, grad)` from here: https://github.com/pytorch/pytorch/blob/77c2a8a11f7b5164c255b5b49dbc66a3f6533e9d/torch/csrc/autograd/input_metadata.h#L91

Do you think it is a problem with PyTorch, or is it something I did incorrectly with STen?