mayabechlerspeicher closed this issue 3 months ago.
Yes, PyG's `DataLoader` is just a wrapper around `torch.utils.data.DataLoader` with a custom `collate_fn`. As such, `collate_fn` is the only argument that cannot be overridden. Let me clarify this in the documentation.
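To make the "custom `collate_fn`" concrete: PyG-style batching merges a list of graphs into one large disconnected graph, concatenating node-level attributes and shifting each graph's `edge_index` by the number of nodes placed before it, plus a `batch` vector mapping nodes to graphs. The following is a simplified pure-Python sketch of that idea under stated assumptions, not PyG's actual `Collater`; the dict-based graph format and `concat_collate` are hypothetical.

```python
from typing import Any, Dict, List

# A "graph" here is a plain dict (hypothetical format, not PyG's Data):
#   num_nodes: int, x: one feature per node, edge_index: [sources, targets]
Graph = Dict[str, Any]


def concat_collate(graphs: List[Graph]) -> Dict[str, Any]:
    """Merge graphs into one disconnected graph (simplified PyG-style batching)."""
    x: List[float] = []
    src: List[int] = []
    dst: List[int] = []
    batch: List[int] = []  # batch[i] is the index of the graph that node i came from
    offset = 0  # number of nodes already placed in the merged graph
    for graph_id, g in enumerate(graphs):
        num_nodes = g["num_nodes"]
        x.extend(g["x"])
        sources, targets = g["edge_index"]
        # Shift node indices so each edge still points at its own graph's nodes.
        src.extend(s + offset for s in sources)
        dst.extend(t + offset for t in targets)
        batch.extend([graph_id] * num_nodes)
        offset += num_nodes
    return {"x": x, "edge_index": [src, dst], "batch": batch, "num_graphs": len(graphs)}


g1 = {"num_nodes": 2, "x": [0.1, 0.2], "edge_index": [[0, 1], [1, 0]]}
g2 = {"num_nodes": 3, "x": [0.3, 0.4, 0.5], "edge_index": [[0, 1, 2], [1, 2, 0]]}
merged = concat_collate([g1, g2])
print(merged["edge_index"])  # [[0, 1, 2, 3, 4], [1, 0, 3, 4, 2]]
print(merged["batch"])       # [0, 0, 1, 1, 1]
```

Because the edges of all graphs end up side by side in a single `edge_index`, differing edge counts per graph are not a problem for this kind of batching.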
Thank you. Could you please explain why it should be prevented from being overridden? I believe it is a typical case for a data object to have custom keys that one wants to batch differently than by concatenation (e.g., when their dimensions do not allow concatenation).
If we allowed overriding `collate_fn` in PyG's `DataLoader`, it would boil down to `torch.utils.data.DataLoader`. In that case, I don't see a good reason not to use the vanilla PyTorch `DataLoader` in the first place.
Note that you can also customize concatenation by overriding `Data.__cat_dim__` (see the advanced mini-batching tutorial in our documentation).
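Conceptually, `__cat_dim__` answers one question per attribute: concatenate along an existing dimension, or (by returning `None`) stack along a new batch dimension. The sketch below models that decision on plain 1-D Python lists so it runs without PyG; `collate_key` and `my_cat_dim` are hypothetical helpers. In PyG itself the hook is a method on your `Data` subclass, `__cat_dim__(self, key, value, *args, **kwargs)`.

```python
from typing import Callable, List, Optional

# A cat_dim rule mimics the role of Data.__cat_dim__: given an attribute name
# and one example value, return 0 to concatenate or None to stack.
CatDimRule = Callable[[str, list], Optional[int]]


def collate_key(key: str, values: List[list], cat_dim: CatDimRule) -> list:
    """Batch one attribute across graphs; values are 1-D lists, one per graph."""
    dim = cat_dim(key, values[0])
    if dim is None:
        # Stacking adds a new leading dimension, so all values must have equal length.
        lengths = {len(v) for v in values}
        if len(lengths) != 1:
            raise ValueError(f"cannot stack {key!r}: lengths {sorted(lengths)}")
        return [list(v) for v in values]
    if dim != 0:
        raise NotImplementedError("this 1-D sketch only concatenates along dim 0")
    # Concatenation along dim 0 works for any lengths.
    return [item for v in values for item in v]


def my_cat_dim(key: str, value: list) -> Optional[int]:
    # Hypothetical rule: stack the graph-level target "y", concatenate everything else.
    return None if key == "y" else 0


print(collate_key("x", [[1, 2], [3]], my_cat_dim))     # [1, 2, 3]
print(collate_key("y", [[1, 0], [0, 1]], my_cat_dim))  # [[1, 0], [0, 1]]
```

Stacking `y` values of different lengths, e.g. `[[1, 0], [1]]`, raises `ValueError`: choosing a cat dim alone cannot keep ragged per-graph attributes grouped by graph.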
Thanks. Nonetheless, the standard PyTorch `DataLoader` fails to stack the edge indices along a new dimension, since different graphs have different numbers of edges.
So let's say I am not interested in batching the edge indices into one huge graph; I just want to wrap multiple graphs together, i.e., group the keys of the graphs in the batch, where the tensors of each key can have different shapes (as with edge indices). The gradient would be computed on the loss over the whole batch, but the forward pass would be done on each graph in the batch separately anyway (so it is not the most GPU-efficient approach, but that's OK). Because the tensors do not have the same dimensions, they cannot be concatenated, so `Data.__cat_dim__` would not help. What should I do in that case?
Do you mean you simply want to "batch" tensors together by stacking them in a list? I am not yet sure I understand, sorry.
Yes. I have some custom keys in my `Data` object that have different dimensions, so I cannot stack them; I just want to put them in a list.
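The requested behavior, "batching" some attributes by simply collecting them in a list, can be sketched without PyG: concatenate the regular attributes and keep each excluded key as one entry per graph. `list_collate` and its dict-based graphs are hypothetical, and node-index shifting for `edge_index` is deliberately omitted; this only mirrors the idea behind the `exclude_keys` approach discussed in this thread.

```python
from typing import Any, Dict, List, Optional


def list_collate(
    graphs: List[Dict[str, Any]],
    exclude_keys: Optional[List[str]] = None,
) -> Dict[str, Any]:
    """Concatenate list-valued attributes, but keep excluded keys as per-graph lists."""
    exclude = set(exclude_keys or [])
    out: Dict[str, Any] = {}
    for key in graphs[0]:
        values = [g[key] for g in graphs]
        if key in exclude:
            # Ragged attributes stay grouped by graph: out[key][i] belongs to graph i.
            out[key] = values
        else:
            # Regular attributes are flattened into one list (no index shifting here).
            out[key] = [item for v in values for item in v]
    return out


g1 = {"x": [1.0, 2.0], "mol_features": [1, 1, 1]}
g2 = {"x": [3.0], "mol_features": [1, 1, 1, 1, 1]}
b = list_collate([g1, g2], exclude_keys=["mol_features"])
print(b["x"])                                # [1.0, 2.0, 3.0]
print([len(f) for f in b["mol_features"]])  # [3, 5]
```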
I see, that's indeed not possible currently. What we could do is provide an option in `Data` to exclude certain attributes from concatenation. Would this work for your use case?
@mayabechlerspeicher You can use `exclude_keys` in a custom `Batch.from_data_list(...)` (`CustomBatch.from_data_list(...)` below) and then add the excluded attributes to the batch however you want:
```python
from typing import List, Optional, Sequence, Union
import random

import torch
from typing_extensions import Self

from torch_geometric.data import Batch, Data, Dataset
from torch_geometric.data.collate import collate
from torch_geometric.data.data import BaseData
from torch_geometric.data.datapipes import DatasetAdapter
from torch_geometric.utils import from_smiles


class CustomBatch(Batch):
    @classmethod
    def from_data_list(
        cls,
        data_list: List[BaseData],
        follow_batch: Optional[List[str]] = None,
        exclude_keys: Optional[List[str]] = None,
    ) -> Self:
        # Same as Batch.from_data_list: collate every attribute except `exclude_keys`.
        batch, slice_dict, inc_dict = collate(
            cls,
            data_list=data_list,
            increment=True,
            add_batch=not isinstance(data_list[0], Batch),
            follow_batch=follow_batch,
            exclude_keys=exclude_keys,
        )
        batch._num_graphs = len(data_list)  # type: ignore
        batch._slice_dict = slice_dict  # type: ignore
        batch._inc_dict = inc_dict  # type: ignore
        # Re-attach each excluded attribute as a plain Python list (one entry per graph).
        if exclude_keys:
            for key in exclude_keys:
                setattr(batch, key, [getattr(d, key) for d in data_list])
        return batch


class Collate:
    def __init__(
        self,
        follow_batch: Optional[List[str]] = None,
        exclude_keys: Optional[List[str]] = None,
    ) -> None:
        self.follow_batch = follow_batch
        self.exclude_keys = exclude_keys

    def __call__(self, batch):
        elem = batch[0]
        if isinstance(elem, Data):
            return CustomBatch.from_data_list(batch, self.follow_batch, self.exclude_keys)
        # Only `Data` objects are handled; fail loudly instead of silently returning None.
        raise TypeError(f"Collate: unsupported element type {type(elem)}")


class CustomDataLoader(torch.utils.data.DataLoader):
    def __init__(
        self,
        dataset: Union[Dataset, Sequence[BaseData], DatasetAdapter],
        batch_size: int = 1,
        shuffle: bool = False,
        follow_batch: Optional[List[str]] = None,
        exclude_keys: Optional[List[str]] = None,
        **kwargs,
    ):
        # Remove for PyTorch Lightning:
        kwargs.pop('collate_fn', None)

        # Save for PyTorch Lightning < 1.6:
        self.follow_batch = follow_batch
        self.exclude_keys = exclude_keys

        super().__init__(
            dataset,
            batch_size,
            shuffle,
            collate_fn=Collate(follow_batch, exclude_keys),
            **kwargs,
        )


if __name__ == '__main__':
    smiles_list = [
        'F/C=C/F',
        'COC(=O)[C@@]1(Cc2ccccc2)[C@H]2C(=O)N(C)C(=O)[C@H]2[C@H]2CN=C(SC)N21',
        'CC1=C(CCN2CCC(CC2)C2=NOC3=C2C=CC(F)=C3)C(=O)N2CCCCC2=N1',
        '[H][C@@]12[C@H]3CC[C@H](C3)[C@]1([H])C(=O)N(C[C@@H]1CCCC[C@H]1CN1CCN(CC1)C1=NSC3=CC=CC=C13)C2=O',
    ]

    data_list = []
    for smiles in smiles_list:
        data = from_smiles(smiles)
        # Ragged, per-molecule attribute that cannot be concatenated or stacked.
        data.mol_features = [1] * random.randint(2, 15)
        data_list.append(data)
        print(data)
    print()

    dl = CustomDataLoader(dataset=data_list, batch_size=2, exclude_keys=['mol_features'])
    for batch in dl:
        print(batch)
        print(batch.mol_features)  # a list with one entry per graph in the batch
```
Oh, you are right. Thanks for pointing this out. Completely forgot about this option :)
🐛 Describe the bug
PyG's `DataLoader` can receive a custom `collate_fn`, since it extends the PyTorch `DataLoader`, but its constructor ignores the given `collate_fn` and always uses `Collater` instead. I'm not sure whether this is a bug or the documentation is wrong: the PyG documentation states that any argument accepted by PyTorch's `DataLoader` can be used with PyG's `DataLoader`, yet the `collate_fn` argument cannot.
So, to actually use a custom `collate_fn`, do I have to extend `DataLoader` so that it uses the given `collate_fn`?
Thanks.
Versions
2.5.3