pyg-team / pytorch_geometric

Graph Neural Network Library for PyTorch
https://pyg.org
MIT License

When accessing added graph data from batch it has been incremented #6135

Open ProfDoof opened 1 year ago

ProfDoof commented 1 year ago

🐛 Describe the bug

Hi, I am a little buried under work right now, so I don't have time to make an MVE. The basic description of the bug is as follows. I add an extra piece of information to a Data object like so:

data = Data(...)
data.thing = 3

I then build a dataset of these objects, where thing can be any value within some range, and use that dataset with y'all's DataLoader. When I access thing from the Batch, the values are higher than they should be and fall outside the expected range, which is detrimental in my particular use case. It appears to come from _collate assuming that every attribute should be incremented, so my thing gets incremented by some cumulative sum. I'm not sure what I should do to remedy this.

rusty1s commented 1 year ago

This should only happen if your attribute's name contains _index. Here is a basic example:

import torch

from torch_geometric.data import Data
from torch_geometric.loader import DataLoader

data = Data(x=torch.randn(10, 16), thing=1, thing_index=1)

loader = DataLoader([data, data, data], batch_size=3)

batch = next(iter(loader))
print(batch)
print(batch.thing)
print(batch.thing_index)

DataBatch(x=[30, 16], thing=[3], thing_index=[3], batch=[30], ptr=[4])
tensor([1, 1, 1])
tensor([ 1, 11, 21])
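
The output above is consistent with each graph's thing_index being shifted by the number of nodes in the graphs that come before it in the batch (0, 10, and 20 here), while thing is left alone. As a rough illustration only, here is a pure-Python sketch of that rule. It assumes the increment is "add num_nodes when the key contains index, otherwise add nothing"; toy_inc and toy_collate are hypothetical names made up for this example and are not PyG functions.

```python
# Toy model of the batching rule described above; NOT PyG's actual code.
# Assumption: keys containing "index" are shifted by the cumulative node count.

def toy_inc(key, num_nodes):
    """Hypothetical stand-in for the per-attribute increment rule."""
    return num_nodes if "index" in key else 0


def toy_collate(graphs, key):
    """Collect `key` across graphs, shifting each value by a running offset."""
    out, offset = [], 0
    for g in graphs:
        out.append(g[key] + offset)
        # The offset grows only for keys that the rule treats as indices.
        offset += toy_inc(key, g["num_nodes"])
    return out


# Three copies of a 10-node graph, mirroring the example above.
graph = {"num_nodes": 10, "thing": 1, "thing_index": 1}
graphs = [graph, graph, graph]

thing = toy_collate(graphs, "thing")
thing_index = toy_collate(graphs, "thing_index")
print(thing)        # [1, 1, 1]
print(thing_index)  # [1, 11, 21]
```

In PyG itself the per-attribute increment is decided by Data.__inc__, so renaming the attribute so that it no longer contains index, or overriding __inc__ in a Data subclass to return 0 for that key, should avoid the shift.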