Lightning-AI / pytorch-lightning

Pretrain, finetune ANY AI model of ANY size on multiple GPUs, TPUs with zero code changes.
https://lightning.ai
Apache License 2.0

Resume from mid steps inside an epoch #19764

Open xiaosuyu1997 opened 7 months ago

xiaosuyu1997 commented 7 months ago

Description & Motivation

LLMs are trained on ever-larger corpora, so resuming only at epoch boundaries is not enough: a model may be trained for only a few epochs, and a single epoch can take several days. Currently, when I try to resume from a step in the middle of an epoch, Lightning prints a warning asking for a resumable dataloader.

However, I can't find any examples of resuming from the middle of an epoch in the docs or blogs (maybe I missed them). It also seems odd to me to implement a dataloader with state_dict/load_state_dict methods: a DataLoader does not hold iteration state by design; it is the iterator created from the DataLoader that is resumable and should hold that state. Besides, we may not need state_dict/load_state_dict to save and load dataloaders at all, since the epoch and step indices carry enough information to restore the position in the batch sequence.
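A minimal sketch of the claim that the epoch and step indices are enough, assuming the batch order is a deterministic function of a seed and the epoch (similar in spirit to how DistributedSampler reseeds its shuffle from `seed + epoch`). The helpers `epoch_batches` and `resume_batches` are hypothetical and use only the standard library; they are not a Lightning or PyTorch API.

```python
import itertools
import random
from typing import Iterator, List


def epoch_batches(num_samples: int, batch_size: int, seed: int, epoch: int) -> List[List[int]]:
    # Hypothetical deterministic sampler: the shuffle depends only on (seed, epoch).
    indices = list(range(num_samples))
    random.Random(seed + epoch).shuffle(indices)
    return [indices[i:i + batch_size] for i in range(0, num_samples, batch_size)]


def resume_batches(num_samples: int, batch_size: int, seed: int, epoch: int, step: int) -> Iterator[List[int]]:
    # Rebuild this epoch's batch order, then skip the `step` batches already trained on.
    return itertools.islice(epoch_batches(num_samples, batch_size, seed, epoch), step, None)


full = epoch_batches(num_samples=10, batch_size=3, seed=42, epoch=1)
resumed = list(resume_batches(num_samples=10, batch_size=3, seed=42, epoch=1, step=2))
print(resumed == full[2:])  # → True
```

If the order depends on state that cannot be rebuilt from (seed, epoch), for example an unseeded shuffle or a stream from an IterableDataset, skipping by index no longer reproduces the interrupted run.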

I propose a possible hack that works around this issue, inspired by the Hugging Face training script.

Pitch

No response

Alternatives

Here is the ugly hack (implemented with hooks in the LightningModule) that I currently use to resume from a specific batch:

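import lightning as L
from lightning.fabric.utilities.data import _set_sampler_epoch
from lightning.pytorch.utilities.seed import isolate_rng
from torch.utils.data import BatchSampler, DataLoader, IterableDataset

class SkipBatchSampler(BatchSampler):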
    r"""
    Modified from huggingface accelerate/data_loader.py
    """
    def __init__(self, batch_sampler: BatchSampler, skip_batches: int = 0):
        self.batch_sampler = batch_sampler
        self.skip_batches = skip_batches

    def __iter__(self):
        for i, batch in enumerate(self.batch_sampler):
            if i >= self.skip_batches:
                yield batch

    def __len__(self):
        return len(self.batch_sampler)  # not `- self.skip_batches`: on restart, on_run_start() in loops/training_epoch_loop.py already sets the fetched count (ugly hack)

_PYTORCH_DATALOADER_KWARGS_SUBSTITUTE = {
    "num_workers": 0,
    "collate_fn": None,
    "pin_memory": False,
    "timeout": 0,
    "worker_init_fn": None,
    "multiprocessing_context": None,
    "generator": None,
    "prefetch_factor": 2,
    "persistent_workers": False,
}

def resume_dataloader(dataloader: DataLoader, steps_in_epoch: int) -> DataLoader:
    r"""
    Skip already-consumed batches at the batch-sampler level. Iterating the
    dataloader itself would load and preprocess every skipped batch, which is
    expensive; iterating the batch sampler only produces index lists.
    """
    # TODO: iterable datasets, DataLoaderDispatcher and DataLoaderShard are not supported yet
    assert not isinstance(dataloader.dataset, IterableDataset)
    new_batch_sampler = SkipBatchSampler(dataloader.batch_sampler, steps_in_epoch)
    kwargs = {k: getattr(dataloader, k, _PYTORCH_DATALOADER_KWARGS_SUBSTITUTE[k])
                for k in _PYTORCH_DATALOADER_KWARGS_SUBSTITUTE}
    return DataLoader(dataloader.dataset, batch_sampler=new_batch_sampler, **kwargs)

class LightningModel(L.LightningModule):
    # hacks
    def on_train_start(self):
        self.restarted_run = False

    def on_train_epoch_start(self):
        # modify train dataloader
        if self.trainer.fit_loop.restarting:
            self.restarted_run = True
            self.trainer.fit_loop.backup_dataloaders = self.trainer.fit_loop._combined_loader.flattened
            self.trainer.fit_loop._combined_loader.flattened = [
                resume_dataloader(dl, self.trainer.fit_loop.epoch_loop.batch_progress.current.completed)
                for dl in self.trainer.fit_loop._combined_loader.flattened
            ]
            # need to call iter to rebuild data_fetcher.iterator (which is originally
            # set in setup_data)
            self.trainer.fit_loop._data_fetcher.setup(self.trainer.fit_loop._combined_loader)
            with isolate_rng():
                iter(self.trainer.fit_loop._data_fetcher)
        else:
            if self.restarted_run:
                self.trainer.fit_loop._combined_loader.flattened = self.trainer.fit_loop.backup_dataloaders
                # set the sampler epoch again; otherwise the epoch right after the restarted one has problems
                for dl in self.trainer.fit_loop._combined_loader.flattened:
                    _set_sampler_epoch(dl, self.trainer.current_epoch)
                self.trainer.fit_loop._data_fetcher.setup(self.trainer.fit_loop._combined_loader)
                # no need to rebuild iterator, already in epoch_loop.on_run_start
                # iter(self.trainer.fit_loop._data_fetcher)
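
The docstring of `resume_dataloader` argues for skipping at the batch-sampler level rather than iterating the DataLoader and discarding batches. Below is a toy, standard-library-only sketch of the difference, assuming a map-style dataset whose `__getitem__` is the expensive step. `CountingDataset`, `batch_indices`, `skip_after_loading` and `skip_in_sampler` are hypothetical names, and `load` is a simplified stand-in for what a DataLoader does with each index batch (no workers, no collation).

```python
from typing import Iterable, Iterator, List


class CountingDataset:
    """Toy map-style dataset that counts how many samples were loaded."""

    def __init__(self, n: int):
        self.n = n
        self.loads = 0

    def __len__(self) -> int:
        return self.n

    def __getitem__(self, idx: int) -> int:
        self.loads += 1  # stands in for expensive decoding / preprocessing
        return idx * 10


def batch_indices(n: int, batch_size: int) -> List[List[int]]:
    # Sequential index batches, like a BatchSampler over a SequentialSampler.
    return [list(range(i, min(i + batch_size, n))) for i in range(0, n, batch_size)]


def load(dataset: CountingDataset, batches: Iterable[List[int]]) -> Iterator[List[int]]:
    # Simplified stand-in for DataLoader: fetch every sample of every index batch.
    for batch in batches:
        yield [dataset[i] for i in batch]


def skip_after_loading(dataset: CountingDataset, batch_size: int, skip: int) -> Iterator[List[int]]:
    # Naive resume: load every batch and throw away the first `skip` of them.
    for step, batch in enumerate(load(dataset, batch_indices(len(dataset), batch_size))):
        if step >= skip:
            yield batch


def skip_in_sampler(dataset: CountingDataset, batch_size: int, skip: int) -> Iterator[List[int]]:
    # SkipBatchSampler-style resume: drop index batches before any sample is loaded.
    remaining = (b for i, b in enumerate(batch_indices(len(dataset), batch_size)) if i >= skip)
    yield from load(dataset, remaining)


naive, fast = CountingDataset(12), CountingDataset(12)
print(list(skip_after_loading(naive, batch_size=4, skip=2)), naive.loads)  # → [[80, 90, 100, 110]] 12
print(list(skip_in_sampler(fast, batch_size=4, skip=2)), fast.loads)  # → [[80, 90, 100, 110]] 4
```

Both paths yield the same remaining batches, but only the sampler-level skip avoids loading the skipped samples; what it discards are plain lists of indices.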

Additional context

No response

cc @borda

Turakar commented 1 month ago

It seems like StatefulDataLoader from torchdata might help here. However, when I replace my old dataloader with StatefulDataLoader, I cannot find a corresponding entry in the saved checkpoint. The warning no longer appears, though.
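A toy sketch of the state_dict/load_state_dict pattern, which also speaks to the original concern that the state belongs to the iterator: the loader object can simply record how far its most recent iterator got. This is roughly the idea behind a stateful dataloader, but `ResumableBatchLoader` is hypothetical and standard-library only; it is not torchdata's implementation, and it does not show how (or whether) Lightning stores such state in a checkpoint.

```python
import itertools
from typing import Any, Dict, Iterator, List, Sequence


class ResumableBatchLoader:
    """Hypothetical loader that records its current iterator's progress."""

    def __init__(self, batches: Sequence[List[int]]):
        self.batches = batches
        self._yielded = 0      # batches handed out by the current iterator
        self._resume_from = 0  # batches to skip on the next iterator (set by load_state_dict)

    def __iter__(self) -> Iterator[List[int]]:
        # Consume the resume offset eagerly so only the first epoch after a restore is shortened.
        start, self._resume_from = self._resume_from, 0
        self._yielded = start
        return self._generate(start)

    def _generate(self, start: int) -> Iterator[List[int]]:
        for batch in itertools.islice(self.batches, start, None):
            self._yielded += 1  # count the batch as consumed once it is handed out
            yield batch

    def __len__(self) -> int:
        return len(self.batches)

    def state_dict(self) -> Dict[str, Any]:
        return {"batches_yielded": self._yielded}

    def load_state_dict(self, state: Dict[str, Any]) -> None:
        self._resume_from = state["batches_yielded"]


batches = [[0, 1], [2, 3], [4, 5], [6, 7]]
loader = ResumableBatchLoader(batches)
it = iter(loader)
next(it)
next(it)                      # two training steps finish, then a checkpoint is taken
state = loader.state_dict()   # {"batches_yielded": 2}

restored = ResumableBatchLoader(batches)
restored.load_state_dict(state)
print(list(restored))         # → [[4, 5], [6, 7]]
print(len(list(restored)))    # → 4
```

Resetting `_resume_from` in `__iter__` plays the same role as restoring `backup_dataloaders` in the hack above: the epoch after the resumed one runs over the full data again.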

xk-huang commented 5 days ago

I am experiencing the same problem when resuming training on a dataset so large that a single epoch is huge. I would support adding batch-skipping logic like the one in the Hugging Face training script.