facebookresearch / fairseq

Facebook AI Research Sequence-to-Sequence Toolkit written in Python.
MIT License

How to Finetune the M2M_100 12B Model for Translation Task? #4213

Open kaiyuhwang opened 2 years ago

kaiyuhwang commented 2 years ago

❓ Questions and Help

While fine-tuning the m2m_100 12B model (12b_last_chk_2_gpus.pt) on two 40 GB NVIDIA A100 GPUs with the recommended configuration settings from the docs, I get an error every time.

In addition, generation with similar settings runs successfully, which suggests the balance strategy itself works. The error appears only during fine-tuning.

Thanks for your attention.

The error appears right after this log line: "2022-02-14 22:13:35 | INFO | fairseq.model_parallel.models.pipeline_parallel_transformer.model | Using fairscale pipe"

```
File "/data/miniconda3/envs/lib/python3.7/site-packages/fairscale/nn/pipe/pipe.py", line 275, in __init__
    raise ValueError(recommend_auto_balance("balance is required"))
ValueError: balance is required

If your model is still under development, its optimal balance would change frequently. In this case, we highly recommend 'fairscale.nn.pipe.balance' for naive automatic balancing:

  from fairscale.nn import Pipe
  from fairscale.nn.pipe.balance import balance_by_time

  partitions = torch.cuda.device_count()
  sample = torch.empty(...)
  balance = balance_by_time(partitions, model, sample)
```
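For context on what the error is asking for: in a pipeline, `balance` is a list giving how many consecutive layers go to each stage, paired with a list of devices for those stages. The sketch below is a hypothetical pure-Python helper (`partition_by_balance` is not a fairscale or fairseq function) that mimics this idea, including raising the same "balance is required" error when no balance is supplied. The 26-module decoder (embedding, 24 layers, output projection) is an assumption inferred from the `[3,22,1]` decoder balance in the script later in this issue.

```python
from typing import List, Optional, Sequence, Tuple


def partition_by_balance(
    layers: Sequence[str],
    balance: Optional[List[int]],
    devices: List[int],
) -> List[Tuple[int, List[str]]]:
    """Split `layers` into consecutive chunks, one per pipeline stage.

    balance[i] is the number of layers placed on stage i and devices[i] is the
    GPU index for that stage. Hypothetical helper for illustration only; it
    mimics the idea behind fairscale's Pipe, not its actual code.
    """
    if balance is None:
        # Same situation as the reported error: no balance reached the pipe.
        raise ValueError("balance is required")
    if len(balance) != len(devices):
        raise ValueError("balance and devices must have the same length")
    if sum(balance) != len(layers):
        raise ValueError(
            f"balance sums to {sum(balance)}, but there are {len(layers)} layers"
        )
    stages = []
    start = 0
    for count, device in zip(balance, devices):
        stages.append((device, list(layers[start:start + count])))
        start += count
    return stages


# Assumed decoder layout: embedding + 24 layers + output projection = 26 modules.
decoder = ["embed"] + [f"layer{i}" for i in range(24)] + ["output"]
for device, chunk in partition_by_balance(decoder, [3, 22, 1], [0, 1, 0]):
    print(device, len(chunk), chunk[0], chunk[-1])
# 0 3 embed layer1
# 1 22 layer2 layer23
# 0 1 output output
```

Note that the same device (GPU 0) can host more than one stage, which is how the `[0,1,0]` decoder devices in the script place both the first and last decoder stages on GPU 0.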

So it seems the error comes from fairscale.

Code

This is the bash script for the finetuning process:

```bash
export CUDA_VISIBLE_DEVICES=0,1

fairseq-train data-bin \
    --finetune-from-model m2m/12b_last_chk_2_gpus.pt \
    --save-dir checkpoints/ \
    --task translation_multi_simple_epoch \
    --encoder-normalize-before \
    --lang-dict m2m/lang_dicts.txt \
    --lang-pairs en-de \
    --max-tokens 1024 \
    --decoder-normalize-before \
    --encoder-langtok src \
    --decoder-langtok \
    --criterion label_smoothed_cross_entropy --label-smoothing 0.2 \
    --optimizer adam --adam-eps 1e-06 --adam-betas '(0.9, 0.98)' \
    --lr-scheduler inverse_sqrt --lr 3e-05 \
    --warmup-updates 2500 --max-update 40000 \
    --dropout 0.3 --attention-dropout 0.1 \
    --weight-decay 0.0 --update-freq 2 \
    --save-interval 1 --save-interval-updates 2500 \
    --keep-interval-updates 10 --no-epoch-checkpoints \
    --seed 222 --log-format simple --log-interval 200 \
    --patience 10 \
    --arch transformer_wmt_en_de_big_pipeline_parallel \
    --encoder-layers 24 --decoder-layers 24 \
    --encoder-attention-heads 16 --decoder-attention-heads 16 \
    --encoder-ffn-embed-dim 16384 --decoder-ffn-embed-dim 16384 \
    --decoder-embed-dim 4096 --encoder-embed-dim 4096 \
    --share-decoder-input-output-embed \
    --share-all-embeddings \
    --ddp-backend no_c10d \
    --dataset-impl mmap \
    --pipeline-model-parallel \
    --pipeline-chunks 1 \
    --pipeline-encoder-balance '[26]' \
    --pipeline-encoder-devices '[0]' \
    --pipeline-decoder-balance '[3,22,1]' \
    --pipeline-decoder-devices '[0,1,0]'
```
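The `--pipeline-*` values in the script can be sanity-checked offline. The sketch below uses hypothetical helpers (`parse_int_list` and `check_pipeline` are not part of fairseq) that parse each encoder/decoder balance and devices pair with `ast.literal_eval`, then check that the lists have matching lengths, that each balance sums to the expected module count, and that every device id exists on a 2-GPU machine. The count of 26 modules per side is an assumption inferred from the balances in the script. A check like this only catches inconsistent values; it cannot catch a balance that never reaches fairscale at all, which is what the error reports.

```python
import ast
from typing import Dict, List

# Pipeline flags copied from the fine-tuning script.
PIPELINE_FLAGS: Dict[str, str] = {
    "--pipeline-encoder-balance": "[26]",
    "--pipeline-encoder-devices": "[0]",
    "--pipeline-decoder-balance": "[3,22,1]",
    "--pipeline-decoder-devices": "[0,1,0]",
}


def parse_int_list(value: str) -> List[int]:
    """Parse a flag value such as '[3,22,1]' into a list of ints."""
    parsed = ast.literal_eval(value)
    if not isinstance(parsed, list) or not all(isinstance(x, int) for x in parsed):
        raise ValueError(f"expected a list of ints, got {value!r}")
    return parsed


def check_pipeline(flags: Dict[str, str], modules_per_side: int, num_gpus: int) -> List[str]:
    """Return the problems found in the balance/devices flags (empty if none).

    Assumption: each side (encoder, decoder) is a sequence of
    `modules_per_side` pipeline modules; 26 is inferred from the script's
    balances, not read from fairseq's code.
    """
    problems = []
    for side in ("encoder", "decoder"):
        balance = parse_int_list(flags[f"--pipeline-{side}-balance"])
        devices = parse_int_list(flags[f"--pipeline-{side}-devices"])
        if len(balance) != len(devices):
            problems.append(
                f"{side}: {len(balance)} balance entries vs {len(devices)} devices"
            )
        if sum(balance) != modules_per_side:
            problems.append(
                f"{side}: balance sums to {sum(balance)}, expected {modules_per_side}"
            )
        bad = [d for d in devices if not 0 <= d < num_gpus]
        if bad:
            problems.append(f"{side}: device ids {bad} outside 0..{num_gpus - 1}")
    return problems


print(check_pipeline(PIPELINE_FLAGS, modules_per_side=26, num_gpus=2))  # []
```

With the values from the script, both the encoder and decoder settings pass, so any problem is more likely in how the flags are consumed at training time than in the numbers themselves.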

What's your environment?

wei-ann-Github commented 2 years ago

Hi @koukaiu, have you had any luck figuring out how to fine-tune the 12B model? I am having a lot of difficulty as well.

Thank you very much.