Lightning-AI / pytorch-lightning

Pretrain, finetune ANY AI model of ANY size on multiple GPUs, TPUs with zero code changes.
https://lightning.ai
Apache License 2.0

IterableDataset with CORRECT length causes validation loop to be skipped #19624

Open mattcleigh opened 8 months ago

mattcleigh commented 8 months ago

Bug description

This is related to this issue: https://github.com/Lightning-AI/pytorch-lightning/issues/10290

Namely, an IterableDataset with a defined length won't trigger a validation epoch, even if the defined length is correct, as long as the following conditions are met:

  1. The IterableDataset defines an accurate length
  2. The dataset is split accurately between multiple workers with no overlap
  3. drop_last=True is set on the dataloader
  4. The dataset size does not divide evenly into batches

In this instance, multiple workers may each be left with an incomplete batch at the end of the training epoch, so the number of "dropped batches" exceeds one. The dataloader then raises a StopIteration before the reported length is reached, causing the validation epoch to be skipped.

This is standard PyTorch behavior as the collation function is called per worker in an IterableDataset. https://github.com/pytorch/pytorch/issues/33413

I am having this issue right now; my current fix is to artificially subtract from the length of my IterableDataset to account for this. Unfortunately, I really would like the length to be defined, so I can't set it to inf, which was the hotfix in the previous thread. The progress bar is useful for judging which partition I need to run certain jobs on, and I use the dataset length to sync my cyclic learning rate with the number of steps in an epoch.
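For reference, the workaround currently looks roughly like this (a sketch, assuming the dataset is split into near-equal contiguous shards per worker, as np.array_split does, and that batch_size and num_workers are known up front; the helper name is just for illustration):

def effective_num_samples(dataset_len: int, batch_size: int, num_workers: int) -> int:
    # Each worker gets a near-equal contiguous shard of the data
    workers = max(num_workers, 1)
    base, extra = divmod(dataset_len, workers)
    shard_sizes = [base + 1] * extra + [base] * (workers - extra)
    # With drop_last=True each worker drops its own incomplete batch,
    # so only this many samples ever reach the training loop
    full_batches = sum(size // batch_size for size in shard_sizes)
    return full_batches * batch_size

# __len__ of the IterableDataset then returns effective_num_samples(...)
# instead of the raw sample count, so that len(dataloader) matches reality.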

What version are you seeing the problem on?

master

How to reproduce the bug

import torch as T
import numpy as np
from torch.utils.data import DataLoader, IterableDataset, get_worker_info

from lightning import LightningModule, Trainer

nwrkrs = 4
drop = True

class Data(IterableDataset):
    def __init__(self) -> None:
        super().__init__()
        self.data = np.random.rand(100, 10).astype(np.float32)

    def __len__(self) -> int:
        return len(self.data)

    def __iter__(self):
        worker_info = get_worker_info()
        worker_id = 0 if worker_info is None else worker_info.id
        num_workers = 1 if worker_info is None else worker_info.num_workers
        worker_samples = np.array_split(self.data, num_workers)[worker_id]

        for i in worker_samples:
            yield i

class Model(LightningModule):
    def __init__(self) -> None:
        super().__init__()
        self.layer = T.nn.Linear(10, 1)
        self.did_validation = False

    def forward(self, x: T.Tensor) -> T.Tensor:
        return self.layer(x)

    def training_step(self, batch):
        return self(batch).mean()

    def validation_step(self, batch):
        self.did_validation = True
        return self(batch).mean()

    def configure_optimizers(self):
        return T.optim.Adam(self.parameters())

model = Model()
trainer = Trainer(logger=False, max_epochs=2, num_sanity_val_steps=0)
train_loader = DataLoader(Data(), num_workers=nwrkrs, batch_size=32, drop_last=drop)
valid_loader = DataLoader(Data(), num_workers=nwrkrs, batch_size=32, drop_last=drop)
trainer.fit(model, train_loader, valid_loader)
print("Performed validation:", model.did_validation)

Setting up the code above and running it with the following settings gives these results:

nwrkrs = 0, drop = True

Performed validation: True

nwrkrs = 4, drop = False

Performed validation: True

nwrkrs = 4, drop = True

Performed validation: False

cc @justusschock @awaelchli

awaelchli commented 8 months ago

@mattcleigh What should Lightning do here?

You have explained in your own words why this happens in the PyTorch DataLoader, you even pointed to the PyTorch GitHub issue where this is explained and acknowledged.

Here is again your code but with Lightning stripped away:

import numpy as np
from torch.utils.data import DataLoader, IterableDataset, get_worker_info

class Data(IterableDataset):
    def __init__(self) -> None:
        super().__init__()
        self.data = np.random.rand(100, 10).astype(np.float32)

    def __len__(self) -> int:
        return len(self.data)

    def __iter__(self):
        worker_info = get_worker_info()
        worker_id = 0 if worker_info is None else worker_info.id
        num_workers = 1 if worker_info is None else worker_info.num_workers
        worker_samples = np.array_split(self.data, num_workers)[worker_id]

        for i in worker_samples:
            yield i

if __name__ == "__main__":
    data = Data()
    print(f"{len(data)=}")

    len1 = len(list(DataLoader(Data(), num_workers=4, batch_size=32, drop_last=True)))
    len2 = len(list(DataLoader(Data(), num_workers=4, batch_size=32, drop_last=False)))
    len3 = len(list(DataLoader(Data(), num_workers=2, batch_size=32, drop_last=False)))

    print(f"{len1=}, {len2=}, {len3=}")

This prints

len(data)=100
len1=0, len2=4, len3=4

As you can see, with a dataset size of 100, a batch size of 32, and 4 workers, setting drop_last=True makes the dataloader length zero. You have explained why that is in your own words, but I'll break it down again:

If you divide 100 by 4, you get 25 samples per worker. That's less than 32, so no worker can form a full batch, and drop_last=True drops the remainder from every worker, resulting in an empty dataloader.

An empty dataloader will lead to a skip in the validation loop, because there is nothing to loop over.

Lightning builds on top of PyTorch, and is therefore subject to PyTorch's design choices. If users bring a dataloader to Lightning, I believe it should behave the same way as in raw PyTorch. Please let me know if I missed something here.

mattcleigh commented 8 months ago

Hi @awaelchli,

Thanks for the response. You are right that the code I sent didn't quite show the bug; that was my mistake, caused by a typo. Instead, change the size of the dataset to 1000: self.data = np.random.rand(1000, 10).astype(np.float32).

This properly highlights what my issue is. When running the following code:


import torch as T
import numpy as np
from torch.utils.data import DataLoader, IterableDataset, get_worker_info

from lightning import LightningModule, Trainer

nwrkrs = 4
drop = True

class Data(IterableDataset):
    def __init__(self) -> None:
        super().__init__()
        self.data = np.random.rand(1000, 10).astype(np.float32)

    def __len__(self) -> int:
        return len(self.data)

    def __iter__(self):
        worker_info = get_worker_info()
        worker_id = 0 if worker_info is None else worker_info.id
        num_workers = 1 if worker_info is None else worker_info.num_workers
        worker_samples = np.array_split(self.data, num_workers)[worker_id]

        for i in worker_samples:
            yield i

class Model(LightningModule):
    def __init__(self) -> None:
        super().__init__()
        self.layer = T.nn.Linear(10, 1)
        self.did_validation = False

    def forward(self, x: T.Tensor) -> T.Tensor:
        return self.layer(x)

    def training_step(self, batch):
        return self(batch).mean()

    def validation_step(self, batch):
        self.did_validation = True
        return self(batch).mean()

    def configure_optimizers(self):
        return T.optim.Adam(self.parameters())

model = Model()
trainer = Trainer(logger=False, max_epochs=2, num_sanity_val_steps=0)
train_loader = DataLoader(Data(), num_workers=nwrkrs, batch_size=32, drop_last=drop)
valid_loader = DataLoader(Data(), num_workers=nwrkrs, batch_size=32, drop_last=drop)
trainer.fit(model, train_loader, valid_loader)
print("Train loader size ", len(list(train_loader)))
print("Valid loader size ", len(list(valid_loader)))
print("Performed validation:", model.did_validation)

My output is

Epoch 0:  90%|█████   | 28/31 [00:00<00:00, 105.05it/s]
Epoch 1:  90%|█████   | 28/31 [00:00<00:00, 101.66it/s]
`Trainer.fit` stopped: `max_epochs=2` reached.
Train loader size  28
Valid loader size  28
Performed validation: False

As you can see, the issue is not that the validation epoch is skipped because the datasets are empty. Rather, it is skipped because more than a single batch was dropped during the training epoch: one per worker.

This is by PyTorch's design; it's intended behavior and can't be changed. The bug in Lightning is that it will never run the validation epoch under these settings, despite there being plenty of data.
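Spelling out the arithmetic for these numbers (assuming the near-equal shards that np.array_split produces):

# len(train_loader) with drop_last=True:  1000 // 32 = 31 batches reported
# shard per worker with 4 workers:        1000 // 4  = 250 samples
# full batches each worker can form:      250 // 32  = 7
# batches actually yielded:               4 * 7      = 28  -> StopIteration at 28/31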

If, however, we do not define the length of the dataset, then Lightning triggers the validation epoch upon encountering the StopIteration. Here is what I get when I comment out the __len__ method:

Epoch 0: || 28/? [00:00<00:00, 58.74it/s]
Epoch 1: || 28/? [00:00<00:00, 58.74it/s]
Train loader size  28
Valid loader size  28
Performed validation: True

Exactly the intended behavior. But now we lose the functionality that comes with defining a length on our dataset: notably the progress bar, and, as I mentioned earlier, I use the length of my dataloaders to define properties of my scheduler, even if the actual length turns out to be a couple of batches shorter.
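For context, this is roughly the kind of thing I rely on the length for (a simplified sketch; OneCycleLR just stands in for my actual cyclic schedule, and estimated_stepping_batches is itself derived from the dataloader length):

    def configure_optimizers(self):
        optimizer = T.optim.Adam(self.parameters())
        # estimated_stepping_batches is derived from len(train_dataloader);
        # with no __len__ on the IterableDataset this information is not available
        scheduler = T.optim.lr_scheduler.OneCycleLR(
            optimizer,
            max_lr=1e-3,
            total_steps=self.trainer.estimated_stepping_batches,
        )
        return {
            "optimizer": optimizer,
            "lr_scheduler": {"scheduler": scheduler, "interval": "step"},
        }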

On the other hand, the validation epoch is still called even with the length defined, as long as the batch size cleanly divides the dataset length so that nothing is dropped. Setting a batch size of 25 gives me this output:

Epoch 1: 100%|███████████████| 40/40 [00:00<00:00, 59.40it/s]
Epoch 1: 100%|███████████████| 40/40 [00:00<00:00, 58.43it/s]
`Trainer.fit` stopped: `max_epochs=2` reached.
Train loader size  40
Valid loader size  40
Performed validation: True

So we end up with a situation where Lightning may or may not call the validation epoch depending on the batch size. If this is still not a bug, it is at least very obscure and non-intuitive behavior.

Essentially this boils down to how Lightning decides when to trigger the validation epoch. That triggering logic is what I think should change: even with the length defined, the validation epoch should still trigger based on the batch index OR on the StopIteration, at least under the default behaviour with check_val_every_n_epoch=1.

Hope this clears things up.

awaelchli commented 8 months ago

Do you agree that the problem here is that

len(train_loader) != len(list(train_loader))

?

Since Lightning can't iterate through the dataloader once up front to find its real length, probably only a heuristic could help with this edge case. For this special case, one could define something like this:

from torch.utils.data import DataLoader, IterableDataset

def length(loader):
    if (
        isinstance(loader, DataLoader)  # probably type(loader) == DataLoader here to be strict
        and isinstance(loader.dataset, IterableDataset)
        and hasattr(loader.dataset, "__len__")
        and loader.num_workers > 0
        and loader.batch_size is not None
        and loader.drop_last
    ):
        return len(loader.dataset) // loader.batch_size  # minus some number here
    return len(loader)  # otherwise fall back to the DataLoader's own length

But I'm wary of doing this because it deviates from PyTorch's definition here: https://github.com/pytorch/pytorch/blob/3e02a7efcdd13f8d8b1ecbb64639fb694988d11f/torch/utils/data/dataloader.py#L479

I am very afraid of breaking anybody's code here; I hope you understand that this is very subtle and that the core issue lies in PyTorch and the IterableDataset. If you were to implement a training loop around such a dataloader, how would you do it?

mattcleigh commented 8 months ago

I believe the issue is that Lightning over-relies on the reported length to determine the end of the training epoch.

If I were to write my own training loop I would simply do

for epoch in range(epochs):
    for batch in train_loader:
        model.train_step(batch)
    for batch in valid_loader:
        model.valid_step(batch)

Obviously I am not trying to account for every use case Lightning supports. But as we have seen, Lightning already does something like this when the length is infinite, so why could it not do both?
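Concretely, I mean something like this (a rough sketch of the idea, not Lightning's actual loop): treat the epoch as finished when either the expected number of batches has been consumed or the iterator is exhausted, and run validation in both cases.

expected = len(train_loader)  # may overcount for IterableDataset + workers + drop_last

for epoch in range(epochs):
    batches = iter(train_loader)
    for batch_idx in range(expected):
        try:
            batch = next(batches)
        except StopIteration:
            break  # fewer batches than reported: still fall through to validation
        model.train_step(batch)
    for batch in valid_loader:
        model.valid_step(batch)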

The logic to define this check is here:

https://github.com/Lightning-AI/pytorch-lightning/blob/6e517bd55b50166138ce6ab915abd4547702994b/src/lightning/pytorch/loops/training_epoch_loop.py#L415C1-L419C24

        # val_check_batch is inf for iterable datasets with no length defined
        is_infinite_dataset = self.trainer.val_check_batch == float("inf")
        is_last_batch = self.batch_progress.is_last_batch
        if is_last_batch and (is_infinite_dataset or isinstance(data_fetcher, _DataLoaderIterDataFetcher)):
            return True

As you can see, the logic that triggers the validation epoch via the is_last_batch variable (which is set internally because of the StopIteration from the dataloader) is only considered if the dataloader has an infinite or undefined length, or if the user is passing the entire loader directly into the training step. Otherwise is_last_batch is completely ignored, and this is what I would tweak.

Since using multiple workers is pretty much standard for dataloading, as is dropping the last batch, I would argue that this isn't much of an edge case: anyone with an IterableDataset that defines a length is going to run into this problem.

At the very least, having your code silently break because of a batch-size change is, I think, obscure enough to warrant some kind of heads-up. My jobs would crash because the early-stopping callback failed when the validation epoch never ran, and it certainly cost our group some time chasing down the reason.

awaelchli commented 8 months ago

One could try to remove the condition

if is_last_batch:  # and (is_infinite_dataset or isinstance(data_fetcher, _DataLoaderIterDataFetcher)):

and see if any tests fail. Do you want to do that? The progress bar will always be wrong, since len(train_loader) is not reported correctly; or is there a way around that other than what I suggested above?

mattcleigh commented 8 months ago

Yeah, I could give it a go. The only thing I think I need to account for is preventing a double execution of the validation epoch when the batch sizes line up.

I'll work on the logic and open an MR if I have the time after work.

awaelchli commented 8 months ago

I investigated this more and tried several ways to make it work, but it is just not possible right now to reconcile this with all the other features and requirements in the Trainer. Just commenting out the condition as we discussed won't be enough.

The biggest blocker is the following: https://github.com/Lightning-AI/pytorch-lightning/blob/fadd2fccdc49e20d64db37fe3654116b4f1b9e49/src/lightning/pytorch/loops/fetchers.py#L105-L107

The requirement in the Trainer is that it needs to know in advance whether a StopIteration is going to happen (i.e., whether we are at the last batch). To do this, it uses the length if one is available, and otherwise falls back to prefetching. Here, the length is assumed to be correct. But as we discussed, in the iterable-dataset case the length may not coincide with the StopIteration. Therefore, in order to support the "broken" definition of the DataLoader's length, we would have to enable prefetching for all iterable-style datasets. This is something we don't want to do, because it caused problems in the past (discussed with @carmocca).
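For context, prefetching here means something like the following (a simplified sketch of the idea, not the actual fetcher code): always hold one batch in advance so the "last batch" flag is known before the batch is processed.

def iterate_with_done_flag(iterable):
    # Yield (batch, is_last) pairs by keeping one batch in hand,
    # so is_last is known without relying on a reported length.
    it = iter(iterable)
    try:
        current = next(it)
    except StopIteration:
        return
    for upcoming in it:
        yield current, False
        current = upcoming
    yield current, True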

For these reasons, we suggest not moving forward with this.

Going back to your use case @mattcleigh, can you explain why you need drop_last in your dataloader? In my opinion, PyTorch should discourage that, given the way batches are fetched with multiple workers.

For Lightning, that probably means we need to emit a loud warning when the combination of an iterable dataset, num_workers > 1, and drop_last=True occurs.
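Roughly something like this (a sketch of the check only; the helper name and wording are placeholders, not an agreed implementation):

import warnings
from torch.utils.data import DataLoader, IterableDataset

def _warn_if_length_may_overcount(loader: DataLoader) -> None:
    # Placeholder helper: with an IterableDataset that defines __len__,
    # multiple workers and drop_last=True, each worker drops its own
    # incomplete batch, so len(loader) can overcount the batches yielded.
    if (
        isinstance(loader.dataset, IterableDataset)
        and hasattr(type(loader.dataset), "__len__")
        and loader.num_workers > 0
        and loader.drop_last
    ):
        warnings.warn(
            "Detected an IterableDataset with a defined length, num_workers > 0 "
            "and drop_last=True. The DataLoader may yield fewer batches than "
            "len(dataloader) reports, and the validation loop may be skipped.",
            UserWarning,
        )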

hungphongtrn commented 1 month ago

Hi folks, I am also facing a similar issue, in which using CombinedLoader to combine two dataloaders with different lengths results in validation_step being skipped during trainer.fit. However, trainer.validate and the early validation check with num_sanity_val_steps=-1 work completely normally.

Here are some code snippets.

An additional question: even if I set drop_last=False, the last batch is still being dropped. What could be the reason for this? Looking forward to your answer. Many thanks!

piotrb5e3 commented 1 month ago

This issue had me waste a good few hours of work trying to understand why my validation loop did not run.

Given that the current status is

For these reasons, we suggest not moving forward with this.

could we add a UserWarning about using this particular configuration (iterable dataset, num_workers > 0, drop_last=True)? I'd like to contribute, but I'm not sure which place in the code would be best for this.

DataLoader is a torch class and should not know about the issue in Lightning. I was looking at trainer.fit, but then there are separate code paths for using a DataLoader and a LightningDataModule.

What do you think?