Lightning-AI / pytorch-lightning

Pretrain, finetune ANY AI model of ANY size on multiple GPUs, TPUs with zero code changes.
https://lightning.ai
Apache License 2.0

`BatchSizeFinder` limits number of validation batches for the whole training process #18834

Closed: BoringDonut closed this issue 1 year ago

BoringDonut commented 1 year ago

Bug description

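Using `BatchSizeFinder` appears to limit the number of validation batches to `BatchSizeFinder._steps_per_trial`. The limit stays in place for the rest of training, not only during the batch size search.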

As a result, the validation set effectively shrinks to a few dozen samples (only 6 in the reproduction below), and the resulting validation metrics are unreliable.

It seems this can be fixed by calling `_reset_dataloaders` one more time, as the `CustomBatchSizeFinder` in the reproduction below does.
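To picture why one more reset could help, here is a toy model built on an assumption: the validation loop caches its batch limit when its data is set up, so restoring the limit after the search does nothing until the data is set up again. `ToyValLoop`, `limit_batches`, `max_batches` and `setup_data` are hypothetical names for illustration only, not Lightning's actual internals.

```python
import math
from typing import Optional


# Hypothetical toy loop (NOT Lightning code): it caches its batch count in setup_data().
class ToyValLoop:
    def __init__(self, dataset_size: int, batch_size: int) -> None:
        self.dataset_size = dataset_size
        self.batch_size = batch_size
        self.limit_batches: Optional[int] = None  # None means "no limit"
        self.max_batches = 0  # cached value, refreshed only by setup_data()

    def setup_data(self) -> None:
        # Derive how many batches to run from the current limit and cache the result.
        total = math.ceil(self.dataset_size / self.batch_size)
        self.max_batches = total if self.limit_batches is None else min(total, self.limit_batches)

    def run(self) -> int:
        # Count the samples one validation epoch sees (the last batch may be partial).
        seen = 0
        for i in range(self.max_batches):
            seen += min(self.batch_size, self.dataset_size - i * self.batch_size)
        return seen


loop = ToyValLoop(dataset_size=123, batch_size=2)

# The search temporarily caps validation and sets up the data with that cap.
loop.limit_batches = 3
loop.setup_data()

# Restoring the limit alone leaves the cached max_batches stale ...
loop.limit_batches = None
stale = loop.run()

# ... while setting the data up again (the role of the extra reset) picks up the restored limit.
loop.setup_data()
fixed = loop.run()
print(stale, fixed)  # → 6 123
```

In this toy model the restored setting only takes effect once the cached value is recomputed, which matches the shape of the workaround but is not a claim about how Lightning's fix in the linked PR works.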

What version are you seeing the problem on?

v1.8, v2.0

How to reproduce the bug

import os

import torch
from lightning.pytorch import LightningModule, Trainer, LightningDataModule
from lightning.pytorch.callbacks import BatchSizeFinder
from lightning.pytorch.tuner.batch_size_scaling import _reset_dataloaders
from torch.utils.data import DataLoader, Dataset

class RandomDataset(Dataset):
    def __init__(self, size, length):
        self.len = length
        self.data = torch.randn(length, size)

    def __getitem__(self, index):
        return self.data[index]

    def __len__(self):
        return self.len

class Data(LightningDataModule):
    def __init__(self, ds_size: int, batch_size: int):
        super().__init__()
        self.ds_size = ds_size
        self.batch_size = batch_size

    def train_dataloader(self):
        return DataLoader(RandomDataset(27, self.ds_size), batch_size=self.batch_size)

    def val_dataloader(self):
        return DataLoader(RandomDataset(27, self.ds_size), batch_size=self.batch_size)

class BoringModel(LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(27, 128)
        self.dropout = torch.nn.Dropout(0.5)
        self.val_sample_counter = 0

    def forward(self, x):
        return self.dropout(self.layer(x))

    def training_step(self, batch, batch_idx):
        loss = self(batch).sum()
        self.log("train_loss", loss)
        return {"loss": loss}

    def validation_step(self, batch, batch_idx):
        loss = self(batch).sum()
        self.log("valid_loss", loss)
        if not self.trainer.sanity_checking:
            self.val_sample_counter += len(batch)

    def test_step(self, batch, batch_idx):
        loss = self(batch).sum()
        self.log("test_loss", loss)

    def configure_optimizers(self):
        return torch.optim.SGD(self.layer.parameters(), lr=0.1)

    def on_validation_epoch_start(self) -> None:
        self.val_sample_counter = 0

    def on_validation_epoch_end(self) -> None:
        if not self.trainer.sanity_checking:
            print(f"VALIDATED {self.val_sample_counter} SAMPLES ON EPOCH {self.trainer.current_epoch}")

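class CustomBatchSizeFinder(BatchSizeFinder):
    # Workaround: reset the dataloaders at the start of every validation run so that
    # validation is no longer capped at `steps_per_trial` batches once the search is over.
    def on_validation_start(self, trainer: Trainer, pl_module: LightningModule) -> None:
        super().on_validation_start(trainer, pl_module)
        _reset_dataloaders(trainer)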

def run():
    DATASET_SIZE = 123
    BATCH_SIZE = 2
    steps_per_trial = 3

    model = BoringModel()
    trainer = Trainer(
        default_root_dir=os.getcwd(),
        max_epochs=2,
        enable_progress_bar=False,
        enable_model_summary=False,
    )
    print("Without BatchSizeFinder:")
    trainer.fit(model, datamodule=Data(DATASET_SIZE, BATCH_SIZE))
    print("-" * 20)

    model = BoringModel()
    callbacks = [BatchSizeFinder(steps_per_trial=steps_per_trial)]
    trainer = Trainer(
        default_root_dir=os.getcwd(),
        max_epochs=2,
        enable_progress_bar=False,
        enable_model_summary=False,
        callbacks=callbacks,
    )
    print("With BatchSizeFinder:")
    trainer.fit(model, datamodule=Data(DATASET_SIZE, BATCH_SIZE))
    print("-" * 20)

    model = BoringModel()
    callbacks = [CustomBatchSizeFinder(steps_per_trial=steps_per_trial)]
    trainer = Trainer(
        default_root_dir=os.getcwd(),
        max_epochs=2,
        enable_progress_bar=False,
        enable_model_summary=False,
        callbacks=callbacks,
    )
    print("With BatchSizeFinder that calls to `_reset_dataloaders`:")
    trainer.fit(model, datamodule=Data(DATASET_SIZE, BATCH_SIZE))

if __name__ == "__main__":
    torch.manual_seed(1)
    run()

Error messages and logs

Here is a log showing the number of validated samples in each epoch. Validation dataset size: 123, number of epochs: 2, batch size: 2 (see the code above).

Without BatchSizeFinder:
VALIDATED 123 SAMPLES ON EPOCH 0
VALIDATED 123 SAMPLES ON EPOCH 1
--------------------
With BatchSizeFinder:
VALIDATED 6 SAMPLES ON EPOCH 0
VALIDATED 6 SAMPLES ON EPOCH 0
VALIDATED 6 SAMPLES ON EPOCH 0
VALIDATED 6 SAMPLES ON EPOCH 0
VALIDATED 6 SAMPLES ON EPOCH 0
VALIDATED 6 SAMPLES ON EPOCH 0
VALIDATED 6 SAMPLES ON EPOCH 0
VALIDATED 6 SAMPLES ON EPOCH 0
VALIDATED 6 SAMPLES ON EPOCH 1
--------------------
With BatchSizeFinder that calls to `_reset_dataloaders`:
VALIDATED 6 SAMPLES ON EPOCH 0
VALIDATED 6 SAMPLES ON EPOCH 0
VALIDATED 6 SAMPLES ON EPOCH 0
VALIDATED 6 SAMPLES ON EPOCH 0
VALIDATED 6 SAMPLES ON EPOCH 0
VALIDATED 6 SAMPLES ON EPOCH 0
VALIDATED 6 SAMPLES ON EPOCH 0
VALIDATED 123 SAMPLES ON EPOCH 0
VALIDATED 123 SAMPLES ON EPOCH 1

As you can see, the first and last runs validated all 123 samples in each of the two epochs, while the second run (with the default `BatchSizeFinder`) validated only 6 samples per epoch. Here 6 = `steps_per_trial * BATCH_SIZE` = 3 * 2.
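The expected counts follow from simple arithmetic. A minimal sketch, assuming the `DataLoader` default `drop_last=False` so the last batch may be partial; `samples_per_val_epoch` is a hypothetical helper, not part of Lightning:

```python
import math
from typing import Optional


def samples_per_val_epoch(dataset_size: int, batch_size: int, max_batches: Optional[int] = None) -> int:
    """Samples seen in one validation epoch when at most `max_batches` batches run (hypothetical helper)."""
    num_batches = math.ceil(dataset_size / batch_size)
    if max_batches is not None:
        num_batches = min(num_batches, max_batches)
    # Every batch is full except possibly the last one, so cap the total at dataset_size.
    return min(num_batches * batch_size, dataset_size)


print(samples_per_val_epoch(123, 2))     # full validation set → 123
print(samples_per_val_epoch(123, 2, 3))  # capped at steps_per_trial = 3 batches → 6
```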

Environment

Current environment
* CUDA:
  - GPU:
    - NVIDIA GeForce RTX 3050 Laptop GPU
  - available: True
  - version: 12.1
* Lightning:
  - lightning: 2.1.0
  - lightning-utilities: 0.9.0
  - pytorch-lightning: 2.1.0
  - torch: 2.1.0
  - torchmetrics: 1.2.0
* Packages:
  - aiohttp: 3.8.6
  - aiosignal: 1.3.1
  - async-timeout: 4.0.3
  - attrs: 23.1.0
  - certifi: 2023.7.22
  - charset-normalizer: 3.3.0
  - filelock: 3.12.4
  - frozenlist: 1.4.0
  - fsspec: 2023.9.2
  - idna: 3.4
  - jinja2: 3.1.2
  - lightning: 2.1.0
  - lightning-utilities: 0.9.0
  - markupsafe: 2.1.3
  - mpmath: 1.3.0
  - multidict: 6.0.4
  - networkx: 3.1
  - numpy: 1.24.4
  - nvidia-cublas-cu12: 12.1.3.1
  - nvidia-cuda-cupti-cu12: 12.1.105
  - nvidia-cuda-nvrtc-cu12: 12.1.105
  - nvidia-cuda-runtime-cu12: 12.1.105
  - nvidia-cudnn-cu12: 8.9.2.26
  - nvidia-cufft-cu12: 11.0.2.54
  - nvidia-curand-cu12: 10.3.2.106
  - nvidia-cusolver-cu12: 11.4.5.107
  - nvidia-cusparse-cu12: 12.1.0.106
  - nvidia-nccl-cu12: 2.18.1
  - nvidia-nvjitlink-cu12: 12.2.140
  - nvidia-nvtx-cu12: 12.1.105
  - packaging: 23.2
  - pip: 23.2.1
  - pytorch-lightning: 2.1.0
  - pyyaml: 6.0.1
  - requests: 2.31.0
  - setuptools: 68.1.2
  - sympy: 1.12
  - torch: 2.1.0
  - torchmetrics: 1.2.0
  - tqdm: 4.66.1
  - triton: 2.1.0
  - typing-extensions: 4.8.0
  - urllib3: 2.0.7
  - wheel: 0.41.2
  - yarl: 1.9.2
* System:
  - OS: Linux
  - architecture:
    - 64bit
    - ELF
  - processor: x86_64
  - python: 3.8.18
  - release: 5.15.0-83-generic
  - version: #92-Ubuntu SMP Mon Aug 14 09:30:42 UTC 2023

More info

@tanaymeh, could you maybe add a related fix to #18826? It seems to touch the same parts of the code and should only require a few additional lines.

awaelchli commented 1 year ago

I think this is a duplicate of https://github.com/Lightning-AI/lightning/issues/18394. Can you confirm?

BoringDonut commented 1 year ago

> I think this is a duplicate of #18394. Can you confirm?

Yes, indeed. Sorry about that.

BoringDonut commented 1 year ago

Duplicate, and fixed by https://github.com/Lightning-AI/lightning/pull/18854