Lightning-AI / pytorch-lightning

Pretrain, finetune and deploy AI models on multiple GPUs, TPUs with zero code changes.
https://lightning.ai
Apache License 2.0

calling iter twice messes up dataloaders with queues #19427

Open ben-da6 opened 4 months ago

ben-da6 commented 4 months ago

Bug description

This bug has reappeared: https://github.com/Lightning-AI/pytorch-lightning/issues/18414

We now call iter() twice in different places: once in the fit loop and once in the training epoch loop (see the logs below).

What version are you seeing the problem on?

v2.1

How to reproduce the bug

import multiprocessing as mp
from queue import Queue
from typing import Iterator

import numpy as np
from lightning import Trainer
from lightning.pytorch.demos.boring_classes import BoringModel
from torch.utils.data import DataLoader, IterableDataset

class QueueDataset(IterableDataset):
    def __init__(self, queue: Queue) -> None:
        super().__init__()
        self.queue = queue

    def __iter__(self) -> Iterator:
        for k in range(5):
            print(f"getting {k}")
            tensor, index = self.queue.get(timeout=10)
            print(f"got {index}")
            yield tensor

if __name__ == "__main__":
    q = mp.Queue()
    arr = np.random.random([1, 32]).astype(np.float32)
    for ind in range(10):
        q.put((arr, ind))
    max_epoch = 1
    dataloader = DataLoader(QueueDataset(q), num_workers=1, batch_size=None, persistent_workers=True)
    trainer = Trainer(max_epochs=max_epoch, enable_progress_bar=False, devices=1)
    trainer.fit(BoringModel(), dataloader)
    trainer.save_checkpoint("model.ckpt")

    # q now has the next 5 elements in it
    # resuming training, we will hit the double iter() issue
    dataloader = DataLoader(QueueDataset(q), num_workers=1, batch_size=None, persistent_workers=True)
    trainer = Trainer(max_epochs=max_epoch + 1, enable_progress_bar=False, devices=1)
    trainer.fit(BoringModel(), dataloader, ckpt_path="model.ckpt")
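
The same drain-the-queue behaviour can be shown without Lightning. Below is a minimal sketch (my own illustration, assuming plain PyTorch with num_workers=1 and persistent_workers=True) of why an extra iter() call on the DataLoader is destructive here: each call makes the persistent worker restart the dataset's __iter__ and prefetch fresh items from the shared queue, so those items are lost to the iterator that is actually consumed.

import multiprocessing as mp

from torch.utils.data import DataLoader, IterableDataset

class QueueDataset(IterableDataset):
    def __init__(self, queue) -> None:
        super().__init__()
        self.queue = queue

    def __iter__(self):
        # every new iterator pulls 3 items out of the shared queue
        for _ in range(3):
            yield self.queue.get(timeout=10)

if __name__ == "__main__":
    q = mp.Queue()
    for i in range(8):
        q.put(i)
    dl = DataLoader(QueueDataset(q), num_workers=1, batch_size=None, persistent_workers=True)
    first = iter(dl)   # the worker immediately starts prefetching items from q
    second = iter(dl)  # the persistent worker is reset and prefetches from q again
    # the items prefetched for `first` are gone; `second` starts further into the queue
    print(list(second))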

Error messages and logs

The relevant logs are:

# first epoch all good
getting 0
got 0
getting 1
got 1
getting 2
got 2
getting 3
got 3
getting 4
got 4

# second epoch we start getting from the queue twice!
# from fit loop iter()
getting 0
got 5
getting 1
got 6
getting 2
got 7
# from training_epoch loop iter()
getting 0
got 8
getting 1
got 9
getting 2

Environment

lightning==2.1.4

More info

No response

cc @justusschock @awaelchli @carmocca

awaelchli commented 4 months ago

This condition here is meant to prevent iter() from being called a second time, because in this case restarting should be True.

https://github.com/Lightning-AI/pytorch-lightning/blob/47c8f4cba089a78fa3fe31dcac6a43416bc13820/src/lightning/pytorch/loops/training_epoch_loop.py#L169-L171

But it isn't. The problem is that the fit loop sets restarting=False even though we are resuming, due to the logic here:

https://github.com/Lightning-AI/pytorch-lightning/blob/47c8f4cba089a78fa3fe31dcac6a43416bc13820/src/lightning/pytorch/loops/fit_loop.py#L123-L128

This is tricky to solve @carmocca. The logic probably needs to be lifted up into the fit loop before epoch_loop.run(), with a different condition that does not rely on restarting.
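
To make the failure concrete, the intent of that guard can be paraphrased roughly as follows (a sketch of the idea only, not the actual Lightning source):

# inside the training epoch loop, before advancing
if not self.restarting:
    iter(data_fetcher)  # start the iterator only when we are not resuming

The fit loop has already called iter() once at this point (the first block in the log above), so on resume this guard is supposed to skip the call; but because restarting was reset to False, iter() runs a second time and the queue is drained twice.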

carmocca commented 4 months ago

I didn't look too deeply. Couldn't we also check restarting for the FitLoop's iter call? We have a lot of tests around this, so if a solution passes them we should be good.
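
The shape of that suggestion would be roughly (again only a sketch of the idea, not a tested patch):

# in the FitLoop, before handing off to epoch_loop.run()
if not self.restarting:
    iter(data_fetcher)

i.e. mirror the check the epoch loop already does, and rely on the existing test suite to catch regressions.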

ben-da6 commented 4 months ago

The problem is that in the restarting property, self._iteration_based_training() is False.

ben-da6 commented 4 months ago

Also, since this has appeared twice now and it's the sort of bug that is hard to track down, could we add a test like my example?
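
Something along these lines could serve as that regression test (a sketch only; the test name and the use of pytest's tmp_path fixture are my own choices, not an existing Lightning test). It fills the queue with exactly enough items for two epochs and resumes from a checkpoint, so any extra iter() drains the queue and makes the worker time out:

import multiprocessing as mp

import numpy as np
from lightning import Trainer
from lightning.pytorch.demos.boring_classes import BoringModel
from torch.utils.data import DataLoader, IterableDataset

class QueueDataset(IterableDataset):
    def __init__(self, queue) -> None:
        super().__init__()
        self.queue = queue

    def __iter__(self):
        for _ in range(5):
            tensor, _ = self.queue.get(timeout=10)  # raises Empty if the queue was over-drained
            yield tensor

def test_resume_does_not_reiterate_queue_dataloader(tmp_path):
    q = mp.Queue()
    arr = np.random.random([1, 32]).astype(np.float32)
    for ind in range(10):  # exactly 5 items per epoch for 2 epochs
        q.put((arr, ind))

    def make_loader():
        return DataLoader(QueueDataset(q), num_workers=1, batch_size=None, persistent_workers=True)

    ckpt = str(tmp_path / "model.ckpt")
    trainer = Trainer(max_epochs=1, enable_progress_bar=False, devices=1, default_root_dir=tmp_path)
    trainer.fit(BoringModel(), make_loader())
    trainer.save_checkpoint(ckpt)

    # with the double iter() bug, the resumed epoch pulls more than 5 items,
    # the queue runs dry, the worker times out, and the test fails
    trainer = Trainer(max_epochs=2, enable_progress_bar=False, devices=1, default_root_dir=tmp_path)
    trainer.fit(BoringModel(), make_loader(), ckpt_path=ckpt)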