Open low5545 opened 3 years ago
N here is the number of batches, not the number of samples, so I think it won't be possible to set the exact indices required when using a subset.
More specifically, what I meant there was that N can be derived from overfit_batches by:

    if isinstance(overfit_batches, float):
        N = int(overfit_batches * len(train_dataloader.dataset))
    elif isinstance(overfit_batches, int):
        N = overfit_batches * effective_train_batch_size

Then, we can set train_dataloader.dataset = torch.utils.data.Subset(train_dataloader.dataset, torch.arange(N)).
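For concreteness, a self-contained toy version of that sketch (the dataset, batch size, and effective_train_batch_size values below are placeholders, not Lightning internals):

```python
import torch
from torch.utils.data import DataLoader, Subset, TensorDataset

dataset = TensorDataset(torch.arange(1000))           # placeholder training dataset
train_dataloader = DataLoader(dataset, batch_size=32, shuffle=True)
effective_train_batch_size = 32                       # placeholder: per-step batch size across devices

overfit_batches = 0.1
if isinstance(overfit_batches, float):
    N = int(overfit_batches * len(train_dataloader.dataset))
elif isinstance(overfit_batches, int):
    N = overfit_batches * effective_train_batch_size

# restrict training to the first N samples while keeping shuffling within that subset
overfit_dataloader = DataLoader(Subset(dataset, torch.arange(N).tolist()),
                                batch_size=32, shuffle=True)
```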
It would actually be:

    if isinstance(overfit_batches, float):
        N = int(overfit_batches * len(train_dataloader)) * effective_train_batch_size
will it work with IterableDatasets?
For overfit_batches, a float value indicates a ratio of the training dataset to overfit on, and an int value indicates the number of batches to overfit on, which I presume is the number of batches per device. With that, I think the brief implementation that I provided is valid.
It won't work with IterableDataset; we can replace the training dataset with a subset only in the case of Dataset (map-style datasets).
Oh, my bad.. it's actually len(dataloader) * overfit_batches, which is how it's handled internally for now. So if we use the new method you proposed, the batches might vary a bit from the current implementation. For example, if the total dataset size is 100, batch_size=32 and overfit_batches=0.25, then the first batch will currently be of size 32, but with the update it will be of size 25.
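To illustrate that size difference with a toy dataset (plain PyTorch, not Lightning internals):

```python
import torch
from torch.utils.data import DataLoader, Subset, TensorDataset

dataset = TensorDataset(torch.arange(100))
overfit_batches = 0.25

# current behaviour: limit the number of *batches* drawn from the full dataloader
full_loader = DataLoader(dataset, batch_size=32, shuffle=False)
num_batches = int(overfit_batches * len(full_loader))   # int(0.25 * 4) = 1
first = next(iter(full_loader))
print(num_batches, first[0].shape)                      # 1 batch of 32 samples

# proposed behaviour: take the first N *samples*, then batch them
n_samples = int(overfit_batches * len(dataset))         # 25
subset_loader = DataLoader(Subset(dataset, range(n_samples)), batch_size=32)
print(next(iter(subset_loader))[0].shape)               # a single batch of 25 samples
```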
Let's see what others think. cc @PyTorchLightning/core-contributors
Ah, now I understand what you were referring to in the previous comment. I personally don't have any preference on how N is determined when it is a float, but the main motivation of this issue is to prevent identical batches across epochs.
Not sure I follow the discussion here completely, but regarding the stated motivation:
"This prevents training batches from being the same for every epoch"
We want the batches to be exactly the same sequence every epoch. The goal of the overfit_batches feature is to check that the model can overfit. There wouldn't be any guarantee of this if the data changed every epoch.
As per my understanding, the only change will be that the order of batches becomes random, but they will be the same in all epochs. @awaelchli
From what I proposed above, not only is the same subset of the training dataset retained for every epoch, it also allows shuffling of the samples within the subset across epochs, which gives a different sequence of batches for every epoch. Furthermore, most, if not all, of the batches are unique throughout training. The current implementation gives the same sequence of batches and non-unique batches across epochs.
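A small sketch of that behaviour (toy data, plain PyTorch):

```python
import torch
from torch.utils.data import DataLoader, Subset, TensorDataset

dataset = TensorDataset(torch.arange(100))
subset = Subset(dataset, range(32))  # the same 32 samples are retained every epoch

# current behaviour: sequential sampling -> identical batch sequence each epoch
sequential = DataLoader(subset, batch_size=8, shuffle=False)
# proposed behaviour: shuffling within the subset -> a different batch sequence each epoch
shuffled = DataLoader(subset, batch_size=8, shuffle=True)

for epoch in range(2):
    seq_first = next(iter(sequential))[0].tolist()   # [0..7] every epoch
    shuf_first = next(iter(shuffled))[0].tolist()    # a different 8 of the same 32 samples each epoch
    print(epoch, seq_first, shuf_first)
```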
@low5545 ok, could you convert the issue to a proposal by changing the title and adding a clear pitch of what the behavior will be with your changes. thanks!
@awaelchli Done!
@low5545 why do you think this change is important? This flag is generally only used for quick debugging where this initial shuffling should not matter to find whether your model learns at all.
@carmocca Well, my use case is that I rely on overfitting on a (possibly smaller) subset of the training dataset to tune a large number of hyperparameters, specifically the weights of various loss terms among others, to identify which combinations are near optimal for the final objective, which cannot be optimized for directly. So it's not really just for debugging.
This issue has been automatically marked as stale because it hasn't had any recent activity. This issue will be closed in 7 days if no further activity occurs. Thank you for your contributions, Pytorch Lightning Team!
The motivation states:
"This prevents training batches from being the same for every epoch"
In that case, I don't think this fits well with overfit_batches. Perhaps you are more interested in limit_train_batches?
However, limit_train_batches alone does not achieve overfitting because the validation & test dataloaders are not replaced with the train dataloader.
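For anyone needing this behaviour today, one possible manual workaround (a rough sketch only; self._full_train_dataset and the subset size are placeholders, and training_step / validation_step / configure_optimizers are omitted):

```python
import pytorch_lightning as pl
from torch.utils.data import DataLoader, Subset

class OverfitOnSubset(pl.LightningModule):
    # training_step, validation_step and configure_optimizers omitted for brevity

    def setup(self, stage=None):
        # keep a fixed pool of the first N training samples
        # (self._full_train_dataset and N=128 are placeholders)
        self._overfit_subset = Subset(self._full_train_dataset, range(128))

    def train_dataloader(self):
        # shuffle within the fixed subset -> a different batch order every epoch
        return DataLoader(self._overfit_subset, batch_size=32, shuffle=True)

    def val_dataloader(self):
        # validate on the same samples so overfitting is actually measurable
        return DataLoader(self._overfit_subset, batch_size=32, shuffle=False)
```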
Proposed refactoring or deprecation
Instead of disabling shuffle / replacing RandomSampler with SequentialSampler in the train dataloader, replace the train dataset with a fixed subset of it using torch.utils.data.Subset (e.g. the first N samples of the dataset, where N is given by overfit_batches; this gives the same dataset samples as the previous implementation).
Motivation
This prevents training batches from being the same for every epoch
Pitch
Added on 12 Oct 2021: The current implementation for overfit_batches disables shuffling by replacing RandomSampler with SequentialSampler in the train dataloader, in order to restrict the training / overfitting to the first N samples of the train dataset for every epoch. However, this gives the same sequence of batches & non-unique batches across epochs, which is undesirable.
We should instead allow shuffling within the N samples across epochs, according to the shuffle option of the train dataloader, in order to give a different sequence of batches across epochs & mostly unique batches throughout the training process.
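As a sanity check that the proposed behaviour keeps the same pool of samples as the current first-N behaviour (only the ordering differs), a toy comparison (the dataset and N are placeholders):

```python
import torch
from torch.utils.data import DataLoader, Subset, TensorDataset

dataset = TensorDataset(torch.arange(100))  # placeholder dataset
N = 40                                      # placeholder sample count derived from overfit_batches

# current behaviour: SequentialSampler over the full dataset, stopping after N samples
current = [x.item() for batch in DataLoader(dataset, batch_size=10) for x in batch[0]][:N]

# proposed behaviour: a fixed Subset of the first N samples, shuffled every epoch
proposed_loader = DataLoader(Subset(dataset, range(N)), batch_size=10, shuffle=True)
proposed = [x.item() for batch in proposed_loader for x in batch[0]]

# the same pool of samples is seen either way; only the ordering differs
assert sorted(current) == sorted(proposed)
```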
cc @borda @justusschock @awaelchli @akihironitta @rohitgr7