rusty1s / pytorch_scatter

PyTorch Extension Library of Optimized Scatter Operations
https://pytorch-scatter.readthedocs.io
MIT License

scatter_max bug: always returns an out-of-bounds (upper-bound) index, and the value associated with it is 0 #438

Closed by Chaoqi-LIU 1 month ago

Chaoqi-LIU commented 2 months ago

Hi, I'm using torch 2.1.0.post303 and torch_scatter 2.1.2 with CUDA 12.2.

Recall: scatter_max returns the maximum values and the indices (argmax) associated with them.

The bug I encountered: the size of src always shows up in the second return value (the indices), and the value associated with that index is 0.
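To make the reported behaviour concrete, here is a plain-Python model of what the thread describes: a segment that receives no elements gets a maximum of 0 and an argmax equal to the length of src. The function `reference_scatter_max` is hypothetical and only mimics the observable output for 1-D inputs; it is not torch_scatter's implementation, and its tie-breaking rule is an assumption.

```python
# Hypothetical plain-Python model of the behaviour described in this issue;
# it is not torch_scatter's implementation, only a sketch of its output.
from typing import List, Optional, Tuple


def reference_scatter_max(
    src: List[float], index: List[int], dim_size: Optional[int] = None
) -> Tuple[List[float], List[int]]:
    """Return (max per segment, argmax per segment) for 1-D inputs."""
    if dim_size is None:
        dim_size = max(index) + 1 if index else 0
    sentinel = len(src)  # out-of-range argmax left in empty segments
    out: List[Optional[float]] = [None] * dim_size
    argmax = [sentinel] * dim_size
    for pos, (value, seg) in enumerate(zip(src, index)):
        # Strict ">" keeps the first position on ties (an assumption here)
        if out[seg] is None or value > out[seg]:
            out[seg] = value
            argmax[seg] = pos
    # Empty segments report 0 as their "maximum"
    values = [0.0 if v is None else v for v in out]
    return values, argmax


src = [3.0, -1.0, 5.0, 2.0]
index = [0, 0, 2, 2]  # segment 1 receives no elements
print(reference_scatter_max(src, index))
# ([3.0, 0.0, 5.0], [0, 4, 2])
```

Segment 1 is empty, so its argmax is 4 (== len(src)) and its value is 0.0, which matches the symptom reported above.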

This is my temporary fix:

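  import torch_scatter

  # Empty segments get argmax == src.size(0) (out of bounds) and a max of 0
  src = in_bbox_particles_L[-2, :, 2]
  max_z, argmax_z = torch_scatter.scatter_max(src, indices)
  bug_mask = argmax_z == src.shape[0]
  max_z = max_z[~bug_mask]
  argmax_z = argmax_z[~bug_mask]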

I checked with torch.unique whether an index equal to the size of src appears in indices; it does not, so this is very likely scatter_max's fault.
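The torch.unique check cannot find the value because it never comes from the index tensor: it marks output positions (segment ids) that no element maps to. The sketch below compares the two sets in plain Python. The helper names and the `argmax` list are hypothetical illustration data, chosen to match the behaviour described in this thread for a 4-element src.

```python
# Hypothetical diagnostic: the out-of-range argmax marks segment ids that never
# occur in `index`, so searching `index` for len(src) cannot find it.
from typing import List, Set


def empty_segments(index: List[int], dim_size: int) -> Set[int]:
    """Segment ids in range(dim_size) that receive no elements."""
    return set(range(dim_size)) - set(index)


def sentinel_positions(argmax: List[int], src_len: int) -> Set[int]:
    """Output positions whose argmax equals the out-of-range value src_len."""
    return {seg for seg, pos in enumerate(argmax) if pos == src_len}


index = [0, 0, 2, 2]
argmax = [0, 4, 2]  # illustrative result for a 4-element src; segment 1 is empty
print(4 in index)                             # False
print(empty_segments(index, dim_size=3))      # {1}
print(sentinel_positions(argmax, src_len=4))  # {1}
```

The two sets coincide, which points at empty segments rather than at a bad value in the index.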

Beyond this, scatter_min has the same problem.

rusty1s commented 1 month ago

Do you mean that the argmax is filled with an invalid index in case the segment is empty? This is working as designed, and your solution is the correct way to handle this downstream.
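One caveat on the downstream fix: dropping masked entries shortens the output, so position i no longer corresponds to segment i. A minimal sketch, assuming you want to keep that correspondence, is to return the surviving segment ids alongside their maxima. `valid_segment_max` is a hypothetical helper in plain Python; in PyTorch the analogous mask is `argmax < src.size(0)`, and `valid.nonzero()` recovers the segment ids. Filtering on argmax rather than on a value of 0 also keeps segments whose true maximum is 0.

```python
# Hypothetical helper: keep non-empty segments and remember their ids.
from typing import Dict, List


def valid_segment_max(
    values: List[float], argmax: List[int], src_len: int
) -> Dict[int, float]:
    """Map each non-empty segment id to its maximum, dropping sentinel slots."""
    return {
        seg: v
        for seg, (v, pos) in enumerate(zip(values, argmax))
        if pos < src_len  # argmax == src_len marks an empty segment
    }


# Illustrative scatter_max output for src = [3.0, -1.0, 0.0, -2.0],
# index = [0, 0, 2, 2]: segment 1 is empty, segment 2 has a real max of 0.0
print(valid_segment_max([3.0, 0.0, 0.0], [0, 4, 2], src_len=4))
# {0: 3.0, 2: 0.0}
```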

Chaoqi-LIU commented 1 month ago

Cool, thanks. I didn't know it was designed this way. :+1: