huggingface / datasets

🤗 The largest hub of ready-to-use datasets for ML models with fast, easy-to-use and efficient data manipulation tools
https://huggingface.co/docs/datasets
Apache License 2.0

concatenate_datasets does not preserve shuffling state #7196

Open · alex-hh opened this issue 1 month ago

alex-hh commented 1 month ago

Describe the bug

After calling concatenate_datasets on iterable datasets, the shuffling state is destroyed, similar to #7156.

This means concatenation can't be used to resolve uneven numbers of samples across devices when using iterable datasets in a distributed setting, as discussed in #6623.
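
For context, the kind of distributed padding pattern I have in mind is sketched below. This is only an illustrative sketch (not code from #6623): the gen function, the dummy padding value and the exact split_dataset_by_node usage are my own assumptions. The idea is to concatenate a few extra samples so that every rank sees the same number of examples, which only works cleanly if the shuffled dataset keeps its shuffling state through the concatenation.

import datasets
from datasets.distributed import split_dataset_by_node

def gen(samples):
    # each shard calls gen with its own sub-list of samples
    for s in samples:
        yield {"value": s}

world_size = 2

# 7 samples do not divide evenly across 2 ranks
base = datasets.IterableDataset.from_generator(gen, gen_kwargs={"samples": list(range(7))})
# pad with one dummy sample (value -1) so the total splits evenly
pad = datasets.IterableDataset.from_generator(gen, gen_kwargs={"samples": [-1]})

padded = datasets.concatenate_datasets([base.shuffle(buffer_size=7, seed=0), pad])

for rank in range(world_size):
    node_ds = split_dataset_by_node(padded, rank=rank, world_size=world_size)
    print(rank, [ex["value"] for ex in node_ds])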

I also noticed that the number of shards is unchanged after concatenation, which I found surprising, though I don't understand the internals well enough to know whether that is actually expected.

Steps to reproduce the bug

import datasets
import torch.utils.data

def gen(shards):
    # each shard calls gen with its own sub-list of shards and yields a single example
    yield {"shards": shards}

def main():
    dataset1 = datasets.IterableDataset.from_generator(
        gen, gen_kwargs={"shards": list(range(25))}  # a list in gen_kwargs is split into 25 shards
    )
    dataset2 = datasets.IterableDataset.from_generator(
        gen, gen_kwargs={"shards": list(range(25, 50))}  # a list in gen_kwargs is split into 25 shards
    )
    dataset1 = dataset1.shuffle(buffer_size=1)
    dataset2 = dataset2.shuffle(buffer_size=1)
    print(dataset1.n_shards)
    print(dataset2.n_shards)

    dataset = datasets.concatenate_datasets(
        [dataset1, dataset2]
    )
    print(dataset.n_shards)
    # dataset = dataset1

    dataloader = torch.utils.data.DataLoader(
        dataset,
        batch_size=8,
        num_workers=0,
    )

    for i, batch in enumerate(dataloader):
        print(batch)
    print("\nNew epoch")

    dataset.set_epoch(1)  # set_epoch updates the dataset in place and returns None

    for i, batch in enumerate(dataloader):
        print(batch)

if __name__ == "__main__":
    main()

Expected behavior

The shuffling state should be preserved after concatenation, so that set_epoch(1) changes the order of the concatenated dataset on the next epoch.
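
A minimal, self-contained sketch of that expectation (with the bug described above, the shuffling state is lost after concatenation, so the two epochs come back in the same order):

import datasets

def gen(shards):
    yield {"shards": shards}

dataset1 = datasets.IterableDataset.from_generator(gen, gen_kwargs={"shards": list(range(25))})
dataset2 = datasets.IterableDataset.from_generator(gen, gen_kwargs={"shards": list(range(25, 50))})

dataset = datasets.concatenate_datasets(
    [dataset1.shuffle(buffer_size=1), dataset2.shuffle(buffer_size=1)]
)

epoch0 = [ex["shards"] for ex in dataset]
dataset.set_epoch(1)
epoch1 = [ex["shards"] for ex in dataset]

# if the shuffling state were preserved, the shard order would (in general)
# differ between epochs
print(epoch0 != epoch1)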

Environment info

Latest datasets