Steps to reproduce the bug
import datasets
import torch.utils.data


def gen(shards):
    yield {"shards": shards}


def main():
    # A list passed in gen_kwargs defines the shards of the iterable dataset.
    dataset1 = datasets.IterableDataset.from_generator(
        gen, gen_kwargs={"shards": list(range(25))}
    )
    dataset2 = datasets.IterableDataset.from_generator(
        gen, gen_kwargs={"shards": list(range(25, 50))}
    )
    dataset1 = dataset1.shuffle(buffer_size=1)
    dataset2 = dataset2.shuffle(buffer_size=1)
    print(dataset1.n_shards)
    print(dataset2.n_shards)

    dataset = datasets.concatenate_datasets(
        [dataset1, dataset2]
    )
    print(dataset.n_shards)
    # Uncomment to compare against a single, non-concatenated dataset:
    # dataset = dataset1

    dataloader = torch.utils.data.DataLoader(
        dataset,
        batch_size=8,
        num_workers=0,
    )

    for i, batch in enumerate(dataloader):
        print(batch)

    print("\nNew epoch")
    # set_epoch updates the dataset in place, so its return value is not reassigned.
    dataset.set_epoch(1)

    for i, batch in enumerate(dataloader):
        print(batch)


if __name__ == "__main__":
    main()
Describe the bug
After calling concatenate_datasets on iterable datasets, the shuffling state is destroyed, similar to #7156.
This means concatenation cannot be used to resolve uneven numbers of samples across devices when using iterable datasets in a distributed setting, as discussed in #6623.
I also noticed that the number of shards is the same after concatenation, which I found surprising, but I don't understand the internals well enough to know whether this is actually surprising or not.
Expected behavior
The shuffling state of the input datasets should be preserved after concatenation.
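To make this expectation checkable without reading printed batches, here is a minimal sketch of a helper that compares the example order across epochs. The names `orders_per_epoch` and `reshuffles_across_epochs` are hypothetical and not part of `datasets`. The sketch only assumes that the dataset has a `set_epoch` method, can be iterated directly, and yields dicts with a `"shards"` key as in the reproduction script. A changed order between epochs is used as a proxy for "shuffling is preserved"; two epochs could in principle produce the same order by chance, though that is very unlikely with 50 shards.

```python
def orders_per_epoch(dataset, epochs=(0, 1), key="shards"):
    """Collect the order in which examples are seen in each epoch.

    `dataset` is assumed to expose `set_epoch(epoch)` and to be iterable,
    yielding dict-like examples that contain `key`.
    """
    orders = []
    for epoch in epochs:
        dataset.set_epoch(epoch)  # assumed to update the dataset in place
        orders.append([example[key] for example in dataset])
    return orders


def reshuffles_across_epochs(dataset, epochs=(0, 1), key="shards"):
    """Return True if every epoch sees the same examples in a changed order."""
    orders = orders_per_epoch(dataset, epochs, key)
    first = orders[0]
    # Every epoch should contain the same examples, regardless of order.
    same_examples = all(sorted(order) == sorted(first) for order in orders[1:])
    # At least one later epoch should differ from the first in ordering.
    order_changes = any(order != first for order in orders[1:])
    return same_examples and order_changes
```

With the script above, `reshuffles_across_epochs(dataset1)` and `reshuffles_across_epochs(datasets.concatenate_datasets([dataset1, dataset2]))` would both be expected to return `True` if the shuffling state were preserved; the behavior described in this report suggests the second call does not.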
Environment info
Latest datasets