Lightning-AI / pytorch-lightning

Pretrain, finetune ANY AI model of ANY size on multiple GPUs, TPUs with zero code changes.
https://lightning.ai
Apache License 2.0

Allow shuffling when overfit_batches is active #9850

Open low5545 opened 3 years ago

low5545 commented 3 years ago

Proposed refactoring or deprecation

Instead of disabling shuffle / replacing RandomSampler with SequentialSampler in the train dataloader, replace the train dataset with a fixed subset of it using torch.utils.data.Subset (e.g. the first N samples of the dataset, where N is derived from overfit_batches). This gives the same dataset samples as the previous implementation.

Motivation

This prevents the training batches from being the same every epoch

Pitch

Added on 12 Oct 2021: The current implementation of overfit_batches disables shuffling by replacing RandomSampler with SequentialSampler in the train dataloader, in order to restrict training / overfitting to the first N samples of the train dataset every epoch. However, this yields the same sequence of batches in every epoch, so the batches themselves repeat across epochs, which is undesirable.

We should instead allow shuffling within those N samples across epochs, according to the shuffle option of the train dataloader, so that each epoch sees a different sequence of batches and the batches are mostly unique throughout the training process.
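
A minimal, self-contained sketch of the proposed behavior (the dummy dataset, N and batch size below are made-up examples, not Lightning internals): every epoch draws from the same N samples, but in a different order, so the batch composition changes across epochs.

import torch
from torch.utils.data import DataLoader, Subset, TensorDataset

# Dummy map-style dataset standing in for the real training set.
train_dataset = TensorDataset(torch.arange(100).float().unsqueeze(1))

N = 25  # e.g. overfit_batches=0.25 on a 100-sample dataset

# Restrict training to the first N samples, as overfit_batches does today ...
overfit_subset = Subset(train_dataset, range(N))

# ... but keep shuffling enabled, so every epoch visits the same N samples in a
# different order and the batch composition changes from epoch to epoch.
train_loader = DataLoader(overfit_subset, batch_size=8, shuffle=True)

for _ in range(2):
    print([batch[0].squeeze(-1).tolist() for batch in train_loader])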



cc @borda @justusschock @awaelchli @akihironitta @rohitgr7

rohitgr7 commented 3 years ago

N here is the number of batches, not the number of samples, so I don't think it will be possible to set the exact indices required when using a Subset.

low5545 commented 3 years ago

More specifically, what I meant there was that N can be derived from overfit_batches by:

# overfit_batches may be a float (fraction of the training dataset)
# or an int (number of batches); derive the number of samples N accordingly.
if isinstance(overfit_batches, float):
    N = int(overfit_batches * len(train_dataloader.dataset))
elif isinstance(overfit_batches, int):
    N = overfit_batches * effective_train_batch_size

Then, we can set train_dataloader.dataset = torch.utils.data.Subset(train_dataloader.dataset, torch.arange(N))
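
A rough end-to-end sketch of that replacement, assuming a map-style dataset; build_overfit_loader and effective_train_batch_size are hypothetical names used for illustration, not existing Lightning API:

from torch.utils.data import DataLoader, Subset

def build_overfit_loader(train_dataloader, overfit_batches, effective_train_batch_size):
    # Sketch only: derive N from overfit_batches and rebuild the loader on a
    # Subset of the first N samples, keeping shuffle enabled.
    # effective_train_batch_size stands for batch size per device times the
    # number of devices; it is not an existing Lightning attribute.
    dataset = train_dataloader.dataset
    if isinstance(overfit_batches, float):
        N = int(overfit_batches * len(dataset))            # fraction of the dataset
    else:
        N = overfit_batches * effective_train_batch_size   # number of batches
    subset = Subset(dataset, range(N))
    # Rebuild the DataLoader rather than mutating .dataset, so a fresh
    # RandomSampler is created over the N retained samples.
    return DataLoader(subset, batch_size=train_dataloader.batch_size, shuffle=True)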

rohitgr7 commented 3 years ago

it would actually be:

if isinstance(overfit_batches, float):
    N = int(overfit_batches * len(train_dataloader)) * effective_train_batch_size

will it work with IterableDatasets?

low5545 commented 3 years ago
  1. According to the documentation for overfit_batches, a float value indicates a ratio of the training dataset to overfit on, and an int value indicates the number of batches to overfit on, which I presume is the number of batches per device (see the brief example after this list). With that, I think the brief implementation I provided is valid.
  2. Shuffling and samplers are irrelevant to IterableDataset. We can replace the training dataset with a subset only in the case of Dataset (map-style datasets).
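
For reference, the two accepted forms of the flag look like this (the values are arbitrary examples):

from pytorch_lightning import Trainer

# float: overfit on a fraction (here 25%) of the training dataset
trainer = Trainer(overfit_batches=0.25)

# int: overfit on a fixed number (here 10) of training batches
trainer = Trainer(overfit_batches=10)
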
rohitgr7 commented 3 years ago

oh my bad.. it's actually len(dataloader) * overfit_batches, that's how it's handled internally for now. So if we use the new method you proposed, the batches might vary a bit from the current implementation. For example, if the total dataset size is 100, batch_size=32 and overfit_batches=0.25, then the first batch is currently of size 32, but with the update it would be of size 25.
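
A quick check of those numbers, simply restating the arithmetic from the comment above:

dataset_size, batch_size, overfit_batches = 100, 32, 0.25

# current behavior: the float is applied to the number of batches
num_batches = -(-dataset_size // batch_size)               # ceil(100 / 32) = 4
num_overfit_batches = int(num_batches * overfit_batches)   # int(4 * 0.25) = 1
current_first_batch_size = batch_size                      # one full batch of 32 samples

# proposed behavior: the float is applied to the number of samples
N = int(overfit_batches * dataset_size)                    # int(0.25 * 100) = 25
proposed_first_batch_size = min(N, batch_size)             # a single batch of 25 samples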

let's see what others think cc @PyTorchLightning/core-contributors

low5545 commented 3 years ago

Ah, now I understand what you were referring to in the previous comment. I personally don't have any preference on how N is determined when overfit_batches is a float, but the main motivation of this issue is to prevent identical batches across epochs.

awaelchli commented 3 years ago

Not sure I follow the discussion here completely, but:

This prevents the training batches from being the same every epoch

We want the batches to be exactly the same sequence every epoch. The goal of the overfit_batches feature is to check that the model can overfit. There wouldn't be any guarantee of this if the data changed every epoch.

rohitgr7 commented 3 years ago

As per my understanding, the only change is that the order of the batches will become random, but they will be the same in all epochs. @awaelchli

low5545 commented 3 years ago

From what I proposed above, not only is the same subset of the training dataset retained every epoch, it also allows shuffling of samples within the subset across epochs, which gives a different sequence of batches every epoch. Furthermore, most, if not all, of the batches are unique throughout training. The current implementation gives the same, repeated sequence of batches across epochs.

awaelchli commented 3 years ago

@low5545 ok, could you convert the issue to a proposal by changing the title and adding a clear pitch of what the behavior will be with your changes. thanks!

low5545 commented 3 years ago

@awaelchli Done!

carmocca commented 3 years ago

@low5545 why do you think this change is important? This flag is generally only used for quick debugging, where the initial shuffling should not matter for checking whether your model learns at all.

low5545 commented 3 years ago

@carmocca Well, my use case is that I rely on overfitting on a smaller subset of the training dataset to tune a large number of hyperparameters, specifically the weights of various loss terms among others, to identify which combinations are near optimal for the final objective, which cannot be optimized directly. So it's not really just for debugging purposes.

stale[bot] commented 2 years ago

This issue has been automatically marked as stale because it hasn't had any recent activity. This issue will be closed in 7 days if no further activity occurs. Thank you for your contributions, Pytorch Lightning Team!

carmocca commented 2 years ago

The motivation states

This prevents the training batches from being the same every epoch

In that case, I don't think this fits well for overfit_batches.

Perhaps you are more interested in limit_train_batches?

low5545 commented 2 years ago

However, limit_train_batches alone does not achieve overfitting because the validation & test dataloaders are not replaced with the train dataloader.
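
For context, a minimal comparison of the two flags; the comments below only restate the behavior described in this thread:

from pytorch_lightning import Trainer

# overfit_batches: restricts training to a fixed portion of the train data and,
# as noted in the comment above, also uses the train data in the val/test loops.
trainer = Trainer(overfit_batches=0.25)

# limit_train_batches: only limits how much of the train dataloader is used per
# epoch; the validation and test dataloaders are left untouched.
trainer = Trainer(limit_train_batches=0.25)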

stale[bot] commented 2 years ago

This issue has been automatically marked as stale because it hasn't had any recent activity. This issue will be closed in 7 days if no further activity occurs. Thank you for your contributions, Pytorch Lightning Team!
