Describe the bug
When the new GraphNeTDataModule is used without a suitable collate_fn in train_dataloader_kwargs, the code fails with:
TypeError: default_collate: batch must contain tensors, numpy arrays, numbers, dicts or lists; found <class 'torch_geometric.data.data.Data'>
To Reproduce
Use graphnet.data.datamodule.GraphNeTDataModule to produce dataloaders without explicitly passing collate_fn as an entry in train_dataloader_kwargs.
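The failure can be reproduced without GraphNeT. Any batched DataLoader that receives non-tensor objects and no collate_fn falls back to PyTorch's default_collate, which raises this TypeError. The sketch below assumes only that torch is installed. StandInGraph is a hypothetical placeholder for torch_geometric.data.Data and is not part of either library:

```python
import torch
from torch.utils.data import DataLoader


class StandInGraph:
    """Hypothetical stand-in for torch_geometric.data.Data.

    It is not a tensor, numpy array, number, dict or list, so
    torch's default_collate has no rule for batching it.
    """

    def __init__(self, n_nodes: int) -> None:
        self.x = torch.zeros(n_nodes, 3)


dataset = [StandInGraph(n) for n in (4, 5, 6)]


def first_batch_error(loader: DataLoader):
    """Return the TypeError message raised by the first batch, or None."""
    try:
        next(iter(loader))
    except TypeError as err:
        return str(err)
    return None


# No collate_fn: the DataLoader falls back to torch's default_collate and fails.
default_msg = first_batch_error(DataLoader(dataset, batch_size=2))

# An explicit collate_fn that accepts the sample type avoids the error
# (here it simply keeps the samples as a Python list).
ok_batch = next(iter(DataLoader(dataset, batch_size=2, collate_fn=list)))
```

For real graph objects, the collate_fn would typically be built on torch_geometric's Batch.from_data_list rather than `list`.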
Full traceback
Traceback (most recent call last):
  File "mwe.py", line 138, in <module>
    main(
  File "mwe.py", line 88, in main
    model.fit(
  File "/home/iwsatlas1/oersoe/github/graphnet/src/graphnet/models/standard_model.py", line 167, in fit
    trainer.fit(
  File "/home/iwsatlas1/oersoe/anaconda3/envs/graphnet/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 531, in fit
    call._call_and_handle_interrupt(
  File "/home/iwsatlas1/oersoe/anaconda3/envs/graphnet/lib/python3.8/site-packages/pytorch_lightning/trainer/call.py", line 41, in _call_and_handle_interrupt
    return trainer.strategy.launcher.launch(trainer_fn, *args, trainer=trainer, **kwargs)
  File "/home/iwsatlas1/oersoe/anaconda3/envs/graphnet/lib/python3.8/site-packages/pytorch_lightning/strategies/launchers/subprocess_script.py", line 91, in launch
    return function(*args, **kwargs)
  File "/home/iwsatlas1/oersoe/anaconda3/envs/graphnet/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 570, in _fit_impl
    self._run(model, ckpt_path=ckpt_path)
  File "/home/iwsatlas1/oersoe/anaconda3/envs/graphnet/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 975, in _run
    results = self._run_stage()
  File "/home/iwsatlas1/oersoe/anaconda3/envs/graphnet/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 1016, in _run_stage
    self._run_sanity_check()
  File "/home/iwsatlas1/oersoe/anaconda3/envs/graphnet/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 1045, in _run_sanity_check
    val_loop.run()
  File "/home/iwsatlas1/oersoe/anaconda3/envs/graphnet/lib/python3.8/site-packages/pytorch_lightning/loops/utilities.py", line 177, in _decorator
    return loop_run(self, *args, **kwargs)
  File "/home/iwsatlas1/oersoe/anaconda3/envs/graphnet/lib/python3.8/site-packages/pytorch_lightning/loops/evaluation_loop.py", line 108, in run
    batch, batch_idx, dataloader_idx = next(data_fetcher)
  File "/home/iwsatlas1/oersoe/anaconda3/envs/graphnet/lib/python3.8/site-packages/pytorch_lightning/loops/fetchers.py", line 136, in __next__
    self._fetch_next_batch(self.dataloader_iter)
  File "/home/iwsatlas1/oersoe/anaconda3/envs/graphnet/lib/python3.8/site-packages/pytorch_lightning/loops/fetchers.py", line 150, in _fetch_next_batch
    batch = next(iterator)
  File "/home/iwsatlas1/oersoe/anaconda3/envs/graphnet/lib/python3.8/site-packages/pytorch_lightning/utilities/combined_loader.py", line 284, in __next__
    out = next(self._iterator)
  File "/home/iwsatlas1/oersoe/anaconda3/envs/graphnet/lib/python3.8/site-packages/pytorch_lightning/utilities/combined_loader.py", line 123, in __next__
    out = next(self.iterators[0])
  File "/home/iwsatlas1/oersoe/anaconda3/envs/graphnet/lib/python3.8/site-packages/torch/utils/data/dataloader.py", line 633, in __next__
    data = self._next_data()
  File "/home/iwsatlas1/oersoe/anaconda3/envs/graphnet/lib/python3.8/site-packages/torch/utils/data/dataloader.py", line 677, in _next_data
    data = self._dataset_fetcher.fetch(index) # may raise StopIteration
  File "/home/iwsatlas1/oersoe/anaconda3/envs/graphnet/lib/python3.8/site-packages/torch/utils/data/_utils/fetch.py", line 54, in fetch
    return self.collate_fn(data)
  File "/home/iwsatlas1/oersoe/anaconda3/envs/graphnet/lib/python3.8/site-packages/torch/utils/data/_utils/collate.py", line 265, in default_collate
    return collate(batch, collate_fn_map=default_collate_fn_map)
  File "/home/iwsatlas1/oersoe/anaconda3/envs/graphnet/lib/python3.8/site-packages/torch/utils/data/_utils/collate.py", line 150, in collate
    raise TypeError(default_collate_err_msg_format.format(elem_type))
TypeError: default_collate: batch must contain tensors, numpy arrays, numbers, dicts or lists; found <class 'torch_geometric.data.data.Data'>
Additional context
The error does not occur when the default collate_fn from graphnet.training.utils is passed explicitly.
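One possible fix is for the datamodule to fill in a default collate_fn whenever the user omits it. The traceback shows the failure inside the validation sanity check (_run_sanity_check → val_loop.run()), so the default would presumably need to apply to the validation and test dataloader kwargs as well, not only to train_dataloader_kwargs. Below is a minimal pure-Python sketch. with_default_collate_fn and placeholder_collate_fn are hypothetical names. In real use, the second argument would be the collate_fn from graphnet.training.utils:

```python
from typing import Any, Callable, Dict, List, Optional


def placeholder_collate_fn(graphs: List[Any]) -> List[Any]:
    """Hypothetical stand-in for graphnet.training.utils.collate_fn."""
    return graphs


def with_default_collate_fn(
    dataloader_kwargs: Optional[Dict[str, Any]],
    default_collate_fn: Callable[[List[Any]], Any],
) -> Dict[str, Any]:
    """Return a copy of the kwargs, adding collate_fn only if the user omitted it.

    The caller's dict is never mutated, and an explicit user-supplied
    collate_fn always takes precedence over the default.
    """
    kwargs = dict(dataloader_kwargs or {})
    kwargs.setdefault("collate_fn", default_collate_fn)
    return kwargs


# Hypothetical usage inside a datamodule: apply the same default to every split.
user_train_kwargs = {"batch_size": 16, "num_workers": 4}
train_kwargs = with_default_collate_fn(user_train_kwargs, placeholder_collate_fn)
val_kwargs = with_default_collate_fn(None, placeholder_collate_fn)
```

Using `setdefault` on a copy keeps the change backwards compatible. Users who already pass a custom collate_fn see no difference, and users who pass none no longer hit the default_collate error.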