lcaquot94 opened 1 year ago
Hi all, is there anything I can do to help solve this issue? I would really like to use LightningDataModule in my project, but I cannot use it for inference because of this issue. Thanks a lot.
I found a workaround: when defining the DataLoaders, provide the `batch_size` as a named parameter, so `DataLoader(self.dataset, self.batch_size)` becomes `DataLoader(self.dataset, batch_size=self.batch_size)`.
This solution follows the logic of the comment found in `pytorch_lightning/utilities/data.py`: "if the dataloader was wrapped in a hook, only take arguments with default values and assume user passes their kwargs correctly".
It's a workaround, but I believe the bug still persists and should be fixed in the future.
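The collision the workaround avoids can be demonstrated without Lightning at all. The sketch below is a hypothetical stand-in for Lightning's re-instantiation step (`Loader` and `reinstantiate` are invented names, not Lightning APIs): when the framework rebuilds a dataloader it injects `batch_size` as a keyword, so a value the user passed positionally collides with it, while a keyword value is simply overridden.

```python
# Hypothetical stand-in for torch.utils.data.DataLoader (invented, minimal).
class Loader:
    def __init__(self, dataset, batch_size=1):
        self.dataset = dataset
        self.batch_size = batch_size

def reinstantiate(cls, args, kwargs, new_batch_size):
    # Sketch of a Lightning-style rebuild: the framework forces its own
    # batch_size keyword argument onto the recorded constructor arguments.
    return cls(*args, **{**kwargs, "batch_size": new_batch_size})

# batch_size recorded positionally -> TypeError ("multiple values for 'batch_size'")
try:
    reinstantiate(Loader, (range(10), 32), {}, new_batch_size=64)
except TypeError as exc:
    print(exc)

# batch_size recorded as a keyword -> the injected value cleanly replaces it
loader = reinstantiate(Loader, (range(10),), {"batch_size": 32}, new_batch_size=64)
print(loader.batch_size)  # 64
```

This is only an approximation of the mechanism, but it shows why switching to the named-parameter form makes the error disappear.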
Bug description
When attempting to use the `trainer.predict` method with a `CustomDataModule`, an error occurs related to the `DataLoader` implementation. It seems that multiple values are being passed for the `batch_size` argument, which results in a `TypeError` and ultimately terminates the program. Note: I don't know if it is related, but I am using PyTorch Geometric objects (see code below).

What version are you seeing the problem on?
v2.0
How to reproduce the bug
Error messages and logs
```
File "C:\Users\Name.Surname\PycharmProjects\ai-developments\.venv\lib\site-packages\pytorch_lightning\utilities\data.py", line 133, in _update_dataloader
    dataloader = _reinstantiate_wrapped_cls(dataloader, *dl_args, **dl_kwargs)
File "C:\Users\Name.Surname\PycharmProjects\ai-developments\.venv\lib\site-packages\lightning_fabric\utilities\data.py", line 280, in _reinstantiate_wrapped_cls
    raise MisconfigurationException(message) from e
lightning_fabric.utilities.exceptions.MisconfigurationException: The DataLoader implementation has an error where more than one `__init__` argument can be passed to its parent's `batch_size=...` `__init__` argument. This is likely caused by allowing passing both a custom argument that will map to the `batch_size` argument as well as `**kwargs`. `kwargs` should be filtered to make sure they don't contain the `batch_size` key. This argument was automatically passed to your object by PyTorch Lightning.
```
Environment
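The filtering that the exception message asks for can be sketched in isolation. All names below are hypothetical: `BaseLoader` stands in for `torch.utils.data.DataLoader`, and `GraphLoader` for a custom loader whose own argument (`num_graphs`) maps onto the parent's `batch_size` while also forwarding `**kwargs`.

```python
# Hypothetical stand-in for torch.utils.data.DataLoader (invented, minimal).
class BaseLoader:
    def __init__(self, dataset, batch_size=1, **kwargs):
        self.dataset = dataset
        self.batch_size = batch_size

class GraphLoader(BaseLoader):
    def __init__(self, dataset, num_graphs=1, **kwargs):
        # num_graphs maps onto the parent's batch_size. If kwargs also
        # carries a batch_size (Lightning injects one when re-instantiating),
        # the parent would receive the argument twice, so filter it out first.
        kwargs.pop("batch_size", None)
        super().__init__(dataset, batch_size=num_graphs, **kwargs)

# Even with an injected batch_size in kwargs, construction now succeeds and
# the custom argument wins.
loader = GraphLoader(range(10), num_graphs=4, batch_size=8)
print(loader.batch_size)  # 4
```

Without the `kwargs.pop(...)` line, the same call would raise the `TypeError` reported above; whether dropping or honoring the injected value is correct depends on how the custom loader is meant to interact with Lightning.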
Current environment
```
#- Lightning Component: Trainer, LightningModule, LightningDataModule
#- PyTorch Lightning Version: 2.0.2
#- PyTorch Version: 1.13.1+cpu
#- Python version: 3.10.9
#- OS: Windows
#- CUDA/cuDNN version: Not used
#- GPU models and configuration: Not used
#- How you installed Lightning (`conda`, `pip`, source): poetry
#- Running environment of LightningApp (e.g. local, cloud): local
```
More info
No response