facebookresearch / fairseq

Facebook AI Research Sequence-to-Sequence Toolkit written in Python.
MIT License

Non-autoregressive models miss prev_output_tokens argument #3469

Closed: speedcell4 closed this issue 3 years ago

speedcell4 commented 3 years ago

🐛 Bug

To Reproduce

Steps to reproduce the behavior (always include the command you ran):

I got this error when trying to evaluate a trained non-autoregressive MT model:

➜  datasets fairseq-generate data-bin/wmt14_en_de_distill --path checkpoints/crf-nat/210413-135213-05d984c5/checkpoint1.pt
2021-04-13 14:06:01 | INFO | fairseq_cli.generate | Namespace(all_gather_list_size=16384, batch_size=None, batch_size_valid=None, beam=5, bf16=False, bpe=None, broadcast_buffers=False, bucket_cap_mb=25, checkpoint_shard_count=1, checkpoint_suffix='', constraints=None, cpu=False, criterion='cross_entropy', curriculum=0, data='data-bin/wmt14_en_de_distill', data_buffer_size=10, dataset_impl=None, ddp_backend='c10d', decoding_format=None, device_id=0, disable_validation=False, distributed_backend='nccl', distributed_init_method=None, distributed_no_spawn=False, distributed_port=-1, distributed_rank=0, distributed_world_size=1, distributed_wrapper='DDP', diverse_beam_groups=-1, diverse_beam_strength=0.5, diversity_rate=-1.0, empty_cache_freq=0, eval_bleu=False, eval_bleu_args=None, eval_bleu_detok='space', eval_bleu_detok_args=None, eval_bleu_print_samples=False, eval_bleu_remove_bpe=None, eval_tokenized_bleu=False, fast_stat_sync=False, find_unused_parameters=False, fix_batches_to_gpus=False, fixed_validation_seed=None, force_anneal=None, fp16=False, fp16_init_scale=128, fp16_no_flatten_grads=False, fp16_scale_tolerance=0.0, fp16_scale_window=None, gen_subset='test', iter_decode_eos_penalty=0.0, iter_decode_force_max_iter=False, iter_decode_max_iter=10, iter_decode_with_beam=1, iter_decode_with_external_reranker=False, left_pad_source='True', left_pad_target='False', lenpen=1, lm_path=None, lm_weight=0.0, load_alignments=False, localsgd_frequency=3, log_format=None, log_interval=100, lr_scheduler='fixed', lr_shrink=0.1, match_source_len=False, max_len_a=0, max_len_b=200, max_source_positions=1024, max_target_positions=1024, max_tokens=12000, max_tokens_valid=None, memory_efficient_bf16=False, memory_efficient_fp16=False, min_len=1, min_loss_scale=0.0001, model_overrides='{}', model_parallel_size=1, nbest=1, no_beamable_mm=False, no_early_stop=False, no_progress_bar=False, no_repeat_ngram_size=0, no_seed_provided=False, nprocs_per_node=8, num_batch_buckets=0, num_shards=1, num_workers=1, optimizer=None, path='checkpoints/crf-nat/210413-135213-05d984c5/checkpoint1.pt', pipeline_balance=None, pipeline_checkpoint='never', pipeline_chunks=0, pipeline_decoder_balance=None, pipeline_decoder_devices=None, pipeline_devices=None, pipeline_encoder_balance=None, pipeline_encoder_devices=None, pipeline_model_parallel=False, prefix_size=0, print_alignment=False, print_step=False, profile=False, quantization_config_path=None, quiet=False, remove_bpe=None, replace_unk=None, required_batch_size_multiple=8, required_seq_len_multiple=1, results_path=None, retain_dropout=False, retain_dropout_modules=None, retain_iter_history=False, sacrebleu=False, sampling=False, sampling_topk=-1, sampling_topp=-1.0, score_reference=False, scoring='bleu', seed=1, shard_id=0, skip_invalid_size_inputs_valid_test=False, slowmo_algorithm='LocalSGD', slowmo_momentum=None, source_lang=None, target_lang=None, task='translation', temperature=1.0, tensorboard_logdir=None, threshold_loss_scale=None, tokenizer=None, tpu=False, train_subset='train', truncate_source=False, unkpen=0, unnormalized=False, upsample_primary=1, user_dir=None, valid_subset='valid', validate_after_updates=0, validate_interval=1, validate_interval_updates=0, warmup_updates=0, zero_sharding='none')
2021-04-13 14:06:01 | INFO | fairseq.tasks.translation | [en] dictionary: 39840 types
2021-04-13 14:06:01 | INFO | fairseq.tasks.translation | [de] dictionary: 39840 types
2021-04-13 14:06:01 | INFO | fairseq.data.data_utils | loaded 3003 examples from: data-bin/wmt14_en_de_distill/test.en-de.en
2021-04-13 14:06:02 | INFO | fairseq.data.data_utils | loaded 3003 examples from: data-bin/wmt14_en_de_distill/test.en-de.de
2021-04-13 14:06:02 | INFO | fairseq.tasks.translation | data-bin/wmt14_en_de_distill test en-de 3003 examples
2021-04-13 14:06:02 | INFO | fairseq_cli.generate | loading model(s) from checkpoints/crf-nat/210413-135213-05d984c5/checkpoint1.pt
Traceback (most recent call last):
  File "~/miniconda3/bin/fairseq-generate", line 8, in <module>
    sys.exit(cli_main())
  File "~/miniconda3/lib/python3.8/site-packages/fairseq_cli/generate.py", line 379, in cli_main
    main(args)
  File "~/miniconda3/lib/python3.8/site-packages/fairseq_cli/generate.py", line 41, in main
    return _main(args, sys.stdout)
  File "~/miniconda3/lib/python3.8/site-packages/fairseq_cli/generate.py", line 191, in _main
    hypos = task.inference_step(
  File "~/miniconda3/lib/python3.8/site-packages/fairseq/tasks/fairseq_task.py", line 433, in inference_step
    return generator.generate(
  File "~/miniconda3/lib/python3.8/site-packages/torch/autograd/grad_mode.py", line 27, in decorate_context
    return func(*args, **kwargs)
  File "~/miniconda3/lib/python3.8/site-packages/fairseq/sequence_generator.py", line 177, in generate
    return self._generate(sample, **kwargs)
  File "~/miniconda3/lib/python3.8/site-packages/fairseq/sequence_generator.py", line 312, in _generate
    lprobs, avg_attn_scores = self.model.forward_decoder(
  File "~/miniconda3/lib/python3.8/site-packages/fairseq/sequence_generator.py", line 824, in forward_decoder
    decoder_out = model.decoder.forward(
  File "~/miniconda3/lib/python3.8/site-packages/fairseq/models/nat/fairseq_nat_model.py", line 40, in wrapper
    return func(
TypeError: forward() missing 1 required positional argument: 'prev_output_tokens'
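
For context: fairseq-generate defaults to --task translation, which builds a beam-search SequenceGenerator. Its forward_decoder calling convention does not match the forward signature of the NAT decoders (wrapped in fairseq_nat_model.py), so prev_output_tokens ends up unbound and forward() raises the TypeError above. Models trained with --task translation_lev are meant to be decoded by the IterativeRefinementGenerator that task builds instead. A minimal sketch for checking which generator a task builds (API names follow the fairseq 0.10.x layout visible in the traceback; the data and checkpoint paths are placeholders):

# Sketch: inspect which generator each task builds for the same checkpoint.
from fairseq import checkpoint_utils, options, tasks

def generator_type(task_name):
    parser = options.get_generation_parser()
    args = options.parse_args_and_arch(parser, [
        "data-bin/wmt14_en_de_distill",              # placeholder data dir
        "--task", task_name,
        "--path", "checkpoints/checkpoint_best.pt",  # placeholder checkpoint
    ])
    task = tasks.setup_task(args)
    models, _ = checkpoint_utils.load_model_ensemble([args.path], task=task)
    return type(task.build_generator(models, args)).__name__

print(generator_type("translation"))      # SequenceGenerator: crashes on NAT decoders
print(generator_type("translation_lev"))  # IterativeRefinementGenerator: works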

The training command is as follows:

$HOME/miniconda3/bin/fairseq-train \
    $HOME/datasets/data-bin/wmt14_en_de_distill \
    --save-dir $HOME/datasets/checkpoints/$STUDY/$TRIAL \
    --ddp-backend=c10d \
    --task translation_lev \
    --criterion nat_loss \
    --arch nacrf_transformer \
    --noise full_mask \
    --optimizer adam --adam-betas '(0.9,0.98)' \
    --lr 0.0005 --lr-scheduler inverse_sqrt \
    --min-lr '1e-09' --warmup-updates 10000 \
    --warmup-init-lr '1e-07' --label-smoothing 0.1 \
    --dropout 0.3 --weight-decay 0.01 \
    --decoder-learned-pos \
    --encoder-learned-pos \
    --pred-length-offset \
    --length-loss-factor 0.1 \
    --word-ins-loss-factor 0.5 \
    --crf-lowrank-approx 32 \
    --crf-beam-approx 64 \
    --apply-bert-init \
    --log-format 'simple' --log-interval 100 \
    --fixed-validation-seed 7 \
    --max-tokens 8000 \
    --save-interval-updates 10000 \
    --max-update 300000 \
    --fp16 \
    --share-all-embeddings &> $HOME/datasets/checkpoints/$STUDY/$TRIAL/log.txt
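
For reference, nacrf_transformer is the NAT model with a structured CRF inference layer ("Fast Structured Decoding for Sequence Models", Sun et al., 2019): --crf-lowrank-approx sets the rank of the factored vocabulary-to-vocabulary transition matrix, and --crf-beam-approx sets the number of candidates over which the CRF normalizer is approximated. A rough sketch of the low-rank idea, with random placeholder tensors rather than the trained fairseq module:

import torch

V, rank = 39840, 32        # vocab size from the log above; rank = --crf-lowrank-approx
E1 = torch.randn(V, rank)  # placeholder low-rank factors; fairseq's DynamicCRF
E2 = torch.randn(V, rank)  # learns these as two V x rank embeddings

# The transition score from token i to token j is a low-rank bilinear form,
# so the full V x V matrix is never materialized; decoding only scores
# transitions among the top --crf-beam-approx candidates at each position.
i, j = 10, 20
score = E1[i] @ E2[j]
print(float(score))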

Expected behavior

Environment

speedcell4 commented 3 years ago

The problem was solved by using the following command:

fairseq-generate data-bin/wmt14_en_de_distill --path checkpoints/crf-nat/210413-135213-05d984c5/checkpoint_best.pt --gen-subset test --task translation_lev --iter-decode-max-iter 1 --iter-decode-eos-penalty 0 --beam 1 --remove-bpe --print-step --batch-size 400
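
This works because --task translation_lev swaps the SequenceGenerator for an IterativeRefinementGenerator, which drives the NAT decoder through its own refinement loop and never passes prev_output_tokens. Roughly how the flags above map onto that generator (keyword names follow fairseq 0.10.x's translation_lev task; treat them as assumptions on other versions):

from fairseq.iterative_refinement_generator import IterativeRefinementGenerator

def build_nat_generator(task, args):
    # Sketch of what translation_lev.build_generator constructs; only the
    # kwargs controlled by the flags in the fixed command are shown.
    return IterativeRefinementGenerator(
        task.target_dictionary,
        eos_penalty=args.iter_decode_eos_penalty,  # --iter-decode-eos-penalty 0
        max_iter=args.iter_decode_max_iter,        # --iter-decode-max-iter 1
        beam_size=args.iter_decode_with_beam,      # note: not --beam
    )

With this generator in place, --print-step reports the number of refinement iterations actually used for each hypothesis.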