rusty1s / pytorch_scatter

PyTorch Extension Library of Optimized Scatter Operations
https://pytorch-scatter.readthedocs.io
MIT License

segment_csr triggers shape error in backward pass when indptr does not cover src and reduction is mean #452

Open noahbadoa opened 2 months ago

noahbadoa commented 2 months ago

Minimal example:

import torch_scatter
import torch

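indptr = torch.tensor([0, 3, 7])  # segments [0, 3) and [3, 7); row 7 of x is not covered
x = torch.randn(8, 11)
lin = torch.nn.Linear(11, 11)
opt = torch.optim.SGD(lin.parameters())  # not used in this reproduction

x = lin(x)
x = torch_scatter.segment_csr(x, indptr, reduce="mean")  # reduces along dim 0 -> shape (2, 11)

loss = x.mean()
loss.backward()  # fails with the RuntimeError below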
Traceback (most recent call last):
  File "test.py", line 13, in <module>
    loss.backward()
  File "../.env/lib/python3.8/site-packages/torch/_tensor.py", line 525, in backward
    torch.autograd.backward(
  File "../.env/lib/python3.8/site-packages/torch/autograd/__init__.py", line 267, in backward
    _engine_run_backward(
  File "../.env/lib/python3.8/site-packages/torch/autograd/graph.py", line 744, in _engine_run_backward
    return Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
RuntimeError: The size of tensor a (8) must match the size of tensor b (7) at non-singleton dimension 0
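For reference, here is a minimal pure-Python sketch of the gradient the reporter argues a mean reduction should produce: every row inside segment i receives that segment's output gradient divided by the segment length, and rows outside indptr's range receive zero. The helpers `segment_mean` and `segment_mean_backward` are hypothetical illustrations on nested lists (rows along dim 0). They are not torch_scatter's implementation, and defining an empty segment's mean as 0 is an assumption of this sketch.

```python
# Hypothetical reference for segment_csr(src, indptr, reduce="mean") along dim 0
# and its hand-derived gradient; NOT torch_scatter's actual code.

def segment_mean(src, indptr):
    """Column-wise mean of src[indptr[i]:indptr[i + 1]] for each segment i."""
    width = len(src[0]) if src else 0
    out = []
    for start, end in zip(indptr[:-1], indptr[1:]):
        rows = src[start:end]
        if not rows:
            out.append([0.0] * width)  # empty segment: mean taken as 0 (assumption)
            continue
        out.append([sum(col) / len(rows) for col in zip(*rows)])
    return out


def segment_mean_backward(grad_out, indptr, num_rows):
    """Gradient w.r.t. src; always has num_rows rows, like src itself."""
    width = len(grad_out[0]) if grad_out else 0
    # Rows never referenced by indptr keep a zero gradient.
    grad_src = [[0.0] * width for _ in range(num_rows)]
    for seg, (start, end) in enumerate(zip(indptr[:-1], indptr[1:])):
        count = end - start  # loop below is skipped when count == 0
        for row in range(start, end):
            grad_src[row] = [g / count for g in grad_out[seg]]
    return grad_src


if __name__ == "__main__":
    src = [[float(i), float(i)] for i in range(8)]
    out = segment_mean(src, [0, 3, 7])
    grad = segment_mean_backward([[0.25, 0.25], [0.25, 0.25]], [0, 3, 7], len(src))
    print(len(grad))  # 8
    print(grad[7])    # [0.0, 0.0]
```

Here the gradient keeps src's 8 rows even though indptr only reaches row 7. The "8 vs 7" mismatch in the traceback is consistent with the library's mean backward returning only indptr[-1] rows, although this sketch does not show how torch_scatter computes it internally.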
rusty1s commented 2 months ago

Do you mean indptr = torch.tensor([0, 3, 8])?

noahbadoa commented 2 months ago

No, I explicitly meant for the range of indptr not to match the size of x. Interestingly, indptr = torch.tensor([1, 8]) works without issue.

rusty1s commented 2 months ago

OK, but this is undefined behavior. If you want to support [0, 3, 7], then you need to pass x[:7] (slicing src along the reduced dimension, dim 0).

noahbadoa commented 2 months ago

I don't think this should be undefined behavior: the correct gradient of src at positions that indptr does not index is zero, and reduce="min" and reduce="max" already behave this way. If this usage of indptr is explicitly unsupported, it would be nice if that were documented, or if the function failed more gracefully, e.g. by raising an assertion error instead of returning uninitialized memory, as happens with reduce="sum".
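If partial coverage is meant to stay unsupported, as the maintainer suggests, an up-front check could turn the confusing backward-pass error into an immediate, descriptive one. The sketch below assumes a 1-D indptr and reduction along dim 0; `check_indptr` is a hypothetical helper, not part of torch_scatter. With tensors it could be called as `check_indptr(indptr.tolist(), x.size(0))` before `segment_csr`.

```python
def check_indptr(indptr, src_size):
    """Validate a 1-D CSR pointer against the size of src along dim 0.

    Hypothetical helper for illustration; torch_scatter does not ship it.
    Raises ValueError instead of letting a later backward pass fail.
    """
    if len(indptr) == 0:
        raise ValueError("indptr must contain at least one element")
    # Segment boundaries must never move backwards.
    if any(later < earlier for earlier, later in zip(indptr, indptr[1:])):
        raise ValueError("indptr must be non-decreasing")
    # Strict policy: the segments must cover every row of src exactly.
    if indptr[0] != 0 or indptr[-1] != src_size:
        raise ValueError(
            f"indptr covers rows [{indptr[0]}, {indptr[-1]}) but src has "
            f"{src_size} rows along the reduced dimension"
        )


if __name__ == "__main__":
    check_indptr([0, 3, 8], 8)  # passes silently
    try:
        check_indptr([0, 3, 7], 8)
    except ValueError as err:
        print(err)  # indptr covers rows [0, 7) but src has 8 rows along the reduced dimension
```

This strict version also rejects [1, 8]. A looser policy matching the reporter's proposal would drop the final coverage condition and instead require the backward pass to zero-fill gradients for uncovered rows.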