facebookresearch / fairseq

Facebook AI Research Sequence-to-Sequence Toolkit written in Python.
MIT License

KeyError: 'online' during generation in Simultaneous Translation task #3472

Closed: VrutikShah closed this issue 3 years ago

VrutikShah commented 3 years ago

❓ Questions and Help

I am trying to run a simultaneous translation task using the following training command:

```bash
fairseq-train \
    data-bin/ \
    --simul-type hard_aligned \
    --user-dir $FAIRSEQ/examples/simultaneous_translation \
    --mass-preservation \
    --criterion label_smoothed_cross_entropy_with_alignment \
    --max-epoch 30 \
    --arch transformer_monotonic_iwslt_de_en \
    --optimizer adam --adam-betas '(0.9, 0.98)' \
    --lr-scheduler 'inverse_sqrt' \
    --lr 5e-4 --stop-min-lr 1e-9 --clip-norm 0.0 --weight-decay 0.0001 \
    --warmup-init-lr 1e-7 --warmup-updates 4000 \
    --dropout 0.3 \
    --label-smoothing 0.1 \
    --max-tokens 3584
```

Training works fine and saves the checkpoints in the `checkpoints` folder. For generation, I am running the following command:

```bash
fairseq-generate data-bin/ \
    --path checkpoints/checkpoint_best.pt \
    --beam 5 --batch-size 64
```

I am getting the following error:

```console
Traceback (most recent call last):
  File "/home/vrutik.shah/anaconda3/envs/dual_env/bin/fairseq-generate", line 33, in <module>
    sys.exit(load_entry_point('fairseq', 'console_scripts', 'fairseq-generate')())
  File "/home/vrutik.shah/fairseq/fairseq_cli/generate.py", line 404, in cli_main
    main(args)
  File "/home/vrutik.shah/fairseq/fairseq_cli/generate.py", line 49, in main
    return _main(cfg, sys.stdout)
  File "/home/vrutik.shah/fairseq/fairseq_cli/generate.py", line 201, in _main
    hypos = task.inference_step(
  File "/home/vrutik.shah/fairseq/fairseq/tasks/fairseq_task.py", line 500, in inference_step
    return generator.generate(
  File "/home/vrutik.shah/.local/lib/python3.8/site-packages/torch/autograd/grad_mode.py", line 26, in decorate_context
    return func(*args, **kwargs)
  File "/home/vrutik.shah/fairseq/fairseq/sequence_generator.py", line 186, in generate
    return self._generate(sample, **kwargs)
  File "/home/vrutik.shah/fairseq/fairseq/sequence_generator.py", line 323, in _generate
    lprobs, avg_attn_scores = self.model.forward_decoder(
  File "/home/vrutik.shah/fairseq/fairseq/sequence_generator.py", line 777, in forward_decoder
    decoder_out = model.decoder.forward(
  File "/home/vrutik.shah/fairseq/fairseq/models/transformer.py", line 817, in forward
    x, extra = self.extract_features(
  File "/home/vrutik.shah/fairseq/examples/simultaneous_translation/models/transformer_monotonic_attention.py", line 235, in extract_features
    if_online = incremental_state["online"]["only"]
KeyError: 'online'
```

Am I missing some arguments in the generate command, or is this error caused by something else?

#### What's your environment?

- fairseq Version: master
- How you installed fairseq (`pip`, source): source
- Build command you used (if compiling from source): `pip install --editable`, run inside the cloned fairseq repo folder
krishnamrith12 commented 3 years ago

Have you tried using SimulEval? See https://github.com/pytorch/fairseq/blob/master/examples/simultaneous_translation/docs/enja-waitk.md
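The last frame of the traceback is a nested dictionary lookup, `incremental_state["online"]["only"]`, inside the monotonic decoder's `extract_features`. One plausible reading, which this thread does not confirm, is that the decoder expects its caller to store an `"online"` entry in the incremental state before each step. The plain beam search behind `fairseq-generate` would not do that, and this reading fits the suggestion to evaluate through SimulEval instead. The minimal Python sketch below is not fairseq code. It uses plain dicts and hypothetical helpers (`read_online_flag`, `plain_generator_state`, `streaming_caller_state`, `try_decode_step`) to show how a missing key produces exactly `KeyError: 'online'`. The meaning it assigns to the `"only"` flag is also a guess.

```python
from typing import Any, Dict, Optional


# Hypothetical stand-in for the failing line in
# transformer_monotonic_attention.py: the decoder reads a nested flag that
# some caller must have stored in the incremental state beforehand.
def read_online_flag(incremental_state: Dict[str, Any]) -> bool:
    return incremental_state["online"]["only"]


# Assumption: a plain beam-search loop starts from an empty incremental
# state and never creates the "online" entry.
def plain_generator_state() -> Dict[str, Any]:
    return {}


# Hypothetical streaming caller that stores the flag before each step.
# Here the flag is assumed to mean "the source is still being read".
def streaming_caller_state(source_finished: bool) -> Dict[str, Any]:
    return {"online": {"only": not source_finished}}


def try_decode_step(state: Dict[str, Any]) -> Optional[bool]:
    """Return the flag, or None if the state lacks the expected entries."""
    try:
        return read_online_flag(state)
    except KeyError as err:
        # str(KeyError("online")) is "'online'", so a missing top-level
        # entry prints: KeyError: 'online'
        print(f"KeyError: {err}")
        return None


if __name__ == "__main__":
    # Empty state, as in plain generation: prints the KeyError, then None
    print(try_decode_step(plain_generator_state()))
    # State populated by the hypothetical streaming caller: prints True
    print(try_decode_step(streaming_caller_state(source_finished=False)))
```

The sketch only illustrates that the lookup has no default. The error therefore appears at the first decoder step rather than at argument parsing, whatever flags are passed on the command line.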