Lightning-AI / pytorch-lightning

Pretrain, finetune and deploy AI models on multiple GPUs, TPUs with zero code changes.
https://lightning.ai
Apache License 2.0

Using a non-named parameter for DataLoader initialization results in an error when using a LightningDataModule #17991

Open lcaquot94 opened 1 year ago

lcaquot94 commented 1 year ago

Bug description

When attempting to use the trainer.predict method with a CustomDataModule, an error occurs related to the DataLoader implementation: multiple values are passed for the batch_size argument, which results in a TypeError and ultimately terminates the program. Note: I don't know if it is related, but I am using PyTorch Geometric objects (see code below).
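For illustration, here is a minimal sketch (not Lightning's actual internals) of the collision: a DataLoader whose batch_size was given positionally is re-instantiated with batch_size supplied again as a keyword.

import torch
from torch.utils.data import DataLoader, TensorDataset

dataset = TensorDataset(torch.rand(8, 3))

args = (dataset, 64)          # batch_size given positionally, as in the hooks below
kwargs = {"batch_size": 64}   # batch_size injected again by keyword on re-instantiation
try:
    DataLoader(*args, **kwargs)
except TypeError as e:
    print(e)  # __init__() got multiple values for argument 'batch_size'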

What version are you seeing the problem on?

v2.0

How to reproduce the bug

import torch.nn as nn
import torch
import pytorch_lightning as pl
from pytorch_lightning import Trainer
from torch_geometric.data import Data
from torch_geometric.data import InMemoryDataset
from torch_geometric.loader import DataLoader

class CustomData(Data):
    # Very simple random graph data
    def __init__(self):
        x = torch.rand(1, 10, 16)  # 10 nodes of shape 16
        edge_index = torch.arange(0, 10).repeat(2, 1)  # Only self edges on each node
        super().__init__(x=x, edge_index=edge_index)

class CustomDataset(InMemoryDataset):
    def __init__(self, num_data=1000, root=None, transform=None, pre_transform=None, pre_filter=None):
        super().__init__(root, transform, pre_transform, pre_filter)
        data_list = [CustomData() for _ in range(num_data)]
        self.data, self.slices = self.collate(data_list)

class CustomDataModule(pl.LightningDataModule):

    def __init__(self):
        super().__init__()
        self.has_setup_fit = False
        self.has_setup_predict = False

    @property
    def batch_size(self):
        return 64

    def setup(self, stage: str):
        if not self.has_setup_fit and stage == 'fit':
            self.dataset = CustomDataset()
            self.has_setup_fit = True
        if not self.has_setup_predict and stage == 'predict':
            self.dataset = CustomDataset()
            self.has_setup_predict = True

    def train_dataloader(self):
        return DataLoader(self.dataset, self.batch_size)

    def val_dataloader(self):
        return DataLoader(self.dataset, self.batch_size)

    def predict_dataloader(self):
        return DataLoader(self.dataset, self.batch_size)

class CustomNeuralNetwork(nn.Module):
    def __init__(self, input_size, layers_size):
        super().__init__()
        layers_list = [nn.Linear(input_size, layers_size[0]), nn.LeakyReLU()]
        for i in range(len(layers_size)-1):
            layers_list.extend([nn.Linear(layers_size[i], layers_size[i+1]), nn.LeakyReLU()])
        self.linears = nn.ModuleList(layers_list)

    def forward(self, x):
        for linear in self.linears:
            x = linear(x)
        return x

class CustomModel(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.neural_network = CustomNeuralNetwork(16, [64, 32, 16])

    def forward(self, data):
        output = self.neural_network(data.x)
        fake_loss = torch.mean(output**2)
        return fake_loss

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=0.0001)

    def training_step(self, batch, batch_idx, dataloader_idx=0):
        return self(batch)

    def validation_step(self, batch, batch_idx, dataloader_idx=0):
        return self(batch)

    def predict_step(self, batch, batch_idx, dataloader_idx=0):
        return self(batch)

model = CustomModel()
trainer = Trainer(max_epochs=10)
dm = CustomDataModule()
trainer.fit(model, datamodule=dm)
trainer.predict(datamodule=dm)

Error messages and logs

File "C:\Users\Name.Surname\PycharmProjects\ai-developments.venv\lib\site-packages\pytorch_lightning\utilities\data.py", line 133, in _update_dataloader dataloader = _reinstantiate_wrapped_cls(dataloader, *dl_args, dl_kwargs) File "C:\Users\Name.Surname\PycharmProjects\ai-developments.venv\lib\site-packages\lightning_fabric\utilities\data.py", line 280, in _reinstantiate_wrapped_cls raise MisconfigurationException(message) from e lightning_fabric.utilities.exceptions.MisconfigurationException: The DataLoader implementation has an error where more than one __init__ argument can be passed to its parent's batch_size=... __init__ argument. This is likely caused by allowing passing both a custom argument that will map to the batch_size argument as well as `kwargs.kwargsshould be filtered to make sure they don't contain thebatch_size` key. This argument was automatically passed to your object by PyTorch Lightning.

Environment

Current environment

```
#- Lightning Component: Trainer, LightningModule, LightningDataModule
#- PyTorch Lightning Version: 2.0.2
#- PyTorch Version: 1.13.1+cpu
#- Python version: 3.10.9
#- OS: Windows
#- CUDA/cuDNN version: Not used
#- GPU models and configuration: Not used
#- How you installed Lightning (`conda`, `pip`, source): poetry
#- Running environment of LightningApp (e.g. local, cloud): local
```

More info

No response

lcaquot94 commented 1 year ago

Hi all, is there anything I can do to help solve this issue? I would really like to use LightningDataModule in my project, but I cannot use it for inference because of this issue. Thanks a lot.

lcaquot94 commented 1 year ago

I found a workaround: when defining the DataLoaders, provide batch_size as a named (keyword) parameter. So DataLoader(self.dataset, self.batch_size) becomes DataLoader(self.dataset, batch_size=self.batch_size).
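Concretely, the hooks from the reproducer then become:

class CustomDataModule(pl.LightningDataModule):
    # __init__, the batch_size property and setup() are unchanged from the reproducer

    def train_dataloader(self):
        return DataLoader(self.dataset, batch_size=self.batch_size)

    def val_dataloader(self):
        return DataLoader(self.dataset, batch_size=self.batch_size)

    def predict_dataloader(self):
        return DataLoader(self.dataset, batch_size=self.batch_size)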

This solution follows the logic of the comment found in pytorch_lightning/utilities/data.py:

if the dataloader was wrapped in a hook, only take arguments with default values and assume user passes their kwargs correctly
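As a rough illustration of that logic (simplified, not Lightning's actual implementation): only parameters that expose a default in the DataLoader signature can be matched back and re-supplied by name, which is why a value the user passed positionally ends up injected a second time as a keyword.

import inspect
from torch.utils.data import DataLoader

# Parameters of DataLoader.__init__ that have default values and can therefore
# be reconstructed by name when the loader is re-instantiated.
params = inspect.signature(DataLoader.__init__).parameters
with_defaults = [n for n, p in params.items() if p.default is not inspect.Parameter.empty]
print(with_defaults)  # includes 'batch_size', 'shuffle', 'num_workers', ...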

It's a workaround, but I believe the bug still persists and should be fixed in the future.