huggingface / datasets


PyArrow Dataset error when calling `load_dataset` #4721

Open · piraka9011 opened this issue 2 years ago

piraka9011 commented 2 years ago

Describe the bug

I am fine-tuning a wav2vec2 model on my own dataset, following this script: https://github.com/huggingface/transformers/blob/main/examples/pytorch/speech-recognition/run_speech_recognition_ctc.py

Loading my audio dataset from the Hub (it was originally generated from files on disk) results in the following PyArrow error:

File "/home/ubuntu/w2v2/run_speech_recognition_ctc.py", line 227, in main
  raw_datasets = load_dataset(
File "/home/ubuntu/.virtualenvs/meval/lib/python3.8/site-packages/datasets/load.py", line 1679, in load_dataset
  builder_instance.download_and_prepare(
File "/home/ubuntu/.virtualenvs/meval/lib/python3.8/site-packages/datasets/builder.py", line 704, in download_and_prepare
  self._download_and_prepare(
File "/home/ubuntu/.virtualenvs/meval/lib/python3.8/site-packages/datasets/builder.py", line 793, in _download_and_prepare
  self._prepare_split(split_generator, **prepare_split_kwargs)
File "/home/ubuntu/.virtualenvs/meval/lib/python3.8/site-packages/datasets/builder.py", line 1268, in _prepare_split
  for key, table in logging.tqdm(
File "/home/ubuntu/.virtualenvs/meval/lib/python3.8/site-packages/tqdm/std.py", line 1195, in __iter__
  for obj in iterable:
File "/home/ubuntu/.virtualenvs/meval/lib/python3.8/site-packages/datasets/packaged_modules/parquet/parquet.py", line 68, in _generate_tables
  for batch_idx, record_batch in enumerate(
File "pyarrow/_parquet.pyx", line 1309, in iter_batches
File "pyarrow/error.pxi", line 121, in pyarrow.lib.check_status
pyarrow.lib.ArrowNotImplementedError: Nested data conversions not implemented for chunked array outputs

Steps to reproduce the bug

I created a dataset from a JSON lines manifest of audio_filepath, text, and duration.

When creating the dataset, I do something like this:

import json
from datasets import Dataset, Audio

# manifest_lines is a list of JSON-lines strings, each with "audio_filepath", "duration", and "text"
manifest_dict = {"audio": [], "duration": [], "transcription": []}
for line in manifest_lines:
    line = line.strip()
    if line:
        line_dict = json.loads(line)
        manifest_dict["audio"].append(f"{root_path}/{line_dict['audio_filepath']}")
        manifest_dict["duration"].append(line_dict["duration"])
        manifest_dict["transcription"].append(line_dict["text"])

# Create a HF dataset
dataset = Dataset.from_dict(manifest_dict).cast_column(
    "audio", Audio(sampling_rate=16_000),
)

# From the docs for saving to disk
# https://huggingface.co/docs/datasets/v2.3.2/en/package_reference/main_classes#datasets.Dataset.save_to_disk
def read_audio_file(example):
    with open(example["audio"]["path"], "rb") as f:
        return {"audio": {"bytes": f.read()}}

dataset = dataset.map(read_audio_file, num_proc=70)
dataset.save_to_disk(f"/audio-data/hf/{artifact_name}")
# org_name and artifact_name are placeholders for the actual repo id
dataset.push_to_hub(f"{org_name}/{artifact_name}", max_shard_size="5GB", private=True)

Then, when I call load_dataset() in my training script with the same dataset generated above (downloaded from the Hugging Face Hub), I get the above stack trace. I am able to load the dataset fine if I use load_from_disk().
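For reference, a minimal sketch of the two loading paths (the repo id and local path below are placeholders carried over from the snippet above, not the real names):

from datasets import load_dataset, load_from_disk

# Loading from the Hub (private repo, hence the auth token) -- this is the call
# that ends in ArrowNotImplementedError while re-reading the parquet shards:
raw_datasets = load_dataset("org_name/artifact_name", use_auth_token=True)

# Loading the same data from the local save_to_disk() output works fine:
raw_datasets = load_from_disk("/audio-data/hf/artifact_name")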

Expected results

load_dataset() should behave just like load_from_disk() and not cause any errors.

Actual results

See above

Environment info

I am using the huggingface/transformers-pytorch-gpu:latest image

lhoestq commented 2 years ago

Hi! It looks like a bug in pyarrow. If you manage to end up with only one chunk per parquet file, that should work around this issue.

To achieve that, you can try lowering the value of max_shard_size and avoiding map before push_to_hub.
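For illustration, a rough sketch of that workaround, reusing the names from the snippet above (the shard size here is just an example value, smaller than the 5GB used originally):

# Skip the read_audio_file map step and push with smaller shards, so each
# resulting parquet file is more likely to stay a single chunk:
dataset = Dataset.from_dict(manifest_dict).cast_column(
    "audio", Audio(sampling_rate=16_000),
)
dataset.push_to_hub(f"{org_name}/{artifact_name}", max_shard_size="500MB", private=True)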

Do you have a minimum reproducible example that we can share with the Arrow team for further debugging?

piraka9011 commented 2 years ago

If you manage to end up with only one chunk per parquet file, that should work around this issue.

Yup, I did not encounter this bug when I was testing my script with a slice of <1000 samples for my dataset.

Do you have a minimum reproducible example...

Not sure I can get more minimal than the script I shared above. Are you asking for a sample JSON file? You could just generate a random manifest list; I can add that to the script above if that's what you mean.
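If it helps, a hypothetical sketch of what "generate a random manifest" could look like (dummy silent wav files plus JSON-lines entries; the path, sample count, and durations are made up, and a run this small may not be big enough to trigger the chunking that causes the error):

import json
import os
import wave

root_path = "/tmp/dummy-audio"  # made-up location for the dummy clips
os.makedirs(root_path, exist_ok=True)

manifest_lines = []
for i in range(1000):
    filename = f"sample_{i}.wav"
    # write one second of 16 kHz, 16-bit mono silence
    with wave.open(f"{root_path}/{filename}", "wb") as f:
        f.setnchannels(1)
        f.setsampwidth(2)
        f.setframerate(16_000)
        f.writeframes(b"\x00\x00" * 16_000)
    manifest_lines.append(json.dumps(
        {"audio_filepath": filename, "duration": 1.0, "text": f"dummy transcription {i}"}
    ))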

lhoestq commented 2 years ago

Actually, this is probably linked to this open issue: https://issues.apache.org/jira/browse/ARROW-5030.

Setting max_shard_size="2GB" should do the job (or max_shard_size="1GB" if you want to be on the safe side, especially given that there can be some variance in the shard sizes if the dataset is not evenly distributed).
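For example, with the placeholder repo id from earlier:

dataset.push_to_hub(f"{org_name}/{artifact_name}", max_shard_size="2GB", private=True)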