microsoft / unilm

Large-scale Self-supervised Pre-training Across Tasks, Languages, and Modalities
https://aka.ms/GeneralAI
MIT License
19.64k stars 2.51k forks

[GAD] Error when trying to use Generalized Aggressive Decoding for inference #768

Open ndkmath1 opened 2 years ago

ndkmath1 commented 2 years ago

I'm trying to use: https://github.com/microsoft/unilm/tree/master/decoding/GAD with pre-trained model in that repo for inference

Model: wmt14.en-de (as listed in that repo)

My inference script

#!/bin/bash

data_dir=./unilm/decoding/GAD/data/wmt14.en-de # the dir that contains the dict files
checkpoint_path=./wmt14-en-de-base-at-verifier.pt # the path to the AT verifier checkpoint
AR_checkpoint_path=./wmt14-en-de-base-nat-drafter-checkpoint.avg10.pt # the path to the NAT drafter checkpoint
input_path=./unilm/decoding/GAD/test.bpe # the path to the bpe test file
output_path=./unilm/decoding/GAD/output.txt # the path for the output file

strategy='gad' # fairseq, AR, gad
batch=1
beam=5

beta=5
tau=3.0
block_size=25

src=en
tgt=de

python3 inference.py ${data_dir} --path ${checkpoint_path} \
      --user-dir block_plugins --task translation_lev_modified --remove-bpe --max-sentences 20 --source-lang ${src} \
      --target-lang ${tgt} --iter-decode-max-iter 0 --iter-decode-eos-penalty 0 --iter-decode-with-beam 1 \
      --gen-subset test --AR-path ${AR_checkpoint_path} --input-path ${input_path} --output-path ${output_path} \
      --block-size ${block_size} --beta ${beta} --tau ${tau} --batch ${batch} --beam ${beam} --strategy ${strategy}

Here's the content of ./unilm/decoding/GAD/test.bpe:

Ro@@ asted barr@@ am@@ un@@ di fish

I'm facing this error when trying to run that script

GAD plugins loaded...
2022-06-24 03:44:45 | INFO | fairseq.tasks.translation |  [translation.py:313] [en] dictionary: 32768 types
2022-06-24 03:44:45 | INFO | fairseq.tasks.translation |  [translation.py:314] [de] dictionary: 32768 types
2022-06-24 03:44:45 | INFO | inference |  [inference.py:265] loading model(s) from /home/user/wmt14-en-de-base-at-verifier.pt
Traceback (most recent call last):
  File "inference.py", line 282, in <module>
    AR_models, _AR_model_args, _AR_model_task = load_model_ensemble_and_task(filenames=[cmd_args.AR_path],
  File "/home/user/unilm/decoding/GAD/fairseq/checkpoint_utils.py", line 350, in load_model_ensemble_and_task
    task = tasks.setup_task(cfg.task)
  File "/home/user/unilm/decoding/GAD/fairseq/tasks/__init__.py", line 44, in setup_task
    return task.setup_task(cfg, **kwargs)
  File "/home/user/unilm/decoding/GAD/fairseq/tasks/translation.py", line 299, in setup_task
    raise Exception(
Exception: Could not infer language pair, please provide it explicitly

Please help to answer the following questions

  1. What should be the format of the file in input_path? Is my bpe input correct?
  2. How to fix the error?

Thank you.

hemingkx commented 2 years ago

Hi @ndkmath1, I noticed that you put the NAT drafter and the AT verifier in opposite positions. They should be placed as follows. Could you try again with them in the correct positions?

checkpoint_path=./checkpoints/wmt14-en-de-base-nat-drafter-checkpoint.avg10.pt # the dir that contains NAT drafter checkpoint
AR_checkpoint_path=./checkpoints/wmt14-en-de-base-at-verifier.pt # the dir that contains AT verifier checkpoint

hemingkx commented 2 years ago

You can pre-process the official WMT14 EN-DE test.en file with our bpemodel here.

Here are the first three lines of our test file:

Gut@@ ach : Incre@@ ased safety for pedestri@@ ans
They are not even 100 metres apart : On Tuesday , the new B 33 pedestrian lights in Dorf@@ park@@ platz in Gut@@ ach became operational - within view of the existing Town Hall traffic lights .
Two sets of lights so close to one another : inten@@ tional or just a sil@@ ly error ?
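
For reference, a file in this format can be produced by applying BPE codes with subword-nmt (an assumption on my part — the released bpemodel may have been trained with a different tool, and the file name bpecodes below is hypothetical):

```shell
# Hypothetical sketch: apply learned BPE codes to the raw WMT14 test set.
# subword-nmt marks subword boundaries with "@@ ", matching the lines above.
subword-nmt apply-bpe -c bpecodes < test.en > test.bpe
head -n 3 test.bpe
```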

ndkmath1 commented 2 years ago

Hi @hemingkx

I noticed that you put the NAT drafter and the AT verifier in opposite positions. They should be placed as follows. Could you try again with them in the correct positions?

It works. Thank you!

I'm trying to train GAD for a different language pair. Could you please let me know how to create your bpe model/vocab in this folder from the official WMT14 EN-DE dataset? Is there any reference script/code so that I can apply the same approach to the different language pair?

hemingkx commented 2 years ago

The bpe codes we used are from the data released by Google. However, you can find a pre-processing script here, which works fine with the wmt14.en-de raw data.
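
For a new language pair, a typical BPE-plus-binarization pipeline can be sketched as follows (a hedged sketch, not the repo's exact recipe — the file names, the 32k merge count, and the joined dictionary are all assumptions):

```shell
# 1. Learn joint BPE codes on the concatenated training data (hypothetical paths).
cat train.src train.tgt | subword-nmt learn-bpe -s 32000 -o bpecodes

# 2. Apply the codes to every split.
for split in train valid test; do
  for lang in src tgt; do
    subword-nmt apply-bpe -c bpecodes < $split.$lang > $split.bpe.$lang
  done
done

# 3. Binarize with fairseq; the result can feed both drafter and verifier training.
fairseq-preprocess --source-lang src --target-lang tgt \
  --trainpref train.bpe --validpref valid.bpe --testpref test.bpe \
  --destdir data-bin/my-pair --joined-dictionary --workers 8
```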

hemingkx commented 2 years ago

If you encounter any problems, feel free to contact me 😊

ndkmath1 commented 2 years ago

Thank you! I will contact you if there are any problems.

ndkmath1 commented 2 years ago

Hi @hemingkx,

Now, I can pre-process the data based on your suggestion, and I can train the NAT drafter of GAD based on the script.

I don't see how to train the AT verifier in readme.md. Could you please let me know how to train it?

Thank you.

hemingkx commented 2 years ago

Hi @ndkmath1, training the AT verifier is exactly the same as training a conventional Transformer in fairseq; just check the scripts below:

  1. fairseq/training-a-new-model
  2. scalingnmt/training-a-new-model-on-wmt16-en-de
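
For completeness, a standard fairseq Transformer training command along the lines of those recipes might look like this (a sketch only — the hyperparameters follow the common WMT14 en-de recipe and are assumptions, not the exact settings used for the released verifier):

```shell
# Train a conventional Transformer as the AT verifier (hypothetical paths).
fairseq-train data-bin/wmt14_en_de \
    --arch transformer_wmt_en_de --share-all-embeddings \
    --optimizer adam --adam-betas '(0.9, 0.98)' --clip-norm 0.0 \
    --lr 7e-4 --lr-scheduler inverse_sqrt --warmup-updates 4000 \
    --criterion label_smoothed_cross_entropy --label-smoothing 0.1 \
    --dropout 0.1 --max-tokens 4096 \
    --save-dir checkpoints/at_verifier
```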

ndkmath1 commented 2 years ago

Thank you. I will try the links.

ndkmath1 commented 2 years ago

Hi @hemingkx,

Now, I'm able to train both NAT drafter and AT verifier based on your suggestion. I'm facing the following error when trying to run inference.

~/xfspell$ sh src/03_inference.sh 
GAD plugins loaded...
2022-07-01 10:50:07 | INFO | fairseq.tasks.translation |  [translation.py:313] [fr] dictionary: 80 types
2022-07-01 10:50:07 | INFO | fairseq.tasks.translation |  [translation.py:314] [en] dictionary: 80 types
2022-07-01 10:50:07 | INFO | inference |  [inference.py:265] loading model(s) from models/nat_drafter01/checkpoint_best.pt
Traceback (most recent call last):
  File "/home/user/unilm/decoding/GAD/inference.py", line 282, in <module>
    AR_models, _AR_model_args, _AR_model_task = load_model_ensemble_and_task(filenames=[cmd_args.AR_path],
  File "/home/user/unilm/decoding/GAD/fairseq/checkpoint_utils.py", line 339, in load_model_ensemble_and_task
    state = load_checkpoint_to_cpu(filename, arg_overrides)
  File "/home/user/unilm/decoding/GAD/fairseq/checkpoint_utils.py", line 271, in load_checkpoint_to_cpu
    overwrite_args_by_name(state["cfg"], arg_overrides)
  File "/home/user/unilm/decoding/GAD/fairseq/dataclass/utils.py", line 427, in overwrite_args_by_name
    with open_dict(cfg):
  File "/usr/lib/python3.8/contextlib.py", line 113, in __enter__
    return next(self.gen)
  File "/home/user/.local/lib/python3.8/site-packages/omegaconf/omegaconf.py", line 669, in open_dict
    prev_state = config._get_node_flag("struct")
AttributeError: 'dict' object has no attribute '_get_node_flag'

To reproduce the error, please use this repo and follow these steps:

  1. Preprocess data: ~/xfspell$ sh src/00_preprocess_data.sh
  2. Train AT verifier: ~/xfspell$ sh src/01_train_AT_verifier.sh
  3. Train NAT drafter: ~/xfspell$ sh src/02_train_NAT_drafter.sh
  4. Run inference: ~/xfspell$ sh src/03_inference.sh

Do you have any suggestions for fixing this error?

Thank you.

hemingkx commented 2 years ago

Maybe this issue helps. Do you also have this problem when just running our code?

ndkmath1 commented 2 years ago

Hi @hemingkx,

Thank you for the link to the issue. I will try it.

Do you also have this problem when just running our code?

I've installed fairseq following the instructions in readme.md:

cd GAD
pip install --editable .

Step 1 (Preprocess data): I use a slightly different approach from your suggestion (please refer to my comment above), but I'm able to train a normal Transformer or EdgeLM on that binarized data without any problems (inference also works well).

Step 2 (Train AT verifier): this is the same as for the Transformer models here.

Steps 3 and 4 (Train NAT drafter and run inference): these are the same as your code; I just copied it into a separate shell script file.

I think the main difference is in Step 1 (Preprocess data).

hemingkx commented 2 years ago

  File "/home/user/.local/lib/python3.8/site-packages/omegaconf/omegaconf.py", line 669, in open_dict
    prev_state = config._get_node_flag("struct")
AttributeError: 'dict' object has no attribute '_get_node_flag'

Hmm, it seems that the error comes from omegaconf. Could you try updating the omegaconf version in your environment and then running the script again?

BTW, please check this issue and this one. Following their solutions might solve your problem 😊.

ndkmath1 commented 2 years ago

Thank you. I will check your suggestions.

MLDeep16 commented 1 year ago

  File "/home/user/.local/lib/python3.8/site-packages/omegaconf/omegaconf.py", line 669, in open_dict
    prev_state = config._get_node_flag("struct")
AttributeError: 'dict' object has no attribute '_get_node_flag'

@hemingkx @ndkmath1 I am facing a similar issue. Were you able to resolve it? Any help would be appreciated.