rusty1s / pytorch_scatter

PyTorch Extension Library of Optimized Scatter Operations
https://pytorch-scatter.readthedocs.io
MIT License

scatter_logsumexp broken for untouched outputs #444

Closed: Pierre-Bartet closed this issue 1 month ago

Pierre-Bartet commented 1 month ago

The following code incorrectly sets an untouched output element to 0 instead of leaving it at its original value (here -10):

```python
import torch
from torch_scatter import scatter_logsumexp

src = torch.tensor([-1., -50.])
index = torch.tensor([0, 0])  # both src values go to out[0]; out[1] is never targeted

out = torch.full((2,), -10.)

scatter_logsumexp(src=src, index=index, out=out)
# returns tensor([-0.9999,  0.0000]) instead of the expected tensor([-0.9999, -10.0000])
```

This means scatter_logsumexp only works in the corner case where every output element receives at least one value from the scatter operation.
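A plain-Python replay of the arithmetic suggests where the 0 comes from. This is a sketch, not the library's code: it assumes the implementation recenters each output slot by a per-index maximum that starts at -inf. For a slot that no index targets, that gives exp(-10 - (-inf)) = inf, then log(inf) + (-inf) = nan, and a `nan_to_num_` call turns the nan into 0. The hypothetical `scatter_logsumexp_ref` shows the intended semantics instead. It assumes, consistent with the -0.9999 reported above, that the initial `out` value takes part in the reduction, and it only recomputes slots that receive at least one element.

```python
import math

NEG_INF = float("-inf")


def naive_untouched_value(out_value: float) -> float:
    """Replay the max-recentered arithmetic for a slot that no index targets.

    Hypothetical helper; assumes the per-index max starts at -inf.
    """
    max_value = NEG_INF  # max over an empty group of src values
    total = math.exp(out_value - max_value)  # exp(-10 - (-inf)) = exp(inf) = inf
    result = math.log(total) + max_value  # inf + (-inf) = nan
    # nan_to_num_ maps nan to 0.0 by default, producing the reported 0
    return 0.0 if math.isnan(result) else result


def scatter_logsumexp_ref(src, index, out):
    """Hypothetical reference: combine each out[i] with the src values sent to it.

    Slots that receive no src value are returned unchanged.
    """
    result = list(out)
    groups = {}
    for value, i in zip(src, index):
        groups.setdefault(i, []).append(value)
    for i, values in groups.items():
        terms = values + [out[i]]  # assumption: the initial out value is included
        m = max(terms)
        if math.isinf(m):
            result[i] = m  # all -inf (log 0) or a +inf term; avoids inf - inf = nan
            continue
        result[i] = m + math.log(sum(math.exp(t - m) for t in terms))
    return result


print(naive_untouched_value(-10.0))  # 0.0
print(scatter_logsumexp_ref([-1.0, -50.0], [0, 0], [-10.0, -10.0]))
```

With the inputs from the reproduction, slot 0 comes out at about -0.9999, matching the reported value, and slot 1 stays at -10.0.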

rusty1s commented 1 month ago

Thanks. Will be fixed in https://github.com/rusty1s/pytorch_scatter/pull/445.

Pierre-Bartet commented 1 month ago

Thanks for the quick fix (and for the whole library)! Looking at the code, I see that there is a `return out.nan_to_num_(neginf=0.0)`, but -inf is a perfectly legitimate output value. Is there something preventing these cases from being handled correctly (https://github.com/rusty1s/pytorch_scatter/pull/426 and https://github.com/rusty1s/pytorch_scatter/issues/407)?
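-inf is indeed a meaningful result: the log-sum-exp of a group whose terms are all -inf is log(0) = -inf, for example when every entry is the log-probability of an impossible event. The small plain-Python sketch below, with hypothetical helper names, shows what mapping that value to 0.0 does: exp(0) = 1, so the slot would claim a total probability of 1 instead of 0. `nan_to_num_neginf_zero` is a stand-in that mimics only the nan and -inf cases of `nan_to_num(neginf=0.0)`.

```python
import math

NEG_INF = float("-inf")


def logsumexp(values):
    """Numerically stable log(sum(exp(v))) that keeps infinite results."""
    m = max(values)
    if math.isinf(m):
        # All terms -inf -> log(0) = -inf; any +inf term -> +inf.
        # Returning early avoids exp(inf - inf) = exp(nan).
        return m
    return m + math.log(sum(math.exp(v - m) for v in values))


def nan_to_num_neginf_zero(x):
    """Scalar stand-in for nan_to_num(neginf=0.0), nan and -inf cases only."""
    if math.isnan(x) or x == NEG_INF:
        return 0.0
    return x


impossible = logsumexp([NEG_INF, NEG_INF])
print(impossible)                          # -inf
print(nan_to_num_neginf_zero(impossible))  # 0.0, i.e. exp(0) = probability 1
```

Guarding the all -inf case explicitly, instead of cleaning up afterwards, keeps -inf as a valid output while still avoiding the nan produced by -inf - (-inf).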