tensorflow / tensor2tensor

Library of deep learning models and datasets designed to make deep learning more accessible and accelerate ML research.
Apache License 2.0

Bug introduced in v1.3.0 causing training divergence #529

Open martinpopel opened 6 years ago

martinpopel commented 6 years ago

I train with exactly the same setup (1GPU, transformer_big_single_gpu, batch_size=1500, en-cs dataset with 58M training sentences, checkpoints saved each hour) with T2T v1.2.9 and v1.3.0 (which is the version immediately following 1.2.9). While in v1.2.9 everything is OK, in v1.3.0 the training diverges after about one hour of training and BLEU drops to almost zero after about 5 hours of training. See the learning curves below (orange=v1.2.9, magenta=1.3.0, y-axis is dev-set BLEU as measured by t2t-bleu, x-axis is training time in hours): v129-vs-v130

I attach also training loss curves for these two experiments (but now orange=v1.3.0, gray=v1.2.9, so the colors don't match the previous graph, x-axis is now in steps): v129-vs-v130-train-loss

martinpopel commented 6 years ago

I did the same experiment as in the previous post, but now with a bigger batch size: 2200 instead of 1500. (I had to add max_length=150 so that I could afford such a batch size.) Again, the setup is exactly the same for v1.2.9 and v1.3.0.
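
For reference, a minimal sketch of these overrides as a registered hparams set (the set name below is made up; the same values can simply be passed via the --hparams command-line flag):

```python
# Hypothetical hparams set expressing the overrides used in this experiment;
# equivalently: --hparams='batch_size=2200,max_length=150' on the command line.
from tensor2tensor.models import transformer
from tensor2tensor.utils import registry


@registry.register_hparams
def transformer_big_single_gpu_b2200():
  hparams = transformer.transformer_big_single_gpu()
  hparams.batch_size = 2200  # batch size is measured in subword tokens
  hparams.max_length = 150   # drop training examples longer than 150 subwords
  return hparams
```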

First the dev-set BLEU: v129-vs-v130-b2200-bleu v1.3.0 (orange) is much worse than v1.2.9 (gray) after two hours of training, but this time the BLEU is not "almost zero"; it continues to grow, just more slowly. This is even more dangerous than in the previous post because most users won't notice there is a problem, they will just get a lower final BLEU. (Maybe after training long enough, the two curves will join, but it is suboptimal anyway.)

And the training loss: v129-vs-v130-b2200-train-loss v1.3.0 has notably higher training loss than v1.2.9 after 3h (20k steps) of training.

mehmedes commented 6 years ago

This is to confirm @martinpopel's observation.

The orange curve was trained on T2T 1.4.1.

The other curves belong to 1.4.2. I tried increasing warmup to 16k or 32k, or decreasing the learning rate to 0.05; the model seems to always diverge:

image

image

martinpopel commented 6 years ago

@mehmedes: so you see a difference between T2T 1.4.1 and 1.4.2? So perhaps that's another issue. I haven't noticed any difference between these versions (for my translation problem), and the changes related to transformer (i.e. not transformer_vae) in #506 also seem just cosmetic (but I may be wrong). I saw a difference between 1.2.9 and 1.3.0, and there are several major changes in #449 which may be the culprit.

> I tried increasing warmup to 16k, 32k [...] the model seems to always diverge

In my case, increasing warmup steps from 16k to 32k helped - after one day of training (170k steps), the curves are basically the same (even now after 3 days of training, not shown on the picture below): v129-vs-v142-b2200-warmup Gray is the original v1.2.9 with warmup=16k. Blue is v1.4.2 (but v1.3.0 gives the same curve) with warmup=16k. Orange is v1.4.2 with the increased warmup=32k.

However, all my experiments with the new version are worse than v1.2.9 in the first 12 hours of training. E.g. increasing warmup to 48k does not help (it starts similarly to the blue curve and joins the gray curve only after two days of training).
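
For context, the reason warmup_steps matters so much is the "noam" learning-rate schedule from the Transformer paper. A rough sketch (the T2T implementation adds its own scaling factors, so the constants are illustrative only):

```python
# Rough sketch of the "noam" schedule (Vaswani et al.); T2T adds its own scaling
# factors, so the absolute values here are illustrative only.
def noam_learning_rate(step, hidden_size=1024, warmup_steps=16000):
  """Linear warmup for warmup_steps steps, then inverse-square-root decay."""
  step = max(step, 1)  # avoid division by zero at step 0
  return hidden_size ** -0.5 * min(step ** -0.5, step * warmup_steps ** -1.5)

# The peak is reached at step == warmup_steps and is proportional to
# warmup_steps ** -0.5, so going from 16k to 32k warmup steps both slows the
# ramp-up and lowers the peak learning rate by a factor of about 1.4.
```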

lukaszkaiser commented 6 years ago

Should we still try to pinpoint which exact change from 1.2.9 to 1.3.0 caused this effect? Or should we just increase warmup_steps in single_gpu configs? There doesn't seem to be that much that changed from 1.2.9 to 1.3.0, I'm surprised it causes this! Do you have any suggestions what could have caused it? We added target_weights_fn, but it doesn't look like we forgot padding. Otherwise, did anything change? Can you decode the 1.2.9-trained model with 1.3.0? Did the weights change in any way?

martinpopel commented 6 years ago

> Can you decode the 1.2.9-trained model with 1.3.0?

No, because the checkpoint format is not compatible (because of 01b8c31da30a7e1109451df2b4b4698946c6c35c and c10e0160e1bd00c68568df1ca80e5cbdd2c81a3b). I thought 1.3.0 was the version where tpu_trainer.py (later renamed to trainer_lib.py) became the default also when training on GPU, but I may be wrong. I am quite busy now, but I may try to git-bisect the exact commit responsible for the change.

vince62s commented 6 years ago

Like @mehmedes, I really think there is an issue with 1.4.2 vs 1.4.1; both on ENFR and ENDE I see a big gap.

martinpopel commented 6 years ago

After training more versions for a longer time, it seems there is not a single culprit commit; the issue is more complicated. See the training loss curves: v129-vs-various-train-loss

- blue = v1.4.2: clearly the worst training loss
- magenta = v1.3.0: a bit better, but the training still does not converge
- gray = c0ce3dd (a commit between v1.2.9 and v1.3.0): again a bit better, but still not converging
- the rest = v1.2.9, 6cf47f9 and 936db05: all three of these commits are OK

I have problems plotting the BLEU curves because most of the versions have broken/missing t2t-bleu and t2t-translate-all scripts, so I would have to patch these versions first (checkpoints from some commits between v1.2.9 and v1.3.0 cannot be decoded with either of the versions I have already patched). Another problem is that until recently T2T did not properly fix the random seed, but even so, I think the difference between the bottom three "converging" curves and the three "diverging" curves is too large to be caused by random initialization.

martinpopel commented 6 years ago

I found the main culprit: it is commit https://github.com/tensorflow/tensor2tensor/pull/449/commits/0ffe0e6654366d700b8850d6423a717083586d12. The training-loss graph below shows this commit's curve in green (it has almost the same curve as v1.3.0 in my previous post). The blue curve is the previous commit (398e85b) and it has almost the same curve as v1.2.9 in my previous post. commit-0ffe0e6-effect

martinpopel commented 6 years ago

T2T before v1.4.2 did not fix the random seed, so I re-ran some experiments several times. The culprit commit 0ffe0e6 (and any later commit) always gives bad results (training diverges like the green curve in my previous post). Older commits (after v1.2.9 and before 0ffe0e6) give the ideal result most of the time, but sometimes a commit that was OK in one experiment results in a higher training loss in another run, like the gray curve in https://github.com/tensorflow/tensor2tensor/issues/529#issuecomment-360786725. However, these random worsenings still have a training loss clearly below the buggy 0ffe0e6.

martinpopel commented 6 years ago

@vince62s suggested the worsening may be caused by the (probably unintentional) removal of bias variables in 0ffe0e6. common_attention.py originally used the function common_layers.conv1d with the default use_bias=True. 0ffe0e6 changed this to tf.layers.dense, which also adds bias variables by default, but this default has been overridden to use_bias=False in all calls. I've changed it back to use_bias=True, but it had no effect: the training loss still starts growing after 16k steps, exactly as with use_bias=False. So I still don't know which code change causes the worsening. Nevertheless, it is possible that the change of use_bias was not intentional (at least it is not mentioned in the commit log, which says just "Make Transformer fast on TPU."). @nshazeer @lukaszkaiser: Can you check whether use_bias=False is what you want and explain why the bias was dropped?
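
For reference, the change in question looks roughly like the following heavily simplified sketch (names and shapes are placeholders, not the literal diff of 0ffe0e6):

```python
# Heavily simplified illustration of the projection change; names/shapes are
# placeholders, not the literal diff of 0ffe0e6. Uses the TF 1.x layers API.
import tensorflow as tf
from tensor2tensor.layers import common_layers

x = tf.zeros([8, 20, 1024])  # [batch, length, hidden_size] dummy input
depth = 1024

# Before 0ffe0e6 (roughly): 1x1 convolution, which creates a bias by default.
q_old = common_layers.conv1d(x, depth, 1, name="q_old")

# After 0ffe0e6 (roughly): dense projection with the bias explicitly disabled.
q_new = tf.layers.dense(x, depth, use_bias=False, name="q_new")
```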

nshazeer commented 6 years ago

I intentionally removed the bias, since it didn't seem to have any effect on quality, and having a smaller number of variables was faster on some systems. If it turns out to be important, we can put it back.


martinpopel commented 6 years ago

@nshazeer: Thank you for the info.

> If it turns out to be important, we can put it back.

No, the bias has no effect according to my experiments. But something else in the same commit can have a disastrous effect in some circumstances (as described in this issue).

rsepassi commented 6 years ago

Does this problem also occur with transformer_base? I ran base recently on translate_ende_wmt32k and it hit neg log perplexity of -1.533 after 250k steps. Unfortunately I don't have the training curves. Thanks very much for your investigations @martinpopel. We're also internally adding more models and configurations to regression tests so we can catch these things sooner.

rsepassi commented 6 years ago

Actually, here's the loss curve for that run.

lossbase

martinpopel commented 6 years ago

All the experiments reported here are with transformer_big_single_gpu on an en-cs dataset with 58M training sentences. I have not tried to check this error with other models or datasets. The error can be worked around with a higher batch size or higher warmup, so it is probable that with different models or datasets the "danger zone" of too-low batch size and/or too-low warmup and/or too-high learning rate is located elsewhere, not near the default values, so the error won't be noticed.

If there are no more ideas what caused the error (in 0ffe0e6), we can close this issue.

> adding more models and configurations to regression tests so we can catch these things sooner.

Great. Thanks.

vince62s commented 6 years ago

I confirm that on 1.4.3 (with a patch for the correct learning rate schedule) the ENDE transformer_base runs fine on 4 GPUs (batch size 5430 x 4) => BLEU 26.62. Still a little bit off vs. the paper though (27.3).

rsepassi commented 6 years ago

Thanks @vince62s. That's good news.

It may be that we need to retune some of the hparams sets and update them.

In general it seems to be very difficult to maintain reproducibility without freezing a codebase/codepath entirely.

What we may want to consider doing is a hard fork of the entire T2TModel.estimator_model_fn and hparams set for every publication. Copy in all the code, put the date in the model name and hparams set, and then continue development on the main model separately.

Open to suggestions and discussion here. This is a hard problem.
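
For illustration, a rough sketch of what such a frozen, dated registration could look like (the class name is hypothetical; a real hard fork would copy the code in rather than subclass it, but the registry usage is the same):

```python
# Hypothetical sketch of a "frozen" per-publication model registration; a real
# hard fork would copy the model code in verbatim rather than subclass it.
from tensor2tensor.models import transformer
from tensor2tensor.utils import registry


@registry.register_model
class Transformer20170612(transformer.Transformer):
  """Snapshot of the Transformer kept for reproducing a specific paper/baseline."""
  # In a real fork the model body / estimator_model_fn would be copied here
  # so that later changes to the parent class cannot affect it.
  pass
```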

rsepassi commented 6 years ago

We could also just say that to reproduce an experiment you have to be on a certain commit hash.

vince62s commented 6 years ago

If I may suggest something: when a version seems stable and gives reproducible results, keep that release in a branch and make it clear to people that master is a dev branch with potential bugs, or one that may change your workflow a little bit. I can tell it is not easy to keep track from a user / light-developer standpoint, because you guys dump big commits once every 2, 3 or 4 weeks, and then retrieving where things have really changed in so many commits is very, very difficult. I have discussed this with @martinpopel, since we are 2 fans, and I think we are a little bit on the same page. Having said that: great work and very happy to use this in open source :) Vincent

rsepassi commented 6 years ago

Yeah, I do think it's reasonable to tag certain commits as the golden commits to reproduce a certain baseline (model-problem pair). Development can continue on master. So maybe for each paper publication or model-problem pair of interest, we tag a commit, and in the tag description we include a command line that reproduces it perfectly. That would be quite nice.

stefan-it commented 6 years ago

Hi,

I'm not sure if this is the right place (if not, I can open a new issue for that), but I trained a transformer_big model on translate_enmk_setimes32k_rev for ~100k ~~epochs~~ steps on a DGX-1.

The resulting BLEU-score was 8-9! When I initially submitted the enmk dataset, the achieved BLEU-score with the "smaller" transformer model was 52, see #158.

Here are some images from TensorBoard:

tensorboard

Should I try an older version like 1.2.9 to check if this low BLEU-score is reproducible?

Here's the BLEU-score graph from an older tensor2tensor version (#158) with the smaller transformer model:

dev-bleu

stefan-it commented 6 years ago

Here are the results using version 1.2.9. I trained for 34k ~~epochs~~ steps and the BLEU-score is 52 (development set):

1 2 9

martinpopel commented 6 years ago

@stefan-it: Thanks for sharing your results. If you want to confirm it is the same bug as reported by me, try commit 0ffe0e6 and the commit before it. If you want to use the newest version, try one (or more) of the strategies to overcome the diverged training (see the sketch after this list):

- increase the batch size
- increase learning_rate_warmup_steps (e.g. from 16k to 32k)
- use stronger gradient clipping, e.g. clip_grad_norm=1.0
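
A minimal sketch of how the last two overrides could be combined (the hparams-set name is hypothetical; the same values can also be passed via --hparams):

```python
# Hypothetical hparams set combining the warmup and gradient-clipping workarounds;
# equivalently: --hparams='learning_rate_warmup_steps=32000,clip_grad_norm=1.0'.
from tensor2tensor.models import transformer
from tensor2tensor.utils import registry


@registry.register_hparams
def transformer_big_single_gpu_workaround():
  hparams = transformer.transformer_big_single_gpu()
  hparams.learning_rate_warmup_steps = 32000  # e.g. 32k instead of the 16k used above
  hparams.clip_grad_norm = 1.0                # tighter gradient clipping
  return hparams
```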

stefan-it commented 6 years ago

@martinpopel thanks for that hint. I trained another model with 1.5.1 using clip_grad_norm=1.0 for 20k ~~epochs~~ steps and compared the results with version 1.2.9 and 1.5.1 (with default parameters).

Here are the results:

| Step | BLEU with 1.2.9 | BLEU with 1.5.1 (default) | BLEU with 1.5.1 (clip_grad_norm) |
|------|-----------------|---------------------------|----------------------------------|
| 1000 | 0.1746 | 0.0435 | 0.0522 |
| 2000 | 0.2665 | 0.1394 | 0.1450 |
| 3000 | 0.3466 | 0.2480 | 0.2495 |
| 4000 | 0.4591 | 0.3134 | 0.3254 |
| 5000 | 0.4666 | 0.3512 | 0.3581 |
| 6000 | 0.4724 | 0.3663 | 0.3764 |
| 7000 | 0.4863 | 0.3678 | 0.3909 |
| 8000 | 0.4998 | 0.3394 ↓ | 0.3821 ↓ |
| 9000 | 0.5076 | 0.3004 | 0.3737 |
| 10000 | 0.5191 | 0.2580 | 0.3694 |
| 11000 | 0.5471 | 0.2544 | 0.3573 |
| 12000 | 0.5304 | 0.2329 | 0.3358 |
| 13000 | 0.5251 | 0.2257 | 0.3292 |
| 14000 | 0.5354 | 0.2030 | 0.3148 |
| 15000 | 0.5425 | 0.2073 | 0.3035 |
| 16000 | 0.5297 | 0.2111 | 0.2952 |
| 17000 | 0.5330 | 0.2097 | 0.2116 |
| 18000 | 0.5598 | 0.2115 | 0.2699 |
| 19000 | 0.5451 | 0.1972 | 0.2646 |
| 20000 | 0.5248 | 0.1948 | 0.2525 |
| 21000 | 0.5564 | 0.1956 | 0.2601 |

clip_grad_norm=1.0 has a good impact on the BLEU-score compared to the default setting.

But compared to version 1.2.9, there's a difference of 29.63 BLEU points.

vince62s commented 6 years ago

IMO it does not make sense to train a big model with 200K sentences. Maybe try another dataset and see if this behavior is replicated.

martinpopel commented 6 years ago

Yes, if the data is translate_enmk_setimes32k with just 205,777 training sentences, then I guess using transformer_big (or transformer_big_single_gpu, which gives me better results even when used with 8 GPUs) is overkill, and transformer_base_single_gpu may give better results (or at least be faster). BTW: OPUS contains 7.5M mk-en sentences, although biased towards the subtitles "domain", so domain adaptation may be needed if the target domain is news.

However, no matter what training data and hyperparameter set @stefan-it used, it would be interesting to confirm whether the drop in BLEU between 1.2.9 and 1.5.1 is caused by the same commit as I found. Just to be sure, I would suggest re-evaluating the BLEU results (at least the final ones) with sacreBLEU or t2t-bleu to have trustworthy results (with approx_bleu there is a risk that the non/autoregressive slow/fast implementation changes between versions also influence the results).
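
For example, a quick re-check with the sacreBLEU Python API could look like this sketch (the file names are placeholders and both files must contain detokenized text):

```python
# Sketch of re-scoring a decoded dev set with sacreBLEU; file names are placeholders
# and both files must contain detokenized text, one sentence per line.
import sacrebleu

with open("dev.hyp.detok") as f:
  hypotheses = [line.rstrip("\n") for line in f]
with open("dev.ref.detok") as f:
  references = [line.rstrip("\n") for line in f]

bleu = sacrebleu.corpus_bleu(hypotheses, [references])
print(round(bleu.score, 2))
```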

And finally a bit of terminological nitpicking: what @stefan-it calls epochs is usually called steps in T2T (or iterations or number of updates). Epoch is usually understood as one pass over the whole training data, see #415.

stefan-it commented 6 years ago

Here are the BLEU scores using t2t-bleu for the model trained with version 1.2.9 and 1.5.1 (with clip_grad_norm=1.0) after 21k steps:

| BLEU | 1.2.9 | 1.5.1 (clip_grad_norm=1.0) |
|------|-------|----------------------------|
| uncased | 43.74 | 20.22 |
| cased | 42.93 | 19.78 |

So I think this problem can be replicated on at least two datasets of different sizes.