ProfDoof opened this issue 1 year ago (status: Open).
This should only happen if your attribute's name contains `_index`. Here is a basic example:
```python
import torch
from torch_geometric.data import Data
from torch_geometric.loader import DataLoader

# Each graph has 10 nodes; `thing` and `thing_index` hold the same scalar value.
data = Data(x=torch.randn(10, 16), thing=1, thing_index=1)
loader = DataLoader([data, data, data], batch_size=3)
batch = next(iter(loader))
print(batch)
print(batch.thing)        # concatenated unchanged
print(batch.thing_index)  # shifted by the node count of the preceding graphs
```

Output:

```
DataBatch(x=[30, 16], thing=[3], thing_index=[3], batch=[30], ptr=[4])
tensor([1, 1, 1])
tensor([ 1, 11, 21])
```
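To make the increment rule concrete, here is a simplified pure-Python model of the collation step. It is a sketch, not PyG's actual code: it assumes, as in PyG 2.x's default `Data.__inc__`, that a key is shifted by the graph's node count whenever its name contains `index`, and that the shift accumulates across the graphs in a batch. `ToyData`, `NoShiftData`, and `toy_collate` are hypothetical names; real collation also handles tensors, concatenation dimensions, and the `batch`/`ptr` vectors, all omitted here.

```python
from typing import Dict, List


class ToyData:
    """Hypothetical stand-in for a PyG Data object holding scalar attributes."""

    def __init__(self, num_nodes: int, **attrs: int) -> None:
        self.num_nodes = num_nodes
        self.attrs = attrs

    def inc(self, key: str) -> int:
        # Assumed default rule: index-like keys shift by this graph's node count.
        return self.num_nodes if "index" in key else 0


class NoShiftData(ToyData):
    """Opts one key out of the shift, analogous to overriding Data.__inc__."""

    def inc(self, key: str) -> int:
        if key == "thing_index":
            return 0
        return super().inc(key)


def toy_collate(graphs: List[ToyData]) -> Dict[str, List[int]]:
    """Concatenate attributes, adding each key's running (cumulative) offset."""
    batch: Dict[str, List[int]] = {}
    offsets: Dict[str, int] = {}
    for graph in graphs:
        for key, value in graph.attrs.items():
            offset = offsets.get(key, 0)
            batch.setdefault(key, []).append(value + offset)
            # The offset for the next graph grows by this graph's increment.
            offsets[key] = offset + graph.inc(key)
    return batch


plain = ToyData(num_nodes=10, thing=1, thing_index=1)
print(toy_collate([plain, plain, plain]))
# {'thing': [1, 1, 1], 'thing_index': [1, 11, 21]}

fixed = NoShiftData(num_nodes=10, thing=1, thing_index=1)
print(toy_collate([fixed, fixed, fixed]))
# {'thing': [1, 1, 1], 'thing_index': [1, 1, 1]}
```

The `NoShiftData` subclass mirrors the usual remedy in PyG: subclass `Data` and override `__inc__(self, key, value, *args, **kwargs)` so it returns `0` for the affected key (check the signature against your installed PyG version). Alternatively, renaming the attribute so its name no longer contains `index` avoids the shift altogether.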
🐛 Describe the bug
Hi, I am a little buried under work right now, so I don't have time to make an MVE. However, the basic description of the bug is that when I add an extra piece of information to a `Data` object like so
and build a dataset of those where `thing` could be any value within some range, I then use this dataset in y'all's `DataLoader`. When I access `thing` from the `Batch`, the values are larger than they should be, which pushes them outside the expected range; this is detrimental in my particular use case. It appears to come from `_collate` assuming that all attributes should be incremented, which causes my `thing` to be shifted by some cumulative sum. I'm not sure what I should do to remedy this.
Environment
How PyTorch and PyG were installed (conda, pip, source): conda