facebookresearch / fairseq

Facebook AI Research Sequence-to-Sequence Toolkit written in Python.
MIT License

How to reproduce wmt19 BLEU score in your submission? #2544

Closed · stas00 closed this issue 2 years ago

stas00 commented 4 years ago

I'm trying to reproduce the BLEU score reported at http://matrix.statmt.org/matrix/output/1914?score_id=37605 and described here https://github.com/pytorch/fairseq/tree/master/examples/wmt19

Would it be possible to add the instructions for getting those reported scores?

Here is what I have tried and I'm not getting anywhere close to the reported scores:

Let's try the wmt19 en-ru pair, which is reported as BLEU 36.4:

# en-ru
git clone https://github.com/pytorch/fairseq/
cd fairseq
mkdir -p data-bin

# get model
curl --output - https://dl.fbaipublicfiles.com/fairseq/models/wmt19.en-ru.ensemble.tar.gz | tar xvzf - -C data-bin

export PAIR=en-ru
export DATA_DIR=data-bin/wmt19.en-ru.ensemble

# get evaluation data
sacrebleu -t wmt19 -l $PAIR --echo src > $DATA_DIR/test.en-ru.en
sacrebleu -t wmt19 -l $PAIR --echo ref > $DATA_DIR/test.en-ru.ru

Generate (w/ 4 model ensemble)

fairseq-generate $DATA_DIR --path $DATA_DIR/model1.pt:$DATA_DIR/model2.pt:$DATA_DIR/model3.pt:$DATA_DIR/model4.pt \
--beam 5 --batch-size 32 --remove-bpe --source-lang en --target-lang ru --task translation --dataset-impl raw | tee /tmp/gen.out.models.4

Evaluate via scripts/sacrebleu.sh:

bash scripts/sacrebleu.sh wmt19 en ru /tmp/gen.out.models.4 | tee /tmp/bleu.out.models.4

gives:

BLEU+case.mixed+lang.en-ru+numrefs.1+smooth.exp+test.wmt19+tok.13a+version.1.4.12 = 18.3 55.2/29.3/18.1/11.5 (BP = 0.760 ratio = 0.785 hyp_len = 37775 ref_len = 48147)

Evaluate via fairseq-score:

grep ^H /tmp/gen.out.models.4 | cut -f3- > /tmp/gen.out.models.4.sys
grep ^T /tmp/gen.out.models.4 | cut -f2- > /tmp/gen.out.models.4.ref
fairseq-score --sys /tmp/gen.out.models.4.sys --ref /tmp/gen.out.models.4.ref

gives:

BLEU4 = 7.93, 31.9/12.5/5.3/2.3 (BP=0.948, ratio=0.950, syslen=38345, reflen=40379)

What am I doing wrong? How can I get a score similar to the reported one?

Thank you very much!


stas00 commented 4 years ago

@edunov, would it be possible to give us a hand here, please? We are in the final stages of porting the fairseq wmt19 transformer to transformers: https://github.com/huggingface/transformers/pull/6940. Since the latter currently doesn't support model ensembles, we are trying to get the best we can with what we have, a single model, so being able to reproduce your BLEU score using fairseq-generate would be very helpful to us. Thank you!

All we need is the command line that you used to get those scores. Thank you!
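
For context, my understanding is that the ensembling fairseq-generate performs when given several --path checkpoints amounts to averaging the models' per-step next-token probabilities before beam search picks candidates. A minimal sketch of that idea (illustration only; step_log_probs is a made-up helper, not fairseq's actual API):

import math
import torch

def ensemble_log_probs(models, tokens):
    # stack per-model log-probabilities: (num_models, batch, vocab)
    lprobs = torch.stack([m.step_log_probs(tokens) for m in models], dim=0)
    # log of the average probability: logsumexp over models minus log(N)
    return torch.logsumexp(lprobs, dim=0) - math.log(len(models))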

edunov commented 4 years ago

Hi @stas00, sorry, I didn't notice your message the first time. There are several steps missing in your commands:

  1. The data needs to be normalized and tokenized:

cat $DATA_DIR/test.en-ru.en | ~/mosesdecoder/scripts/tokenizer/normalize-punctuation.perl -l en | ~/mosesdecoder/scripts/tokenizer/tokenizer.perl -a -l en -q > $DATA_DIR/temp.en-ru.en

cat $DATA_DIR/test.en-ru.ru | ~/mosesdecoder/scripts/tokenizer/normalize-punctuation.perl -l ru | ~/mosesdecoder/scripts/tokenizer/tokenizer.perl -a -l ru -q > $DATA_DIR/temp.en-ru.ru

  2. Then you need to apply BPE; in this case we used fastBPE. Also, I just noticed that the downloaded tar.gz file with the models doesn't have the correct BPE codes, so temporarily I've put the correct BPE codes online, and I'll fix the tar.gz this week:

git clone git@github.com:glample/fastBPE.git

download BPE codes

wget https://dl.fbaipublicfiles.com/fairseq/ru24k.fastbpe.code
wget https://dl.fbaipublicfiles.com/fairseq/en24k.fastbpe.code

apply BPE

fastBPE/fast applybpe $DATA_DIR/test.en-ru.en $DATA_DIR/temp.en-ru.en en24k.fastbpe.code
fastBPE/fast applybpe $DATA_DIR/test.en-ru.ru $DATA_DIR/temp.en-ru.ru ru24k.fastbpe.code

Then you can run generate the way you did:

fairseq-generate $DATA_DIR --path $DATA_DIR/model1.pt:$DATA_DIR/model2.pt:$DATA_DIR/model3.pt:$DATA_DIR/model4.pt --beam 5 --batch-size 32 --remove-bpe --source-lang en --target-lang ru --task translation --dataset-impl raw | tee /tmp/gen.out.models.4

Finally, for the evaluation, you'll need to detokenize:

cat /tmp/gen.out.models.4 | grep ^H | sort -nr -k1.2 | cut -f3- | ~/mosesdecoder/scripts/tokenizer/detokenizer.perl | sacrebleu -t wmt19 -l $PAIR

That should give you everything except reranking, so the BLEU score will still be a little short of the one reported in matrix, but reranking scripts are more cumbersome to run, and we didn't release them. We're working on releasing a better reranking approach.

stas00 commented 4 years ago

Fantastic, thank you so much for these explicit extra steps, @edunov. I'm now getting 35.7 with fairseq, yay! As you said, slightly below 36.4.

Currently all 4 tar.gz model files (not just this pair) contain only one set of BPE codes, and both are needed for this to work. I hope they will all have them when you complete the update.

P.S. If you are updating the tar.gz files and it's not too much trouble, you may consider removing the 8.8GB of optimizer state from the models (4 x 2.2GB); that would make the download much faster for end users. But if it's too much trouble, it's fine as it is.
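
In case it's useful, here is a minimal sketch of how one could strip the optimizer state from a downloaded checkpoint locally. It assumes the .pt file is a plain torch-saved dict and that the optimizer state sits under a key such as 'last_optimizer_state'; that key name is an assumption, so inspect the keys on your file first:

# hypothetical helper script, not part of fairseq
import sys
import torch

def strip_optimizer_state(path_in, path_out):
    ckpt = torch.load(path_in, map_location="cpu")
    print("keys before:", sorted(ckpt.keys()))
    # drop the large optimizer state if present (key name is an assumption)
    ckpt.pop("last_optimizer_state", None)
    torch.save(ckpt, path_out)

if __name__ == "__main__":
    strip_optimizer_state(sys.argv[1], sys.argv[2])  # e.g. model1.pt model1.slim.pt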

stas00 commented 4 years ago

Here is the full script. I will update it when you release the updated tgz files.

# en-ru
git clone https://github.com/pytorch/fairseq/
cd fairseq

git clone https://github.com/moses-smt/mosesdecoder
git clone git@github.com:glample/fastBPE.git
cd fastBPE; g++ -std=c++11 -pthread -O3 fastBPE/main.cc -IfastBPE -o fast; cd -

mkdir -p data-bin

# get model
curl --output - https://dl.fbaipublicfiles.com/fairseq/models/wmt19.en-ru.ensemble.tar.gz | tar xvzf - -C data-bin

export PAIR=en-ru
export DATA_DIR=data-bin/wmt19.en-ru.ensemble

# get evaluation data
sacrebleu -t wmt19 -l $PAIR --echo src > $DATA_DIR/test.en-ru.en
sacrebleu -t wmt19 -l $PAIR --echo ref > $DATA_DIR/test.en-ru.ru

# normalize and tokenize eval data
cat $DATA_DIR/test.en-ru.en | ./mosesdecoder/scripts/tokenizer/normalize-punctuation.perl -l en | ./mosesdecoder/scripts/tokenizer/tokenizer.perl -a -l en -q > $DATA_DIR/temp.en-ru.en
cat $DATA_DIR/test.en-ru.ru | ./mosesdecoder/scripts/tokenizer/normalize-punctuation.perl -l ru | ./mosesdecoder/scripts/tokenizer/tokenizer.perl -a -l ru -q > $DATA_DIR/temp.en-ru.ru

# download BPE codes (temporary, wmt19.en-ru.ensemble.tar.gz is waiting to be fixed to include the right codes), otherwise will use those in DATA_DIR
wget -P $DATA_DIR https://dl.fbaipublicfiles.com/fairseq/ru24k.fastbpe.code
wget -P $DATA_DIR https://dl.fbaipublicfiles.com/fairseq/en24k.fastbpe.code
# apply BPE
./fastBPE/fast applybpe $DATA_DIR/test.en-ru.en $DATA_DIR/temp.en-ru.en $DATA_DIR/en24k.fastbpe.code
./fastBPE/fast applybpe $DATA_DIR/test.en-ru.ru $DATA_DIR/temp.en-ru.ru $DATA_DIR/ru24k.fastbpe.code

# which checkpoints to eval against (all or just a specific one)
export CHKPT=$DATA_DIR/model1.pt:$DATA_DIR/model2.pt:$DATA_DIR/model3.pt:$DATA_DIR/model4.pt
#export CHKPT=$DATA_DIR/model4.pt

# generate (w/ 4 model ensemble)
fairseq-generate $DATA_DIR --path $CHKPT --beam 5 --batch-size 32 --remove-bpe --source-lang en --target-lang ru \
--task translation --dataset-impl raw | tee /tmp/gen.out.models.4

# detokenize + eval bleu
cat /tmp/gen.out.models.4 | grep ^H | sort -nr -k1.2 | cut -f3- | ./mosesdecoder/scripts/tokenizer/detokenizer.perl | sacrebleu -t wmt19 -l $PAIR

Output:

Detokenizer Version $Revision: 4134 $
Language: en
BLEU+case.mixed+lang.en-ru+numrefs.1+smooth.exp+test.wmt19+tok.13a+version.1.4.12 = 35.7 65.3/43.6/31.3/22.8 (BP = 0.947 ratio = 0.948 hyp_len = 45640 ref_len = 48147)

stas00 commented 4 years ago

@edunov, if you don't mind, one more question: how can I find out which params you were using besides those stored in the checkpoints' args?

From your paper I derived that you used a beam size of 50, but the paper also mentions a "length penalty", and I couldn't find what final (after the search) values were used. We currently use no penalty and are behind on the final score, so sharing this and any other custom settings that aren't in args would be extremely useful for us.

edit: Re-reading section 3.4 of the paper, I think it says that you calculate the "length penalty" dynamically at generation time, depending on the inputs, when you ensemble the models, so this is basically not accessible to us unless we implement model ensembling. Did I get it right this time around, or is there a way we could improve our score beyond using the weights and model args you provided?
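
For reference, my current understanding of how the length penalty enters the ranking in fairseq-generate (exposed as --lenpen, default 1.0 as seen in the Namespace dumps below): finished hypotheses are scored by their total log-probability divided by length**lenpen, so values above 1 favor longer outputs. A simplified illustration, not the actual fairseq code:

def hypothesis_score(sum_log_prob, hyp_len, lenpen=1.0):
    # rank finished beams by length-normalized log-probability
    return sum_log_prob / (hyp_len ** lenpen)

print(hypothesis_score(-12.0, 20, lenpen=1.0))  # -0.6
print(hypothesis_score(-12.0, 20, lenpen=1.2))  # ~ -0.33, longer hypotheses are penalized less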

Thank you!

babangain commented 4 years ago

I am trying to reproduce the same, except for en-de.

When I run this command

!fairseq-generate data-bin/wmt19.en-de.joined-dict.single_model --path data-bin/wmt19.en-de.joined-dict.single_model/model.pt --beam 5 --batch-size 32 --remove-bpe --source-lang en --target-lang de \
--task translation --dataset-impl raw | tee /tmp/gen.out.models

I get,

tcmalloc: large alloc 1078984704 bytes == 0x4ee48000 @  0x7f4f7a80db6b 0x7f4f7a82d379 0x7f4f2ef9a92e 0x7f4f2ef9c946 0x7f4f67377a05 0x7f4f76a4eb49 0x551755 0x5a9eec 0x50a783 0x50c1f4 0x507f24 0x509202 0x5a4d81 0x5a50d8 0x4e01be 0x50a7b1 0x50c1f4 0x507f24 0x588fac 0x59fe1e 0x50d596 0x507f24 0x509c50 0x50a64d 0x50cfd6 0x507f24 0x509c50 0x50a64d 0x50c1f4 0x507f24 0x509c50
  0% 0/63 [00:00<?, ?it/s]Namespace(beam=5, bpe=None, cpu=False, criterion='cross_entropy', data='data-bin/wmt19.en-de.joined-dict.single_model', dataset_impl='raw', decoding_format=None, diverse_beam_groups=-1, diverse_beam_strength=0.5, empty_cache_freq=0, force_anneal=None, fp16=False, fp16_init_scale=128, fp16_scale_tolerance=0.0, fp16_scale_window=None, gen_subset='test', iter_decode_eos_penalty=0.0, iter_decode_force_max_iter=False, iter_decode_max_iter=10, lazy_load=False, left_pad_source='True', left_pad_target='False', lenpen=1, load_alignments=False, log_format=None, log_interval=1000, lr_scheduler='fixed', lr_shrink=0.1, match_source_len=False, max_len_a=0, max_len_b=200, max_sentences=32, max_source_positions=1024, max_target_positions=1024, max_tokens=None, memory_efficient_fp16=False, min_len=1, min_loss_scale=0.0001, model_overrides='{}', momentum=0.99, nbest=1, no_beamable_mm=False, no_early_stop=False, no_progress_bar=False, no_repeat_ngram_size=0, num_shards=1, num_workers=1, optimizer='nag', path='data-bin/wmt19.en-de.joined-dict.single_model/model.pt', prefix_size=0, print_alignment=False, print_step=False, quiet=False, raw_text=False, remove_bpe='@@ ', replace_unk=None, required_batch_size_multiple=8, results_path=None, retain_iter_history=False, sacrebleu=False, sampling=False, sampling_topk=-1, sampling_topp=-1.0, score_reference=False, seed=1, shard_id=0, skip_invalid_size_inputs_valid_test=False, source_lang='en', target_lang='de', task='translation', temperature=1.0, tensorboard_logdir='', threshold_loss_scale=None, tokenizer=None, truncate_source=False, unkpen=0, unnormalized=False, upsample_primary=1, user_dir=None, warmup_updates=0, weight_decay=0.0)
| [en] dictionary: 42024 types
| [de] dictionary: 42024 types
| loaded 1997 examples from: data-bin/wmt19.en-de.joined-dict.single_model/test.en-de.en
| loaded 1997 examples from: data-bin/wmt19.en-de.joined-dict.single_model/test.en-de.de
| data-bin/wmt19.en-de.joined-dict.single_model test en-de 1997 examples
| loading model(s) from data-bin/wmt19.en-de.joined-dict.single_model/model.pt
Traceback (most recent call last):
  File "/usr/local/bin/fairseq-generate", line 8, in <module>
    sys.exit(cli_main())
  File "/usr/local/lib/python3.6/dist-packages/fairseq_cli/generate.py", line 199, in cli_main
    main(args)
  File "/usr/local/lib/python3.6/dist-packages/fairseq_cli/generate.py", line 104, in main
    hypos = task.inference_step(generator, models, sample, prefix_tokens)
  File "/usr/local/lib/python3.6/dist-packages/fairseq/tasks/fairseq_task.py", line 265, in inference_step
    return generator.generate(models, sample, prefix_tokens=prefix_tokens)
  File "/usr/local/lib/python3.6/dist-packages/torch/autograd/grad_mode.py", line 15, in decorate_context
    return func(*args, **kwargs)
  File "/usr/local/lib/python3.6/dist-packages/fairseq/sequence_generator.py", line 113, in generate
    return self._generate(model, sample, **kwargs)
  File "/usr/local/lib/python3.6/dist-packages/torch/autograd/grad_mode.py", line 15, in decorate_context
    return func(*args, **kwargs)
  File "/usr/local/lib/python3.6/dist-packages/fairseq/sequence_generator.py", line 379, in _generate
    scores.view(bsz, beam_size, -1)[:, :, :step],
  File "/usr/local/lib/python3.6/dist-packages/fairseq/search.py", line 81, in step
    torch.div(self.indices_buf, vocab_size, out=self.beams_buf)
RuntimeError: Integer division of tensors using div or / is no longer supported, and in a future release div will perform true division as in Python 3. Use true_divide or floor_divide (// in Python) instead.

When I change the call to true_divide, I get:

tcmalloc: large alloc 1078984704 bytes == 0x508ba000 @  0x7f8056742b6b 0x7f8056762379 0x7f800aecf92e 0x7f800aed1946 0x7f80432aca05 0x7f8052983b49 0x551755 0x5a9eec 0x50a783 0x50c1f4 0x507f24 0x509202 0x5a4d81 0x5a50d8 0x4e01be 0x50a7b1 0x50c1f4 0x507f24 0x588fac 0x59fe1e 0x50d596 0x507f24 0x509c50 0x50a64d 0x50cfd6 0x507f24 0x509c50 0x50a64d 0x50c1f4 0x507f24 0x509c50
  0% 0/250 [00:00<?, ?it/s]Namespace(beam=5, bpe=None, cpu=False, criterion='cross_entropy', data='data-bin/wmt19.en-de.joined-dict.single_model', dataset_impl='raw', decoding_format=None, diverse_beam_groups=-1, diverse_beam_strength=0.5, empty_cache_freq=0, force_anneal=None, fp16=False, fp16_init_scale=128, fp16_scale_tolerance=0.0, fp16_scale_window=None, gen_subset='test', iter_decode_eos_penalty=0.0, iter_decode_force_max_iter=False, iter_decode_max_iter=10, lazy_load=False, left_pad_source='True', left_pad_target='False', lenpen=1, load_alignments=False, log_format=None, log_interval=1000, lr_scheduler='fixed', lr_shrink=0.1, match_source_len=False, max_len_a=0, max_len_b=200, max_sentences=8, max_source_positions=1024, max_target_positions=1024, max_tokens=None, memory_efficient_fp16=False, min_len=1, min_loss_scale=0.0001, model_overrides='{}', momentum=0.99, nbest=1, no_beamable_mm=False, no_early_stop=False, no_progress_bar=False, no_repeat_ngram_size=0, num_shards=1, num_workers=1, optimizer='nag', path='data-bin/wmt19.en-de.joined-dict.single_model/model.pt', prefix_size=0, print_alignment=False, print_step=False, quiet=False, raw_text=False, remove_bpe='@@ ', replace_unk=None, required_batch_size_multiple=8, results_path=None, retain_iter_history=False, sacrebleu=False, sampling=False, sampling_topk=-1, sampling_topp=-1.0, score_reference=False, seed=1, shard_id=0, skip_invalid_size_inputs_valid_test=False, source_lang='en', target_lang='de', task='translation', temperature=1.0, tensorboard_logdir='', threshold_loss_scale=None, tokenizer=None, truncate_source=False, unkpen=0, unnormalized=False, upsample_primary=1, user_dir=None, warmup_updates=0, weight_decay=0.0)
| [en] dictionary: 42024 types
| [de] dictionary: 42024 types
| loaded 1997 examples from: data-bin/wmt19.en-de.joined-dict.single_model/test.en-de.en
| loaded 1997 examples from: data-bin/wmt19.en-de.joined-dict.single_model/test.en-de.de
| data-bin/wmt19.en-de.joined-dict.single_model test en-de 1997 examples
| loading model(s) from data-bin/wmt19.en-de.joined-dict.single_model/model.pt
Traceback (most recent call last):
  File "/usr/local/bin/fairseq-generate", line 8, in <module>
    sys.exit(cli_main())
  File "/usr/local/lib/python3.6/dist-packages/fairseq_cli/generate.py", line 199, in cli_main
    main(args)
  File "/usr/local/lib/python3.6/dist-packages/fairseq_cli/generate.py", line 104, in main
    hypos = task.inference_step(generator, models, sample, prefix_tokens)
  File "/usr/local/lib/python3.6/dist-packages/fairseq/tasks/fairseq_task.py", line 265, in inference_step
    return generator.generate(models, sample, prefix_tokens=prefix_tokens)
  File "/usr/local/lib/python3.6/dist-packages/torch/autograd/grad_mode.py", line 15, in decorate_context
    return func(*args, **kwargs)
  File "/usr/local/lib/python3.6/dist-packages/fairseq/sequence_generator.py", line 113, in generate
    return self._generate(model, sample, **kwargs)
  File "/usr/local/lib/python3.6/dist-packages/torch/autograd/grad_mode.py", line 15, in decorate_context
    return func(*args, **kwargs)
  File "/usr/local/lib/python3.6/dist-packages/fairseq/sequence_generator.py", line 379, in _generate
    scores.view(bsz, beam_size, -1)[:, :, :step],
  File "/usr/local/lib/python3.6/dist-packages/fairseq/search.py", line 81, in step
    torch.true_divide(self.indices_buf, vocab_size, out=self.beams_buf)
RuntimeError: result type Float can't be cast to the desired output type Long

Does anyone know how to resolve this issue?

Python 3.6, Torch 1.6

stas00 commented 4 years ago

Use floor_divide and then send a PR so it's fixed for all of us ;)
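
Something along these lines, a self-contained illustration of the suggested change (the buffer names just mirror the traceback, this isn't the actual fairseq code):

import torch

vocab_size = 42024
indices_buf = torch.tensor([10, 42030, 84050])  # flattened (beam * vocab_size + token) indices
beams_buf = torch.empty_like(indices_buf)

# old form that now raises: torch.div(indices_buf, vocab_size, out=beams_buf)
torch.floor_divide(indices_buf, vocab_size, out=beams_buf)
print(beams_buf)  # tensor([0, 1, 2]), i.e. the originating beam for each index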

Thank you!

babangain commented 4 years ago

It is already fixed in master. However, when installed via pip, it's still the old one. How can I fix/PR that?

stas00 commented 4 years ago

Are you installing from master?

pip install -e git+https://github.com/pytorch/fairseq/#egg=fairseq

stas00 commented 4 years ago

Alternatively, if you are continually working with fairseq and want to stay up to date:

git clone https://github.com/pytorch/fairseq/
cd fairseq
pip install -e .

Now this checkout is the code your Python will run, so whenever you do:

git pull

in that folder, it'll automatically get you the latest source and you don't need to do anything else.

babangain commented 4 years ago

Thanks, that worked. I thought of sending a PR for it, but it has already been done by someone else. Earlier I installed via pip install fairseq; that's why it was like that.

I think most people install via pip install fairseq only, so an update to the PyPI fairseq package would be great.

stas00 commented 4 years ago

I agree. Please open a separate issue and indicate why a new pypi fairseq release is wanted.

stale[bot] commented 3 years ago

This issue has been automatically marked as stale. If this issue is still affecting you, please leave any comment (for example, "bump"), and we'll keep it open. We are sorry that we haven't been able to prioritize it yet. If you have any new additional information, please include it with your comment!

zgm1ybq commented 2 years ago

Does anyone know how to solve this problem? I have the same en-de error that @babangain reported above; how can I solve it?


stale[bot] commented 2 years ago

This issue has been automatically marked as stale. If this issue is still affecting you, please leave any comment (for example, "bump"), and we'll keep it open. We are sorry that we haven't been able to prioritize it yet. If you have any new additional information, please include it with your comment!

stale[bot] commented 2 years ago

Closing this issue after a prolonged period of inactivity. If this issue is still present in the latest release, please create a new issue with up-to-date information. Thank you!