huggingface / datasets

🤗 The largest hub of ready-to-use datasets for ML models with fast, easy-to-use and efficient data manipulation tools
https://huggingface.co/docs/datasets
Apache License 2.0

Streaming IterableDatasets do not work with torch DataLoaders #5720

Open jlehrer1 opened 1 year ago

jlehrer1 commented 1 year ago

Describe the bug

When using streaming datasets split into train/val with .skip() and .take(), the following error occurs when iterating over a torch DataLoader:

  File "/Users/julian/miniconda3/envs/sims/lib/python3.9/site-packages/torch/utils/data/dataloader.py", line 363, in __iter__
    self._iterator = self._get_iterator()
  File "/Users/julian/miniconda3/envs/sims/lib/python3.9/site-packages/torch/utils/data/dataloader.py", line 314, in _get_iterator
    return _MultiProcessingDataLoaderIter(self)
  File "/Users/julian/miniconda3/envs/sims/lib/python3.9/site-packages/torch/utils/data/dataloader.py", line 927, in __init__
    w.start()
  File "/Users/julian/miniconda3/envs/sims/lib/python3.9/multiprocessing/process.py", line 121, in start
    self._popen = self._Popen(self)
  File "/Users/julian/miniconda3/envs/sims/lib/python3.9/multiprocessing/context.py", line 224, in _Popen
    return _default_context.get_context().Process._Popen(process_obj)
  File "/Users/julian/miniconda3/envs/sims/lib/python3.9/multiprocessing/context.py", line 284, in _Popen
    return Popen(process_obj)
  File "/Users/julian/miniconda3/envs/sims/lib/python3.9/multiprocessing/popen_spawn_posix.py", line 32, in __init__
    super().__init__(process_obj)
  File "/Users/julian/miniconda3/envs/sims/lib/python3.9/multiprocessing/popen_fork.py", line 19, in __init__
    self._launch(process_obj)
  File "/Users/julian/miniconda3/envs/sims/lib/python3.9/multiprocessing/popen_spawn_posix.py", line 47, in _launch
    reduction.dump(process_obj, fp)
  File "/Users/julian/miniconda3/envs/sims/lib/python3.9/multiprocessing/reduction.py", line 60, in dump
    ForkingPickler(file, protocol).dump(obj)
AttributeError: Can't pickle local object '_generate_examples_from_tables_wrapper.<locals>.wrapper'
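The last line of the traceback is the key. The worker is being started with the spawn method (note popen_spawn_posix above), which pickles the process object, including the dataset passed to the DataLoader worker, and pickle cannot serialize a function defined inside another function, such as _generate_examples_from_tables_wrapper.<locals>.wrapper. The stdlib-only sketch below reproduces that class of error. make_closure, scale and is_picklable are hypothetical names for illustration, and functools.partial over a module-level function is shown as one common picklable alternative to a closure, not as what datasets does internally.

```python
import functools
import pickle


def make_closure(factor: int):
    # `wrapper` is defined inside another function, so pickle cannot find it
    # by a module-level name: it is a "local object".
    def wrapper(x: int) -> int:
        return x * factor
    return wrapper


def scale(x: int, factor: int) -> int:
    # Module-level function: pickled by reference to its qualified name.
    return x * factor


def is_picklable(obj) -> bool:
    """Return True if `obj` survives pickle.dumps, otherwise print the error and return False."""
    try:
        pickle.dumps(obj)
        return True
    except (AttributeError, pickle.PicklingError) as exc:
        print(f"{type(exc).__name__}: {exc}")
        return False


if __name__ == "__main__":
    closure = make_closure(3)
    partial = functools.partial(scale, factor=3)
    print(closure(2), partial(2))  # same behaviour: 6 6
    print(is_picklable(closure))   # prints the "local object" error, then False
    print(is_picklable(partial))   # True
```

A similar pickle.dumps(dataset) call on the streaming dataset may surface the same error without involving a DataLoader at all.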

To reproduce, run the following code:

from datasets import load_dataset
from torch.utils.data import DataLoader

# `args` and `tokenizer` are defined elsewhere in the original script
data = load_dataset(args.dataset_name, split="train", streaming=True)
train_len = 5000
val_len = 100

train, val = data.take(train_len), data.skip(train_len).take(val_len)
traindata = IterableClipDataset(data, context_length=args.max_len, tokenizer=tokenizer, image_key="url", text_key="text")

traindata = DataLoader(traindata, batch_size=args.batch_size, num_workers=args.num_workers, persistent_workers=True)

where IterableClipDataset is a simple wrapper that exposes the dataset as a torch IterableDataset, defined as:

import torch
from torch.utils.data import Dataset, IterableDataset
from torchvision.transforms import Compose, Resize, ToTensor
from transformers import AutoTokenizer
import requests
from PIL import Image

class IterableClipDataset(IterableDataset):
    def __init__(self, dataset, context_length: int, image_transform=None, tokenizer=None, image_key="image", text_key="text"):
        self.dataset = dataset
        self.context_length = context_length
        self.image_transform = Compose([Resize((224, 224)), ToTensor()]) if image_transform is None else image_transform
        self.tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased") if tokenizer is None else tokenizer
        self.image_key = image_key
        self.text_key = text_key

    def read_image(self, url: str):
        try:  # Try to download and decode the image; force 3 channels so batches stack
            image = Image.open(requests.get(url, stream=True).raw).convert("RGB")
        except Exception:  # Fall back to a black placeholder image
            image = Image.new("RGB", (224, 224), (0, 0, 0))
        return image

    def process_sample(self, image, text):
        if isinstance(image, str):
            image = self.read_image(image)
        if self.image_transform is not None:
            image = self.image_transform(image)
        text = self.tokenizer.encode(
            text, add_special_tokens=True, max_length=self.context_length, truncation=True, padding="max_length"
        )
        text = torch.tensor(text, dtype=torch.long)
        return image, text

    def __iter__(self):
        for sample in self.dataset:
            image, text = sample[self.image_key], sample[self.text_key]
            yield self.process_sample(image, text)

Steps to reproduce the bug


  1. Install datasets, torch, torchvision, transformers, requests, and Pillow (to reproduce exactly)
  2. Run the code above

Expected behavior

The DataLoader yields batches of data.

Environment info

datasets == 2.9.0
python == 3.9.12
torch == 1.11.0
jlehrer1 commented 1 year ago

Edit: This happens even without .take()/.skip().

ivanprado commented 1 year ago

I'm experiencing the same problem as @jlehrer1. I was able to reproduce it with a very small example:

from datasets import Dataset, load_dataset, load_dataset_builder
from torch.utils.data import DataLoader

def my_gen():
    for i in range(1, 4):
        yield {"a": i}

# Saving the dataset as a parquet file
dataset = Dataset.from_generator(my_gen)
train_path = "/tmp/test.parquet"
dataset.to_parquet(train_path)

# Creating a local dataset from the parquet file
data_files = {"train": [str(train_path)]}
builder = load_dataset_builder("parquet", data_files=data_files)
builder.download_and_prepare("/tmp/test_ds", file_format="parquet")

# Loading the dataset from the local directory as streaming
dataset = load_dataset("parquet", data_dir="/tmp/test_ds", split="train", streaming=True)
dataset = dataset.with_format("torch")  # with_format returns a new dataset; it does not modify `dataset` in place

dl = DataLoader(dataset, batch_size=2, num_workers=1)
for row in dl:
    print(row)

My env info:

datasets                      2.11.0
torch                         2.0.0
torchvision                   0.15.1
Python 3.9.16

Note that the example above does not fail with num_workers=0.

mariosasko commented 1 year ago

I cannot reproduce this error, not even with your MRE @ivanprado (your env appears to be the same as Colab's, and your code runs there without issues).

ivanprado commented 1 year ago

@mariosasko you are right, it works on Colab. I dug deeper and found that the problem arises when the multiprocessing start method is set to spawn. This code reproduces the problem in Colab:

from datasets import Dataset, load_dataset, load_dataset_builder
from torch.utils.data import DataLoader
import multiprocessing as mp

def my_gen():
    for i in range(1, 4):
        yield {"a": i}

def main():
    # Saving the dataset as a parquet file
    dataset = Dataset.from_generator(my_gen)
    train_path = "/tmp/test.parquet"
    dataset.to_parquet(train_path)

    # Creating a local dataset from the parquet file
    data_files = {"train": [str(train_path)]}
    builder = load_dataset_builder("parquet", data_files=data_files)
    builder.download_and_prepare("/tmp/test_ds", file_format="parquet")

    # Loading the dataset from the local directory as streaming
    dataset = load_dataset("parquet", data_dir="/tmp/test_ds", split="train", streaming=True)
    dataset = dataset.with_format("torch")  # with_format returns a new dataset

    dl = DataLoader(dataset, batch_size=2, num_workers=1)
    for row in dl:
        print(row)

if __name__ == "__main__":
    # Set the start method once, at the entry point: spawned children re-import
    # this module when run as a script and must not call it again or re-run main()
    mp.set_start_method('spawn')
    main()
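The spawn vs fork difference can be reproduced with the standard library alone. Under spawn, Process.start() pickles the process object, including its target, before the child is launched; under fork, the child inherits the parent's memory, so nothing is pickled. This is consistent with the original traceback going through popen_spawn_posix: since Python 3.8, spawn is the default start method on macOS, while Linux (which Colab runs) has traditionally defaulted to fork. In this sketch, try_start and local_target are hypothetical names, and the local function stands in for the local wrapper named in the traceback.

```python
import multiprocessing as mp
import pickle


def try_start(method: str) -> str:
    """Start a child process whose target is a local (unpicklable) function."""

    def local_target() -> None:  # defined inside try_start, so not picklable
        pass

    ctx = mp.get_context(method)
    proc = ctx.Process(target=local_target)
    try:
        # With "spawn", start() pickles the Process object (and its target)
        # before launching the child, so the error is raised here.
        proc.start()
    except (AttributeError, pickle.PicklingError) as exc:
        return f"{method}: failed ({type(exc).__name__})"
    # With "fork", the child inherits the parent's memory; nothing is pickled.
    proc.join()
    return f"{method}: ok (exit code {proc.exitcode})"


if __name__ == "__main__":
    for method in ("spawn", "fork"):
        if method in mp.get_all_start_methods():  # "fork" is missing on Windows
            print(try_start(method))
```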
jlehrer1 commented 1 year ago

So is there a way to fix this by changing the multiprocessing start method? This is blocking any use of the datasets library for me.

ivanprado commented 1 year ago

@jlehrer1 can you try adding mp.set_start_method('fork') at the beginning of your code? Maybe this helps you. Keep us posted.
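If you try changing the start method, it generally has to happen once, at the program's entry point, before anything starts worker processes; it is not a confirmed fix for this issue. The stdlib-only sketch below shows that pattern; use_start_method and the placeholder main are hypothetical names. As a narrower alternative, torch's DataLoader also accepts a multiprocessing_context argument (for example multiprocessing_context="fork") that applies only to that loader's workers. Note that fork is not available on Windows, and Python 3.8 stopped using it as the macOS default because it can crash with some system libraries.

```python
import multiprocessing as mp


def use_start_method(method: str) -> str:
    """Set the global multiprocessing start method and return the one in effect."""
    if method not in mp.get_all_start_methods():
        # e.g. "fork" is not available on Windows
        raise ValueError(f"start method {method!r} is not available here")
    # Without force=True, set_start_method raises RuntimeError
    # ("context has already been set") once a start method has been set,
    # or fixed by earlier use of the default context.
    mp.set_start_method(method, force=True)
    return mp.get_start_method()


def main() -> None:
    ...  # hypothetical placeholder: build the streaming dataset and DataLoader here


if __name__ == "__main__":
    # Choose the method once, before any DataLoader starts worker processes.
    print(use_start_method("fork" if "fork" in mp.get_all_start_methods() else "spawn"))
    main()
```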

Wonder1905 commented 1 year ago

I have a similar issue:

mp.set_start_method('fork')

Didn't work.