huggingface / datasets

🤗 The largest hub of ready-to-use datasets for ML models with fast, easy-to-use and efficient data manipulation tools
https://huggingface.co/docs/datasets
Apache License 2.0

Save and resume the state of a DataLoader #5454

Open · lhoestq opened this issue 1 year ago

lhoestq commented 1 year ago

It would be nice, when using datasets with a PyTorch DataLoader, to be able to resume training from a DataLoader state (e.g. to resume a training run that crashed).

What I have in mind (but lmk if you have other ideas or comments):

For map-style datasets, this requires a PyTorch Sampler state that can be saved and reloaded per node and worker.
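
For illustration, here is a minimal sketch of what such a resumable sampler could look like; the class name and state format are hypothetical and not an existing datasets or PyTorch API, and a distributed version would additionally need to track its rank and shard:

import torch
from torch.utils.data import Sampler

class ResumableRandomSampler(Sampler):
    """Hypothetical sketch: a random sampler whose position can be saved and restored."""

    def __init__(self, data_source, seed: int = 0):
        self.data_source = data_source
        self.seed = seed
        self.epoch = 0
        self.num_yielded = 0  # indices already returned in the current epoch

    def __iter__(self):
        g = torch.Generator()
        g.manual_seed(self.seed + self.epoch)  # deterministic permutation per epoch
        indices = torch.randperm(len(self.data_source), generator=g).tolist()
        for idx in indices[self.num_yielded:]:  # resume mid-epoch if state was loaded
            self.num_yielded += 1
            yield idx
        self.num_yielded = 0
        self.epoch += 1

    def __len__(self):
        return len(self.data_source)

    def state_dict(self):
        return {"seed": self.seed, "epoch": self.epoch, "num_yielded": self.num_yielded}

    def load_state_dict(self, state_dict):
        self.seed = state_dict["seed"]
        self.epoch = state_dict["epoch"]
        self.num_yielded = state_dict["num_yielded"]

Note that with num_workers > 0 the DataLoader prefetches batches, so a counter kept in the sampler can run ahead of the training loop; that is part of why the state ultimately needs to be handled at the dataloader level.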

For iterable datasets, this requires saving the state of the dataset iterator, which includes at least the shard currently being read and the position within that shard.

Right now you can already resume the data loading of an iterable dataset by using IterableDataset.skip, but it takes a lot of time because it re-iterates over all the past data until it reaches the resuming point.
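
For reference, the skip-based workaround looks roughly like this (the dataset name and the number of already-seen examples below are placeholders):

from datasets import load_dataset

ds = load_dataset("allenai/c4", "en", split="train", streaming=True)  # placeholder dataset
num_examples_seen = 1_000_000  # e.g. restored from your own checkpoint metadata

# Correct, but slow: skip() still streams over everything it skips
# before reaching the resuming point.
resumed_ds = ds.skip(num_examples_seen)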

cc @stas00 @sgugger

thomasw21 commented 1 year ago

Something that'd be nice to have is a "manual update of state". One of the learnings from training LLMs is that the ability to skip some batches whenever we notice a huge spike would be handy.

stas00 commented 1 year ago

Your outline spec is very sound and clear, @lhoestq - thank you!

@thomasw21, indeed that would be a wonderful extra feature. In Megatron-DeepSpeed we manually drained the dataloader for the range we wanted. I wasn't very satisfied with the way we did it, since its behavior would change if you did multiple range skips. I think it should remember all the ranges it skipped, not just skip the last range, since otherwise the data is inconsistent (but we should probably discuss this in a separate issue so as not to derail this much bigger one).
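
As a rough sketch of the kind of range skipping discussed here, replaying all previously skipped ranges (not just the last one) so the data stays consistent across restarts could look like this; the ranges and the stand-in dataloader are hypothetical:

from torch.utils.data import DataLoader

# Hypothetical step ranges around loss spikes, tracked and checkpointed manually
skipped_ranges = [(1200, 1250), (5400, 5420)]

def should_skip(step: int) -> bool:
    # Check against every recorded range, not just the most recent one
    return any(start <= step < end for start, end in skipped_ranges)

dataloader = DataLoader(range(10_000), batch_size=8)  # stand-in for the real dataloader
for step, batch in enumerate(dataloader):
    if should_skip(step):
        continue  # drain the batch without training on it
    ...  # training step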

yqy2001 commented 8 months ago

Hi there! I think this is a critical issue, and I have an urgent need for it in my attempt to train on a super large-scale dataset using datasets. It is impossible to resume a time-consuming (e.g. one-month-long) experiment by iterating over all the seen data again, which could take several days.

@stas00 @thomasw21 @lhoestq Any updates on this problem now that a year has passed?

dancingpipi commented 8 months ago

any update?

lhoestq commented 8 months ago

No update so far. I wonder if someone has implemented a resumable PyTorch Sampler somewhere.

As for resuming a streaming dataset, we'd first like to have an efficient way to skip shards automatically, but this is not implemented yet.

lhoestq commented 8 months ago

I opened a draft here for IterableDataset: https://github.com/huggingface/datasets/pull/6658

"""Requires https://github.com/huggingface/datasets/pull/6658 (WIP)"""
from datasets import load_dataset
from torch.utils.data import DataLoader

ds = load_dataset(..., streaming=True)
# ds = ds.map(tokenize)
# ds = ds.shuffle(seed=42, buffer_size=1000)

# Init the dataset state_dict, or load it from a checkpoint
dataset_state_dict = ds.state_dict()

# Resumable training loop
ds.load_state_dict(dataset_state_dict)
dataloader = DataLoader(ds, batch_size=batch_size)
for step, batch in enumerate(dataloader):
    ...
    if step % save_steps == 0:
        dataset_state_dict = ds.state_dict()
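
One thing the sketch above doesn't show is persistence; assuming the state dict only contains JSON-serializable values, it can simply be written next to the model checkpoint, for example:

import json

# Persist the dataset state alongside the model checkpoint
with open("dataset_state.json", "w") as f:
    json.dump(dataset_state_dict, f)

# ...and restore it before calling ds.load_state_dict(...)
with open("dataset_state.json") as f:
    dataset_state_dict = json.load(f)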

jwliu36 commented 8 months ago

Hi @lhoestq - can you provide more information on how to save and restore vanilla DataLoader states with map-style datasets?

lhoestq commented 8 months ago

For now the easiest approach is probably to use the vanilla DataLoader only for batching and multiprocessing, and to implement the resuming logic using a Dataset (it has .select() to skip examples) and a dataset_state_dict:

from datasets import load_dataset
from torch.utils.data import DataLoader

ds = load_dataset(...)
# ds = ds.map(tokenize)
# ds = ds.shuffle(seed=42)

# Init the dataset state_dict, or load it from a checkpoint
dataset_state_dict = {"step": 0}  

# Resumable training loop
start_step = dataset_state_dict["step"]
dataloader = DataLoader(ds.select(range(start_step * batch_size, len(ds))), batch_size=batch_size)
for step, batch in enumerate(dataloader, start=start_step):
    ...
    if step % save_steps == 0:
        dataset_state_dict = {"step": step}

xgbj commented 7 months ago

Hello, I found a similar implementation online that seems to solve your problem: https://github.com/facebookresearch/vissl/blob/main/vissl/data/data_helper.py#L93. It looks like we can use set_start_iter in StatefulDistributedSampler to implement the stateful resume behavior we want.

andrewkho commented 5 months ago

Hi y'all, @lhoestq I wanted to flag that we currently have a StatefulDataLoader in pytorch/data/torchdata that has state_dict/load_state_dict methods, which will call a dataset's state_dict/load_state_dict methods but also handle multiprocessing under the hood. Any chance we can collaborate on this and try to get them to work well together? Please have a look here for some basic examples: https://github.com/pytorch/data/tree/main/torchdata/stateful_dataloader#saving-and-loading-state
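
For reference, the basic pattern from the linked examples looks roughly like this (a paraphrased sketch; the toy dataset, batch size, and snapshot step are placeholders):

from torchdata.stateful_dataloader import StatefulDataLoader

dataset = list(range(100))  # placeholder map-style dataset

dataloader = StatefulDataLoader(dataset, batch_size=4, num_workers=2)
state_dict = None
for step, batch in enumerate(dataloader):
    if step == 10:
        state_dict = dataloader.state_dict()  # snapshot taken mid-epoch
        break

# Later: rebuild the dataloader and resume where the snapshot was taken
dataloader = StatefulDataLoader(dataset, batch_size=4, num_workers=2)
dataloader.load_state_dict(state_dict)
for step, batch in enumerate(dataloader):
    ...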

lhoestq commented 5 months ago

Fantastic! This will help push our IterableDataset state_dict implementation at https://github.com/huggingface/datasets/pull/6658 :) I'll check whether anything is missing to make them work together, and add tests and some docs referring to the StatefulDataLoader :)

lhoestq commented 5 months ago

Ah, I just saw this disclaimer in the torchdata README, and it feels like people should not rely on it. Should the StatefulDataLoader live elsewhere, @andrewkho?

⚠️ As of July 2023, we have paused active development on TorchData and have paused new releases. We have learnt a lot from building it and hearing from users, but also believe we need to re-evaluate the technical design and approach given how much the industry has changed since we began the project. During the rest of 2023 we will be re-evaluating our plans in this space. Please reach out if you have suggestions or comments (please use https://github.com/pytorch/data/issues/1196 for feedback).

andrewkho commented 5 months ago

@lhoestq Good find, we are in the midst of updating this disclaimer as we're restarting development and regular releases, though our approach will be to iterate on DL V1 (i.e. StatefulDataLoader) instead of continuing development on datapipes + DL V2. Let's discuss on a call at some point to figure out the best path forward!

lhoestq commented 2 months ago

As a heads up, IterableDataset state_dict has been added in https://github.com/huggingface/datasets/pull/6658

...and it works out of the box with the torchdata StatefulDataLoader :)

See the docs at https://huggingface.co/docs/datasets/main/en/use_with_pytorch#checkpoint-and-resume
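
Roughly, the combined usage looks like this (the dataset name and hyperparameters below are placeholders; see the linked docs for the exact example):

from datasets import load_dataset
from torchdata.stateful_dataloader import StatefulDataLoader

# Placeholder dataset: any streaming dataset from the Hub works the same way
iterable_dataset = load_dataset("allenai/c4", "en", split="train", streaming=True)

dataloader = StatefulDataLoader(iterable_dataset, batch_size=32, num_workers=4)
for step, batch in enumerate(dataloader):
    ...
    if step % 100 == 0:
        state_dict = dataloader.state_dict()  # includes the IterableDataset state

# After a crash: rebuild the dataloader and resume mid-stream
dataloader = StatefulDataLoader(iterable_dataset, batch_size=32, num_workers=4)
dataloader.load_state_dict(state_dict)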

stas00 commented 2 months ago

amazing! Thank you, @lhoestq

Does it work with non-iterable datasets as well? The docs only mention iterable datasets.

lhoestq commented 2 months ago

It's for iterable datasets only. For regular datasets I believe the sampler should implement state_dict, but @andrewkho might know best how to resume a regular dataset with torchdata.

andrewkho commented 2 months ago

@stas00 The stateful dataloader will save and resume samplers for map-style datasets. If the sampler doesn't provide state_dict/load_state_dict, it will naively skip samples to fast-forward. See here for more details: https://github.com/pytorch/data/blob/main/torchdata/stateful_dataloader/README.md

Hope this helps!

stas00 commented 2 months ago

Thank you very much for clarifying that, Andrew.