facebookresearch / fairseq

Facebook AI Research Sequence-to-Sequence Toolkit written in Python.
MIT License

comparing fairseq vs tensor2tensor #417

Closed nicolabertoldi closed 5 years ago

nicolabertoldi commented 5 years ago

We are analyzing performance of Tensor2Tensor vs performance of Fairseq.

Currently we are not able to reach the very good performance your system achieved at WMT 2018, and in our experiments T2T still outperforms fairseq.

Probably we are using the wrong settings for training. Could you please describe the settings you actually used to train each of the six models composing the ensemble?

edunov commented 5 years ago

Here is the exact command we used to train the models:

python train.py $DATA/wmt18_en_de_bpej32k/ $DATA/bt_data_wmt18_en_de_bpej32k \
    -a transformer_vaswani_wmt_en_de_big --clip-norm 0.0 --lr 0.0005 \
    --source-lang en --target-lang de --label-smoothing 0.1 \
    --upsample-primary $UPSAMPLE --attention-dropout 0.1 --dropout 0.3 \
    --max-tokens 3584 --no-progress-bar --log-interval 100 --weight-decay 0.0 \
    --criterion label_smoothed_cross_entropy --fp16 --max-update 200000 \
    --seed $SEED --sample-without-replacement 128000 --share-all-embeddings \
    --optimizer adam --adam-betas '(0.9, 0.98)' --lr-scheduler inverse_sqrt \
    --warmup-init-lr 1e-07 --warmup-updates 4000 --min-lr 1e-09 \
    --distributed-port 12597 --distributed-world-size 128

For the ensemble of 6 models we used different seeds and somewhat different upsample ratios: UPSAMPLE = 16 for four of the models, and 8 and 32 for the remaining two.

Note that this command assumes a 128-GPU setup (--distributed-world-size 128). If you have 8 GPUs, you can instead use --update-freq 16 --distributed-world-size 8.

Can you post here the results you're getting with T2T vs Fairseq?

nicolabertoldi commented 5 years ago

@edunov thank you, I will try your configuration as soon as possible.

Our first comparison is based on a "transformer_base" model for T2T and "transformer_wmt_en_de" for fairseq, which have almost the same number of parameters.

We are using only 60M tokens currently.

Here are the scores on two test sets (attached as images): testset1 and testset2.

This is not very promising.

myleott commented 5 years ago

When using transformer_wmt_en_de (base), make sure to increase the learning rate.

--lr=0.0007 is a good learning rate for the base model with 8 GPUs.

nicolabertoldi commented 5 years ago

@edunov, @myleott is the performance very sensitive to the learning rate? Is there any reasonable rule for choosing the right learning rate when varying the architecture, the number of GPUs, and/or max_tokens?

myleott commented 5 years ago

The original "Attention is all you need" paper suggests setting the learning rate proportionally to the inverse square root of the embedding dimension [1]. Thus when using the smaller model (base) you should use a slightly larger learning rate compared to the big model.

W.r.t. batch size, larger batch sizes can support and benefit from larger learning rates. However, it's important to use a warmup (--warmup-updates 4000). In our Scaling NMT paper [2] we found that the big model could support a learning rate as large as 0.001 when trained with batches with 400K tokens (~= 128 GPUs). Larger learning rates seem to both speed up training and often result in better final performance as well.

[1] See Section 5.3 of "Attention Is All You Need": https://arxiv.org/pdf/1706.03762.pdf
[2] See Table 1 of "Scaling Neural Machine Translation": https://arxiv.org/pdf/1806.00187.pdf
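
To make the schedule concrete, here is a minimal Python sketch of the inverse_sqrt-with-warmup schedule used in the commands above. It is a simplification, not fairseq's actual scheduler; peak_lr, warmup_init_lr, and warmup_updates are meant to mirror the --lr, --warmup-init-lr, and --warmup-updates flags.

def inverse_sqrt_lr(step, peak_lr=0.0005, warmup_init_lr=1e-7, warmup_updates=4000):
    """Linear warmup from warmup_init_lr to peak_lr, then inverse-sqrt decay (sketch)."""
    if step < warmup_updates:
        # warmup phase: increase the learning rate linearly up to the peak
        return warmup_init_lr + (peak_lr - warmup_init_lr) * step / warmup_updates
    # after warmup: decay proportionally to 1/sqrt(step)
    return peak_lr * (warmup_updates ** 0.5) * (step ** -0.5)

# example: the learning rate peaks at step 4000 and decays afterwards
for step in (1000, 4000, 16000, 100000):
    print(step, inverse_sqrt_lr(step))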

davidecaroselli commented 5 years ago

Hi @myleott

first of all, thanks for all the help! At the moment, as @nicolabertoldi mentioned, we are running a small test, a "sanity check", just to verify that the overall architecture is comparable to the t2t transformer. Unfortunately our tests are not going well at all.

We are now training two equivalent systems: the first is a t2t model with transformer_big, the second is a fairseq model with the options you suggested. Both are running on 1 GPU. While the t2t model is training smoothly, the fairseq one is diverging. Here is the exact command we issued:

python3.6 ./fairseq/train.py ./data-bin/ --save-dir ./checkpoints/ \
    -a transformer_vaswani_wmt_en_de_big --clip-norm 0.0 --lr 0.001 \
    --label-smoothing 0.1 --attention-dropout 0.1 --dropout 0.3 \
    --max-tokens 1536 --no-progress-bar --log-interval 1000 --weight-decay 0.0 \
    --criterion label_smoothed_cross_entropy --max-update 2000000 --seed 1234 \
    --share-all-embeddings --optimizer adam --adam-betas '(0.9, 0.98)' \
    --lr-scheduler inverse_sqrt --warmup-init-lr 1e-07 --warmup-updates 16000 \
    --min-lr 1e-09

And here is a summary of the command output:

| [en] dictionary: 32768 types
| [it] dictionary: 32768 types
| ./data-bin/ train 2871049 examples
| ./data-bin/ valid 5000 examples
| model transformer_vaswani_wmt_en_de_big, criterion LabelSmoothedCrossEntropyCriterion
| num. model params: 209911808
| training on 1 GPUs
| max tokens per GPU = 1536 and max sentences per GPU = None
| epoch 001:   1000 / 49946 loss=11.813, nll_loss=11.290, ppl=2503.53, wps=4532, ups=3.2, wpb=1284, bsz=57, num_updates=1001, lr=6.26562e-05, gnorm=2.681, clip=0%, oom=0, wall=318, train_wall=255
| epoch 001:   2000 / 49946 loss=11.160, nll_loss=10.534, ppl=1482.79, wps=4538, ups=3.3, wpb=1290, bsz=56, num_updates=2001, lr=0.00012515, gnorm=2.284, clip=0%, oom=0, wall=603, train_wall=513
...
| epoch 001 | loss 10.739 | nll_loss 10.051 | ppl 1061.21 | wps 4515 | ups 3.5 | wpb 1293 | bsz 57 | num_updates 49946 | lr 0.000565991 | gnorm 49.466 | clip 0% | oom 0 | wall 14340 | train_wall 12897
| epoch 001 | valid on 'valid' subset | valid_loss 12.5821 | valid_nll_loss 12.0973 | valid_ppl 4381.69 | num_updates 49946
...
| epoch 002 | loss 11.099 | nll_loss 10.462 | ppl 1410.16 | wps 4510 | ups 3.5 | wpb 1293 | bsz 57 | num_updates 99892 | lr 0.000400216 | gnorm 48.665 | clip 0% | oom 0 | wall 28686 | train_wall 25807
| epoch 002 | valid on 'valid' subset | valid_loss 13.3439 | valid_nll_loss 13.0128 | valid_ppl 8265.24 | num_updates 99892 | best 12.5821
...
| epoch 005 | loss 11.091 | nll_loss 10.452 | ppl 1401.16 | wps 4518 | ups 3.5 | wpb 1293 | bsz 57 | num_updates 249730 | lr 0.000253119 | gnorm 1389.671 | clip 0% | oom 0 | wall 71688 | train_wall 64499
| epoch 005 | valid on 'valid' subset | valid_loss 14.3626 | valid_nll_loss 14.1353 | valid_ppl 17995.23 | num_updates 249730 | best 12.5821
...
| epoch 006:  40000 / 49946 loss=11.129, nll_loss=10.496, ppl=1444.59, wps=4509, ups=3.5, wpb=1294, bsz=57, num_updates=289731, lr=0.000234997, gnorm=340.567, clip=0%, oom=0, wall=83195, train_wall=74851
| epoch 006:  41000 / 49946 loss=11.127, nll_loss=10.495, ppl=1442.86, wps=4509, ups=3.5, wpb=1294, bsz=57, num_updates=290731, lr=0.000234593, gnorm=334.005, clip=0%, oom=0, wall=83479, train_wall=75107

Thanks for your help!

myleott commented 5 years ago

That learning rate (0.001) is way too big for that batch size, which is why it's diverging. Notice the words-per-batch (wpb=1290), which gives the effective batch size.

Some more context: the learning rate and batch size are very intimately connected. If you use a smaller batch size, you also need to use a smaller learning rate. I only recommend 0.001 if training with an effective batch size of 400k tokens (i.e., 128 GPUs and --max-tokens 4000). When training with an effective batch size of 25k tokens (i.e., 8 GPUs and --max-tokens 4000) then I recommend a smaller learning rate of 0.0005. Your effective batch size is 1.3k tokens.

Fortunately fairseq can simulate larger effective batch sizes using the --update-freq setting, which accumulates gradients over multiple batches before each update. So to get an effective batch size of ~25k tokens, I'd recommend setting --update-freq=19 (25k/1290 ~= 19) and using --lr 0.0005.
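
To make the arithmetic explicit, here is a small sketch of how --update-freq can be chosen from the observed wpb; the 25k-token target and the helper name are illustrative, not part of fairseq.

def pick_update_freq(target_tokens, tokens_per_batch, num_gpus=1):
    """Gradient-accumulation steps needed to reach a target effective batch size."""
    # effective batch size per update = tokens_per_batch * num_gpus * update_freq
    return round(target_tokens / (tokens_per_batch * num_gpus))

# from the log above: wpb ~= 1290 tokens on a single GPU, target ~25k tokens per update
print(pick_update_freq(25000, 1290))  # -> 19, matching the suggested --update-freq=19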

davidecaroselli commented 5 years ago

That's great! Again thanks a lot for your detailed explanation.

We will post updates asap!

nicolabertoldi commented 5 years ago

@myleott, @edunov we are continuing our comparison between T2T and fairseq using all your suggestions about the parameters, and the preliminary results are promising. Thanks a lot!

Before sharing our results with you, we would like to know how you decide to stop a training run. Which termination policy do you usually apply, and which metrics do you consider?

myleott commented 5 years ago

We usually look at the validation loss and stop after it's plateaued for a while. Note that best performance usually comes after averaging the last N checkpoints (with N between 5 and 10). You can do this with the average_checkpoints script:

python scripts/average_checkpoints.py --inputs /path/to/checkpoints --num-epoch-checkpoints 10 --output /path/to/checkpoints/averaged.pt
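
Conceptually, checkpoint averaging just averages the parameter tensors across the selected checkpoints. Here is a minimal sketch of the idea; the real scripts/average_checkpoints.py handles more cases, and the checkpoint paths are illustrative.

import torch

def average_checkpoints(paths):
    """Average the 'model' parameter tensors of several fairseq checkpoints (sketch)."""
    avg = None
    for path in paths:
        state = torch.load(path, map_location="cpu")["model"]  # fairseq stores weights under 'model'
        if avg is None:
            avg = {k: v.clone().float() for k, v in state.items()}
        else:
            for k, v in state.items():
                avg[k] += v.float()
    return {k: v / len(paths) for k, v in avg.items()}

# e.g. average the last 10 epoch checkpoints (paths are illustrative)
averaged = average_checkpoints([f"checkpoints/checkpoint{i}.pt" for i in range(91, 101)])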

nicolabertoldi commented 5 years ago

thanks @myleott, ok for the termination policy.

I see that you average the last 5-10 checkpoints, assuming you save a checkpoint every epoch. Actually we are saving a checkpoint every 2K steps; our training runs about 20K steps per epoch, so we have roughly 10 checkpoints per epoch. I will use --num-update-checkpoints 10 instead of --num-epoch-checkpoints 10.

Do you think it is still reasonable to average the last 10 checkpoints (saved every 2K steps), or is it better to increase the number of checkpoints included in the average?

myleott commented 5 years ago

Averaging the last 10 checkpoints saved every 2k steps should be fine. I think this is closer to what Vaswani et al. did too.

nicolabertoldi commented 5 years ago

@myleott thank you for your suggestions

xforceco commented 5 years ago

Just dropping by with a quick question about the learning rate setting for single-GPU training. How should I set lr, update-freq, and max-tokens to correctly reproduce the results on 1 GPU? Previously I used lr=0.0005, update-freq=1, and max-tokens=3000 on one GPU, and there is not much learning from the 7th epoch onward, with the loss stuck at ~7.02. I am also wondering what the rough converged loss is supposed to be for the --arch transformer_vaswani_wmt_en_de_big experiment.

myleott commented 5 years ago

@xforceco, you should use --update-freq=8 (or maybe up to 10) when training on 1 GPU. Look at the training log: you should see words-per-batch (wpb), which should be around 25k. When using a single GPU and --update-freq=1, as you did, the effective batch size (wpb) is too small.

The converged validation loss should be ~2.06.

davidecaroselli commented 5 years ago

Hello @myleott and @edunov

we have just completed our first tests with fairseq and the results are very good! Not only is the training speed higher than tensor2tensor's but, most importantly, on the exact same training data fairseq is giving us better results than t2t (probably thanks to the 'update-freq' feature?).

Here are the final BLEU scores on our test set: fairseq 36.59, T2T 35.41.

Thanks for all your amazing support!

xforceco commented 5 years ago

@myleott Just saw the intermediate training log; your suggestion seems to be working very well. Thank you very much. *EDIT: To clarify, I got valid_nll_loss ~2.5 after epoch 3 with 15K updates; is that reasonable progress?

myleott commented 5 years ago

@xforceco, yep that's similar to what I get:

{"epoch": 1, "valid_loss": 4.9776272644057125, "valid_nll_loss": 3.263508564650427, "valid_ppl": "9.60"}
{"epoch": 2, "valid_loss": 4.430108650668893, "valid_nll_loss": 2.684877586701665, "valid_ppl": "6.43"}
{"epoch": 3, "valid_loss": 4.245328851412793, "valid_nll_loss": 2.496941800665174, "valid_ppl": "5.64"}
{"epoch": 4, "valid_loss": 4.15511640321181, "valid_nll_loss": 2.402382043843173, "valid_ppl": "5.29"}
{"epoch": 5, "valid_loss": 4.089958693182424, "valid_nll_loss": 2.338727264762895, "valid_ppl": "5.06"}
{"epoch": 6, "valid_loss": 4.053738931742309, "valid_nll_loss": 2.304074601711548, "valid_ppl": "4.94"}
{"epoch": 7, "valid_loss": 4.026322747002758, "valid_nll_loss": 2.268473556725225, "valid_ppl": "4.82"}
{"epoch": 8, "valid_loss": 3.996494588912812, "valid_nll_loss": 2.237285659433318, "valid_ppl": "4.72"}
{"epoch": 9, "valid_loss": 3.983175458325269, "valid_nll_loss": 2.227702413029846, "valid_ppl": "4.68"}
{"epoch": 10, "valid_loss": 3.9605515295527676, "valid_nll_loss": 2.2074773416315474, "valid_ppl": "4.62"}
{"epoch": 11, "valid_loss": 3.9444221234610377, "valid_nll_loss": 2.190658053450889, "valid_ppl": "4.57"}
{"epoch": 12, "valid_loss": 3.9395180726592143, "valid_nll_loss": 2.1856860732464756, "valid_ppl": "4.55"}
{"epoch": 13, "valid_loss": 3.9237129949933083, "valid_nll_loss": 2.1751739871773843, "valid_ppl": "4.52"}
{"epoch": 14, "valid_loss": 3.926290171509098, "valid_nll_loss": 2.171692958404979, "valid_ppl": "4.51"}
{"epoch": 15, "valid_loss": 3.9146234881228517, "valid_nll_loss": 2.160284582831526, "valid_ppl": "4.47"}
{"epoch": 16, "valid_loss": 3.9060891609331176, "valid_nll_loss": 2.158919684965754, "valid_ppl": "4.47"}
{"epoch": 17, "valid_loss": 3.899397612133591, "valid_nll_loss": 2.1512093535810477, "valid_ppl": "4.44"}
{"epoch": 18, "valid_loss": 3.8987179256462876, "valid_nll_loss": 2.146370778773354, "valid_ppl": "4.43"}
{"epoch": 19, "valid_loss": 3.8984733676243635, "valid_nll_loss": 2.148081701674582, "valid_ppl": "4.43"}
{"epoch": 20, "valid_loss": 3.8788223213997437, "valid_nll_loss": 2.132689921546082, "valid_ppl": "4.39"}
{"epoch": 21, "valid_loss": 3.878543980216159, "valid_nll_loss": 2.1293252743699487, "valid_ppl": "4.38"}
{"epoch": 22, "valid_loss": 3.879500282735846, "valid_nll_loss": 2.131955859216356, "valid_ppl": "4.38"}
{"epoch": 24, "valid_loss": 3.8724092669906605, "valid_nll_loss": 2.12616384874608, "valid_ppl": "4.37"}
{"epoch": 25, "valid_loss": 3.8700676029959706, "valid_nll_loss": 2.12229279938514, "valid_ppl": "4.35"}
{"epoch": 26, "valid_loss": 3.864768752810862, "valid_nll_loss": 2.1173498193501894, "valid_ppl": "4.34"}
{"epoch": 27, "valid_loss": 3.8671614955524296, "valid_nll_loss": 2.1221588745542084, "valid_ppl": "4.35"}
{"epoch": 28, "valid_loss": 3.8670298509852215, "valid_nll_loss": 2.119325281176031, "valid_ppl": "4.34"}
{"epoch": 29, "valid_loss": 3.8658127337268167, "valid_nll_loss": 2.1193866699049972, "valid_ppl": "4.35"}
{"epoch": 30, "valid_loss": 3.8663510732454474, "valid_nll_loss": 2.1216935118696014, "valid_ppl": "4.35"}
{"epoch": 31, "valid_loss": 3.862011436510338, "valid_nll_loss": 2.1204680605778394, "valid_ppl": "4.35"}
{"epoch": 32, "valid_loss": 3.8538177532274953, "valid_nll_loss": 2.114586016942868, "valid_ppl": "4.33"}
{"epoch": 33, "valid_loss": 3.85449387818002, "valid_nll_loss": 2.115149830774492, "valid_ppl": "4.33"}
{"epoch": 34, "valid_loss": 3.8585339094590756, "valid_nll_loss": 2.1132496690702416, "valid_ppl": "4.33"}
{"epoch": 35, "valid_loss": 3.8584225007895325, "valid_nll_loss": 2.1155223742381755, "valid_ppl": "4.33"}
{"epoch": 36, "valid_loss": 3.858366456461459, "valid_nll_loss": 2.1117439468180943, "valid_ppl": "4.32"}
{"epoch": 37, "valid_loss": 3.850525921283631, "valid_nll_loss": 2.1101350575859037, "valid_ppl": "4.32"}
{"epoch": 38, "valid_loss": 3.8543249686678864, "valid_nll_loss": 2.1134291624479564, "valid_ppl": "4.33"}
{"epoch": 39, "valid_loss": 3.853496583381426, "valid_nll_loss": 2.1123668072026014, "valid_ppl": "4.32"}
{"epoch": 40, "valid_loss": 3.8554913807522997, "valid_nll_loss": 2.1123970760507755, "valid_ppl": "4.32"}
{"epoch": 41, "valid_loss": 3.845356717680398, "valid_nll_loss": 2.106304198791268, "valid_ppl": "4.31"}

Please re-open if you have any more questions!

gxzks commented 5 years ago

@myleott I trained an English-to-Chinese model using DynamicConv, but the result is about 6 BLEU lower than the model I trained with t2t. My training script is here:

python -m torch.distributed.launch --nproc_per_node 8 $(which fairseq-train) \
    $BIN_DATA --fp16 --log-interval 100 --no-progress-bar \
    --source-lang en --target-lang zh \
    --max-update 40000 --optimizer adam \
    --adam-betas '(0.9, 0.98)' --lr-scheduler inverse_sqrt \
    --clip-norm 0.0 --weight-decay 0.0 \
    --criterion label_smoothed_cross_entropy --label-smoothing 0.1 \
    --min-lr 1e-09 --update-freq 16 --keep-last-epochs 10 \
    --ddp-backend=no_c10d --max-tokens 3584 \
    --lr-scheduler cosine --warmup-init-lr 1e-7 --warmup-updates 10000 \
    --lr-shrink 1 --max-lr 0.001 --lr 1e-7 --min-lr 1e-9 \
    --t-mult 1 --lr-period-updates 20000 \
    --arch $ARCH --save-dir $SAVE \
    --dropout 0.25 --attention-dropout 0.2 --weight-dropout 0.2 \
    --encoder-glu 1 --decoder-glu 1

Did I set max-update too small? I trained on t2t with train_steps set to 500k using 8 GPUs without accumulating gradients.

gxzks commented 5 years ago

My fault, I should have applied BPE to the test file first. After applying BPE, the result is normal.
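
For anyone hitting the same issue: the test file has to be segmented with the same BPE codes learned on the training data before translation. A minimal sketch, assuming subword-nmt was used to learn the codes; the file names are illustrative.

from subword_nmt.apply_bpe import BPE

# codes.bpe: the BPE merge operations learned on the training data (assumed path)
with open("codes.bpe", encoding="utf-8") as codes:
    bpe = BPE(codes)

# apply the same segmentation to the raw test file before running fairseq-generate
with open("test.en", encoding="utf-8") as fin, open("test.bpe.en", "w", encoding="utf-8") as fout:
    for line in fin:
        fout.write(bpe.process_line(line))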