Lightning-AI / pytorch-lightning

Pretrain, finetune ANY AI model of ANY size on multiple GPUs, TPUs with zero code changes.
https://lightning.ai
Apache License 2.0

Changing batch_size during training (setup_data() not working) #19286

Closed BardiaKh closed 8 months ago

BardiaKh commented 9 months ago

Bug description

I am using PL to train a model that uses progressive resizing of the input. To keep the pipeline efficient, I need to decrease the training batch_size as I increase the resolution, to prevent OOM. Based on issue #17327, I use the following code in my on_validation_epoch_start hook:

    def on_validation_epoch_start(self):
        # np is numpy; self.train_proportions and self.batch_sizes are set elsewhere in the module
        global_step_proportion = (self.trainer.global_step + 1) / self.trainer.max_steps
        training_cut = np.searchsorted(self.train_proportions.cumsum(), global_step_proportion) + 1
        training_cut = int(training_cut.item())

        self.batch_size = self.batch_sizes[training_cut - 1]
        self._reset_dataloaders()

    def _reset_dataloaders(self):
        loop = self.trainer.fit_loop
        loop._combined_loader = None  # force a reload
        loop.setup_data()
        self.print("dataloader reset")

This code runs successfully. However, in the training step, when I use the following:

    def training_step(self, batch, batch_idx):
        img = batch["img"]
        cls = batch["cls"] if self.class_conditioned else None

        batch_size = img.shape[0]

        if len(self.current_cut) > 1:
            self.print("*"*50)
            self.print(self.trainer._active_loop)
            self.print(f"Actual Batch size: {batch_size}")
            self.print(f"Current Batch size:  {self.batch_size}")
            self.print("DL Batch shape:"next(iter(self.trainer._active_loop._combined_loader._iterables))['img'].shape)
            self.print("*"*50)

Here, I get the following:

**************************************************
<pytorch_lightning.loops.fit_loop._FitLoop object at 0x7fe1d9faf110>
Actual Batch size: 280
Current Batch size:  70
DL Batch shape: torch.Size([70, 1, 512, 512])
**************************************************

To me this means that _reset_dataloaders() and loop.setup_data() work properly (the DL Batch shape is correct), but it seems that the old dataloader is still being used to fetch batches for training_step.

Is there something that I am missing?

What version are you seeing the problem on?

v2.1

How to reproduce the bug

    pl.Trainer(
        gradient_clip_val=1.0,
        deterministic=True,
        callbacks=[checkpoint_callback1, lr_monitor, ema],
        profiler='simple',
        logger=wandb_logger,
        precision=self.precision,
        accelerator="gpu",
        devices=-1,
        num_nodes=1,
        strategy=DDPStrategy(find_unused_parameters=True),
        log_every_n_steps=10,
        default_root_dir=self.root_directory,
        num_sanity_val_steps=1,
        fast_dev_run=False,
        max_epochs=-1,
        max_steps=10000,
        use_distributed_sampler=False,  # using DDP sampler in the code
        val_check_interval=10,
    )

Error messages and logs


Environment

Current environment

    #- Lightning Component: Trainer, LightningModule
    #- PyTorch Lightning Version: 2.1.3
    #- PyTorch Version: 2.1.3
    #- Python version: 3.11
    #- OS: Linux

More info

No response

awaelchli commented 9 months ago

@BardiaKh Changing the batch size during an epoch is not supported by the Trainer. You can only do it from one epoch to the next by implementing your train_dataloader() accordingly, for example returning a dataloader whose batch size is a function of self.current_epoch, and setting Trainer(reload_dataloaders_every_n_epochs=1).
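
A minimal sketch of that pattern, assuming a simple epoch-keyed batch-size schedule (the stage values and the self.train_ds attribute are illustrative placeholders, not from this issue):

    import pytorch_lightning as pl
    from torch.utils.data import DataLoader

    class ProgressiveResizeModule(pl.LightningModule):
        # Assumed schedule: the epoch at which each batch size starts (illustrative values).
        STAGES = {0: 280, 10: 140, 20: 70}

        def train_dataloader(self):
            # Pick the batch size of the stage that the current epoch falls into.
            start_epoch = max(e for e in self.STAGES if e <= self.current_epoch)
            batch_size = self.STAGES[start_epoch]
            return DataLoader(self.train_ds, batch_size=batch_size, shuffle=True)

    # reload_dataloaders_every_n_epochs=1 makes the Trainer call train_dataloader()
    # again at every epoch boundary, so the new batch size takes effect.
    trainer = pl.Trainer(max_steps=10000, reload_dataloaders_every_n_epochs=1)

With this, the batch size only ever changes at epoch boundaries, which is the granularity the Trainer supports.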

Note that reaching into the internals of the loop like you do right now is not advised and will lead to misbehavior.

awaelchli commented 9 months ago

@BardiaKh Could you check whether my reply helps with your use case, please?

BardiaKh commented 9 months ago

Thanks for your help. I see that per-iteration validation does not reload the loop. However, I need to evaluate at certain iterations and change the batch size at those points, which do not necessarily fall at the end of an epoch.

To address this, I add a sampler that changes the number of iterations in each epoch.

Here is the code in case someone faces a similar use case:

    def _reset_dataloaders(self):
        loop = self.trainer.fit_loop
        world_size = self.trainer.world_size
        val_check_interval = int(self.config.validation.interval)
        train_batch_size = int(self.batch_size)
        sampler_num_samples = int(val_check_interval * train_batch_size)
        self.train_sampler = torch.utils.data.RandomSampler(self.train_ds, replacement=True, num_samples=sampler_num_samples)
        if world_size > 1:
            self.ddpm_sampler = True

        loop._combined_loader = None  # force a reload
        loop.setup_data()

Note that I wrap the sampler (in train_dataloader()) with a distributed sampler proxy when creating the data loader.
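
For reference, a rough sketch of what that wrapping could look like. The DistributedSamplerProxy class below is a hypothetical illustration of sharding the sampler's indices by rank; it is not the exact proxy used above, and it assumes the base sampler yields the same index sequence on every rank (e.g. via a shared seed):

    from torch.utils.data import DataLoader, Sampler

    class DistributedSamplerProxy(Sampler):
        """Hypothetical sketch: each rank keeps every world_size-th index of the base sampler."""

        def __init__(self, base_sampler, rank, world_size):
            self.base_sampler = base_sampler
            self.rank = rank
            self.world_size = world_size

        def __iter__(self):
            indices = list(self.base_sampler)
            return iter(indices[self.rank::self.world_size])

        def __len__(self):
            return len(self.base_sampler) // self.world_size

    # Sketch of the matching train_dataloader(); self.train_sampler, self.ddpm_sampler,
    # self.train_ds and self.batch_size are the attributes set in the snippet above.
    def train_dataloader(self):
        sampler = self.train_sampler
        if getattr(self, "ddpm_sampler", False):
            sampler = DistributedSamplerProxy(sampler, self.trainer.global_rank, self.trainer.world_size)
        return DataLoader(self.train_ds, batch_size=self.batch_size, sampler=sampler)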

awaelchli commented 8 months ago

In this case it's not supported by Lightning, as you said. Reloading the dataloaders at arbitrary points in the loop is probably too complex to introduce. For now, I think what you did is probably the best approach.

If next time you know in advance that you need this level of granular control, I would go with Lightning Fabric (where you write the training loop yourself and can handle such logic very easily).
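
For completeness, a rough sketch of what that could look like with Fabric. The model (MyModel), dataset (train_ds), loss computation (compute_loss), and the step-indexed schedule are placeholders for illustration:

    import torch
    from torch.utils.data import DataLoader, RandomSampler
    from lightning.fabric import Fabric

    fabric = Fabric(accelerator="gpu", devices="auto", strategy="ddp")
    fabric.launch()

    model = MyModel()                                  # placeholder model
    optimizer = torch.optim.AdamW(model.parameters())
    model, optimizer = fabric.setup(model, optimizer)

    max_steps = 10000
    schedule = [(0, 280), (4000, 140), (8000, 70)]     # placeholder (start_step, batch_size) stages
    global_step = 0
    for i, (start, batch_size) in enumerate(schedule):
        next_start = schedule[i + 1][0] if i + 1 < len(schedule) else max_steps
        # One bounded loader per stage: enough samples to cover the steps until the next resize.
        sampler = RandomSampler(train_ds, replacement=True,
                                num_samples=batch_size * (next_start - start))
        loader = fabric.setup_dataloaders(DataLoader(train_ds, batch_size=batch_size, sampler=sampler))
        for batch in loader:
            loss = compute_loss(model, batch)          # placeholder loss computation
            optimizer.zero_grad()
            fabric.backward(loss)
            optimizer.step()
            global_step += 1

Because the loop is yours, validation at arbitrary step counts can also be run inline, without touching any Trainer internals.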