microsoft / ProphetNet

A research project for natural language generation, containing the official implementations by the MSRA NLC team.
MIT License

Assertion Error in fine-tuning of Gigaword #31

Closed takase closed 3 years ago

takase commented 3 years ago

Hi, thank you for distributing your code! I tried to fine-tune the pre-trained ProphetNet (160G) on the English Gigaword summarization dataset. I performed the pre-processing described in the README and then attempted fine-tuning, but hit the following AssertionError:

  File "~/anaconda3/envs/py36pytorch14/bin/fairseq-train", line 8, in <module>
    sys.exit(cli_main())
  File "~/anaconda3/envs/py36pytorch14/lib/python3.6/site-packages/fairseq_cli/train.py", line 333, in cli_main
    main(args)
  File "~/anaconda3/envs/py36pytorch14/lib/python3.6/site-packages/fairseq_cli/train.py", line 86, in main
    train(args, trainer, task, epoch_itr)
  File "~/anaconda3/envs/py36pytorch14/lib/python3.6/site-packages/fairseq_cli/train.py", line 126, in train
    for i, samples in enumerate(progress, start=epoch_itr.iterations_in_epoch):
  File "~/anaconda3/envs/py36pytorch14/lib/python3.6/site-packages/tqdm/std.py", line 1127, in __iter__
    for obj in iterable:
  File "~/anaconda3/envs/py36pytorch14/lib/python3.6/site-packages/fairseq/data/iterators.py", line 314, in __next__
    chunk.append(next(self.itr))
  File "~/anaconda3/envs/py36pytorch14/lib/python3.6/site-packages/fairseq/data/iterators.py", line 43, in __next__
    return next(self.itr)
  File "~/anaconda3/envs/py36pytorch14/lib/python3.6/site-packages/fairseq/data/iterators.py", line 36, in __iter__
    for x in self.iterable:
  File "~/anaconda3/envs/py36pytorch14/lib/python3.6/site-packages/torch/utils/data/dataloader.py", line 345, in __next__
    data = self._next_data()
  File "~/anaconda3/envs/py36pytorch14/lib/python3.6/site-packages/torch/utils/data/dataloader.py", line 856, in _next_data
    return self._process_data(data)
  File "~/anaconda3/envs/py36pytorch14/lib/python3.6/site-packages/torch/utils/data/dataloader.py", line 881, in _process_data
    data.reraise()
  File "~/anaconda3/envs/py36pytorch14/lib/python3.6/site-packages/torch/_utils.py", line 394, in reraise
    raise self.exc_type(msg)
AssertionError: Caught AssertionError in DataLoader worker process 0.
Original Traceback (most recent call last):
  File "~/anaconda3/envs/py36pytorch14/lib/python3.6/site-packages/torch/utils/data/_utils/worker.py", line 178, in _worker_loop
    data = fetcher.fetch(index)
  File "~/anaconda3/envs/py36pytorch14/lib/python3.6/site-packages/torch/utils/data/_utils/fetch.py", line 47, in fetch
    return self.collate_fn(data)
  File "~/anaconda3/envs/py36pytorch14/lib/python3.6/site-packages/fairseq/data/language_pair_dataset.py", line 252, in collater
    input_feeding=self.input_feeding,
  File "~/anaconda3/envs/py36pytorch14/lib/python3.6/site-packages/fairseq/data/language_pair_dataset.py", line 69, in collate
    move_eos_to_beginning=True,
  File "~/anaconda3/envs/py36pytorch14/lib/python3.6/site-packages/fairseq/data/language_pair_dataset.py", line 22, in merge
    pad_idx, eos_idx, left_pad, move_eos_to_beginning,
  File "~/anaconda3/envs/py36pytorch14/lib/python3.6/site-packages/fairseq/data/data_utils.py", line 44, in collate_tokens
    copy_tensor(v, res[i][size - len(v):] if left_pad else res[i][:len(v)])
  File "~/anaconda3/envs/py36pytorch14/lib/python3.6/site-packages/fairseq/data/data_utils.py", line 37, in copy_tensor
    assert src[-1] == eos_idx
AssertionError

pytorch version == 1.4.0

fairseq version == 0.9.0
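For context, the assertion comes from fairseq's `collate_tokens`, which checks that every target sample ends with the dictionary's EOS index before rotating EOS to the front for input feeding. Below is a simplified, list-based sketch of that logic (not fairseq's actual tensor code; `pad_idx`/`eos_idx` values are illustrative) showing how a sample binarized with a mismatched dictionary triggers the error:

```python
def collate_tokens(values, pad_idx, eos_idx, left_pad=False, move_eos_to_beginning=False):
    """Simplified, list-based sketch of fairseq.data.data_utils.collate_tokens."""
    size = max(len(v) for v in values)
    res = [[pad_idx] * size for _ in values]  # pad all rows to the max length

    def copy_tensor(src, dst_row, start):
        if move_eos_to_beginning:
            # The assertion from the traceback: every target sample must end
            # with the dictionary's EOS index before EOS is moved to the front.
            assert src[-1] == eos_idx
            dst_row[start] = eos_idx
            dst_row[start + 1:start + len(src)] = src[:-1]
        else:
            dst_row[start:start + len(src)] = src

    for i, v in enumerate(values):
        copy_tensor(v, res[i], size - len(v) if left_pad else 0)
    return res

eos_idx = 2  # illustrative EOS id (fairseq's default dictionary uses 2)
print(collate_tokens([[5, 6, 2]], pad_idx=1, eos_idx=eos_idx,
                     move_eos_to_beginning=True))  # [[2, 5, 6]]

# A sample binarized with a different dictionary does not end with eos_idx:
try:
    collate_tokens([[5, 6, 7]], pad_idx=1, eos_idx=eos_idx,
                   move_eos_to_beginning=True)
except AssertionError:
    print("AssertionError: sample does not end with EOS")
```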

In addition, when I trained the original Transformer (--arch transformer_wmt_en_de) with label_smoothed_cross_entropy, training succeeded.

Do you have any idea how to solve the above error?

takase commented 3 years ago

I found that the EOS indices of the dictionary and of the binarized training data differ from each other. Do I need to modify bert_dictionary.py?
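A quick way to confirm this kind of mismatch is to check whether every binarized target sample ends with the dictionary's EOS id. A minimal, hypothetical checker (the token ids below are illustrative only, e.g. 2 for fairseq's default EOS versus a BERT-style [SEP] id):

```python
def check_eos_consistency(samples, eos_idx):
    """Return indices of samples whose final token is not the expected EOS id."""
    return [i for i, s in enumerate(samples) if not s or s[-1] != eos_idx]

# Illustrative token-id sequences: the first ends with EOS id 2,
# the second ends with a different id (e.g. a BERT-style [SEP]).
samples = [[5, 6, 2], [7, 8, 102]]
print(check_eos_consistency(samples, eos_idx=2))  # [1]
```

Any non-empty result means the data was binarized with a different dictionary than the one used at fine-tuning time.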

takase commented 3 years ago

I'm sorry, I had omitted the task option during pre-processing. After adding the option, the problem was solved.
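For anyone hitting the same error: the preprocessing command should look roughly like the following sketch. The task name `translation_prophetnet` is the custom task registered in the ProphetNet codebase, and the paths here are placeholders; see the repo's README for the exact script. Without the `--user-dir`/`--task` options, fairseq-preprocess falls back to its default dictionary, so the binarized data's EOS id no longer matches the BERT-style dictionary used at fine-tuning time.

```shell
fairseq-preprocess \
  --user-dir prophetnet \
  --task translation_prophetnet \
  --source-lang src --target-lang tgt \
  --trainpref gigaword/train --validpref gigaword/valid --testpref gigaword/test \
  --destdir processed --srcdict vocab.txt --tgtdict vocab.txt \
  --workers 20
```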