graphnet-team / graphnet

A Deep learning library for neutrino telescopes
https://graphnet-team.github.io/graphnet/
Apache License 2.0

`TypeError: default_collate: batch must contain tensors, numpy arrays, numbers, dicts or lists; found <class 'torch_geometric.data.data.Data'>` #685

Closed RasmusOrsoe closed 3 months ago

RasmusOrsoe commented 3 months ago

Describe the bug: When using the new `GraphNeTDataModule` without specifying a suitable `collate_fn` in `train_dataloader_kwargs`, the code fails with:

TypeError: default_collate: batch must contain tensors, numpy arrays, numbers, dicts or lists; found <class 'torch_geometric.data.data.Data'>

To Reproduce: Use `graphnet.data.datamodule.GraphNeTDataModule` to produce dataloaders without explicitly passing `collate_fn` as an entry in `train_dataloader_kwargs`.

Full traceback:

Traceback (most recent call last):
  File "mwe.py", line 138, in <module>
    main(
  File "mwe.py", line 88, in main
    model.fit(
  File "/home/iwsatlas1/oersoe/github/graphnet/src/graphnet/models/standard_model.py", line 167, in fit
    trainer.fit(
  File "/home/iwsatlas1/oersoe/anaconda3/envs/graphnet/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 531, in fit
    call._call_and_handle_interrupt(
  File "/home/iwsatlas1/oersoe/anaconda3/envs/graphnet/lib/python3.8/site-packages/pytorch_lightning/trainer/call.py", line 41, in _call_and_handle_interrupt
    return trainer.strategy.launcher.launch(trainer_fn, *args, trainer=trainer, **kwargs)
  File "/home/iwsatlas1/oersoe/anaconda3/envs/graphnet/lib/python3.8/site-packages/pytorch_lightning/strategies/launchers/subprocess_script.py", line 91, in launch
    return function(*args, **kwargs)
  File "/home/iwsatlas1/oersoe/anaconda3/envs/graphnet/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 570, in _fit_impl
    self._run(model, ckpt_path=ckpt_path)
  File "/home/iwsatlas1/oersoe/anaconda3/envs/graphnet/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 975, in _run
    results = self._run_stage()
  File "/home/iwsatlas1/oersoe/anaconda3/envs/graphnet/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 1016, in _run_stage
    self._run_sanity_check()
  File "/home/iwsatlas1/oersoe/anaconda3/envs/graphnet/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 1045, in _run_sanity_check
    val_loop.run()
  File "/home/iwsatlas1/oersoe/anaconda3/envs/graphnet/lib/python3.8/site-packages/pytorch_lightning/loops/utilities.py", line 177, in _decorator
    return loop_run(self, *args, **kwargs)
  File "/home/iwsatlas1/oersoe/anaconda3/envs/graphnet/lib/python3.8/site-packages/pytorch_lightning/loops/evaluation_loop.py", line 108, in run
    batch, batch_idx, dataloader_idx = next(data_fetcher)
  File "/home/iwsatlas1/oersoe/anaconda3/envs/graphnet/lib/python3.8/site-packages/pytorch_lightning/loops/fetchers.py", line 136, in __next__
    self._fetch_next_batch(self.dataloader_iter)
  File "/home/iwsatlas1/oersoe/anaconda3/envs/graphnet/lib/python3.8/site-packages/pytorch_lightning/loops/fetchers.py", line 150, in _fetch_next_batch
    batch = next(iterator)
  File "/home/iwsatlas1/oersoe/anaconda3/envs/graphnet/lib/python3.8/site-packages/pytorch_lightning/utilities/combined_loader.py", line 284, in __next__
    out = next(self._iterator)
  File "/home/iwsatlas1/oersoe/anaconda3/envs/graphnet/lib/python3.8/site-packages/pytorch_lightning/utilities/combined_loader.py", line 123, in __next__
    out = next(self.iterators[0])
  File "/home/iwsatlas1/oersoe/anaconda3/envs/graphnet/lib/python3.8/site-packages/torch/utils/data/dataloader.py", line 633, in __next__
    data = self._next_data()
  File "/home/iwsatlas1/oersoe/anaconda3/envs/graphnet/lib/python3.8/site-packages/torch/utils/data/dataloader.py", line 677, in _next_data
    data = self._dataset_fetcher.fetch(index)  # may raise StopIteration
  File "/home/iwsatlas1/oersoe/anaconda3/envs/graphnet/lib/python3.8/site-packages/torch/utils/data/_utils/fetch.py", line 54, in fetch
    return self.collate_fn(data)
  File "/home/iwsatlas1/oersoe/anaconda3/envs/graphnet/lib/python3.8/site-packages/torch/utils/data/_utils/collate.py", line 265, in default_collate
    return collate(batch, collate_fn_map=default_collate_fn_map)
  File "/home/iwsatlas1/oersoe/anaconda3/envs/graphnet/lib/python3.8/site-packages/torch/utils/data/_utils/collate.py", line 150, in collate
    raise TypeError(default_collate_err_msg_format.format(elem_type))
TypeError: default_collate: batch must contain tensors, numpy arrays, numbers, dicts or lists; found <class 'torch_geometric.data.data.Data'>

Additional context: The error does not occur when the default `collate_fn` from `graphnet.training.utils` is passed as `collate_fn` in `train_dataloader_kwargs`.
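To illustrate the mechanism outside of graphnet, here is a minimal sketch using only `torch`. `ToyGraph`, `ToyGraphDataset` and `toy_collate_fn` are hypothetical stand-ins (for `torch_geometric.data.data.Data`, a graph dataset, and graphnet's `collate_fn`, respectively), not the actual graphnet implementation. It shows that a `DataLoader` without a `collate_fn` falls back to `default_collate`, which raises the `TypeError` above for an unknown sample type, while passing a suitable `collate_fn` avoids it.

```python
import torch
from torch.utils.data import DataLoader, Dataset


class ToyGraph:
    """Hypothetical stand-in for torch_geometric.data.data.Data."""

    def __init__(self, x: torch.Tensor) -> None:
        self.x = x  # node features, shape [n_nodes, n_features]


class ToyGraphDataset(Dataset):
    """Hypothetical dataset returning one ToyGraph per index."""

    def __init__(self, sizes):
        self.sizes = list(sizes)

    def __len__(self):
        return len(self.sizes)

    def __getitem__(self, index):
        return ToyGraph(torch.ones(self.sizes[index], 3))


def toy_collate_fn(graphs):
    """Concatenate node features and record which graph each node came from."""
    x = torch.cat([g.x for g in graphs], dim=0)
    batch = torch.cat(
        [
            torch.full((g.x.shape[0],), i, dtype=torch.long)
            for i, g in enumerate(graphs)
        ]
    )
    return {"x": x, "batch": batch}


dataset = ToyGraphDataset([2, 3, 4, 1])

# Without collate_fn, DataLoader uses default_collate, which does not know
# how to merge ToyGraph objects and raises TypeError.
try:
    next(iter(DataLoader(dataset, batch_size=2)))
    default_collate_failed = False
    error_message = ""
except TypeError as err:
    default_collate_failed = True
    error_message = str(err)

# With an explicit collate_fn, the same dataset batches without error.
loader = DataLoader(dataset, batch_size=2, collate_fn=toy_collate_fn)
first_batch = next(iter(loader))
```

In graphnet, the analogous step would be to add the `collate_fn` from `graphnet.training.utils` to `train_dataloader_kwargs`. The traceback above fails inside the validation sanity check, so the validation dataloader presumably needs the same `collate_fn`; this is an inference from the traceback, not something the report states.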