Smerity / sha-rnn

Single Headed Attention RNN - "Stop thinking with your head"

Reproduced BPC of 1.077 using model with one attention layer #3

Open mkroutikov opened 4 years ago

mkroutikov commented 4 years ago
| epoch  13 | 11210/10986 batches | lr 0.00100 | ms/batch 188.01 | loss  0.71 | ppl     2.04 | bpc    1.029
| epoch  13 | 11220/10986 batches | lr 0.00100 | ms/batch 188.13 | loss  0.64 | ppl     1.90 | bpc    0.925
| epoch  13 | 11230/10986 batches | lr 0.00100 | ms/batch 193.46 | loss  0.73 | ppl     2.07 | bpc    1.052
| epoch  13 | 11240/10986 batches | lr 0.00100 | ms/batch 193.60 | loss  0.74 | ppl     2.10 | bpc    1.071
| epoch  13 | 11250/10986 batches | lr 0.00100 | ms/batch 193.58 | loss  0.71 | ppl     2.03 | bpc    1.021
| epoch  13 | 11260/10986 batches | lr 0.00100 | ms/batch 185.65 | loss  0.73 | ppl     2.07 | bpc    1.051
| epoch  13 | 11270/10986 batches | lr 0.00100 | ms/batch 194.74 | loss  0.72 | ppl     2.04 | bpc    1.032
| epoch  13 | 11280/10986 batches | lr 0.00100 | ms/batch 194.90 | loss  0.67 | ppl     1.95 | bpc    0.964
| epoch  13 | 11290/10986 batches | lr 0.00100 | ms/batch 180.78 | loss  0.73 | ppl     2.08 | bpc    1.057
| epoch  13 | 11300/10986 batches | lr 0.00100 | ms/batch 193.61 | loss  0.73 | ppl     2.07 | bpc    1.050
-----------------------------------------------------------------------------------------
| end of epoch  14 | time: 2183.27s | valid loss  0.76 | valid ppl     2.15 | valid bpc    1.101
-----------------------------------------------------------------------------------------
Model total parameters: 53790932
=========================================================================================
| End of training | test loss  0.76 | test ppl     2.15 | test bpc    1.102
=========================================================================================

Above is the output at the end of the second training run, following the README instructions.
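The loss, ppl and bpc columns in these logs all express the same quantity. Assuming loss is the mean per-character cross-entropy in nats (the logged numbers are consistent with that), bpc = loss / ln 2 and ppl = e^loss = 2^bpc. The sketch below checks this against the final test numbers; `from_bpc` is a made-up helper name.

```python
import math
from typing import Tuple

def from_bpc(bpc: float) -> Tuple[float, float]:
    """Convert bits per character to (loss in nats, per-character perplexity)."""
    loss_nats = bpc * math.log(2)  # bits -> nats
    ppl = 2.0 ** bpc               # same as math.exp(loss_nats)
    return loss_nats, ppl

loss, ppl = from_bpc(1.102)
print(f"loss {loss:.2f} | ppl {ppl:.2f}")  # loss 0.76 | ppl 2.15
```

Converting from bpc rather than from loss avoids compounding rounding, since the logged loss has only two decimals.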

My setup:

  1. V100 GPU (AWS p3.2xlarge)
  2. PyTorch 1.3.1
  3. apex 0.1
  4. Followed the training instructions from the README: ran training twice (14 + 14 epochs), the second time with a smaller batch size and learning rate.

The trained model is here (205 MB).

Other notes:

Smerity commented 4 years ago

Thanks for running this from scratch, and I'm glad you're already almost at a full replication! I can help you get the rest of the way, as the remaining gap comes from issues in the commands I included in the README :)

The faster training time is expected, as the code currently runs the single headed SHA-LSTM. From the paper, the benefit of "the single headed SHA-LSTM is that each epoch took almost exactly 1800 ± 1 seconds (30 minutes) compared to the 4 headed SHA-LSTM which took 4020 seconds (67 minutes)". That maps fairly well to your 5493 batches * 0.27 seconds per batch ≈ 25 minutes per epoch.
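The arithmetic behind that comparison, as a quick sketch. The 0.27 s per batch is the average step time from the original report, the 1800 s and 4020 s figures are the paper timings quoted above, and `epoch_minutes` is a name made up for illustration.

```python
def epoch_minutes(num_batches: int, seconds_per_batch: float) -> float:
    """Estimate one epoch's wall-clock time in minutes."""
    return num_batches * seconds_per_batch / 60.0

# Single headed SHA-LSTM, batch size 16 on a V100 (numbers from the report).
v100 = epoch_minutes(5493, 0.27)
print(f"V100 estimate: {v100:.1f} min per epoch")  # V100 estimate: 24.7 min per epoch

# Paper timings per epoch: 4 headed vs single headed SHA-LSTM.
print(f"4 headed / single headed: {4020 / 1800:.2f}x")  # 4 headed / single headed: 2.23x
```

The estimate agrees with the 1481.99 s (about 24.7 minutes) epoch that appears in a later log in this thread.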

If you'd like to use the full 4 headed SHA-LSTM, you can set use_attn to True for all layers on this line. As noted in the paper, it gets a slightly better result but is about twice as slow, and it requires a batch size of 8 on the Titan V (your V100 may be able to fit a larger batch though!). Sorry for not making that a command line flag; I was setting options manually.
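A minimal sketch of what that change amounts to. It assumes the model builds one block per layer, each taking a boolean use_attn flag, and that the default single headed model places its one attention layer on the second-to-last block; both are assumptions about the code rather than lines copied from it, and `attention_flags` is a hypothetical helper.

```python
from typing import List

def attention_flags(num_layers: int, attn_on_all_layers: bool) -> List[bool]:
    """Return one use_attn flag per layer (hypothetical helper)."""
    if attn_on_all_layers:
        # Full 4 headed variant: every layer gets an attention mechanism.
        return [True] * num_layers
    # Assumed default: a single attention layer on the penultimate block.
    return [idx == num_layers - 2 for idx in range(num_layers)]

print(attention_flags(4, attn_on_all_layers=False))  # [False, False, True, False]
print(attention_flags(4, attn_on_all_layers=True))   # [True, True, True, True]
```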

The commands supplied were originally written for the 4 layer SHA-LSTM, where every layer contains an attention mechanism, not for the single headed SHA-LSTM. The batch size 16 model needs a few extra epochs, as it takes fewer training steps per epoch than the batch size 8 model. As such, the 14 epoch training limit is incorrect here: it referred to the full 4 layer SHA-LSTM (i.e. reproducing Figure 3 of the paper), which used only 19 epochs in total, 16 with lr=2e-3 and 3 with lr=1e-3.
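The step-count argument can be checked against the batch totals printed in the logs in this thread: 10986 batches per epoch for the batch size 8 run and 5493 for batch size 16. The sketch assumes every epoch covers the same number of tokens regardless of batch size, which those two totals are consistent with (10986 * 8 == 5493 * 16). The helper names are made up for illustration.

```python
import math

BATCHES_BS8 = 10986   # batches per epoch at batch size 8 (from the logs)
BATCHES_BS16 = 5493   # batches per epoch at batch size 16 (from the logs)

def batches_per_epoch(batch_size: int, ref_batches: int = BATCHES_BS8, ref_bs: int = 8) -> int:
    # Same tokens per epoch, so the number of batches scales as 1 / batch_size.
    return ref_batches * ref_bs // batch_size

def epochs_for_same_steps(epochs: int, from_bs: int, to_bs: int) -> int:
    """Epochs needed at to_bs to take as many optimizer steps as `epochs` at from_bs."""
    total_steps = epochs * batches_per_epoch(from_bs)
    return math.ceil(total_steps / batches_per_epoch(to_bs))

assert batches_per_epoch(16) == BATCHES_BS16
print(epochs_for_same_steps(14, from_bs=8, to_bs=16))  # 28
```

Step count is only a rough proxy for training progress, and the printed totals are approximate (the progress counter in the logs runs past them), but it shows why a batch size 16 run needs more epochs.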


To reproduce the single headed SHA-LSTM result, training with the first command (lr=2e-3) should continue until the validation perplexity stops improving. You can also leave the model running longer, as it only saves models whose validation perplexity beats the best seen so far.

Training with the second command (lr=1e-3) generally only takes two or three epochs until validation perplexity stops improving. The majority of the gain is usually in the first epoch.
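The two-phase procedure above (train at lr=2e-3 until validation stops improving, then switch to lr=1e-3 for a few epochs, only keeping models that beat the best validation score) can be sketched as a small driver loop. `train_until_plateau`, the `run_epoch` callback and the patience values are hypothetical; in the thread, training was stopped by hand after watching the logs. Instead of training, the usage replays the single headed SHA-LSTM validation bpc values posted in this thread.

```python
from typing import Callable, Tuple

def train_until_plateau(
    run_epoch: Callable[[float], float],
    lr: float,
    patience: int,
    max_epochs: int,
    best_so_far: float = float("inf"),
) -> Tuple[float, int]:
    """Train at a fixed lr until valid bpc fails to improve `patience` times in a row.

    Returns (best valid bpc, epochs run). `run_epoch(lr)` is a hypothetical
    callback that trains one epoch and returns the validation bpc.
    """
    best = best_so_far
    stale = 0
    epochs_run = 0
    for _ in range(max_epochs):
        valid_bpc = run_epoch(lr)
        epochs_run += 1
        if valid_bpc < best:
            best = valid_bpc
            stale = 0
            # A real script would save a checkpoint here, only on improvement.
        else:
            stale += 1
            if stale >= patience:
                break
    return best, epochs_run

# Logged validation bpc for the lr=2e-3 and lr=1e-3 phases.
PHASE1 = [1.334, 1.252, 1.214, 1.189, 1.171, 1.160, 1.149, 1.142, 1.138,
          1.134, 1.129, 1.125, 1.123, 1.120, 1.121, 1.117, 1.118, 1.116,
          1.114, 1.115, 1.114, 1.113, 1.112, 1.112, 1.112, 1.115, 1.112]
PHASE2 = [1.100, 1.100, 1.100, 1.102, 1.100]

replay1 = iter(PHASE1)
best1, ran1 = train_until_plateau(lambda lr: next(replay1), lr=2e-3,
                                  patience=3, max_epochs=len(PHASE1))
replay2 = iter(PHASE2)
best2, ran2 = train_until_plateau(lambda lr: next(replay2), lr=1e-3,
                                  patience=2, max_epochs=len(PHASE2),
                                  best_so_far=best1)
print(best1, ran1, best2, ran2)  # 1.112 26 1.1 3
```

With these assumed patience values, the replay ends the first phase at epoch 26 (the logged run was stopped by hand at 27) and the second phase after 3 epochs, with the best score reached in the first of them.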


If you still have the log, you can run grep "valid bpc" on it and we can check whether the model had plateaued by epoch 14 at the learning rate of 2e-3. As an example, below is a log for the single headed SHA-LSTM showing that there was still some improvement to be made at epoch 14. It was my error not to update the commands with higher epoch counts when switching models.

| end of epoch   1 | time: 1813.52s | valid loss  0.92 | valid ppl     2.52 | valid bpc    1.334
| end of epoch   2 | time: 1811.97s | valid loss  0.87 | valid ppl     2.38 | valid bpc    1.252
| end of epoch   3 | time: 1815.17s | valid loss  0.84 | valid ppl     2.32 | valid bpc    1.214
| end of epoch   4 | time: 1813.15s | valid loss  0.82 | valid ppl     2.28 | valid bpc    1.189
| end of epoch   5 | time: 1812.83s | valid loss  0.81 | valid ppl     2.25 | valid bpc    1.171
| end of epoch   6 | time: 1815.57s | valid loss  0.80 | valid ppl     2.23 | valid bpc    1.160
| end of epoch   7 | time: 1809.98s | valid loss  0.80 | valid ppl     2.22 | valid bpc    1.149
| end of epoch   8 | time: 1806.74s | valid loss  0.79 | valid ppl     2.21 | valid bpc    1.142
| end of epoch   9 | time: 1814.90s | valid loss  0.79 | valid ppl     2.20 | valid bpc    1.138
| end of epoch  10 | time: 1805.94s | valid loss  0.79 | valid ppl     2.19 | valid bpc    1.134
| end of epoch  11 | time: 1803.10s | valid loss  0.78 | valid ppl     2.19 | valid bpc    1.129
| end of epoch  12 | time: 1800.77s | valid loss  0.78 | valid ppl     2.18 | valid bpc    1.125
| end of epoch  13 | time: 1801.50s | valid loss  0.78 | valid ppl     2.18 | valid bpc    1.123
| end of epoch  14 | time: 1797.74s | valid loss  0.78 | valid ppl     2.17 | valid bpc    1.120
| end of epoch  15 | time: 1799.41s | valid loss  0.78 | valid ppl     2.18 | valid bpc    1.121
| end of epoch  16 | time: 1796.43s | valid loss  0.77 | valid ppl     2.17 | valid bpc    1.117
| end of epoch  17 | time: 1797.27s | valid loss  0.77 | valid ppl     2.17 | valid bpc    1.118
| end of epoch  18 | time: 1798.39s | valid loss  0.77 | valid ppl     2.17 | valid bpc    1.116
| end of epoch  19 | time: 1799.21s | valid loss  0.77 | valid ppl     2.17 | valid bpc    1.114
| end of epoch  20 | time: 1799.06s | valid loss  0.77 | valid ppl     2.17 | valid bpc    1.115
| end of epoch  21 | time: 1799.13s | valid loss  0.77 | valid ppl     2.16 | valid bpc    1.114
| end of epoch  22 | time: 1799.01s | valid loss  0.77 | valid ppl     2.16 | valid bpc    1.113
| end of epoch  23 | time: 1803.80s | valid loss  0.77 | valid ppl     2.16 | valid bpc    1.112
| end of epoch  24 | time: 1804.56s | valid loss  0.77 | valid ppl     2.16 | valid bpc    1.112
| end of epoch  25 | time: 1806.76s | valid loss  0.77 | valid ppl     2.16 | valid bpc    1.112
| end of epoch  26 | time: 1807.95s | valid loss  0.77 | valid ppl     2.17 | valid bpc    1.115
| end of epoch  27 | time: 1805.45s | valid loss  0.77 | valid ppl     2.16 | valid bpc    1.112
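The plateau check suggested above can also be scripted. A minimal sketch, assuming log lines in exactly the format shown here (the regular expression is written against these lines, not against the training script's source). It reports the best epoch and how many epochs have passed without a strict improvement; the sample is the last five epochs of the log above.

```python
import re
from typing import List, Tuple

# Matches lines such as "| end of epoch  23 | ... | valid bpc    1.112".
VALID_BPC = re.compile(r"end of epoch\s+(\d+)\s*\|.*valid bpc\s+([\d.]+)")

def parse_valid_bpc(log_text: str) -> List[Tuple[int, float]]:
    """Extract (epoch, valid bpc) pairs from training log text."""
    return [(int(m.group(1)), float(m.group(2))) for m in VALID_BPC.finditer(log_text)]

def plateau_report(points: List[Tuple[int, float]]) -> Tuple[int, float, int]:
    """Return (best epoch, best valid bpc, epochs since that best)."""
    best_epoch, best_bpc = points[0]
    for epoch, bpc in points[1:]:
        if bpc < best_bpc:  # strict: ties do not count as an improvement
            best_epoch, best_bpc = epoch, bpc
    return best_epoch, best_bpc, points[-1][0] - best_epoch

SAMPLE = """\
| end of epoch  23 | time: 1803.80s | valid loss  0.77 | valid ppl     2.16 | valid bpc    1.112
| end of epoch  24 | time: 1804.56s | valid loss  0.77 | valid ppl     2.16 | valid bpc    1.112
| end of epoch  25 | time: 1806.76s | valid loss  0.77 | valid ppl     2.16 | valid bpc    1.112
| end of epoch  26 | time: 1807.95s | valid loss  0.77 | valid ppl     2.17 | valid bpc    1.115
| end of epoch  27 | time: 1805.45s | valid loss  0.77 | valid ppl     2.16 | valid bpc    1.112
"""

print(plateau_report(parse_valid_bpc(SAMPLE)))  # (23, 1.112, 4)
```

In practice the text would come from the saved log (for example open("train.log").read(), with a hypothetical file name) or from the output of grep "valid bpc".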

The 27 epochs logged above are about 13.5 hours of compute and get to approximately your number. I killed training at that stage as the validation perplexity had stopped improving. Then I resumed with lr=1e-3, keeping the same batch size (16). Unfortunately, you used batch size 8 for the second run because I hadn't updated the README properly, sorry -_-

| end of epoch   1 | time: 1799.11s | valid loss  0.76 | valid ppl     2.14 | valid bpc    1.100
| end of epoch   2 | time: 1810.37s | valid loss  0.76 | valid ppl     2.14 | valid bpc    1.100
| end of epoch   3 | time: 1814.39s | valid loss  0.76 | valid ppl     2.14 | valid bpc    1.100
| end of epoch   4 | time: 1801.80s | valid loss  0.76 | valid ppl     2.15 | valid bpc    1.102
| end of epoch   5 | time: 1800.42s | valid loss  0.76 | valid ppl     2.14 | valid bpc    1.100

As noted, in the case above all of the benefit comes in the first epoch. Sometimes the drop is spread over a few epochs. There are likely better ways to decay the learning rate, but I haven't explored them.

The above model resulted in a test bpc of 1.078.


If you have spare compute and want to try again, do so and get back to me. Otherwise I'll repeat the experiment myself overnight (... it's 5am but I'll pretend it's overnight lol ...) and report back. Do also let me know whether you prefer the faster model or the slightly better but slower / heavier model. I may have gone the wrong direction by making the faster one the default for the codebase.

Thanks for your experiment! ^_^

pgmmpk commented 4 years ago

Thanks for the quick and detailed reply. I was foolish enough to run this in a terminal, so the logs are mostly lost to the limited scrollback :(. Having watched the first run, I am pretty sure it plateaued over the last few epochs. Validation bpc values were lagging a bit, but fairly close to those plotted in the article.

Yes, it makes sense to use batch size 16 on the second run and train for just a few epochs. If I have time, I'll repeat the experiment and report here.

Love the speed and high GPU utilization. Thank you for publishing this. Not everyone has 1K TPUs; this thing gives us poor guys some hope ;)

mkroutikov commented 4 years ago
| epoch   4 |  5580/ 5493 batches | lr 0.00100 | ms/batch 265.68 | loss  0.70 | ppl     2.02 | bpc    1.012
| epoch   4 |  5590/ 5493 batches | lr 0.00100 | ms/batch 253.69 | loss  0.68 | ppl     1.96 | bpc    0.974
| epoch   4 |  5600/ 5493 batches | lr 0.00100 | ms/batch 268.03 | loss  0.71 | ppl     2.04 | bpc    1.028
| epoch   4 |  5610/ 5493 batches | lr 0.00100 | ms/batch 250.21 | loss  0.70 | ppl     2.02 | bpc    1.011
| epoch   4 |  5620/ 5493 batches | lr 0.00100 | ms/batch 258.41 | loss  0.69 | ppl     1.99 | bpc    0.989
| epoch   4 |  5630/ 5493 batches | lr 0.00100 | ms/batch 265.86 | loss  0.70 | ppl     2.02 | bpc    1.013
-----------------------------------------------------------------------------------------
| end of epoch   5 | time: 1481.99s | valid loss  0.76 | valid ppl     2.14 | valid bpc    1.099
-----------------------------------------------------------------------------------------
Model total parameters: 53790932
=========================================================================================
| End of training | test loss  0.75 | test ppl     2.11 | test bpc    1.077
=========================================================================================

I did this:

  1. Loaded the model trained overnight (see the link above) and trained it with lr=2e-3 and batch size 16 for 2 more epochs (to shake it up a bit).
  2. Fine-tuned with lr=1e-3 and batch size 16 for 5 epochs (it took 2 epochs to reach the best validation score; the rest were wasted).

The latest trained model can be downloaded from here.

Summary: all works as advertised!