Lightning-AI / pytorch-lightning


Overfit batches parameter gives a validation batch #15021

Open · HekpoMaH opened this issue 1 year ago

HekpoMaH commented 1 year ago

### Bug description

When overfitting on a single batch (`overfit_batches=1`) with the dataloaders defined inside the LightningModule, the batch passed to `validation_step` is different from the batch passed to `training_step`. I was told in the Slack community that this is NOT the intended behaviour.
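For reference, this is the behaviour I expected based on the docs (a sketch of the assumed semantics, not of Lightning's internals):

```python
import torch
import torch_geometric

# Assumed expected semantics of overfit_batches=1 (based on the docs,
# not on Lightning's implementation): validation reuses the training batch.
dataset = [torch_geometric.data.Data(x=torch.tensor([i])) for i in range(10)]
train_loader = torch_geometric.loader.DataLoader(dataset, batch_size=2)

first_batch = next(iter(train_loader))
print(first_batch.x)  # tensor([0, 1]), the single batch being overfit
# Expected: validation_step also receives this batch,
# not the first batch of val_dataloader.
```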

### How to reproduce the bug

```python
import pytorch_lightning as pl
import torch_geometric
import torch

# Disjoint value ranges make train and val batches easy to tell apart.
dataset = [torch_geometric.data.Data(x=torch.tensor([i])) for i in range(10)]
val_dataset = [torch_geometric.data.Data(x=torch.tensor([j])) for j in range(10, 20)]

class LitModule(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.param = torch.nn.Parameter(torch.tensor([0.]))

    def train_dataloader(self):
        return torch_geometric.loader.DataLoader(dataset, batch_size=2)

    def val_dataloader(self):
        return torch_geometric.loader.DataLoader(val_dataset, batch_size=2)

    def training_step(self, batch, batch_idx):
        print('train', batch.x)  # prints the batch seen during training
        return torch.nn.functional.mse_loss(self.param, torch.tensor([1.]).to(self.param))

    def validation_step(self, batch, batch_idx):
        print('val', batch.x)  # prints the batch seen during validation
        return torch.nn.functional.mse_loss(self.param, torch.tensor([1.]).to(self.param))

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=1e-4)

litmod = LitModule()
trainer = pl.Trainer(
    overfit_batches=1,
    accelerator='cuda',
    max_epochs=20,
    check_val_every_n_epoch=10,
)
trainer.fit(litmod)
print(litmod)
```

The val batch is the `[10, 11]` tensor, while the train batch is the `[0, 1]` tensor.
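Until this is fixed, a possible workaround (a sketch, under the assumption that validating directly on the training data is acceptable for an overfitting run) is to have `val_dataloader` return the training dataloader, so both steps see the same batch:

```python
# Workaround sketch: reuse the training dataloader for validation so that,
# with overfit_batches=1, both steps receive the batch with x = [0, 1].
# The subclass name is hypothetical, chosen only for this illustration.
class LitModuleWorkaround(LitModule):
    def val_dataloader(self):
        return self.train_dataloader()

pl.Trainer(overfit_batches=1, max_epochs=20, check_val_every_n_epoch=10).fit(
    LitModuleWorkaround()
)
```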


### Environment


### More info

_No response_

cc @justusschock @awaelchli
israfelsr commented 8 months ago

I had the same problem. I was going crazy because, according to the documentation, they are supposed to be the same 😅.