huggingface / transformers

🤗 Transformers: State-of-the-art Machine Learning for Pytorch, TensorFlow, and JAX.
https://huggingface.co/transformers
Apache License 2.0

NonMatchingSplitsSizesError on Flax BART with wiki summary dataset #29596

Status: Open. RissyRan opened this issue 3 months ago.

RissyRan commented 3 months ago

System Info

Platform: TPU
Python: python3.11

Who can help?

@sanchit-gandhi

Reproduction

  1. Set up the example as described here.
  2. Run the command below on the wiki_summary dataset:

          cd /tmp/transformers/examples/flax/summarization && \
          JAX_PLATFORM_NAME=TPU python3 run_summarization_flax.py \
            --model_name_or_path facebook/bart-base \
            --tokenizer_name facebook/bart-base \
            --dataset_name wiki_summary \
            --do_train --do_eval --do_predict --predict_with_generate \
            --learning_rate 5e-5 --warmup_steps 0 \
            --output_dir=/tmp/transformers/bart-base-wiki --overwrite_output_dir \
            --num_train_epochs 3 \
            --max_source_length 512 --max_target_length 64 \
            --per_device_train_batch_size=64 --per_device_eval_batch_size=64

Expected behavior

The run fails while loading the dataset, with the error below:

[2024-03-11, 02:09:49 UTC] {logging_mixin.py:150} WARNING - Generating test split:  18%|█▊        | 1000/5638 [00:00<00:00, 8781.53 examples/s]
[2024-03-11, 02:09:49 UTC] {logging_mixin.py:150} WARNING - Generating test split:  37%|███▋      | 2074/5638 [00:00<00:00, 9863.29 examples/s]
[2024-03-11, 02:09:49 UTC] {logging_mixin.py:150} WARNING - Generating test split:  55%|█████▌    | 3112/5638 [00:00<00:00, 10087.86 examples/s]
[2024-03-11, 02:09:50 UTC] {logging_mixin.py:150} WARNING - Generating test split:  74%|███████▍  | 4188/5638 [00:00<00:00, 10345.56 examples/s]
[2024-03-11, 02:09:50 UTC] {logging_mixin.py:150} WARNING - Generating test split:  93%|█████████▎| 5253/5638 [00:00<00:00, 10450.02 examples/s]
[2024-03-11, 02:09:50 UTC] {logging_mixin.py:150} WARNING - Generating test split: 100%|██████████| 5638/5638 [00:00<00:00, 10140.16 examples/s]
[2024-03-11, 02:09:50 UTC] {logging_mixin.py:150} WARNING - Generating validation split:   0%|          | 0/5074 [00:00<?, ? examples/s]
[2024-03-11, 02:09:50 UTC] {logging_mixin.py:150} WARNING - Generating validation split:  20%|██        | 1035/5074 [00:00<00:00, 10319.92 examples/s]
[2024-03-11, 02:09:50 UTC] {logging_mixin.py:150} WARNING - Generating validation split:  42%|████▏     | 2110/5074 [00:00<00:00, 10566.07 examples/s]
[2024-03-11, 02:09:50 UTC] {logging_mixin.py:150} WARNING - Generating validation split:  63%|██████▎   | 3175/5074 [00:00<00:00, 10597.24 examples/s]
[2024-03-11, 02:09:50 UTC] {logging_mixin.py:150} WARNING - Generating validation split:  84%|████████▍ | 4253/5074 [00:00<00:00, 10663.04 examples/s]
[2024-03-11, 02:09:50 UTC] {logging_mixin.py:150} WARNING - Generating validation split: 100%|██████████| 5074/5074 [00:00<00:00, 10458.45 examples/s]
[2024-03-11, 02:09:50 UTC] {logging_mixin.py:150} WARNING - Traceback (most recent call last):
  File "/tmp/transformers/examples/flax/summarization/run_summarization_flax.py", line 1031, in <module>
    main()
  File "/tmp/transformers/examples/flax/summarization/run_summarization_flax.py", line 499, in main
    dataset = load_dataset(
  File "/home/ml-auto-solutions/.local/lib/python3.10/site-packages/datasets/load.py", line 2582, in load_dataset
    builder_instance.download_and_prepare(
  File "/home/ml-auto-solutions/.local/lib/python3.10/site-packages/datasets/builder.py", line 1005, in download_and_prepare
    self._download_and_prepare(
  File "/home/ml-auto-solutions/.local/lib/python3.10/site-packages/datasets/builder.py", line 1767, in _download_and_prepare
    super()._download_and_prepare(
  File "/home/ml-auto-solutions/.local/lib/python3.10/site-packages/datasets/builder.py", line 1118, in _download_and_prepare
    verify_splits(self.info.splits, split_dict)
  File "/home/ml-auto-solutions/.local/lib/python3.10/site-packages/datasets/utils/info_utils.py", line 101, in verify_splits
    raise NonMatchingSplitsSizesError(str(bad_splits))
datasets.utils.info_utils.NonMatchingSplitsSizesError: [{'expected': SplitInfo(name='train', num_bytes=207186608, num_examples=45654, shard_lengths=None, dataset_name=None), 'recorded': SplitInfo(name='train', num_bytes=0, num_examples=0, shard_lengths=None, dataset_name='wiki_summary')}]
[2024-03-11, 02:09:54 UTC] {taskinstance.py:1826} ERROR - Task failed with exception
Traceback (most recent call last):
  File "/opt/python3.11/lib/python3.11/site-packages/airflow/decorators/base.py", line 220, in execute
    return_value = super().execute(context)
                   ^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/python3.11/lib/python3.11/site-packages/airflow/operators/python.py", line 181, in execute
    return_value = self.execute_callable()
                   ^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/python3.11/lib/python3.11/site-packages/airflow/operators/python.py", line 198, in execute_callable
    return self.python_callable(*self.op_args, **self.op_kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/airflow/gcs/dags/xlml/utils/tpu.py", line 402, in ssh_tpu
    ssh_group.run(cmds, env=env)
  File "/opt/python3.11/lib/python3.11/site-packages/fabric/group.py", line 116, in run
    return self._do("run", *args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/python3.11/lib/python3.11/site-packages/fabric/group.py", line 282, in _do
    raise GroupException(results)
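The failure comes from the split-size check that `datasets` runs after generating the splits (`verify_splits` in the traceback): the dataset metadata expects 45654 train examples, but 0 were recorded, while the test (5638) and validation (5074) splits were generated. The sketch below is a simplified re-implementation of that kind of check to illustrate why the error fires, not the library's actual code. `SplitSizeMismatch`, `find_mismatched_splits`, and `verify_split_sizes` are hypothetical names, and only example counts are compared.

```python
from __future__ import annotations


class SplitSizeMismatch(Exception):
    """Hypothetical stand-in for datasets' NonMatchingSplitsSizesError."""


def find_mismatched_splits(expected: dict[str, int], recorded: dict[str, int]) -> list[dict]:
    """Return one entry per split whose recorded example count differs from the expected one.

    Both arguments map split name -> number of examples. A split missing from
    `recorded` is treated as having 0 examples (a simplification).
    """
    return [
        {"name": name, "expected": count, "recorded": recorded.get(name, 0)}
        for name, count in expected.items()
        if recorded.get(name, 0) != count
    ]


def verify_split_sizes(expected: dict[str, int], recorded: dict[str, int]) -> None:
    """Raise if any split's example count does not match the expected metadata."""
    bad = find_mismatched_splits(expected, recorded)
    if bad:
        raise SplitSizeMismatch(str(bad))


# Counts taken from the log above. Only train is flagged in the error, so the
# expected and recorded counts for test and validation are assumed to be equal.
expected = {"train": 45654, "test": 5638, "validation": 5074}
recorded = {"train": 0, "test": 5638, "validation": 5074}

try:
    verify_split_sizes(expected, recorded)
except SplitSizeMismatch as err:
    print(err)  # prints [{'name': 'train', 'expected': 45654, 'recorded': 0}]
```

Because the recorded train count is 0 rather than slightly off, the train files were most likely never downloaded or parsed, which points at the dataset source rather than at the Flax BART example itself.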

Based on this thread, a few flags could be passed to load_dataset() to work around the error.
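The issue does not say which flags the thread suggests. As an assumption, the sketch below uses two `load_dataset` arguments available in recent `datasets` releases: `verification_mode="no_checks"`, which skips post-download verification such as the split-size check, and `download_mode="force_redownload"`, which discards cached files and fetches the data again. Calling `load_dataset` directly also isolates the failure from the model and TPU setup, since the traceback shows it is raised inside `load_dataset`. `wiki_summary_load_kwargs` and `load_wiki_summary` are hypothetical helpers, not part of the example script.

```python
from __future__ import annotations

from typing import Any


def wiki_summary_load_kwargs(skip_split_checks: bool = False, force_redownload: bool = False) -> dict[str, Any]:
    """Build extra keyword arguments for datasets.load_dataset (hypothetical helper)."""
    kwargs: dict[str, Any] = {}
    if skip_split_checks:
        # Skip verification (split sizes, checksums) that raised NonMatchingSplitsSizesError.
        kwargs["verification_mode"] = "no_checks"
    if force_redownload:
        # Ignore cached downloads and prepared files and fetch the source data again.
        kwargs["download_mode"] = "force_redownload"
    return kwargs


def load_wiki_summary(skip_split_checks: bool = False, force_redownload: bool = False):
    """Load wiki_summary directly, outside run_summarization_flax.py (hypothetical helper)."""
    # Imported lazily so the kwargs helper can be used without `datasets` installed.
    from datasets import load_dataset

    return load_dataset(
        "wiki_summary",
        **wiki_summary_load_kwargs(skip_split_checks, force_redownload),
    )


# Usage (needs network access and the `datasets` package):
#   load_wiki_summary()                       # should reproduce the error in isolation
#   load_wiki_summary(force_redownload=True)  # retry with a fresh download
```

Skipping the check only silences the error: the log recorded 0 train examples, so with `no_checks` training would presumably see an empty train split. A fresh download, or a fix to the dataset's source files, is more likely to address the underlying problem.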

amyeroberts commented 2 months ago

Gentle ping @sanchit-gandhi

amyeroberts commented 1 month ago

Another ping @sanchit-gandhi