Lightning-AI / litdata

Transform datasets at scale. Optimize datasets for fast AI model training.
Apache License 2.0

Resuming Training w/ Streaming Dataset on DDP with Multiple Nodes Fails #248

Closed: schopra8 closed this issue 1 month ago.

schopra8 commented 1 month ago

🐛 Bug

We trained a model for several epochs on multiple nodes, and we wanted to continue training with PyTorch Lightning and LitData.

✅ Resuming on a single device works as expected.
✅ Resuming on a single node with N devices works as expected.
❌ Resuming on multiple nodes with N devices each fails.

To Reproduce

Run trainer.fit with an existing checkpoint using DDP across multiple nodes:

Stack trace:

[rank7]: Traceback (most recent call last):
[rank7]:   File "/home/train.py", line 70, in <module>
[rank7]:     main(config)
[rank7]:   File "/home/train.py", line 50, in main
[rank7]:     trainer.fit(model, datamodule=custom_data_module, ckpt_path=ckpt)
[rank7]:   File "/home/sahil/.cache/pypoetry/virtualenvs/env-auw7Hy33-py3.10/lib/python3.10/site-packages/pytorch_lightning/trainer/trainer.py", line 543, in fit
[rank7]:     call._call_and_handle_interrupt(
[rank7]:   File "/home/sahil/.cache/pypoetry/virtualenvs/env-auw7Hy33-py3.10/lib/python3.10/site-packages/pytorch_lightning/trainer/call.py", line 43, in _call_and_handle_interrupt
Failures:
[rank7]:   File "/home/sahil/.cache/pypoetry/virtualenvs/env-auw7Hy33-py3.10/lib/python3.10/site-packages/pytorch_lightning/loops/fetchers.py", line 133, in __next__
[rank7]:     batch = super().__next__()
[rank7]:   File "/home/sahil/.cache/pypoetry/virtualenvs/env-auw7Hy33-py3.10/lib/python3.10/site-packages/pytorch_lightning/loops/fetchers.py", line 60, in __next__
[rank7]:     batch = next(self.iterator)
[rank7]:   File "/home/sahil/.cache/pypoetry/virtualenvs/env-auw7Hy33-py3.10/lib/python3.10/site-packages/pytorch_lightning/utilities/combined_loader.py", line 341, in __next__
[rank7]:     out = next(self._iterator)
[rank7]:   File "/home/sahil/.cache/pypoetry/virtualenvs/env-auw7Hy33-py3.10/lib/python3.10/site-packages/pytorch_lightning/utilities/combined_loader.py", line 78, in __next__
[rank7]:     out[i] = next(self.iterators[i])
[rank7]:   File "/home/sahil/.cache/pypoetry/virtualenvs/env-auw7Hy33-py3.10/lib/python3.10/site-packages/litdata/streaming/dataloader.py", line 628, in __iter__
[rank7]:     for batch in super().__iter__():
[rank7]:   File "/home/sahil/.cache/pypoetry/virtualenvs/env-auw7Hy33-py3.10/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 631, in __next__
[rank7]:     data = self._next_data()
[rank7]:   File "/home/sahil/.cache/pypoetry/virtualenvs/env-auw7Hy33-py3.10/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 1326, in _next_data
[rank7]:     return self._process_data(data)
[rank7]:   File "/home/sahil/.cache/pypoetry/virtualenvs/env-auw7Hy33-py3.10/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 1372, in _process_data
[rank7]:     data.reraise()
[rank7]:   File "/home/sahil/.cache/pypoetry/virtualenvs/env-auw7Hy33-py3.10/lib/python3.10/site-packages/torch/_utils.py", line 705, in reraise
[rank7]:     raise exception
[rank7]: IndexError: Caught IndexError in DataLoader worker process 1.
[rank7]: Original Traceback (most recent call last):
[rank7]:   File "/home/sahil/.cache/pypoetry/virtualenvs/env-auw7Hy33-py3.10/lib/python3.10/site-packages/torch/utils/data/_utils/worker.py", line 252, in _worker_loop
[rank7]:     fetcher = _DatasetKind.create_fetcher(dataset_kind, dataset, auto_collation, collate_fn, drop_last)
[rank7]:   File "/home/sahil/.cache/pypoetry/virtualenvs/env-auw7Hy33-py3.10/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 79, in create_fetcher
[rank7]:     return _utils.fetch._IterableDatasetFetcher(dataset, auto_collation, collate_fn, drop_last)
[rank7]:   File "/home/sahil/.cache/pypoetry/virtualenvs/env-auw7Hy33-py3.10/lib/python3.10/site-packages/torch/utils/data/_utils/fetch.py", line 21, in __init__
[rank7]:     self.dataset_iter = iter(dataset)
[rank7]:   File "/home/sahil/.cache/pypoetry/virtualenvs/env-auw7Hy33-py3.10/lib/python3.10/site-packages/litdata/streaming/combined.py", line 144, in __iter__
[rank7]:     self._iterator = _CombinedDatasetIterator(
[rank7]:   File "/home/sahil/.cache/pypoetry/virtualenvs/env-auw7Hy33-py3.10/lib/python3.10/site-packages/litdata/streaming/combined.py", line 192, in __init__
[rank7]:     self._dataset_iters = [iter(dataset) for dataset in datasets]
[rank7]:   File "/home/sahil/.cache/pypoetry/virtualenvs/env-auw7Hy33-py3.10/lib/python3.10/site-packages/litdata/streaming/combined.py", line 192, in <listcomp>
[rank7]:     self._dataset_iters = [iter(dataset) for dataset in datasets]
[rank7]:   File "/home/sahil/.cache/pypoetry/virtualenvs/env-auw7Hy33-py3.10/lib/python3.10/site-packages/litdata/streaming/dataset.py", line 211, in __iter__
[rank7]:     self._resume(chunks_replica, intervals_replica)
[rank7]:   File "/home/sahil/.cache/pypoetry/virtualenvs/env-auw7Hy33-py3.10/lib/python3.10/site-packages/litdata/streaming/dataset.py", line 272, in _resume
[rank7]:     interval = self.worker_intervals[self.chunk_index]
[rank7]: IndexError: list index out of range
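Reading the last frame: `_resume` looks up `self.worker_intervals[self.chunk_index]`, so the IndexError means the restored chunk index is larger than the number of chunk intervals this worker owns. The toy model below is a hypothetical sketch of one way that can happen, not litdata's actual assignment logic: it assumes chunks are dealt round-robin over every (rank, dataloader worker) pair and that a worker restores a chunk index that was recorded by a different worker. `assign_chunks`, `global_worker_id`, and the sizes (2 nodes with 8 GPUs each, 2 dataloader workers per GPU, 40 chunks) are all made up for illustration.

```python
# Hypothetical, simplified model: NOT litdata's real chunk-assignment logic.
# It shows how a chunk index recorded for one worker can overflow the
# interval list of another worker that owns fewer chunks.

NUM_NODES, DEVICES_PER_NODE, NUM_WORKERS = 2, 8, 2
NUM_CHUNKS = 40


def assign_chunks(num_chunks, total_workers):
    """Deal chunk ids round-robin over every global dataloader worker."""
    return [list(range(w, num_chunks, total_workers)) for w in range(total_workers)]


def global_worker_id(rank, worker_id, num_workers=NUM_WORKERS):
    """Flatten (rank, local worker id) into a single global worker index."""
    return rank * num_workers + worker_id


world_size = NUM_NODES * DEVICES_PER_NODE
assignment = assign_chunks(NUM_CHUNKS, world_size * NUM_WORKERS)

# Suppose the restored state carries a chunk index recorded by rank 0 / worker 0,
# which owns two chunks and was on its second one.
saved_chunk_index = 1
print(assignment[global_worker_id(0, 0)])  # [0, 32]

# Rank 7 / worker 1 owns a single chunk, so the same index is out of range,
# mirroring `interval = self.worker_intervals[self.chunk_index]`.
worker_intervals = assignment[global_worker_id(7, 1)]
print(worker_intervals)  # [15]
try:
    worker_intervals[saved_chunk_index]
except IndexError as err:
    print(f"IndexError: {err}")  # IndexError: list index out of range
```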

Code sample

I've scrubbed my code below:

from litdata import CombinedStreamingDataset, StreamingDataLoader, StreamingDataset
from pytorch_lightning import LightningDataModule, Trainer
from pytorch_lightning.callbacks import DeviceStatsMonitor
from pytorch_lightning.loggers import WandbLogger
from pytorch_lightning.strategies import DDPStrategy

# `config`, `Model`, `collate_fn` and `ckpt` (path to the saved checkpoint)
# were scrubbed from this sample.


class CustomDataModule(LightningDataModule):
    """
    Custom Data Module wraps training/validation StreamingDataset objects.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config

    def setup(self, stage=None):
        if stage in (None, 'fit'):
            # Create train datasets
            self.train_datasets = []
            for ds_config in self.config.datasets.train:
                dataset = StreamingDataset(
                    input_dir=ds_config.path,
                    subsample=ds_config.subsample
                )
                self.train_datasets.append(dataset)
            self.train_dataset = CombinedStreamingDataset(self.train_datasets)

            # Create validation datasets
            self.val_datasets = []
            for ds_config in self.config.datasets.val:
                dataset = StreamingDataset(
                    input_dir=ds_config.path,
                    subsample=ds_config.subsample
                )
                self.val_datasets.append(dataset)
            self.val_dataset = CombinedStreamingDataset(self.val_datasets)

    def train_dataloader(self):
        return StreamingDataLoader(
            self.train_dataset,
            collate_fn=collate_fn,
            **self.config.dataloader
        )

    def val_dataloader(self):
        return StreamingDataLoader(
            self.val_dataset,
            collate_fn=collate_fn,
            **self.config.dataloader
        )


# Create dataset module (the class must be defined before this line runs)
custom_data_module = CustomDataModule(config)

# Initialize the model
model = Model()

# Define the PyTorch Lightning Trainer
wandb_logger = WandbLogger(**config.wandb_logger)
device_stats_monitor = DeviceStatsMonitor()
strategy = DDPStrategy(find_unused_parameters=True)
trainer = Trainer(
    logger=wandb_logger,
    callbacks=[device_stats_monitor],
    strategy=strategy,
    max_epochs=4,
    val_check_interval=500,
    accelerator='gpu',
    devices=8,       # GPUs per node
    num_nodes=2,     # 2 x 8 = 16 DDP ranks in total
    enable_progress_bar=True,
    log_every_n_steps=50,
    precision=32,
    default_root_dir='/scratch/lightning_logs'
)

# Resume training from the existing checkpoint
trainer.fit(model, datamodule=custom_data_module, ckpt_path=ckpt)
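The stack trace shows the dataloader state being restored inside `StreamingDataset.__iter__` via `_resume` when `trainer.fit` is called with `ckpt_path`. As a rough, hypothetical illustration of what a resumable sharded iterable has to track (not litdata's implementation), the toy class below records how many samples each global worker has yielded and exposes `state_dict`/`load_state_dict` methods. `ToyResumableShards` and its state layout are invented for this sketch.

```python
class ToyResumableShards:
    """Hypothetical stateful iterable, NOT litdata's StreamingDataset.

    Each global worker reads its own round-robin shard of `items`, and the
    state records how many samples each worker has yielded, so on resume a
    worker skips exactly what it already produced.
    """

    def __init__(self, items, total_workers):
        self.items = list(items)
        self.total_workers = total_workers
        self.yielded = [0] * total_workers  # per-worker progress counters

    def shard(self, worker):
        return self.items[worker::self.total_workers]

    def iterate(self, worker):
        # Skip the samples this worker already yielded before the checkpoint.
        for item in self.shard(worker)[self.yielded[worker]:]:
            self.yielded[worker] += 1
            yield item

    def state_dict(self):
        return {"total_workers": self.total_workers, "yielded": list(self.yielded)}

    def load_state_dict(self, state):
        # Refuse state recorded for a different worker layout instead of
        # failing later with an out-of-range index inside a worker.
        if state["total_workers"] != self.total_workers:
            raise ValueError("state was saved with a different number of workers")
        self.yielded = list(state["yielded"])


ds = ToyResumableShards(range(10), total_workers=2)
it = ds.iterate(worker=1)
print(next(it), next(it))  # 1 3
state = ds.state_dict()

resumed = ToyResumableShards(range(10), total_workers=2)
resumed.load_state_dict(state)
print(list(resumed.iterate(worker=1)))  # [5, 7, 9]
print(list(resumed.iterate(worker=0)))  # [0, 2, 4, 6, 8]
```

Keying progress by global worker means each worker restores its own entry, and checking the layout in `load_state_dict` turns a mismatch into an immediate, descriptive error.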

Expected behavior

Training resumes from the checkpoint on multiple nodes, as it does on a single node.

Environment

github-actions[bot] commented 1 month ago

Hi! Thanks for your contribution, great first issue!

tchaton commented 1 month ago

Hey @schopra8,

Thanks for reporting the issue. Yes, we are aware of it :) and are already looking into it. If we fix it next week, we will make a new release. cc @awaelchli

Unfortunately, the fix is coming as a series of changes (https://github.com/Lightning-AI/litdata/pull/237) that aren't backward compatible: we won't be able to load old checkpoints because the core logic has changed too much.
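When new resume logic cannot read state saved by older code, a common way to make the break explicit is to tag saved state with a format version and reject mismatches at load time. A minimal sketch, assuming a dict-based state; `STATE_FORMAT_VERSION`, `save_loader_state`, and `load_loader_state` are hypothetical names, not litdata APIs:

```python
STATE_FORMAT_VERSION = 2  # hypothetical: bump whenever the resume logic changes


def save_loader_state(state):
    """Wrap the raw dataloader state with the format version that produced it."""
    return {"format_version": STATE_FORMAT_VERSION, "state": state}


def load_loader_state(blob):
    """Return the raw state, or fail fast if it was written by another format."""
    version = blob.get("format_version", 1)  # untagged state counts as v1
    if version != STATE_FORMAT_VERSION:
        raise RuntimeError(
            f"dataloader state has format v{version}, expected "
            f"v{STATE_FORMAT_VERSION}; checkpoints from older releases "
            "cannot be resumed"
        )
    return blob["state"]
```

This surfaces an incompatible checkpoint as one clear error at load time rather than an IndexError inside a dataloader worker.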

schopra8 commented 1 month ago

np! Thanks for the heads up.

schopra8 commented 1 month ago

Also wanted to flag:

I tried resuming on 1 node with N devices. Everything worked for the first couple hundred steps, but then I hit the same error, so it looks like DDP on a single node has a similar issue as well.

tchaton commented 1 month ago

Hey @schopra8, here are the release notes: https://github.com/Lightning-AI/litdata/releases/tag/v0.2.17.

Would you mind trying again with the latest version, 0.2.17?

Unfortunately, old checkpoints won't work.
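To confirm that the environment actually picked up the fixed release before resuming (for example after `pip install -U litdata`), the installed version can be compared against 0.2.17. This is a dependency-free sketch using `importlib.metadata.version`; the helper names `parse_version` and `has_resume_fix` are hypothetical, and `packaging.version.Version` is the more robust choice when `packaging` is available.

```python
import re
from importlib.metadata import PackageNotFoundError, version


def parse_version(text):
    """'0.2.17' -> (0, 2, 17). Pre-release/dev suffixes such as 'rc1' are ignored."""
    parts = []
    for piece in text.split(".")[:3]:
        match = re.match(r"\d+", piece)
        parts.append(int(match.group()) if match else 0)
    parts += [0] * (3 - len(parts))  # pad "0.3" to (0, 3, 0)
    return tuple(parts)


def has_resume_fix(minimum="0.2.17"):
    """Return True if the installed litdata is at least `minimum`."""
    try:
        installed = version("litdata")
    except PackageNotFoundError:
        return False
    return parse_version(installed) >= parse_version(minimum)


print(parse_version("0.2.9") < parse_version("0.2.17"))  # True
```

Comparing tuples of integers avoids the string-comparison trap where "0.2.9" sorts after "0.2.17".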

schopra8 commented 1 month ago

Awesome! I'll try in the next 1-2 days and report back with my results.

tchaton commented 1 month ago

Thanks @schopra8.

schopra8 commented 1 month ago

@tchaton Tested this and it works! Closing this issue since the DDP problem is solved.

We're running into another issue (https://github.com/Lightning-AI/litdata/issues/263) when resuming training with a new dataset: we want to preserve optimizer states, etc. when we continue training. Any guidance would be much appreciated!