facebookresearch / fairseq

Facebook AI Research Sequence-to-Sequence Toolkit written in Python.
MIT License

mBART50 Translation/Fine Tuning with Many-to-One Model not working #3474

Open sergej-d opened 3 years ago

sergej-d commented 3 years ago

Hello, I think something is wrong with the architecture recorded in the checkpoint. When I look at the args of the mbart50.ft.n1 model, they say "arch": "denoising_large", but as far as I can see, denoising_large is not an architecture available in fairseq.
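
One way to check what a checkpoint actually records is a short Python sketch like the one below. It assumes the legacy layout, where the checkpoint dict stores an argparse.Namespace under "args" (the object that convert_namespace_to_omegaconf receives in the traceback). It falls back to "cfg" for newer checkpoints. stored_arch and load_state are illustrative names, not fairseq APIs.

```python
import sys
from argparse import Namespace


def stored_arch(state):
    """Return the architecture name recorded in a loaded fairseq checkpoint.

    Assumes legacy checkpoints keep an argparse.Namespace under state["args"]
    and newer ones keep a model config under state["cfg"]["model"].
    Returns None if neither is found.
    """
    args = state.get("args")
    if args is not None:
        return getattr(args, "arch", None)
    cfg = state.get("cfg")
    if cfg is None or "model" not in cfg:
        return None
    model_cfg = cfg["model"]
    if isinstance(model_cfg, dict):
        return model_cfg.get("arch", model_cfg.get("_name"))
    return getattr(model_cfg, "arch", getattr(model_cfg, "_name", None))


def load_state(path):
    import torch  # lazy import: stored_arch() itself does not need torch

    # The checkpoint pickles an argparse.Namespace. Recent PyTorch releases
    # default to weights_only=True and refuse such objects, while very old
    # releases do not accept the keyword at all.
    try:
        return torch.load(path, map_location="cpu", weights_only=False)
    except TypeError:
        return torch.load(path, map_location="cpu")


if __name__ == "__main__":
    if len(sys.argv) > 1:
        print(stored_arch(load_state(sys.argv[1])))
    else:
        # Offline demo with a fake legacy checkpoint (no torch needed).
        print(stored_arch({"args": Namespace(arch="denoising_large")}))
```

Run as `python inspect_arch.py mbart50.ft.n1/model.pt`. The script name is hypothetical.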

To Reproduce:

path_2_data=/home/sergej/fairseq/data4translation/
model=/home/sergej/fairseq/mbart50.ft.n1/model.pt
lang_list=mbart50.ft.n1/ML50_langs.txt
source_lang=de_DE
lang_pairs=de_DE-en_XX
target_lang=en_XX

fairseq-generate $path_2_data \
  --path $model \
  --task translation_multi_simple_epoch \
  --gen-subset test \
  --source-lang $source_lang \
  --target-lang $target_lang \
  --sacrebleu \
  --remove-bpe 'sentencepiece' \
  --batch-size 32 \
  --encoder-langtok "src" \
  --decoder-langtok \
  --lang-dict "$lang_list" \
  --lang-pairs "$lang_pairs" > ${source_lang}${target_lang}_mBART50FTN1_on_KWS_Test.txt

Error:

Traceback (most recent call last):
  File "/home/sergej/nlp/bin/fairseq-generate", line 33, in <module>
    sys.exit(load_entry_point('fairseq', 'console_scripts', 'fairseq-generate')())
  File "/home/sergej/fairseq/fairseq_cli/generate.py", line 404, in cli_main
    main(args)
  File "/home/sergej/fairseq/fairseq_cli/generate.py", line 49, in main
    return _main(cfg, sys.stdout)
  File "/home/sergej/fairseq/fairseq_cli/generate.py", line 96, in _main
    models, saved_cfg = checkpoint_utils.load_model_ensemble(
  File "/home/sergej/fairseq/fairseq/checkpoint_utils.py", line 319, in load_model_ensemble
    ensemble, args, _task = load_model_ensemble_and_task(
  File "/home/sergej/fairseq/fairseq/checkpoint_utils.py", line 361, in load_model_ensemble_and_task
    state = load_checkpoint_to_cpu(filename, arg_overrides)
  File "/home/sergej/fairseq/fairseq/checkpoint_utils.py", line 295, in load_checkpoint_to_cpu
    state = _upgrade_state_dict(state)
  File "/home/sergej/fairseq/fairseq/checkpoint_utils.py", line 537, in _upgrade_state_dict
    state["cfg"] = convert_namespace_to_omegaconf(state["args"])
  File "/home/sergej/fairseq/fairseq/dataclass/utils.py", line 386, in convert_namespace_to_omegaconf
    _set_legacy_defaults(cfg.model, ARCH_MODEL_REGISTRY[args.arch])
KeyError: 'denoising_large'

sukuya commented 3 years ago

Facing the same issue when fine-tuning many2one.

alexandra-chron commented 3 years ago

A quick-and-dirty workaround is to edit fairseq/dataclass/utils.py and manually set args.arch = 'mbart_large' right after line 380 (from fairseq.models import ARCH_MODEL_REGISTRY).
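
The workaround above hardcodes mbart_large for every checkpoint that goes through that code path. The sketch below is a slightly narrower variant that only remaps the unregistered name. It assumes you are patching the same spot in fairseq/dataclass/utils.py. resolve_arch and LEGACY_ARCH_ALIASES are hypothetical helpers and are not part of fairseq.

```python
# Hypothetical alias table. The denoising_large -> mbart_large mapping is the
# one suggested in this thread; fairseq itself does not ship such a table.
LEGACY_ARCH_ALIASES = {"denoising_large": "mbart_large"}


def resolve_arch(arch, registry, aliases=None):
    """Return a name present in `registry`, remapping known legacy aliases.

    Raises KeyError for names that are neither registered nor aliased, so
    genuinely unknown architectures still fail loudly.
    """
    aliases = LEGACY_ARCH_ALIASES if aliases is None else aliases
    if arch in registry:
        return arch
    target = aliases.get(arch)
    if target is not None and target in registry:
        return target
    raise KeyError(arch)


if __name__ == "__main__":
    # Toy registry standing in for fairseq's ARCH_MODEL_REGISTRY.
    toy_registry = {"mbart_large": object(), "transformer": object()}
    print(resolve_arch("denoising_large", toy_registry))
```

Inside convert_namespace_to_omegaconf, it would be called as args.arch = resolve_arch(args.arch, ARCH_MODEL_REGISTRY) just before the _set_legacy_defaults(...) line. Any checkpoint whose arch is already registered is then left untouched.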

joelb-git commented 3 years ago

For fine-tuning the many-to-one model, I've found that I additionally need to pass --arch mbart_large to fairseq-train when using --task translation_multi_simple_epoch. The README currently says to use --arch transformer (as of 6f6f704d10):

https://github.com/pytorch/fairseq/blob/master/examples/multilingual/README.md

Failure to do this results in:

2021-05-19 11:29:45 | INFO | fairseq.trainer | Preparing to load checkpoint models/mbart50.ft.n1/model.pt
Traceback (most recent call last):
  File "fairseq/fairseq/trainer.py", line 460, in load_checkpoint
    self.model.load_state_dict(
  File "fairseq/fairseq/models/fairseq_model.py", line 125, in load_state_dict
    return super().load_state_dict(new_state_dict, strict)
  File "fairseq/venv/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1223, in load_state_dict
    raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
RuntimeError: Error(s) in loading state_dict for BARTModel:
        Missing key(s) in state_dict: "decoder.output_projection.weight".
        Unexpected key(s) in state_dict: "encoder.embed_positions.weight", "encoder.layers.6.self_attn.k_proj.weight",
    ...
        size mismatch for encoder.embed_tokens.weight: copying a param with shape torch.Size([250053, 1024]) from checkpoint, the shape in current model is torch.Size([250053, 512]).
        size mismatch for encoder.layernorm_embedding.weight: copying a param with shape torch.Size([1024]) from checkpoint, the shape in current model is torch.Size([512]).
        ...
Exception: Cannot load model parameters from checkpoint models/mbart50.ft.n1/model.pt; please ensure that the architectures match.
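
These errors fit a specific pattern: 1024- vs. 512-dim widths, plus checkpoint keys such as encoder.layers.6.* that the model does not expect. That is what you would see when an mbart_large checkpoint (1024-dim, 12 layers) is loaded into fairseq's base transformer architecture (512-dim, 6 layers). To list every such difference at once before calling load_state_dict, you can compare parameter shapes directly. The sketch below does this; shape_mismatches is a hypothetical helper, and the toy shapes are illustrative.

```python
def shape_mismatches(ckpt_shapes, model_shapes):
    """Compare two {parameter_name: shape} mappings.

    Returns (missing, unexpected, mismatched):
    - missing: names the model needs but the checkpoint lacks,
    - unexpected: names the checkpoint has but the model does not,
    - mismatched: {name: (checkpoint_shape, model_shape)} for shared names
      whose shapes differ.
    These mirror the three error kinds in PyTorch's strict loading.
    """
    ckpt_names, model_names = set(ckpt_shapes), set(model_shapes)
    missing = sorted(model_names - ckpt_names)
    unexpected = sorted(ckpt_names - model_names)
    mismatched = {
        name: (tuple(ckpt_shapes[name]), tuple(model_shapes[name]))
        for name in sorted(ckpt_names & model_names)
        if tuple(ckpt_shapes[name]) != tuple(model_shapes[name])
    }
    return missing, unexpected, mismatched


if __name__ == "__main__":
    # Toy shapes modeled on the log above (the k_proj shape is illustrative).
    ckpt = {
        "encoder.embed_tokens.weight": (250053, 1024),
        "encoder.layers.6.self_attn.k_proj.weight": (1024, 1024),
    }
    model = {
        "encoder.embed_tokens.weight": (250053, 512),
        "decoder.output_projection.weight": (250053, 512),
    }
    for part in shape_mismatches(ckpt, model):
        print(part)
```

With real objects, build the inputs as {k: tuple(v.shape) for k, v in state["model"].items()} for the loaded checkpoint, and with the same expression over model.state_dict() for the freshly built model.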

gegallego commented 3 years ago

Hi! I've also faced this issue. I manually changed args.arch in the checkpoint from denoising_large to mbart_large, and it works: I loaded the checkpoint with torch.load(PATH), modified it, and saved it again with torch.save(ckpt_dict, PATH). It would be great if someone at Facebook could update this checkpoint.
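
The checkpoint edit described above can be scripted roughly as follows. This is a sketch under a few assumptions:
- The arch name lives in state["args"] as an argparse.Namespace.
- A .bak copy is kept before the file is overwritten.
- weights_only=False is needed because newer PyTorch releases refuse by default to unpickle arbitrary objects such as a Namespace.

patch_arch and patch_file are illustrative names.

```python
import shutil
import sys
from argparse import Namespace


def patch_arch(state, old="denoising_large", new="mbart_large"):
    """Rename state["args"].arch in place; return True if anything changed."""
    args = state.get("args")
    if args is None or getattr(args, "arch", None) != old:
        return False
    args.arch = new
    return True


def patch_file(path):
    import torch  # lazy import so patch_arch() is usable without torch

    try:
        state = torch.load(path, map_location="cpu", weights_only=False)
    except TypeError:  # older PyTorch without the weights_only keyword
        state = torch.load(path, map_location="cpu")
    if not patch_arch(state):
        return False
    shutil.copy2(path, path + ".bak")  # keep the original checkpoint
    torch.save(state, path)
    return True


if __name__ == "__main__":
    if len(sys.argv) > 1:
        print("patched" if patch_file(sys.argv[1]) else "nothing to change")
    else:
        # Offline demo on a fake legacy checkpoint dict.
        demo = {"args": Namespace(arch="denoising_large")}
        patch_arch(demo)
        print(demo["args"].arch)
```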

Remorax commented 3 years ago

I have tried all the approaches in this thread (from @gegallego, @joelb-git, and @alexandra-chron), and none of them work for me. I still get dimension mismatch errors like this one:

copying a param with shape torch.Size([1024]) from checkpoint, the shape in current model is torch.Size([512]) (Similar to what @joelb-git mentions)

Is there anything else I could try? Any help would be greatly appreciated. Thanks!

joelb-git commented 3 years ago

@Remorax - not sure this will help you, but here are some more specifics of what I did.

Here is the commit I used:

$ git log | head -1
commit a4e1d4a3daf4f6f5557505026fd94b8716fba7b3

The code changes:

$ git diff
diff --git a/fairseq/dataclass/utils.py b/fairseq/dataclass/utils.py
index 27c9006..58be37f 100644
--- a/fairseq/dataclass/utils.py
+++ b/fairseq/dataclass/utils.py
@@ -383,6 +383,9 @@ def convert_namespace_to_omegaconf(args: Namespace) -> DictConfig:
         cfg.model = Namespace(**vars(args))
         from fairseq.models import ARCH_MODEL_REGISTRY

+        # https://github.com/pytorch/fairseq/issues/3474
+        args.arch = 'mbart_large'
+
         _set_legacy_defaults(cfg.model, ARCH_MODEL_REGISTRY[args.arch])
         cfg.model._name = args.arch
     if cfg.optimizer is None and getattr(args, "optimizer", None):
diff --git a/fairseq/tasks/translation_from_pretrained_bart.py b/fairseq/tasks/translation_from_pretrained_bart.py
index 0fd7a5b..28d7c13 100644
--- a/fairseq/tasks/translation_from_pretrained_bart.py
+++ b/fairseq/tasks/translation_from_pretrained_bart.py
@@ -51,6 +51,8 @@ class TranslationFromPretrainedBARTTask(TranslationTask):

     def __init__(self, args, src_dict, tgt_dict):
         super().__init__(args, src_dict, tgt_dict)
+        # https://github.com/pytorch/fairseq/issues/3169
+        #self.args = args
         self.langs = args.langs.split(",")
         for d in [src_dict, tgt_dict]:
             for l in self.langs:

A snippet from my fine-tuning script:

mbart_checkpoint=models/mbart50.ft.n1/model.pt
total_num_update=50000
lr=3e-05
warmup_updates=3000

fairseq-train $data_dir --save-dir $save_dir \
  --finetune-from-model $mbart_checkpoint \
  --encoder-normalize-before --decoder-normalize-before \
  --arch mbart_large --layernorm-embedding \
  --task translation_multi_simple_epoch \
  --sampling-method "temperature" \
  --sampling-temperature 1.5 \
  --encoder-langtok "src" \
  --decoder-langtok \
  --lang-dict mbart50.pretrained/ML50_langs.txt \
  --lang-pairs ka_GE-en_XX \
  --criterion label_smoothed_cross_entropy --label-smoothing 0.2 \
  --optimizer adam --adam-eps 1e-06 --adam-betas '(0.9, 0.98)' \
  --lr-scheduler inverse_sqrt --lr $lr --warmup-updates $warmup_updates --max-update $total_num_update \
  --dropout 0.3 --attention-dropout 0.1 --weight-decay 0.0 \
  --max-tokens 1024 --update-freq 2 \
  --save-interval 1 --save-interval-updates 5000 --keep-interval-updates 10 --no-epoch-checkpoints \
  --seed 222 --log-format simple --log-interval 2
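
With these settings, --lr-scheduler inverse_sqrt ramps the learning rate up linearly during warmup and then decays it with the inverse square root of the update number. The sketch below approximates that schedule in plain Python using the values from the script (lr=3e-05, warmup_updates=3000). It assumes warmup starts from 0, which is believed to be fairseq's behavior when --warmup-init-lr is not given. inverse_sqrt_lr is an illustrative name.

```python
import math


def inverse_sqrt_lr(num_updates, lr=3e-05, warmup_updates=3000, warmup_init_lr=0.0):
    """Approximate learning rate after `num_updates` optimizer updates."""
    if num_updates < warmup_updates:
        # Linear warmup from warmup_init_lr up to the peak lr.
        return warmup_init_lr + num_updates * (lr - warmup_init_lr) / warmup_updates
    # After warmup: peak lr scaled by sqrt(warmup_updates / num_updates).
    return lr * math.sqrt(warmup_updates / num_updates)


if __name__ == "__main__":
    for step in (0, 1500, 3000, 12000, 50000):
        print(step, f"{inverse_sqrt_lr(step):.3e}")
```

With these numbers, the rate peaks at 3e-05 after 3000 updates and has halved by update 12000.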

Remorax commented 3 years ago

Thank you so much for providing the snippets, @joelb-git! I tried to replicate them exactly by checking out the commit you mentioned and applying your changes. That resolved the issue I mentioned, but now I get a new error:

RuntimeError: Error(s) in loading state_dict for BARTModel:

        size mismatch for encoder.embed_tokens.weight: copying a param with shape torch.Size([250054, 1024]) from checkpoint, the shape in current model is torch.Size([250003, 1024]).
        size mismatch for decoder.embed_tokens.weight: copying a param with shape torch.Size([250054, 1024]) from checkpoint, the shape in current model is torch.Size([250003, 1024]).
        size mismatch for decoder.output_projection.weight: copying a param with shape torch.Size([250054, 1024]) from checkpoint, the shape in current model is torch.Size([250003, 1024]).

It appears that the matrices are larger in the updated version of mBART. I have downloaded the pretrained model from this link. Could you share the link to the version you used, or even better, tell me how I can update the 250003 expected in the model configuration to 250054?

Thanks again!

joelb-git commented 3 years ago

@Remorax I just confirmed that I can still run my fine-tuning script. I also double-checked that the model I started with is the same as yours: https://dl.fbaipublicfiles.com/fairseq/models/mbart50/mbart50.ft.n1.tar.gz (I re-downloaded it and ran cmp to be completely sure.)

I wonder about two things: (1) when you synced back to the earlier commit, did you remember to re-run pip install --editable ./? (2) Did you reprocess your data with the code from the old commit?

Remorax commented 3 years ago

@joelb-git thank you so much for going to the trouble of downloading the huge mBART binary and running cmp on it to verify it's the same.

  1. Yes, I did this. I ran pip install --user --editable ./.
  2. No, I re-ran all the preprocessing scripts: SentencePiece tokenization with spm_encode, then fairseq-preprocess, and finally fairseq-train. I'm still facing the same error described in my previous reply.

Do you have any other suggestions or things I could try?

gegallego commented 3 years ago

Hi @Remorax! I think I faced a similar issue at the beginning of my experiments with mBART50. It seems like the 250003 comes from the 249997 tokens in dict.en_XX.txt plus the 4 special tokens (<s>, <pad>, </s>, <unk>). However, mBART50 was trained with 53 extra special tokens (the 52 languages in ML50_langs.txt and <mask>), which sums up to 250054, the size of the embedding table in the checkpoint.
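
The arithmetic behind these sizes is easy to check. Note that 249997 + 4 is 250001, not 250003. The sketch below assumes the remaining 2 rows are language tokens for the two languages in --lang-pairs (ka_GE, en_XX), on the assumption that only those get tokens when no --lang-dict is given. Under the same assumptions, it also reproduces the 250053 vs. 250054 difference reported by NomadXD (all ML50 language tokens, with and without <mask>).

```python
BASE_DICT = 249997  # entries in dict.en_XX.txt, as stated in the comment
SPECIAL = 4         # <s>, <pad>, </s>, <unk>
ML50_LANGS = 52     # language tokens listed in ML50_langs.txt


def vocab_size(n_lang_tokens, with_mask):
    """Rows in the embedding table for a given dictionary augmentation."""
    return BASE_DICT + SPECIAL + n_lang_tokens + (1 if with_mask else 0)


# mBART50 checkpoint: all 52 language tokens plus <mask>.
checkpoint_rows = vocab_size(ML50_LANGS, with_mask=True)
# Assumption: without --lang-dict, only ka_GE and en_XX (from --lang-pairs)
# get language tokens, and no <mask> is added.
pairs_only_rows = vocab_size(2, with_mask=False)
# --lang-dict ML50_langs.txt but no <mask>.
no_mask_rows = vocab_size(ML50_LANGS, with_mask=False)

print(checkpoint_rows, pairs_only_rows, no_mask_rows)  # 250054 250003 250053
```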

I used the mBART50 checkpoint for a different task, so I decided to manually add the extra tokens to the dictionary. However, it seems that for the original task you can use --lang-dict to give the ML50_langs.txt path.

I hope this solves your errors!

joelb-git commented 3 years ago

@Remorax I think @gegallego is right on point.

If I don't pass --lang-dict:

fairseq-train $data_dir --save-dir $save_dir \
  --finetune-from-model $mbart_checkpoint \
  --encoder-normalize-before --decoder-normalize-before \
  --arch mbart_large --layernorm-embedding \
  --task translation_multi_simple_epoch \
  --sampling-method "temperature" \
  --sampling-temperature 1.5 \
  --encoder-langtok "src" \
  --decoder-langtok \
  --lang-pairs ka_GE-en_XX \
  --criterion label_smoothed_cross_entropy --label-smoothing 0.2 \
  --optimizer adam --adam-eps 1e-06 --adam-betas '(0.9, 0.98)' \
  --lr-scheduler inverse_sqrt --lr $lr --warmup-updates $warmup_updates --max-update $total_num_update \
  --dropout 0.3 --attention-dropout 0.1 --weight-decay 0.0 \
  --max-tokens 1024 --update-freq 2 \
  --save-interval 1 --save-interval-updates 5000 --keep-interval-updates 10 --no-epoch-checkpoints \
  --seed 222 --log-format simple --log-interval 2
Traceback (most recent call last):
  File "/nas/material02/users/joelb/saral/3B/mbart-models/fairseq/fairseq/trainer.py", line 460, in load_checkpoint
    self.model.load_state_dict(
  File "/nas/material02/users/joelb/saral/3B/mbart-models/fairseq/fairseq/models/fairseq_model.py", line 125, in load_state_dict
    return super().load_state_dict(new_state_dict, strict)
  File "/nas/material02/users/joelb/saral/3B/mbart-models/fairseq/venv/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1223, in load_state_dict
    raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
RuntimeError: Error(s) in loading state_dict for BARTModel:
        size mismatch for encoder.embed_tokens.weight: copying a param with shape torch.Size([250054, 1024]) from checkpoint, the shape in current model is torch.Size([250003, 1024]).
        size mismatch for decoder.embed_tokens.weight: copying a param with shape torch.Size([250054, 1024]) from checkpoint, the shape in current model is torch.Size([250003, 1024]).
        size mismatch for decoder.output_projection.weight: copying a param with shape torch.Size([250054, 1024]) from checkpoint, the shape in current model is torch.Size([250003, 1024]).

This looks like the same size issue. When I add --lang-dict mbart50.ft.n1/ML50_langs.txt things work correctly.

FWIW, I think I remember working through this (many months ago), hitting the same issue because I started with the fairseq-train command from the original mbart25, which uses --langs, --source-lang, --target-lang en_XX, but later found I needed --lang-dict and --lang-pairs for mbart50.

BTW, another quirk here is that the ML50_langs.txt file appears to be only 51 lines according to wc:

$ wc -l mbart50.ft.n1/ML50_langs.txt
51 mbart50.ft.n1/ML50_langs.txt

but it really does have 52 lines. There is just no trailing newline on this file.
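
wc -l counts newline characters rather than lines, so a file whose last line has no trailing newline comes up one short. The small Python check below counts both ways. count_lines and count_newlines are illustrative helpers, not part of fairseq.

```python
import sys


def count_lines(text):
    # str.splitlines() counts a final line even without a trailing newline,
    # and does not add an empty element when the text does end with one.
    return len(text.splitlines())


def count_newlines(text):
    # This is what `wc -l` reports: the number of "\n" characters.
    return text.count("\n")


if __name__ == "__main__" and len(sys.argv) > 1:
    with open(sys.argv[1], encoding="utf-8") as f:
        content = f.read()
    print(count_lines(content), count_newlines(content))
```

For a 52-entry file without a final newline, count_lines gives 52 while count_newlines (like wc -l) gives 51.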

gegallego commented 3 years ago

I was also confused by the number of lines: I was using wc -l too, and the numbers didn't add up!

Remorax commented 3 years ago

Hi @gegallego and @joelb-git

Thanks so much to both of you! Passing ML50_langs.txt via --lang-dict did work after all, and my model is training successfully :) Perhaps the README could be improved on this point.

Yes, wc -l can be misleading: it counts newline characters, so its output depends on whether the file ends with a trailing newline.

Anyway, thanks again for all your help!

NomadXD commented 3 years ago

@gegallego @joelb-git I am also using mBART these days and ran into a similar issue. I fine-tuned mBART for en_XX-si_LK translation with the translation_multi_simple_epoch task, using the following snippet:

lang_pairs="en_XX-si_LK"
path_2_data="ft/preprocess"
lang_list="mbart50.pretrained/ML50_langs.txt"  # <a file which contains a list of languages separated by new lines>
pretrained_model="mbart50.pretrained/model.pt" # <path to the pretrained model, e.g. mbart or another trained multilingual model>

!fairseq-train $path_2_data \
  --finetune-from-model $pretrained_model \
  --encoder-normalize-before --decoder-normalize-before \
  --arch mbart_large --layernorm-embedding \
  --task translation_multi_simple_epoch \
  --sampling-method "temperature" \
  --sampling-temperature 1.5 \
  --encoder-langtok "src" \
  --decoder-langtok \
  --lang-dict "$lang_list" \
  --lang-pairs "$lang_pairs" \
  --memory-efficient-fp16 \
  --criterion label_smoothed_cross_entropy --label-smoothing 0.2 \
  --optimizer adam --adam-eps 1e-06 --adam-betas '(0.9, 0.98)' \
  --lr-scheduler inverse_sqrt --lr 3e-05 --warmup-updates 2500 --max-update 100000 \
  --dropout 0.3 --attention-dropout 0.1 --weight-decay 0.0 \
  --max-tokens 1024 --update-freq 2 \
  --save-interval 1 --save-interval-updates 5000 --keep-interval-updates 10 --no-epoch-checkpoints \
  --seed 222 --log-format simple --log-interval 2 \
  --save-dir ft/checkpoints > ft/logs 2>&1 & disown

!tail -f ft/logs

Then I copied checkpoint_last.pt and checkpoint_best.pt from the run above and tried to fine-tune them for si_LK text simplification, where the source language is the complex form and the target language is the simple form. I ran this as a translation_from_pretrained_bart task using the following snippet:

path_2_data="ft/preprocess"
pretrained_model="SiTa-ft-translation/checkpoint_last.pt"

!fairseq-train $path_2_data \
  --finetune-from-model $pretrained_model \
  --encoder-normalize-before --decoder-normalize-before \
  --arch mbart_large --layernorm-embedding \
  --task translation_from_pretrained_bart \
  --langs "ar_AR,cs_CZ,de_DE,en_XX,es_XX,et_EE,fi_FI,fr_XX,gu_IN,hi_IN,it_IT,ja_XX,kk_KZ,ko_KR,lt_LT,lv_LV,my_MM,ne_NP,nl_XX,ro_RO,ru_RU,si_LK,tr_TR,vi_VN,zh_CN,af_ZA,az_AZ,bn_IN,fa_IR,he_IL,hr_HR,id_ID,ka_GE,km_KH,mk_MK,ml_IN,mn_MN,mr_IN,pl_PL,ps_AF,pt_XX,sv_SE,sw_KE,ta_IN,te_IN,th_TH,tl_XX,uk_UA,ur_PK,xh_ZA,gl_ES,sl_SI" \
  --source-lang complex --target-lang simple \
  --memory-efficient-fp16 \
  --criterion label_smoothed_cross_entropy --label-smoothing 0.2 \
  --optimizer adam --adam-eps 1e-06 --adam-betas '(0.9, 0.98)' \
  --lr-scheduler inverse_sqrt --lr 3e-05 --warmup-updates 2500 --max-update 20000 \
  --dropout 0.3 --attention-dropout 0.1 --weight-decay 0.0 \
  --max-tokens 1024 --update-freq 2 \
  --save-interval 1 --save-interval-updates 5000 --keep-interval-updates 10 --no-epoch-checkpoints \
  --seed 222 --log-format simple --log-interval 2 \
  --save-dir ft/checkpoints > ft/logs 2>&1 & disown

!tail -f ft/logs

When I try to fine tune the text simplification part, I get the following error which is similar to the one mentioned here.

RuntimeError: Error(s) in loading state_dict for BARTModel:
    size mismatch for encoder.embed_tokens.weight: copying a param with shape torch.Size([250053, 1024]) from checkpoint, the shape in current model is torch.Size([250054, 1024]).
    size mismatch for decoder.embed_tokens.weight: copying a param with shape torch.Size([250053, 1024]) from checkpoint, the shape in current model is torch.Size([250054, 1024]).
    size mismatch for decoder.output_projection.weight: copying a param with shape torch.Size([250053, 1024]) from checkpoint, the shape in current model is torch.Size([250054, 1024]).

There's a difference of 1 between the dimensions (250053 vs. 250054). After some debugging, the issue seems to be the <mask> token. The translation_from_pretrained_bart task adds the <mask> token automatically (relevant section), whereas translation_multi_simple_epoch has no equivalent logic. Any ideas or suggestions? When I use the same task for both steps, it works. It would also be great if you could explain a bit more about these two tasks and when to use each. Thanks in advance.

gegallego commented 3 years ago

Hello @NomadXD, I'm sorry, but I cannot help you with that... I used the mBART checkpoint for a different task, so I have no idea about the tasks you mention.

Mao-KU commented 3 years ago

Hello @sergej-d, I am also using the mbart50-n-to-1 model for translation. Did you manage to get translations successfully? In my case, both mbart50-n-to-1 and mbart-50-n-to-n just repeat one word over and over. Here is a file I obtained for ar-en translation (translated with the n-to-n model): https://lotus.kuee.kyoto-u.ac.jp/~zhuoyuanmao/ar_AR-en_XX Also, here is a file translated with the n-to-1 model: https://lotus.kuee.kyoto-u.ac.jp/~zhuoyuanmao/ar_AR-en_XX.n1 Do you have any ideas about this problem? Thank you in advance!
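
A quick way to quantify this kind of degenerate output is to scan the hypothesis lines of a fairseq-generate log for sentences dominated by a single token. The sketch assumes the usual H-<id><TAB><score><TAB><text> line format. The threshold and minimum length are arbitrary illustrative values, and degenerate_hypotheses is a hypothetical helper. It only measures the symptom; it does not explain or fix it.

```python
import sys
from collections import Counter


def degenerate_hypotheses(lines, threshold=0.5, min_tokens=4):
    """Yield (sentence_id, ratio, text) for H- lines dominated by one token.

    `ratio` is the share of the hypothesis taken up by its most frequent
    token; hypotheses shorter than `min_tokens` are skipped.
    """
    for line in lines:
        if not line.startswith("H-"):
            continue
        parts = line.rstrip("\n").split("\t")
        if len(parts) < 3:
            continue
        tokens = parts[2].split()
        if len(tokens) < min_tokens:
            continue
        ratio = Counter(tokens).most_common(1)[0][1] / len(tokens)
        if ratio >= threshold:
            yield parts[0][2:], ratio, parts[2]


if __name__ == "__main__" and len(sys.argv) > 1:
    with open(sys.argv[1], encoding="utf-8") as f:
        for sent_id, ratio, text in degenerate_hypotheses(f):
            print(f"{sent_id}\t{ratio:.2f}\t{text[:80]}")
```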

ThaminduR commented 2 years ago

I used the mBART50 checkpoint for a different task, so I decided to manually add the extra tokens to the dictionary.

Hi @gegallego, how can we add extra tokens to the dictionary?

gegallego commented 2 years ago

I just appended the language tokens, each followed by a space and a 1, to the end of the dictionary files (e.g. dict.en_XX.txt).

You can run this command to add all the language tokens:

cat $MBART_ROOT/ML50_langs.txt | cut -d'_' -f1 | sed 's/^/<lang:/g' | \
  sed 's/$/> 1/g' >> $MBART_ROOT/dict.en_XX.txt && \
echo "<mask> 1" >> $MBART_ROOT/dict.en_XX.txt

Where $MBART_ROOT is the directory where you extracted the mBART50 tar.gz file.
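
The same edit can be sketched in Python. It keeps the token format of the one-liner above (<lang:xx> 1 entries followed by <mask> 1), which belongs to that commenter's own task and may need adjusting for other setups. It also adds two safeguards:
- It skips tokens that are already present.
- It handles a dictionary file that lacks a trailing newline, so the first new entry is not glued onto the last existing line.

tokens_to_append and append_tokens are illustrative names.

```python
import sys
from pathlib import Path


def tokens_to_append(lang_codes):
    """Tokens in the one-liner's format: <lang:xx> for each xx_YY, then <mask>."""
    return [f"<lang:{code.split('_')[0]}>" for code in lang_codes] + ["<mask>"]


def append_tokens(dict_path, langs_path):
    """Append missing tokens to a fairseq dict file as '<token> 1' lines."""
    # split() copes with a langs file that has no trailing newline.
    langs = Path(langs_path).read_text(encoding="utf-8").split()
    dict_file = Path(dict_path)
    text = dict_file.read_text(encoding="utf-8")
    existing = {line.split()[0] for line in text.splitlines() if line.strip()}
    new = [tok for tok in tokens_to_append(langs) if tok not in existing]
    if not new:
        return 0
    with dict_file.open("a", encoding="utf-8") as f:
        if text and not text.endswith("\n"):
            f.write("\n")  # don't glue the first new entry onto the last line
        f.writelines(f"{tok} 1\n" for tok in new)
    return len(new)


if __name__ == "__main__" and len(sys.argv) == 3:
    # Usage: python append_tokens.py dict.en_XX.txt ML50_langs.txt
    print(append_tokens(sys.argv[1], sys.argv[2]), "tokens appended")
```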