jlehrer1 opened this issue 1 year ago
Edit: This behavior occurs even without `.take`/`.skip`
I'm experiencing the same problem as @jlehrer1. I was able to reproduce it with a very small example:
```python
from datasets import Dataset, load_dataset, load_dataset_builder
from torch.utils.data import DataLoader

def my_gen():
    for i in range(1, 4):
        yield {"a": i}

# Saving the dataset as a parquet file
dataset = Dataset.from_generator(my_gen)
train_path = "/tmp/test.parquet"
dataset.to_parquet(train_path)

# Creating a local dataset from the parquet file
data_files = {"train": [str(train_path)]}
builder = load_dataset_builder("parquet", data_files=data_files)
builder.download_and_prepare("/tmp/test_ds", file_format="parquet")

# Loading the dataset from the local directory as streaming
dataset = load_dataset("parquet", data_dir="/tmp/test_ds", split="train", streaming=True)
dataset = dataset.with_format("torch")  # with_format returns a new dataset, so reassign
dl = DataLoader(dataset, batch_size=2, num_workers=1)
for row in dl:
    print(row)
```
My env info:

```
datasets 2.11.0
torch 2.0.0
torchvision 0.15.1
Python 3.9.16
```
Note that the example above doesn't fail when `num_workers=0`, as shown below.
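For reference, this is the only change needed to make the loop run, keeping everything else identical (a sketch of the workaround, assuming `dataset` is the streaming dataset from the example above):

```python
from torch.utils.data import DataLoader

# Same loader as above, but iterating in the main process
# (no worker subprocesses) avoids the crash
dl = DataLoader(dataset, batch_size=2, num_workers=0)
for row in dl:
    print(row)
```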
I cannot reproduce this error, not even with your MRE @ivanprado (your env appears to be the same as Colab's, and your code runs there without issues).
@mariosasko you are right, it works on Colab. I dug deeper and found that the problem arises when the multiprocessing start method is set to `spawn`. This code reproduces the problem in Colab:
```python
from datasets import Dataset, load_dataset, load_dataset_builder
from torch.utils.data import DataLoader
import multiprocessing as mp

mp.set_start_method('spawn')

def my_gen():
    for i in range(1, 4):
        yield {"a": i}

def main():
    # Saving the dataset as a parquet file
    dataset = Dataset.from_generator(my_gen)
    train_path = "/tmp/test.parquet"
    dataset.to_parquet(train_path)

    # Creating a local dataset from the parquet file
    data_files = {"train": [str(train_path)]}
    builder = load_dataset_builder("parquet", data_files=data_files)
    builder.download_and_prepare("/tmp/test_ds", file_format="parquet")

    # Loading the dataset from the local directory as streaming
    dataset = load_dataset("parquet", data_dir="/tmp/test_ds", split="train", streaming=True)
    dataset = dataset.with_format("torch")  # with_format returns a new dataset, so reassign
    dl = DataLoader(dataset, batch_size=2, num_workers=1)
    for row in dl:
        print(row)

main()
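This would also explain why the error doesn't show up on Colab by default: assuming standard CPython behavior, Linux defaults to `fork`, while macOS (since Python 3.8) and Windows default to `spawn`. A quick check of what your environment uses:

```python
import multiprocessing as mp

# Default start method differs by platform: 'fork' on Linux,
# 'spawn' on macOS (Python >= 3.8) and Windows
print(mp.get_start_method())
```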
So is there a way to fix this by changing the `mp` start method? This is blocking any usage of the `datasets` library for me.
@jlehrer1 can you try adding `mp.set_start_method('fork')` at the beginning of your code? Maybe that helps. Keep us posted.
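If changing the global start method isn't an option (e.g. another library already set it), a possible alternative is to scope the method to the loader itself via DataLoader's standard `multiprocessing_context` argument. A minimal sketch, assuming the streaming `dataset` from the examples above:

```python
import multiprocessing as mp
from torch.utils.data import DataLoader

# Use fork only for this loader's workers,
# leaving the global start method untouched
dl = DataLoader(dataset, batch_size=2, num_workers=1,
                multiprocessing_context=mp.get_context("fork"))
```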
I have a similar issue: adding `mp.set_start_method('fork')` didn't work.
Describe the bug

When using streaming datasets set up with a train/val split using `.skip()` and `.take()`, an error occurs when iterating over a torch dataloader. To reproduce, run the code, where the class `IterableClipDataset` is a simple wrapper to cast the dataset to a torch `IterableDataset`.
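The original wrapper definition isn't reproduced above; a minimal sketch consistent with that description (the delegation-only design and the constructor signature are assumptions):

```python
from torch.utils.data import IterableDataset

# Hypothetical reconstruction of the wrapper described above: it only
# re-exposes a streaming HF dataset as a torch IterableDataset
class IterableClipDataset(IterableDataset):
    def __init__(self, hf_dataset):
        self.hf_dataset = hf_dataset

    def __iter__(self):
        # Delegate iteration to the underlying streaming dataset
        yield from self.hf_dataset
```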
Steps to reproduce the bug

Install `datasets`, `torch`, and `PIL` (if you want to reproduce exactly).

Expected behavior
Batched data is produced from the dataloader
Environment info