@BardiaKh Changing the batch size during an epoch is not supported by the Trainer. You can only change it from one epoch to the next, by implementing your `train_dataloader()` accordingly (returning a dataloader as a function of `self.current_epoch`, for example) and setting `Trainer(reload_dataloaders_every_n_epochs=1)`.
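A minimal sketch of that supported pattern (the batch-size schedule and the `self.train_ds` attribute are illustrative assumptions, not from this issue):

```python
import lightning.pytorch as pl
from torch.utils.data import DataLoader

class LitModel(pl.LightningModule):
    def train_dataloader(self):
        # Hypothetical schedule: shrink the batch size as the epoch
        # index (and therefore the input resolution) grows.
        batch_size = max(8, 64 // (2 ** (self.current_epoch // 10)))
        return DataLoader(self.train_ds, batch_size=batch_size, shuffle=True)

# Rebuild the train dataloader at the start of every epoch.
trainer = pl.Trainer(reload_dataloaders_every_n_epochs=1)
```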
Note that reaching into the internals of the loop like you do right now is not advised and will lead to misbehavior.
@BardiaKh Could you check whether my reply helps with your use case, please?
Thanks for your help. I see that per-iteration validation does not reload the dataloaders. However, I need to evaluate at certain iterations and change the batch size at those points, which do not necessarily fall at the end of an epoch. To address this, I add a sampler that controls the number of iterations in each epoch. Here is the code in case someone faces a similar use case:
```python
import torch

def _reset_dataloaders(self):
    loop = self.trainer.fit_loop
    world_size = self.trainer.world_size
    val_check_interval = int(self.config.validation.interval)
    train_batch_size = int(self.batch_size)
    # Draw exactly enough samples for `val_check_interval` iterations,
    # so each "epoch" ends right where validation should run.
    sampler_num_samples = int(val_check_interval * train_batch_size)
    self.train_sampler = torch.utils.data.RandomSampler(
        self.train_ds, replacement=True, num_samples=sampler_num_samples
    )
    if world_size > 1:
        # Flag read in train_dataloader() to wrap the sampler with a
        # distributed sampler proxy (see the note below).
        self.ddpm_sampler = True
    loop._combined_loader = None  # force a reload
    loop.setup_data()
```
Note that I wrap the sampler with a distributed sampler proxy when creating the data loader in `train_dataloader()`.
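The proxy class itself is not shown in the issue; a minimal sketch of what such a wrapper and the matching `train_dataloader()` might look like (the `DistributedSamplerProxy` class and the use of the `ddpm_sampler` flag are assumptions based on the snippet above):

```python
import torch
from torch.utils.data import DataLoader, Sampler

class DistributedSamplerProxy(Sampler):
    """Shards indices from a wrapped sampler across DDP ranks."""

    def __init__(self, sampler, num_replicas, rank):
        self.sampler = sampler
        self.num_replicas = num_replicas
        self.rank = rank

    def __iter__(self):
        # Each rank keeps every num_replicas-th index of the base sampler.
        indices = list(self.sampler)
        return iter(indices[self.rank::self.num_replicas])

    def __len__(self):
        return len(self.sampler) // self.num_replicas

def train_dataloader(self):
    sampler = self.train_sampler
    if getattr(self, "ddpm_sampler", False):
        sampler = DistributedSamplerProxy(
            sampler, self.trainer.world_size, self.trainer.global_rank
        )
    return DataLoader(self.train_ds, batch_size=self.batch_size, sampler=sampler)
```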
In this case it's not supported by Lightning, as you said. Reloading the dataloaders at arbitrary points in the loop is probably too complex to introduce, so for now I think what you did is probably the best approach. If next time you know in advance that you need such granular control, I would go with Lightning Fabric, where you write the training loop yourself and can handle this kind of logic very easily.
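A rough sketch of that approach in Fabric (everything here — `build_model`, `train_ds`, `validate`, the interval and schedule — is a placeholder, not from this issue):

```python
import torch
from lightning.fabric import Fabric
from torch.utils.data import DataLoader

fabric = Fabric()
fabric.launch()

model = build_model()  # placeholder
optimizer = torch.optim.AdamW(model.parameters())
model, optimizer = fabric.setup(model, optimizer)

batch_size, step, max_steps, val_interval = 64, 0, 10_000, 500
loader = fabric.setup_dataloaders(DataLoader(train_ds, batch_size=batch_size))

while step < max_steps:
    for batch in loader:
        loss = model(batch).mean()  # stand-in for the real loss
        fabric.backward(loss)
        optimizer.step()
        optimizer.zero_grad()
        step += 1
        if step % val_interval == 0:
            validate(model)  # placeholder
            # Shrink the batch size as the resolution grows, then
            # rebuild the dataloader at exactly this iteration.
            batch_size = max(1, batch_size // 2)
            loader = fabric.setup_dataloaders(DataLoader(train_ds, batch_size=batch_size))
            break  # restart iteration with the new dataloader
```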
Bug description
I am using PL to train a model that utilizes progressive resizing of the input. To keep the pipeline efficient, I need to decrease the training `batch_size` as I increase the resolution, to prevent OOM. Based on issue #17327, I call the `_reset_dataloaders()` helper (posted above) from my `on_validation_epoch_start` hook. This code runs successfully: when I inspect the rebuilt dataloader directly, the logged `DL Batch shape` is correct. To me this means that `_reset_dataloaders()` and `loop.setup_data()` work properly; however, the batches passed to `training_step` still have the old shape, so it seems the old dataloader is still being used to fetch them. Is there something that I am missing?
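Presumably the hook performs the resize and then calls the helper, along these lines (a sketch — the `set_resolution`/`next_resolution` calls and the halving schedule are hypothetical):

```python
def on_validation_epoch_start(self):
    # Hypothetical progressive-resizing step: raise the input
    # resolution and shrink the batch size, then rebuild loaders.
    self.train_ds.set_resolution(self.next_resolution())  # placeholder API
    self.batch_size = max(1, self.batch_size // 2)
    self._reset_dataloaders()
```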
What version are you seeing the problem on?
v2.1
How to reproduce the bug
Error messages and logs
Environment
```
#- Lightning Component: Trainer, LightningModule
#- PyTorch Lightning Version: 2.1.3
#- PyTorch Version: 2.1.3
#- Python version: 3.11
#- OS: Linux
```
More info
No response