facebookresearch / fairseq

Facebook AI Research Sequence-to-Sequence Toolkit written in Python.
MIT License

nonauto-regressive-nmt bleu 0.0 #1231

Closed MTNewer closed 4 years ago

MTNewer commented 4 years ago

Hi, thank you for publishing the non-autoregressive MT code. I have tried it following your instructions. The training loss decreases normally, but I get BLEU 0.0 on newstest2014 En-De.

Data is preprocessed following https://github.com/pytorch/fairseq/blob/master/examples/translation/prepare-wmt14en2de.sh

Training script:

```
python train.py \
    ../data-bin/wmt14_en_de \
    --save-dir ../checkpoints/Non-autoregressive_Transformer/ \
    --ddp-backend=no_c10d \
    --task translation_lev \
    --criterion nat_loss \
    --arch nonautoregressive_transformer \
    --noise full_mask \
    --share-all-embeddings \
    --optimizer adam --adam-betas '(0.9,0.98)' \
    --lr 0.0005 --lr-scheduler inverse_sqrt \
    --min-lr '1e-09' --warmup-updates 10000 \
    --warmup-init-lr '1e-07' --label-smoothing 0.1 \
    --dropout 0.3 --weight-decay 0.01 \
    --decoder-learned-pos \
    --encoder-learned-pos \
    --pred-length-offset \
    --length-loss-factor 0.1 \
    --apply-bert-init \
    --log-format json --log-interval 100 \
    --fixed-validation-seed 7 \
    --max-tokens 8000 \
    --save-interval-updates 1000 \
    --keep-interval-updates 300 \
    --max-update 300000
```

Testing script:

```
python generate.py \
    ../data-bin/wmt14_en_de/ \
    --gen-subset test \
    --task translation_lev \
    --path ../checkpoints/Non-autoregressive_Transformer/checkpoint_best.pt \
    --iter-decode-max-iter 9 \
    --iter-decode-eos-penalty 0 \
    --beam 1 --remove-bpe \
    --print-step \
    --batch-size 400
```

Result:

```
| Translated 2737 sentences (81252 tokens) in 12.0s (228.84 sentences/s, 6793.32 tokens/s)
| Generate test with beam=1: BLEU4 = 0.00, 17.6/1.0/0.0/0.0 (BP=0.727, ratio=0.758, syslen=44974, reflen=59311)
```
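As a sanity check on that output line, the reported BP value follows directly from syslen and reflen; a minimal sketch of the standard BLEU brevity penalty (exp(1 - reflen/syslen) when the system output is shorter than the reference):

```python
import math

def brevity_penalty(syslen: int, reflen: int) -> float:
    """Standard BLEU brevity penalty: penalizes outputs shorter than the reference."""
    if syslen >= reflen:
        return 1.0
    return math.exp(1.0 - reflen / syslen)

# Values from the generate.py output above.
print(round(brevity_penalty(44974, 59311), 3))  # 0.727
```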

Is there anything wrong with my experiment?

MultiPath commented 4 years ago

Hi,

The arguments look OK to me. However, the number of sentences looks suspicious: 2737 vs. 3003 (the standard test set). Can you make sure you preprocessed the dataset correctly?
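One quick way to catch this kind of mismatch is to count sentences in the raw test files before binarizing; a minimal sketch (the file names and the expected count of 3003 for newstest2014 are assumptions based on this thread, and the demo below runs on a throwaway file):

```python
import os
import tempfile

def count_lines(path: str) -> int:
    """Count sentences (one per line) in a raw text file."""
    with open(path, encoding="utf-8") as f:
        return sum(1 for _ in f)

# Demo on a temporary file; with real data you would point this at the
# preprocessed test.en / test.de and expect 3003 lines each for newstest2014.
with tempfile.NamedTemporaryFile("w", delete=False, suffix=".en") as f:
    f.write("first sentence\nsecond sentence\n")
    tmp = f.name
print(count_lines(tmp))  # 2
os.remove(tmp)
```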

MTNewer commented 4 years ago

Hi, I evaluated my model on the standard version (3003 sentences) and the BLEU is 0.0 as well. I am not sure whether the error occurred in the training stage or the test stage. Below is part of the log from the final phase of training:

```
{"epoch": 146, "update": 145.341, "loss": "7.058", "nll_loss": "5.295", "ppl": "39.27", "wps": "155163", "ups": "2", "wpb": "59781.572", "bsz": "1905.601", "num_updates": "298386", "lr": "9.15337e-05", "gnorm": "0.971", "clip": "0.000", "oom": "0.000", "wall": "117127", "train_wall": "113107", "w_ins": "6.6492", "len": "4.08802"}
{"epoch": 146, "update": 145.39, "loss": "7.055", "nll_loss": "5.292", "ppl": "39.17", "wps": "155406", "ups": "2", "wpb": "59753.975", "bsz": "1911.154", "num_updates": "298486", "lr": "9.15183e-05", "gnorm": "0.971", "clip": "0.000", "oom": "0.000", "wall": "117165", "train_wall": "113145", "w_ins": "6.64596", "len": "4.08617"}
{"epoch": 146, "update": 145.438, "loss": "7.057", "nll_loss": "5.294", "ppl": "39.23", "wps": "155487", "ups": "2", "wpb": "59764.698", "bsz": "1906.941", "num_updates": "298586", "lr": "9.1503e-05", "gnorm": "0.969", "clip": "0.000", "oom": "0.000", "wall": "117203", "train_wall": "113183", "w_ins": "6.64806", "len": "4.08764"}
{"epoch": 146, "update": 145.487, "loss": "7.065", "nll_loss": "5.302", "ppl": "39.46", "wps": "155506", "ups": "2", "wpb": "59739.604", "bsz": "1905.728", "num_updates": "298686", "lr": "9.14877e-05", "gnorm": "0.969", "clip": "0.000", "oom": "0.000", "wall": "117242", "train_wall": "113221", "w_ins": "6.65545", "len": "4.09434"}
{"epoch": 146, "update": 145.536, "loss": "7.067", "nll_loss": "5.305", "ppl": "39.53", "wps": "155602", "ups": "3", "wpb": "59742.559", "bsz": "1904.510", "num_updates": "298786", "lr": "9.14724e-05", "gnorm": "0.969", "clip": "0.000", "oom": "0.000", "wall": "117280", "train_wall": "113259", "w_ins": "6.65734", "len": "4.0964"}
{"epoch": 146, "update": 145.585, "loss": "7.067", "nll_loss": "5.305", "ppl": "39.54", "wps": "155675", "ups": "3", "wpb": "59735.918", "bsz": "1904.368", "num_updates": "298886", "lr": "9.14571e-05", "gnorm": "0.968", "clip": "0.000", "oom": "0.000", "wall": "117318", "train_wall": "113297", "w_ins": "6.65761", "len": "4.0954"}
{"epoch": 146, "update": 145.633, "loss": "7.066", "nll_loss": "5.305", "ppl": "39.52", "wps": "155787", "ups": "3", "wpb": "59759.158", "bsz": "1904.672", "num_updates": "298986", "lr": "9.14418e-05", "gnorm": "0.970", "clip": "0.000", "oom": "0.000", "wall": "117356", "train_wall": "113335", "w_ins": "6.65705", "len": "4.09251"}
{"epoch": 146, "valid_loss": "7.914", "valid_nll_loss": "6.215", "valid_ppl": "74.26", "valid_num_updates": "299000", "valid_best_loss": "7.58654", "valid_w_ins": "7.56595", "valid_len": "4.02043"}
| saved checkpoint ../checkpoints/Non-autoregressive_Transformer/checkpoint_146_299000.pt (epoch 146 @ 299000 updates) (writing took 1.2446017265319824 seconds)
{"epoch": 146, "update": 145.682, "loss": "7.064", "nll_loss": "5.303", "ppl": "39.47", "wps": "155200", "ups": "3", "wpb": "59754.620", "bsz": "1909.352", "num_updates": "299086", "lr": "9.14265e-05", "gnorm": "0.969", "clip": "0.000", "oom": "0.000", "wall": "117396", "train_wall": "113373", "w_ins": "6.65529", "len": "4.08988"}
{"epoch": 146, "update": 145.731, "loss": "7.062", "nll_loss": "5.300", "ppl": "39.40", "wps": "155286", "ups": "3", "wpb": "59755.431", "bsz": "1917.326", "num_updates": "299186", "lr": "9.14112e-05", "gnorm": "0.971", "clip": "0.000", "oom": "0.000", "wall": "117435", "train_wall": "113411", "w_ins": "6.65286", "len": "4.08712"}
{"epoch": 146, "update": 145.779, "loss": "7.063", "nll_loss": "5.302", "ppl": "39.45", "wps": "155360", "ups": "3", "wpb": "59756.429", "bsz": "1918.777", "num_updates": "299286", "lr": "9.13959e-05", "gnorm": "0.971", "clip": "0.000", "oom": "0.000", "wall": "117473", "train_wall": "113449", "w_ins": "6.65464", "len": "4.08798"}
{"epoch": 146, "update": 145.828, "loss": "7.064", "nll_loss": "5.303", "ppl": "39.49", "wps": "155430", "ups": "3", "wpb": "59756.219", "bsz": "1916.657", "num_updates": "299386", "lr": "9.13807e-05", "gnorm": "0.973", "clip": "0.000", "oom": "0.000", "wall": "117511", "train_wall": "113486", "w_ins": "6.65566", "len": "4.08811"}
{"epoch": 146, "update": 145.877, "loss": "7.063", "nll_loss": "5.302", "ppl": "39.46", "wps": "155454", "ups": "3", "wpb": "59756.188", "bsz": "1917.096", "num_updates": "299486", "lr": "9.13654e-05", "gnorm": "0.972", "clip": "0.000", "oom": "0.000", "wall": "117549", "train_wall": "113525", "w_ins": "6.65473", "len": "4.0866"}
{"epoch": 146, "update": 145.925, "loss": "7.062", "nll_loss": "5.301", "ppl": "39.42", "wps": "155496", "ups": "3", "wpb": "59757.929", "bsz": "1919.016", "num_updates": "299586", "lr": "9.13501e-05", "gnorm": "0.972", "clip": "0.000", "oom": "0.000", "wall": "117588", "train_wall": "113563", "w_ins": "6.65336", "len": "4.08424"}
{"epoch": 146, "update": 145.974, "loss": "7.066", "nll_loss": "5.305", "ppl": "39.53", "wps": "155536", "ups": "3", "wpb": "59748.152", "bsz": "1915.519", "num_updates": "299686", "lr": "9.13349e-05", "gnorm": "0.972", "clip": "0.000", "oom": "0.000", "wall": "117626", "train_wall": "113600", "w_ins": "6.65702", "len": "4.08669"}
{"epoch": 146, "train_loss": "7.067", "train_nll_loss": "5.306", "train_ppl": "39.57", "train_wps": "155499", "train_ups": "3", "train_wpb": "59739.525", "train_bsz": "1915.450", "train_num_updates": "299738", "train_lr": "9.1327e-05", "train_gnorm": "0.971", "train_clip": "0.000", "train_oom": "0.000", "train_wall": "117646", "train_train_wall": "113620", "train_w_ins": "6.65837", "train_len": "4.08704"}
{"epoch": 146, "valid_loss": "7.879", "valid_nll_loss": "6.176", "valid_ppl": "72.33", "valid_num_updates": "299738", "valid_best_loss": "7.58654", "valid_w_ins": "7.53562", "valid_len": "3.96013"}
| saved checkpoint ../checkpoints/Non-autoregressive_Transformer/checkpoint146.pt (epoch 146 @ 299738 updates) (writing took 1.2365179061889648 seconds)
{"epoch": 147, "update": 146.049, "loss": "7.100", "nll_loss": "5.337", "ppl": "40.41", "wps": "155941", "ups": "2", "wpb": "59393.624", "bsz": "1903.683", "num_updates": "299839", "lr": "9.13116e-05", "gnorm": "0.961", "clip": "0.000", "oom": "0.000", "wall": "117702", "train_wall": "113659", "w_ins": "6.68563", "len": "4.14182"}
{"epoch": 147, "update": 146.097, "loss": "7.095", "nll_loss": "5.332", "ppl": "40.28", "wps": "156035", "ups": "2", "wpb": "59434.940", "bsz": "1883.990", "num_updates": "299939", "lr": "9.12964e-05", "gnorm": "0.966", "clip": "0.000", "oom": "0.000", "wall": "117740", "train_wall": "113697", "w_ins": "6.68113", "len": "4.13613"}
{"epoch": 147, "valid_loss": "7.888", "valid_nll_loss": "6.184", "valid_ppl": "72.70", "valid_num_updates": "300000", "valid_best_loss": "7.58654", "valid_w_ins": "7.54538", "valid_len": "3.95957"}
| saved checkpoint ../checkpoints/Non-autoregressive_Transformer/checkpoint_147_300000.pt (epoch 147 @ 300000 updates) (writing took 1.2359647750854492 seconds)
{"epoch": 147, "train_loss": "7.075", "train_nll_loss": "5.312", "train_ppl": "39.74", "train_wps": "152636", "train_ups": "2", "train_wpb": "59528.237", "train_bsz": "1889.718", "train_num_updates": "300000", "train_lr": "9.12871e-05", "train_gnorm": "0.966", "train_clip": "0.000", "train_oom": "0.000", "train_wall": "117766", "train_train_wall": "113720", "train_w_ins": "6.66368", "train_len": "4.11232"}
{"epoch": 147, "valid_loss": "7.888", "valid_nll_loss": "6.184", "valid_ppl": "72.70", "valid_num_updates": "300000", "valid_best_loss": "7.58654", "valid_w_ins": "7.54538", "valid_len": "3.95957"}
| saved checkpoint ../checkpoints/Non-autoregressive_Transformer/checkpoint_147_300000.pt (epoch 147 @ 300000 updates) (writing took 1.2770981788635254 seconds)
| done training in 117749.3 seconds
```

Is this training loss correct? Could you help me check it? Thank you.

MultiPath commented 4 years ago

Hi,

Thanks for providing detailed messages. You are running the vanilla non-autoregressive NMT model, so you cannot set the maximum number of decoding iterations to 9: that model has no iterative refinement. Can you try the following and let me know the result?

```
python generate.py \
    ../data-bin/wmt14_en_de/ \
    --gen-subset test \
    --task translation_lev \
    --path ../checkpoints/Non-autoregressive_Transformer/checkpoint_best.pt \
    --iter-decode-max-iter 0 \
    --beam 1 --remove-bpe \
    --print-step \
    --batch-size 400
```

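For intuition, with `--iter-decode-max-iter 0` the vanilla NAT produces the whole output in one parallel pass, taking the argmax token at each target position independently. A toy sketch of that decoding rule (the vocabulary and logits below are made up for illustration, not taken from the model):

```python
# Toy sketch: single-pass NAT decoding picks the argmax token at every
# target position independently, with no iterative refinement.
vocab = ["<pad>", "wir", "danken", "ihnen"]

# Hypothetical decoder logits: one row per target position.
logits = [
    [0.1, 2.5, 0.3, 0.2],
    [0.0, 0.4, 3.1, 0.6],
    [0.2, 0.1, 0.5, 2.8],
]

def argmax(row):
    """Index of the largest logit in a row."""
    return max(range(len(row)), key=lambda i: row[i])

tokens = [vocab[argmax(row)] for row in logits]
print(" ".join(tokens))  # wir danken ihnen
```
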
MTNewer commented 4 years ago

Hi,

Thanks for your help. I tried the new script; here is the result:

```
Iter=0: | Translated 3003 sentences (88856 tokens) in 4.2s (722.83 sentences/s, 21387.96 tokens/s)
        | Generate test with beam=1: BLEU4 = 7.61, 37.5/11.6/4.2/1.8 (BP=1.000, ratio=1.045, syslen=67415, reflen=64506)
Iter=1: | Translated 3003 sentences (88856 tokens) in 4.9s (608.84 sentences/s, 18015.12 tokens/s)
        | Generate test with beam=1: BLEU4 = 0.21, 18.2/1.2/0.0/0.0 (BP=0.996, ratio=0.996, syslen=64251, reflen=64506)
Iter=2: | Translated 3003 sentences (88856 tokens) in 6.3s (477.82 sentences/s, 14138.37 tokens/s)
        | Generate test with beam=1: BLEU4 = 0.22, 18.4/1.2/0.0/0.0 (BP=0.871, ratio=0.879, syslen=56683, reflen=64506)
```

Without iterative refinement, the model obtains a BLEU score of 7.61.

MultiPath commented 4 years ago

Yeah, that is more or less expected. (1) Iterative refinement is not implemented for the vanilla NAT, so setting it will give weird results. Thanks for pointing this out; we should assert against that. (2) This model is trained on En-De without knowledge distillation, which is a pretty difficult task for this simplest non-autoregressive model (no advanced techniques at all). If you train with more updates, it should get slightly higher BLEU scores. (3) We also found that averaging the last 5 checkpoints sometimes gives better BLEU scores.
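Checkpoint averaging just takes the element-wise mean of the saved parameter tensors across several checkpoints; fairseq ships a script for this, but the core idea fits in a few lines. A simplified sketch, using plain floats in place of parameter tensors:

```python
def average_checkpoints(states):
    """Element-wise mean of parameter dicts (floats stand in for tensors)."""
    n = len(states)
    avg = {k: 0.0 for k in states[0]}
    for state in states:
        for k, v in state.items():
            avg[k] += v / n
    return avg

# Toy example with two "checkpoints" sharing the same parameter names.
ckpt_a = {"w": 1.0, "b": 0.0}
ckpt_b = {"w": 3.0, "b": 2.0}
print(average_checkpoints([ckpt_a, ckpt_b]))  # {'w': 2.0, 'b': 1.0}
```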

MTNewer commented 4 years ago

Thank you very much. Your suggestions were very helpful, and I appreciate you taking the time to look into my issue.