huggingface / datasets

🤗 The largest hub of ready-to-use datasets for ML models with fast, easy-to-use and efficient data manipulation tools
https://huggingface.co/docs/datasets
Apache License 2.0

Make `BufferShuffledExamplesIterable` resumable #7056

Open yzhangcs opened 4 months ago

yzhangcs commented 4 months ago

This PR aims to implement a resumable BufferShuffledExamplesIterable. Instead of saving the entire buffer content, which is very memory-intensive, the newly implemented BufferShuffledExamplesIterable saves only the minimal state necessary for recovery, e.g., the random generator states and the state of the first (i.e., oldest) example in the buffer dict.

The idea is that, since the buffer size is limited, we can rebuild the buffer even if its entire content is discarded, as long as the state of the oldest example in it is recorded. For buffer size $B$, the number of steps between when an example is pushed and when it is yielded follows a geometric distribution with success probability $\frac{1}{B}$, so the expected distance is $d = \sum_{k=1}^{\infty} k\frac{1}{B} (1 - \frac{1}{B} )^{k-1} = B$. Simulation experiments support this claim:

from random import randint

BUFFER_SIZE = 1024

dists = []
buffer = []
for i in range(10000000):
    if i < BUFFER_SIZE:
        # warm-up: fill the buffer before yielding anything
        buffer.append(i)
    else:
        # pick a random slot, record how long its occupant waited, then replace it
        index = randint(0, BUFFER_SIZE - 1)
        dists.append(i - buffer[index])
        buffer[index] = i

print(f"MIN DIST: {min(dists)}\nMAX DIST: {max(dists)}\nAVG DIST: {sum(dists) / len(dists):.2f}\n")

which produces the following output:

MIN DIST: 1
MAX DIST: 15136
AVG DIST: 1023.95
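
To make the recovery idea concrete, here is a minimal, self-contained sketch of the mechanism. It is not the code of this PR: the class ToyResumableShuffleBuffer and all names in it are invented for illustration. During iteration it tracks, for every example currently sitting in the buffer, the source offset and RNG state recorded when that example was pushed; a checkpoint serializes only the snapshot of the oldest buffered example plus the current source position, and resuming replays the source from that offset while discarding everything that had already been yielded before the checkpoint:

import random
from itertools import islice
from typing import Any, Callable, Dict, Iterable, Iterator, Optional


class ToyResumableShuffleBuffer:
    """Sketch of the recovery idea only; it does not mirror the actual
    BufferShuffledExamplesIterable implementation."""

    def __init__(self, buffer_size: int, seed: int = 42):
        self.buffer_size = buffer_size
        self.rng = random.Random(seed)
        self.offset = 0                       # source examples consumed so far
        self.buffer: Dict[int, Any] = {}      # slot -> example
        self.pushed_at: Dict[int, int] = {}   # slot -> source offset of its occupant
        self.rng_states: Dict[int, Any] = {}  # offset -> RNG state right before that push
                                              # (kept in memory only for buffered examples)

    def state_dict(self) -> Dict[str, Any]:
        # O(1)-sized checkpoint: the source offset and RNG state of the *oldest*
        # buffered example plus the current source position; no buffer contents.
        oldest = min(self.pushed_at.values())
        return {'oldest_offset': oldest,
                'oldest_rng_state': self.rng_states[oldest],
                'checkpoint_offset': self.offset}

    def iterate(self, make_source: Callable[[], Iterable[Any]],
                state: Optional[Dict[str, Any]] = None) -> Iterator[Any]:
        start, checkpoint = 0, 0
        if state is not None:
            # resume: replay the source from the oldest buffered example with the
            # RNG state it had back then; every example "yielded" during the replay
            # was already consumed before the checkpoint, so those yields are dropped
            start, checkpoint = state['oldest_offset'], state['checkpoint_offset']
            self.rng.setstate(state['oldest_rng_state'])
        self.offset = start
        # islice naively re-reads and skips `start` examples; the real code would
        # restore the inner iterable's own state instead
        for example in islice(make_source(), start, None):
            before_push = self.rng.getstate()
            if self.offset < self.buffer_size:
                slot, out = self.offset, None  # warm-up: just append, nothing to yield
            else:
                slot = self.rng.randint(0, self.buffer_size - 1)
                out = self.buffer.get(slot)
            if slot in self.pushed_at:         # forget the snapshot of the evicted example
                self.rng_states.pop(self.pushed_at[slot], None)
            self.buffer[slot] = example
            self.pushed_at[slot] = self.offset
            self.rng_states[self.offset] = before_push
            self.offset += 1
            if out is not None and self.offset > checkpoint:
                yield out
        # drain what remains at the end (a real shuffle buffer would shuffle it too)
        for slot in sorted(self.buffer):
            yield self.buffer[slot]


# toy check: checkpoint after 3000 examples, resume in a fresh object and verify
# that the continuation matches the uninterrupted run
def make_source():
    return range(100000)

original = ToyResumableShuffleBuffer(buffer_size=128, seed=0)
stream = original.iterate(make_source)
consumed = [next(stream) for _ in range(3000)]
ckpt = original.state_dict()

resumed = ToyResumableShuffleBuffer(buffer_size=128, seed=0)
restream = resumed.iterate(make_source, state=ckpt)
assert [next(stream) for _ in range(10)] == [next(restream) for _ in range(10)]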

The overall time for reconstructing the buffer during recovery should not be too long. The following code mimics resuming online tokenization with datasets and StatefulDataLoader under a distributed scenario:

import pickle
import time
from itertools import chain
from typing import Any, Dict, List

import torch
from datasets import load_dataset
from torchdata.stateful_dataloader import StatefulDataLoader
from tqdm import tqdm
from transformers import AutoTokenizer, DataCollatorForLanguageModeling

tokenizer = AutoTokenizer.from_pretrained('fla-hub/gla-1.3B-100B')
tokenizer.pad_token = tokenizer.eos_token
data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)

torch.manual_seed(42)

def tokenize(examples: Dict[str, List[Any]]) -> Dict[str, List[List[int]]]:
    input_ids = tokenizer(examples['text'])['input_ids']
    input_ids = list(chain(*input_ids))
    total_length = len(input_ids)
    chunk_size = 2048
    total_length = (total_length // chunk_size) * chunk_size
    # the last chunk smaller than chunk_size will be discarded
    return {'input_ids': [input_ids[i: i+chunk_size] for i in range(0, total_length, chunk_size)]}

batch_size = 16
num_workers = 5
context_length = 2048
# rank/world_size describe the mimicked distributed setup; they are not used to shard the stream in this snippet
rank = 1
world_size = 32
prefetch_factor = 2
steps = 2048
path = 'fla-hub/slimpajama-test'
dataset = load_dataset(
    path=path,
    split='train',
    streaming=True,
    trust_remote_code=True
)
dataset = dataset.map(tokenize, batched=True, remove_columns=list(next(iter(dataset)).keys()))  # drop the raw columns once tokenized
dataset = dataset.shuffle(seed=42)
loader = StatefulDataLoader(dataset=dataset,
                            batch_size=batch_size,
                            collate_fn=data_collator,
                            num_workers=num_workers,
                            persistent_workers=False,
                            prefetch_factor=prefetch_factor)
start = time.time()
for i, batch in tqdm(enumerate(loader)):
    if i == 0:
        print(f'{i}\n{batch["input_ids"]}')
    if i == steps - 1:
        print(f'{i}\n{batch["input_ids"]}')
        state_dict = loader.state_dict()
    if i == steps:
        print(f'{i}\n{batch["input_ids"]}')
        break
print(f"{time.time() - start:.2f}s elapsed")
print(f"{len(pickle.dumps(state_dict)) / 1024**2:.2f}MB states in total")
for worker in state_dict['_snapshot']['_worker_snapshots'].keys():
    print(f"{worker} {len(pickle.dumps(state_dict['_snapshot']['_worker_snapshots'][worker])) / 1024**2:.2f}MB")
print(state_dict['_snapshot']['_worker_snapshots']['worker_0']['dataset_state'])

loader = StatefulDataLoader(dataset=dataset,
                            batch_size=batch_size,
                            collate_fn=data_collator,
                            num_workers=num_workers,
                            persistent_workers=False,
                            prefetch_factor=prefetch_factor)
print("Loading state dict")
loader.load_state_dict(state_dict)
start = time.time()
for batch in loader:
    print(batch['input_ids'])
    break

print(f"{time.time() - start:.2f}s elapsed")

and the outputs are

0
tensor([[  909,   395, 19082,  ..., 13088, 16232,   395],
        [  601, 28705, 28770,  ..., 28733,   923,   288],
        [21753, 15071, 13977,  ...,  9369, 28723,   415],
        ...,
        [21763, 28751, 20300,  ..., 28781, 28734,  4775],
        [  354,   396, 10214,  ...,   298,   429, 28770],
        [  333,  6149, 28768,  ...,  2773,   340,   351]])
2047
tensor([[28723,   415,  3889,  ...,   272,  3065,  2609],
        [  403,  3214,  3629,  ...,   403, 21163, 16434],
        [28723,    13, 28749,  ..., 28705, 28750, 28734],
        ...,
        [ 2778,  2251, 28723,  ...,   354,   684,   429],
        [ 5659,   298,  1038,  ...,  5290,   297, 22153],
        [  938, 28723,  1537,  ...,  9123, 28733, 12154]])
2048
tensor([[  769,   278, 12531,  ..., 28721, 19309, 28739],
        [  415, 23347,   622,  ...,  3937,  2426, 28725],
        [28745,  4345, 28723,  ...,   338, 28725,   583],
        ...,
        [ 1670, 28709,  5809,  ..., 28734, 28760,   393],
        [  340,  1277,   624,  ...,   325, 28790,  1329],
        [  523,  1144,  3409,  ...,   359,   359, 17422]])
65.97s elapsed
0.00MB states in total
worker_0 0.00MB
worker_1 0.00MB
worker_2 0.00MB
worker_3 0.00MB
worker_4 0.00MB
{'ex_iterable': {'ex_iterable': {'shard_idx': 0, 'shard_example_idx': 14000}, 'num_examples_since_previous_state': 166, 'previous_state_example_idx': 7394, 'previous_state': {'shard_idx': 0, 'shard_example_idx': 13000}}, 'num_taken': 6560, 'global_example_idx': 7560, 'buffer_state_dict': {'num_taken': 6560, 'global_example_idx': 356, 'index_offset': 0, 'first_state': {'ex_iterable': {'shard_idx': 0, 'shard_example_idx': 1000}, 'num_examples_since_previous_state': 356, 'previous_state_example_idx': 0, 'previous_state': {'shard_idx': 0, 'shard_example_idx': 0}}, 'bit_generator_state': {'state': {'state': 274674114334540486603088602300644985544, 'inc': 332724090758049132448979897138935081983}, 'bit_generator': 'PCG64', 'has_uint32': 0, 'uinteger': 0}}}
Loading state dict
tensor([[  769,   278, 12531,  ..., 28721, 19309, 28739],
        [  415, 23347,   622,  ...,  3937,  2426, 28725],
        [28745,  4345, 28723,  ...,   338, 28725,   583],
        ...,
        [ 1670, 28709,  5809,  ..., 28734, 28760,   393],
        [  340,  1277,   624,  ...,   325, 28790,  1329],
        [  523,  1144,  3409,  ...,   359,   359, 17422]])
24.60s elapsed

I'm not sure whether this PR complies with the datasets code style. Looking for your help @lhoestq; I'm also very willing to improve the code further if you have any suggestions.

lhoestq commented 4 months ago

Oh cool !

The time it takes to resume depends on the expected maximum distance in this case right ? Do you know its relationship with $B$ ?

In your test it's already as high as 15k for $B=1024$, which is ok for text datasets but is maybe not ideal for datasets with heavy samples like audio/image/video ? Though for heavy-sample datasets the buffer size is generally much smaller to avoid memory issues.

Maybe we could just add a warning message on resuming to tell the user that it might take some time to recover the shuffle buffer (with a progress bar maybe ?), and have the option to stop + re-run with an env variable to disable shuffle buffer recovering ? WDYT ?

yzhangcs commented 4 months ago

The time it takes to resume depends on the expected maximum distance in this case right ? Do you know its relationship with $B$ ?

Hi, I created a histogram to visualize the distances in the simulation experiment (the image itself isn't reproduced here; a snippet to regenerate it is below). I think there is no guarantee as to when the oldest example will be yielded: it could stay in the buffer until the entire shard is consumed. However, this should be rare, and in most cases the pushed examples are yielded very quickly; in the histogram, most examples are yielded within $2B$ steps. Things improve further if the dataset is split into enough shards and each shard is not too large.
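
A similar plot can be regenerated from the simulation above with something like the following, where the matplotlib styling is illustrative only:

from random import randint
import matplotlib.pyplot as plt

BUFFER_SIZE = 1024
dists, buffer = [], []
for i in range(1000000):
    if i < BUFFER_SIZE:
        buffer.append(i)
    else:
        index = randint(0, BUFFER_SIZE - 1)
        dists.append(i - buffer[index])
        buffer[index] = i

plt.hist(dists, bins=200)
plt.axvline(2 * BUFFER_SIZE, linestyle='--')  # most distances fall below roughly 2B
plt.xlabel('distance between push and yield')
plt.ylabel('count')
plt.show()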

I agree that we may need to add some warnings or provide some options to allow users to make their own choices.

yzhangcs commented 4 months ago

Maybe there's a middle ground between rebuilding the buffer from scratch and storing the entire buffer, but the logic is a bit complicated and takes time to implement. At least for now, we have a way to make shuffled IterableDataset resumable :)

yzhangcs commented 4 months ago

@lhoestq I'm not sure it's ok to use a progress bar when having multiple workers. How about passing an arg resumable=True to IterableDataset.shuffle to allow users to control the behavior?

lhoestq commented 4 months ago

I feel like the default behavior should ideally be fast and perfect resuming.

Loading from disk is a good option for this (although it's not always possible to serialize the content of the buffer, in that case the buffer would restart empty and we can show a warning).

The state_dict() would be part of the training state_dict that is saved to disk along with the model and optimizer anyway. Cc @muellerzr who worked on storing training state_dicts for the accelerate lib, in case you have an opinion.

I also feel like it is simpler and more intuitive for users. It doesn't require explaining why we need to stream a lot of data just to recover a buffer.

Maybe there's a middle ground between rebuilding the buffer from scratch and storing the entire buffer, but the logic is a bit complicated and takes time to implement.

definitely, and it would also make things even harder for users to understand

yzhangcs commented 4 months ago

@lhoestq

Loading from disk is a good option for this (although it's not always possible to serialize the content of the buffer, in that case the buffer would restart empty and we can show a warning). The state_dict() would be part of the training state_dict that is saved to disk along with the model and optimizer anyway. Cc @muellerzr who worked on storing training state_dicts for the accelerate lib, in case you have an opinion. I also feel like it is simpler and more intuitive for users. It doesn't require explaining why we need to stream a lot of data just to recover a buffer.

Yea, I agree with you. But here's the thing: saving buffers as a state dict can get pretty tricky. For tokenized text data, a multi-worker shuffle buffer can take several hundred GB of memory in my case. That's just not feasible for most machine environments out there, and it can be even more severe for audio/video data.

Also, serializing the buffer takes a major toll on performance. In my experience, I've had to lean heavily on numpy/torch tensor operations to manage the tokenized text data efficiently, which isn't easily transferable to other scenarios; it's a custom fix that works for now, but not a one-size-fits-all solution. So, given those limitations, directly serializing the buffer content isn't that ideal for me.

lhoestq commented 4 months ago

For tokenized text data, a multi-worker shuffle buffer can take several hundred GB of memory in my case.

it's kinda close to the size of a model + optimizer no ?

Anyway that makes sense, and I'm ok with adding the feature to recover the shuffle buffer (at least as an opt-in for now; we can decide on the default later based on user feedback and experience).

Are you ok with adding buffer_resuming_mode= to .shuffle() to enable buffer recovering using your method with buffer_resuming_mode="recover_from_source" ? (feel free to suggest other names for the parameter and value)
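
For illustration, opting in would then look something like this (hypothetical usage, since the parameter name and value are exactly what is being discussed here and are not part of any released API):

from datasets import load_dataset

dataset = load_dataset('fla-hub/slimpajama-test', split='train', streaming=True)
dataset = dataset.shuffle(
    seed=42,
    buffer_size=1024,
    buffer_resuming_mode='recover_from_source',  # proposed opt-in: rebuild the buffer from the source on resume
)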

yzhangcs commented 4 months ago

@lhoestq

Are you ok with adding buffer_resuming_mode= to .shuffle() to enable buffer recovering using your method with buffer_resuming_mode="recover_from_source" ? (feel free to suggest other names for the parameter and value)

Of course, I appreciate your feedback.