Lightning-AI / litdata

Transform datasets at scale. Optimize datasets for fast AI model training.
Apache License 2.0

Resuming Training w/ Streaming Dataset on DDP with Multiple Nodes Fails #248

Closed: schopra8 closed this issue 1 month ago

schopra8 commented 1 month ago

🐛 Bug

We trained a model for several epochs on multiple nodes and wanted to continue training with PyTorch Lightning and LitData.

✅ When we resume training on a single device, resumption works as expected.
✅ When we resume training on a single node with N devices, resumption works as expected.
❌ When we resume training on multiple nodes with N devices, resumption fails.

To Reproduce

Resume from an existing checkpoint with DDP on multiple devices. The run fails with the following traceback:


[rank7]: Traceback (most recent call last):
[rank7]:   File "/home/", line 70, in <module>
[rank7]:     main(config)
[rank7]:   File "/home/", line 50, in main
[rank7]:, datamodule=custom_data_module, ckpt_path=ckpt)
[rank7]:   File "/home/sahil/.cache/pypoetry/virtualenvs/env-auw7Hy33-py3.10/lib/python3.10/site-packages/pytorch_lightning/trainer/", line 543, in fit
[rank7]:     call._call_and_handle_interrupt(
[rank7]:   File "/home/sahil/.cache/pypoetry/virtualenvs/env-auw7Hy33-py3.10/lib/python3.10/site-packages/pytorch_lightning/trainer/", line 43, in _call_and_handle_interrupt
[rank7]:   File "/home/sahil/.cache/pypoetry/virtualenvs/env-auw7Hy33-py3.10/lib/python3.10/site-packages/pytorch_lightning/loops/", line 133, in __next__
[rank7]:     batch = super().__next__()
[rank7]:   File "/home/sahil/.cache/pypoetry/virtualenvs/env-auw7Hy33-py3.10/lib/python3.10/site-packages/pytorch_lightning/loops/", line 60, in __next__
[rank7]:     batch = next(self.iterator)
[rank7]:   File "/home/sahil/.cache/pypoetry/virtualenvs/env-auw7Hy33-py3.10/lib/python3.10/site-packages/pytorch_lightning/utilities/", line 341, in __next__
[rank7]:     out = next(self._iterator)
[rank7]:   File "/home/sahil/.cache/pypoetry/virtualenvs/env-auw7Hy33-py3.10/lib/python3.10/site-packages/pytorch_lightning/utilities/", line 78, in __next__
[rank7]:     out[i] = next(self.iterators[i])
[rank7]:   File "/home/sahil/.cache/pypoetry/virtualenvs/env-auw7Hy33-py3.10/lib/python3.10/site-packages/litdata/streaming/", line 628, in __iter__
[rank7]:     for batch in super().__iter__():
[rank7]:   File "/home/sahil/.cache/pypoetry/virtualenvs/env-auw7Hy33-py3.10/lib/python3.10/site-packages/torch/utils/data/", line 631, in __next__
[rank7]:     data = self._next_data()
[rank7]:   File "/home/sahil/.cache/pypoetry/virtualenvs/env-auw7Hy33-py3.10/lib/python3.10/site-packages/torch/utils/data/", line 1326, in _next_data
[rank7]:     return self._process_data(data)
[rank7]:   File "/home/sahil/.cache/pypoetry/virtualenvs/env-auw7Hy33-py3.10/lib/python3.10/site-packages/torch/utils/data/", line 1372, in _process_data
[rank7]:     data.reraise()
[rank7]:   File "/home/sahil/.cache/pypoetry/virtualenvs/env-auw7Hy33-py3.10/lib/python3.10/site-packages/torch/", line 705, in reraise
[rank7]:     raise exception
[rank7]: IndexError: Caught IndexError in DataLoader worker process 1.
[rank7]: Original Traceback (most recent call last):
[rank7]:   File "/home/sahil/.cache/pypoetry/virtualenvs/env-auw7Hy33-py3.10/lib/python3.10/site-packages/torch/utils/data/_utils/", line 252, in _worker_loop
[rank7]:     fetcher = _DatasetKind.create_fetcher(dataset_kind, dataset, auto_collation, collate_fn, drop_last)
[rank7]:   File "/home/sahil/.cache/pypoetry/virtualenvs/env-auw7Hy33-py3.10/lib/python3.10/site-packages/torch/utils/data/", line 79, in create_fetcher
[rank7]:     return _utils.fetch._IterableDatasetFetcher(dataset, auto_collation, collate_fn, drop_last)
[rank7]:   File "/home/sahil/.cache/pypoetry/virtualenvs/env-auw7Hy33-py3.10/lib/python3.10/site-packages/torch/utils/data/_utils/", line 21, in __init__
[rank7]:     self.dataset_iter = iter(dataset)
[rank7]:   File "/home/sahil/.cache/pypoetry/virtualenvs/env-auw7Hy33-py3.10/lib/python3.10/site-packages/litdata/streaming/", line 144, in __iter__
[rank7]:     self._iterator = _CombinedDatasetIterator(
[rank7]:   File "/home/sahil/.cache/pypoetry/virtualenvs/env-auw7Hy33-py3.10/lib/python3.10/site-packages/litdata/streaming/", line 192, in __init__
[rank7]:     self._dataset_iters = [iter(dataset) for dataset in datasets]
[rank7]:   File "/home/sahil/.cache/pypoetry/virtualenvs/env-auw7Hy33-py3.10/lib/python3.10/site-packages/litdata/streaming/", line 192, in <listcomp>
[rank7]:     self._dataset_iters = [iter(dataset) for dataset in datasets]
[rank7]:   File "/home/sahil/.cache/pypoetry/virtualenvs/env-auw7Hy33-py3.10/lib/python3.10/site-packages/litdata/streaming/", line 211, in __iter__
[rank7]:     self._resume(chunks_replica, intervals_replica)
[rank7]:   File "/home/sahil/.cache/pypoetry/virtualenvs/env-auw7Hy33-py3.10/lib/python3.10/site-packages/litdata/streaming/", line 272, in _resume
[rank7]:     interval = self.worker_intervals[self.chunk_index]
[rank7]: IndexError: list index out of range
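
The last frame indexes `self.worker_intervals` with `self.chunk_index`, a position restored from the checkpoint. The toy sketch below is not litdata's implementation. It only shows how that lookup raises the same `IndexError`, under the unconfirmed assumption that the saved index was recorded against a longer per-worker list than the one rebuilt on resume. The names `resume_interval` and `saved_chunk_index` and the interval values are hypothetical.

```python
from typing import List, Optional, Tuple

Interval = Tuple[int, int]


def resume_interval(
    worker_intervals: List[Interval], saved_chunk_index: int
) -> Optional[Interval]:
    """Return the interval to resume from, or None if the saved index no longer fits."""
    if 0 <= saved_chunk_index < len(worker_intervals):
        return worker_intervals[saved_chunk_index]
    return None


# Hypothetical: the checkpoint recorded position 3 when this worker owned four intervals.
saved_chunk_index = 3
intervals_at_save = [(0, 100), (100, 200), (200, 300), (300, 400)]
# Hypothetical: on resume, the same worker is assigned only two intervals.
intervals_at_resume = [(0, 100), (100, 200)]

try:
    # Same unguarded lookup pattern as the last frame of the traceback.
    intervals_at_resume[saved_chunk_index]
    unguarded_error = None
except IndexError as exc:
    unguarded_error = str(exc)

print(unguarded_error)
print(resume_interval(intervals_at_save, saved_chunk_index))
print(resume_interval(intervals_at_resume, saved_chunk_index))
```

The guarded helper only makes the mismatch visible. Whether a stale index should be skipped, clamped, or treated as an error is a design decision for the library and is not something this sketch settles.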

Code sample

I've scrubbed my code below --

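from litdata import CombinedStreamingDataset, StreamingDataLoader, StreamingDataset
from pytorch_lightning import LightningDataModule, Trainer
from pytorch_lightning.callbacks import DeviceStatsMonitor
from pytorch_lightning.loggers import WandbLogger
from pytorch_lightning.strategies import DDPStrategy


class CustomDataModule(LightningDataModule):
    """Custom data module that wraps training/validation StreamingDataset objects."""

    def __init__(self, config):
        super().__init__()
        self.config = config

    def setup(self, stage=None):
        if stage in (None, 'fit'):
            # Create train datasets
            self.train_datasets = []
            for ds_config in self.config.datasets.train:
                dataset = StreamingDataset(
                    # arguments scrubbed in the original report
                )
                self.train_datasets.append(dataset)
            self.train_dataset = CombinedStreamingDataset(self.train_datasets)

            # Create validation datasets
            self.val_datasets = []
            for ds_config in self.config.datasets.val:
                dataset = StreamingDataset(
                    # arguments scrubbed in the original report
                )
                self.val_datasets.append(dataset)
            self.val_dataset = CombinedStreamingDataset(self.val_datasets)

    def train_dataloader(self):
        return StreamingDataLoader(
            self.train_dataset,
            # remaining arguments scrubbed in the original report
        )

    def val_dataloader(self):
        return StreamingDataLoader(
            self.val_dataset,
            # remaining arguments scrubbed in the original report
        )


# `config`, `ckpt`, and `Model` are defined elsewhere in the reporter's code.

# Create dataset module
custom_data_module = CustomDataModule(config)

# Initialize the model
model = Model()

# Define the PyTorch Lightning Trainer
wandb_logger = WandbLogger(**config.wandb_logger)
device_stats_monitor = DeviceStatsMonitor()
strategy = DDPStrategy(find_unused_parameters=True)
trainer = Trainer(
    # arguments scrubbed in the original report
)

# Resume training from the checkpoint
trainer.fit(model, datamodule=custom_data_module, ckpt_path=ckpt)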

Expected behavior

Training should resume from the checkpoint on multiple nodes without errors.


github-actions[bot] commented 1 month ago

Hi! Thanks for your contribution! Great first issue!

tchaton commented 1 month ago

Hey @schopra8,

Thanks for reporting the issue. Yes, we are aware of it and are already looking into it :) If we fix it next week, we will make a new release. cc @awaelchli

Unfortunately, the fix is coming as a series of changes that aren't backward compatible: we won't be able to load old checkpoints, as the core logic has changed too much.

schopra8 commented 1 month ago

np! thanks for the heads up

schopra8 commented 1 month ago

Also wanted to flag --

I tried resuming on 1 node with N devices. Everything worked for the first couple hundred steps, but then I hit the same error. So it looks like there is a similar issue with DDP on 1 node as well.

tchaton commented 1 month ago

Hey @schopra8 Here are the release notes:

Would you mind trying again with the latest version, 0.2.17?

Unfortunately, old checkpoints won't work.
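
Before resuming, it can help to confirm that the environment on every node actually picked up the new release. The snippet below is a minimal sketch using only the standard library. The helper names (`parse_version`, `has_resume_fix`, `installed_litdata_version`) are made up for illustration, and the parser only reads the leading numeric fields, so a pre-release such as `0.2.17rc1` would be treated like `0.2.17`.

```python
from importlib.metadata import PackageNotFoundError, version
from typing import Optional, Tuple

MIN_LITDATA = (0, 2, 17)  # the version named in this thread


def parse_version(text: str) -> Tuple[int, ...]:
    """Parse the leading numeric fields of a dotted version, e.g. '0.2.17' -> (0, 2, 17)."""
    parts = []
    for field in text.split("."):
        digits = ""
        for ch in field:
            if not ch.isdigit():
                break
            digits += ch
        if not digits:
            break  # stop at the first non-numeric field such as 'post1'
        parts.append(int(digits))
    return tuple(parts)


def has_resume_fix(installed: Optional[str]) -> bool:
    """True if the given version string is at least MIN_LITDATA."""
    return installed is not None and parse_version(installed) >= MIN_LITDATA


def installed_litdata_version() -> Optional[str]:
    """Return the installed litdata version, or None if the package is missing."""
    try:
        return version("litdata")
    except PackageNotFoundError:
        return None


found = installed_litdata_version()
print(f"litdata installed: {found}, has resume fix: {has_resume_fix(found)}")
```

Since old checkpoints cannot be loaded after the upgrade, this check only guards the library version. It does not say anything about whether a given checkpoint was written by the new release.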

schopra8 commented 1 month ago

Awesome! I'll try in the next 1-2 days and report back my results

tchaton commented 1 month ago

Thanks @schopra8.

schopra8 commented 1 month ago

@tchaton Tested this and it works! Closing this issue, since the DDP problem is solved.

We're having another issue with resuming training with a new dataset. We want to preserve optimizer states, etc. when we continue training. Any guidance would be much appreciated!