facebookresearch / fairseq

Facebook AI Research Sequence-to-Sequence Toolkit written in Python.

Reproducing results on WMT14 en-fr #85

Closed Zrachel closed 6 years ago

Zrachel commented 6 years ago

Following the latest code, with the training parameters specified by @edunov in https://github.com/facebookresearch/fairseq-py/issues/41 and in the pretrained-models section of the README, I get exploding updates on WMT14 en-fr:

+ miniconda3/bin/python3 PyFairseq/train.py data-bin --save-dir model -s en -t fr --arch fconv_wmt_en_fr --dropout 0.1 --lr 2.5 --clip-norm 0.1 --max-tokens 4000 --force-anneal 32 
Namespace(adam_betas='(0.9, 0.999)', arch='fconv_wmt_en_fr', clip_norm=0.1, curriculum=0, data='data-bin', decoder_attention='True', decoder_embed_dim=768, decoder_layers='[(512, 3)] * 6 + [(768, 3)] * 4 + [(1024, 3)] * 3 + [(2048, 1)] * 1 + [(4096, 1)] * 1', decoder_out_embed_dim=512, dropout=0.1, encoder_embed_dim=768, encoder_layers='[(512, 3)] * 6 + [(768, 3)] * 4 + [(1024, 3)] * 3 + [(2048, 1)] * 1 + [(4096, 1)] * 1', force_anneal=32, label_smoothing=0, log_format=None, log_interval=1000, lr='2.5', lrshrink=0.1, max_epoch=0, max_sentences=None, max_source_positions=1024, max_target_positions=1024, max_tokens=4000, min_lr=1e-05, model='fconv', momentum=0.99, no_epoch_checkpoints=False, no_progress_bar=False, no_save=False, num_gpus=8, optimizer='nag', restore_file='checkpoint_last.pt', sample_without_replacement=0, save_dir='model', save_interval=-1, seed=1, sentence_avg=False, skip_invalid_size_inputs_valid_test=False, source_lang='en', target_lang='fr', train_subset='train', valid_subset='valid', weight_decay=0.0, workers=1) 
| [en] dictionary: 43881 types 
| [fr] dictionary: 43978 types 
| data-bin train 35482842 examples 
| data-bin valid 26663 examples 
| using 8 GPUs (with max tokens per GPU = 4000 and max sentences per GPU = None) 
| model fconv_wmt_en_fr, criterion CrossEntropyCriterion 
Warning! 1 samples are either too short or too long and will be ignored, first few sample ids=[28743556] 
| epoch 001: 1000 / 331737 loss=9.21 (10.89), wps=16319, wpb=31291, bsz=850, lr=2.5, clip=100%, gnorm=2.50713, oom=0 
| epoch 001: 2000 / 331737 loss=588.92 (19.76), wps=16417, wpb=31241, bsz=838, lr=2.5, clip=100%, gnorm=5.39344e+09, oom=0 
| epoch 001: 3000 / 331737 loss=126867869305.41 (3395251823.97), wps=16436, wpb=31258, bsz=849, lr=2.5, clip=100%, gnorm=2.05028e+16, oom=0 
| epoch 001: 4000 / 331737 loss=137727644131954352.00 (3821157344375131.00), wps=16438, wpb=31229, bsz=853, lr=2.5, clip=100%, gnorm=inf, oom=0 
| epoch 001: 5000 / 331737 loss=358248860949876800.00 (64219013624718560.00), wps=16454, wpb=31251, bsz=861, lr=2.5, clip=100%, gnorm=inf, oom=0 
| epoch 001: 6000 / 331737 loss=74803270219822464.00 (85287362140370208.00), wps=16464, wpb=31255, bsz=857, lr=2.5, clip=100%, gnorm=inf, oom=0 
| epoch 001: 7000 / 331737 loss=1124810776667683.12 (75791177781467504.00), wps=16478, wpb=31266, bsz=854, lr=2.5, clip=100%, gnorm=inf, oom=0 
| epoch 001: 8000 / 331737 loss=nan (nan), wps=16486, wpb=31252, bsz=852, lr=2.5, clip=94%, gnorm=nan, oom=0 
| epoch 001: 9000 / 331737 loss=nan (nan), wps=16493, wpb=31241, bsz=852, lr=2.5, clip=83%, gnorm=nan, oom=0 
| epoch 001: 10000 / 331737 loss=nan (nan), wps=16502, wpb=31244, bsz=855, lr=2.5, clip=75%, gnorm=nan, oom=0 
| epoch 001: 11000 / 331737 loss=nan (nan), wps=16511, wpb=31239, bsz=855, lr=2.5, clip=68%, gnorm=nan, oom=0 
| epoch 001: 12000 / 331737 loss=nan (nan), wps=16521, wpb=31240, bsz=855, lr=2.5, clip=62%, gnorm=nan, oom=0 
| epoch 001: 13000 / 331737 loss=nan (nan), wps=16529, wpb=31244, bsz=853, lr=2.5, clip=58%, gnorm=nan, oom=0 
| epoch 001: 14000 / 331737 loss=nan (nan), wps=16536, wpb=31239, bsz=851, lr=2.5, clip=53%, gnorm=nan, oom=0 
| epoch 001: 15000 / 331737 loss=nan (nan), wps=16539, wpb=31236, bsz=852, lr=2.5, clip=50%, gnorm=nan, oom=0 
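A minimal sketch of a guard one could add around the training step to fail fast once the loss goes non-finite, instead of continuing with nan updates for thousands of steps as above (illustrative only, not fairseq's actual trainer code):

import math

def assert_finite_loss(loss_value, num_updates):
    # nan/inf losses never recover; stop and retune --lr / --seed instead.
    if not math.isfinite(loss_value):
        raise RuntimeError('non-finite loss %f at update %d; '
                           'try a lower --lr or a different --seed'
                           % (loss_value, num_updates))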

Changing only the learning rate to 1.25 avoids the exploding problem, but BLEU increases very slowly:

checkpoint1.pt/test.bleu:BLEU4 = 30.11, 59.5/35.8/23.8/16.2 (BP=1.000, ratio=0.975, syslen=83264, reflen=81204)
checkpoint2.pt/test.bleu:BLEU4 = 31.34, 60.4/37.1/25.0/17.2 (BP=1.000, ratio=0.986, syslen=82348, reflen=81204)
checkpoint3.pt/test.bleu:BLEU4 = 32.56, 61.4/38.4/26.1/18.2 (BP=1.000, ratio=0.988, syslen=82230, reflen=81204)
checkpoint4.pt/test.bleu:BLEU4 = 32.71, 61.5/38.5/26.3/18.4 (BP=1.000, ratio=0.989, syslen=82140, reflen=81204)
checkpoint5.pt/test.bleu:BLEU4 = 33.13, 62.0/38.9/26.7/18.7 (BP=1.000, ratio=0.997, syslen=81437, reflen=81204)
checkpoint6.pt/test.bleu:BLEU4 = 33.04, 61.5/38.8/26.7/18.7 (BP=1.000, ratio=0.995, syslen=81632, reflen=81204)
checkpoint7.pt/test.bleu:BLEU4 = 33.01, 61.6/38.8/26.6/18.7 (BP=1.000, ratio=0.987, syslen=82282, reflen=81204)
checkpoint8.pt/test.bleu:BLEU4 = 33.60, 62.2/39.4/27.2/19.1 (BP=1.000, ratio=0.992, syslen=81830, reflen=81204)
checkpoint9.pt/test.bleu:BLEU4 = 33.07, 61.6/38.9/26.7/18.7 (BP=1.000, ratio=0.993, syslen=81783, reflen=81204)
checkpoint10.pt/test.bleu:BLEU4 = 33.39, 62.2/39.3/27.0/19.0 (BP=0.999, ratio=1.001, syslen=81099, reflen=81204)
checkpoint11.pt/test.bleu:BLEU4 = 33.74, 62.5/39.6/27.3/19.2 (BP=1.000, ratio=0.993, syslen=81744, reflen=81204)
checkpoint12.pt/test.bleu:BLEU4 = 33.37, 61.8/39.1/27.0/19.0 (BP=1.000, ratio=0.992, syslen=81892, reflen=81204)
checkpoint13.pt/test.bleu:BLEU4 = 34.07, 62.6/39.9/27.6/19.5 (BP=1.000, ratio=0.996, syslen=81534, reflen=81204)
checkpoint14.pt/test.bleu:BLEU4 = 33.81, 62.4/39.6/27.4/19.3 (BP=1.000, ratio=0.994, syslen=81685, reflen=81204)
checkpoint15.pt/test.bleu:BLEU4 = 33.78, 62.6/39.7/27.3/19.2 (BP=0.999, ratio=1.001, syslen=81110, reflen=81204)
checkpoint16.pt/test.bleu:BLEU4 = 34.09, 62.8/39.9/27.6/19.5 (BP=1.000, ratio=0.994, syslen=81723, reflen=81204)
checkpoint17.pt/test.bleu:BLEU4 = 33.94, 62.3/39.7/27.5/19.5 (BP=1.000, ratio=0.990, syslen=81988, reflen=81204)
checkpoint18.pt/test.bleu:BLEU4 = 34.43, 62.8/40.2/28.0/19.9 (BP=1.000, ratio=0.993, syslen=81811, reflen=81204)
checkpoint19.pt/test.bleu:BLEU4 = 34.14, 62.6/40.0/27.7/19.6 (BP=1.000, ratio=0.994, syslen=81661, reflen=81204)
checkpoint20.pt/test.bleu:BLEU4 = 34.05, 62.5/39.9/27.6/19.6 (BP=1.000, ratio=0.999, syslen=81314, reflen=81204)
checkpoint21.pt/test.bleu:BLEU4 = 34.20, 62.8/40.0/27.8/19.6 (BP=1.000, ratio=0.999, syslen=81259, reflen=81204)
checkpoint22.pt/test.bleu:BLEU4 = 34.13, 62.4/40.0/27.7/19.6 (BP=1.000, ratio=0.998, syslen=81331, reflen=81204)
checkpoint23.pt/test.bleu:BLEU4 = 34.31, 62.6/40.1/27.9/19.8 (BP=1.000, ratio=0.991, syslen=81972, reflen=81204)
checkpoint26.pt/test.bleu:BLEU4 = 34.11, 62.9/40.1/27.7/19.4 (BP=1.000, ratio=0.999, syslen=81260, reflen=81204)

My question is: are the results I got within expectation? Should I wait for the lr=1.25 run to finish, or is there something wrong with my data/config?

myleott commented 6 years ago

Yep, I'm seeing the same thing. I'll investigate a bit and get back to you shortly.

myleott commented 6 years ago

I dug into it a bit. It seems lr=2.5 is fragile and sensitive to the random seed. We also changed the way we seed the RNG (e.g., 104cead16ef010465228635158ae02b44b2e8210 and 5ef59abd1fb2cde1615d316ecc5185ee7b9ccfc7), which means that the default seed of 1 no longer produces the same results as before, and instead produces the "exploding" behavior you observed.

Usually we try to find hyper-parameters that work well across several random seeds, but in this case, either because of changes in the code or luck when originally tuning the lr, this configuration seems to be quite unstable. You can try lr=1.25, which should be more stable and give comparable results, although I'm not sure why your lr=1.25 run seems to plateau around BLEU=34. You can also try several seeds with lr=2.5 until one makes it past ~10k updates (e.g., I tried seed=10 and it worked).
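One way to automate that seed search is a small retry loop over --seed values, rerunning the command from the top of this issue (a sketch: paths are simplified, the seed list is arbitrary, and whether a run survived past ~10k updates is still judged from its log):

import subprocess

for seed in (1, 2, 10, 42):
    cmd = ['python', 'train.py', 'data-bin',
           '--save-dir', 'model_seed%d' % seed,
           '-s', 'en', '-t', 'fr', '--arch', 'fconv_wmt_en_fr',
           '--dropout', '0.1', '--lr', '2.5', '--clip-norm', '0.1',
           '--max-tokens', '4000', '--force-anneal', '32',
           '--seed', str(seed)]
    print('launching lr=2.5 run with seed', seed)
    subprocess.run(cmd, check=False)  # inspect model_seed*/ logs afterwards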

edunov commented 6 years ago

Hi @Zrachel

I still don't understand why the dictionary files you shared with us are encoded in ISO-8859-1; that might be a clue as to why you have such a low BLEU score. So, let me ask a few questions:

1) How do you measure the BLEU score? Do you use our code or something else?
2) What does echo $LANG report on your system? If it is not "en_US.UTF-8", can you try export LANG="en_US.UTF-8" and re-do everything, including data preparation? (See the encoding check sketched after this list.)
3) Did you move your data between Linux and Windows?
4) Did you mix data with the Lua Torch fairseq?
5) What is apply_bpe_fix.py? How is it different from apply_bpe.py?
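To check question 2's hypothesis directly, one can try to decode the dictionary files strictly as UTF-8 and see whether decoding fails. A minimal sketch, assuming the dict.*.txt files produced by preprocessing (the paths are assumptions):

# Hypothetical check: a file that only decodes as ISO-8859-1, not as strict
# UTF-8, was probably produced under the wrong locale.
for path in ('data-bin/dict.en.txt', 'data-bin/dict.fr.txt'):
    raw = open(path, 'rb').read()
    try:
        raw.decode('utf-8')
        print(path, 'decodes cleanly as UTF-8')
    except UnicodeDecodeError as err:
        print(path, 'is NOT valid UTF-8:', err)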

In theory, you don't even need to train for that many epochs to see that something is wrong: after epoch 3 you should have a solid 37-38 BLEU, regardless of the learning rate (unless it's really small or too big). 1.25 is a pretty good learning rate, so there must be some other problem.

Can you also try training on IWSLT and see if you can reach BLEU > 31 with 8 GPUs and --max-tokens 1000? (Use the prepare script we provided in data.)

Another thing: maybe try WMT14 En2De and see if you can achieve BLEU > 25?

That warning you see ("Warning! 1 samples are either too short or too long and will be ignored, sample ids=[28743556]") is also very suspicious, but I'm running out of ideas regarding that one.

Zrachel commented 6 years ago

Thank you @myleott and @edunov .

  1. BLEU: Yes, I use the measurement code in fairseq-py.
  2. echo $LANG: en_US. OK, I'll change it to en_US.UTF-8 and try again.
  3. NO, I only used linux.
  4. NO.
  5. I only added some exception handling to apply_bpe.py.

On IWSLT: I got BLEU=30.41 on IWSLT with 4 GPUs.

python PyFairseq/train.py data/output --save-dir local/train/model -s de -t en --arch fconv_iwslt_de_en --max-tokens 1000 --dropout 0.2 --lr 0.25 --clip-norm 0.1 --momentum 0.99
Namespace(arch='fconv_iwslt_de_en', clip_norm=0.1, data='data/output', decoder_attention='True', decoder_embed_dim=256, decoder_layers='[(256, 3)] * 3', decoder_out_embed_dim=256, dropout=0.2, encoder_embed_dim=256, encoder_layers='[(256, 3)] * 4', force_anneal=0, label_smoothing=0, log_interval=1000, lr=0.25, lrshrink=0.1, max_epoch=0, max_positions=1024, max_tokens=1000, min_lr=1e-05, model='fconv', momentum=0.99, no_epoch_checkpoints=False, no_progress_bar=False, no_save=False, restore_file='checkpoint_last.pt', sample_without_replacement=0, save_dir='local/train/model', save_interval=-1, seed=1, source_lang='de', target_lang='en', test_subset='test', train_subset='train', valid_subset='valid', weight_decay=0.0, workers=1)
| [de] dictionary: 21577 types
| [en] dictionary: 16051 types
| data/output train 160215 examples
| data/output valid 7282 examples
| data/output test 6750 examples
| using 4 GPUs (with max tokens per GPU = 1000)

On WMT14 en-de: yes, I got BLEU=25.35 on the test set with 8 GPUs.

edunov commented 6 years ago

Great, it seems like you have reasonably good results on the two other datasets. There are ways to push these numbers up, but for the default setup this is what we expect to see.

For En2Fr, by contrast, the results are bad. Once you have your first 3 epochs trained with LANG set to en_US.UTF-8, can you please run generation, and if the BLEU score comes in below 37, paste a sample output of generate.py here? Also, please paste your train and valid losses after each epoch. Hopefully we can deduce why it is not working.

Zrachel commented 6 years ago

Thank you. I'm working on it.

Zrachel commented 6 years ago

Hello @edunov, please take a look at my last reply in issue https://github.com/facebookresearch/fairseq-py/issues/41 before reading the results below.

I have set LANG to en_US.UTF-8, and now use with open(f, 'r', encoding='utf-8') as fd: in dictionary.py. After training for one epoch, I get the following results:

(result of generate.py:)
Generate test with beam=5: BLEU4 = 33.19, 60.1/38.7/27.1/19.2 (BP=1.000, ratio=0.987, syslen=95426, reflen=94221)
(result of score.py:)
checkpoint1.pt/test.bleu:BLEU4 = 30.20, 59.5/36.0/23.9/16.2 (BP=1.000, ratio=0.975, syslen=83253, reflen=81194)
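For reference, the dictionary.py change boils down to reading the vocabulary with an explicit encoding instead of the locale default; roughly this (a sketch assuming a "word count" per-line dict format, not the actual fairseq code):

def load_dictionary(path):
    # Force UTF-8 so parsing does not depend on $LANG.
    counts = {}
    with open(path, 'r', encoding='utf-8') as fd:
        for line in fd:
            word, count = line.rstrip('\n').rsplit(' ', 1)
            counts[word] = int(count)
    return counts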

The results still look worse than expected. Here is a sample generated by checkpoint1.pt:

S-39    cr@@ anes arrived on the site just after 10@@ am , and traffic on the main road was diver@@ ted after@@ wards .
T-39    des gr@@ ues sont arriv@@ ées sur place peu après 10 heures , et la circulation sur la nationale a été détour@@ née dans la fou@@ lée .
H-39    -0.34878697991371155    Les gr@@ ues arriv@@ èrent sur le site juste après 10@@ h , et le trafic sur la route principale a été détour@@ né par la suite .
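The @@ markers above are BPE continuation symbols; the --remove-bpe flag to generate.py that edunov mentions below strips them before display. A minimal equivalent, assuming the default '@@ ' separator:

def remove_bpe(sentence, separator='@@ '):
    # 'gr@@ ues' -> 'grues': drop the separator to rejoin subword units.
    return (sentence + ' ').replace(separator, '').rstrip()
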
edunov commented 6 years ago

@Zrachel I checked the same sentence on my side after the first epoch (remember, I'm using --remove-bpe in generate.py so my sentences have BPE encoding removed):

S-355 Cranes arrived on the site just after 10am , and traffic on the main road was diverted afterwards .
T-355 Des grues sont arrivées sur place peu après 10 heures , et la circulation sur la nationale a été détournée dans la foulée .
H-355 -0.4027450680732727 Les grues sont arrivées sur le site juste après 10h , et le trafic sur la route principale a été détourné par la suite .

One thing that stands out is that you seem to have everything in the test set lower-cased, while in the train set you clearly have capital letters (e.g. your hypo starts with "Les"). We do not use lowercasing in our training at all, so that might be the reason for the difference you observe.
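A quick, hypothetical sanity check for this kind of mismatch is to compare how much capitalization each split contains (the file names below are assumptions):

def upper_ratio(path):
    # Fraction of alphabetic characters that are uppercase.
    upper = alpha = 0
    with open(path, encoding='utf-8') as fd:
        for line in fd:
            for ch in line:
                if ch.isalpha():
                    alpha += 1
                    upper += ch.isupper()
    return upper / max(alpha, 1)

for f in ('train.en', 'test.en'):  # hypothetical file names
    print(f, round(upper_ratio(f), 4))
# A near-zero ratio on test but not on train points at accidental lowercasing.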

edunov commented 6 years ago

Also, can you please report the training and validation loss that you observe after the first epoch? (You can find them in the training log.)

Zrachel commented 6 years ago

My fault. I had removed the lowercasing step for the training data, but forgot to remove it for the test data. Thank you very much.

Result on corrected testset:

checkpoint1.pt/test.bleu:BLEU4 = 35.73, 64.2/41.8/29.2/20.8 (BP=1.000, ratio=0.991, syslen=81952, reflen=81194)
checkpoint2.pt/test.bleu:BLEU4 = 37.20, 65.2/43.3/30.7/22.2 (BP=0.999, ratio=1.001, syslen=81142, reflen=81194)

Training and validation loss:

...
| epoch 001 | train loss 2.24 | train ppl 4.73 | s/checkpoint 77473 | words/s 16713 | words/batch 31228 | bsz 856 | lr 1.25 | clip 18% | gnorm 0.0936811 
| epoch 001 | valid on 'valid' subset | valid loss 1.75 | valid ppl 3.37 
...
| epoch 002 | train loss 1.76 | train ppl 3.38 | s/checkpoint 78280 | words/s 16541 | words/batch 31228 | bsz 856 | lr 1.25 | clip 0% | gnorm 0.0558167 
| epoch 002 | valid on 'valid' subset | valid loss 1.63 | valid ppl 3.11 
dagarcia-nvidia commented 6 years ago

Hi @Zrachel,

If you could please upload your fixed dataset, that would help me a lot. We are currently using the dataset you uploaded earlier and are running into the same problem.

Zrachel commented 6 years ago

Hi @dagarcia-nvidia , here: https://drive.google.com/open?id=1bFMhfhhMhhedPAPo0TDWBfVga8dFuTE1

dagarcia-nvidia commented 6 years ago

Thank you @Zrachel! That seems to solve the problem. Much appreciated!! :)