mkroutikov opened this issue 5 years ago
Thanks for running this from scratch, and I'm glad you're already almost there on replicating it! I'll be able to help you get the rest of the way, as the mistakes are issues with the commands I included in the README :)
The faster training time is expected: as the paper notes, a quirk of "the single headed SHA-LSTM is that each epoch took almost exactly 1800 ± 1 seconds (30 minutes) compared to the 4 headed SHA-LSTM which took 4020 seconds (67 minutes)", and the code currently runs the single headed SHA-LSTM. That maps fairly well to your 5493 batches * 0.27 seconds per batch ~= 25 minutes per epoch.
If you'd like to use the full 4 headed SHA-LSTM (slightly better results as noted in the paper, but roughly twice as slow, and it requires a batch size of 8 on the Titan V, though your V100 may fit a larger batch), you can set `use_attn` to True for all layers on this line. Sorry for not making that a command line flag; I was setting options internally by hand.
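Roughly, the change amounts to passing `use_attn=True` to every layer instead of just one. A minimal sketch of the shape of it (Block here is a stand-in for the repo's actual per-layer module, and which single layer carries the head by default is my assumption):

```python
import torch.nn as nn

# Sketch only: "Block" stands in for the repo's per-layer module; the real
# constructor takes more arguments than this.
class Block(nn.Module):
    def __init__(self, nhid, use_attn=False):
        super().__init__()
        self.rnn = nn.LSTM(nhid, nhid)
        self.use_attn = use_attn  # whether this layer also runs the attention head

num_layers, nhid = 4, 1024

# Default single headed SHA-LSTM: attention on just one layer (assumed here to
# be the second-to-last one).
single_head = nn.ModuleList(
    [Block(nhid, use_attn=(idx == num_layers - 2)) for idx in range(num_layers)]
)

# Full 4 headed SHA-LSTM: use_attn=True for every layer (slower, slightly better bpc).
four_head = nn.ModuleList(
    [Block(nhid, use_attn=True) for idx in range(num_layers)]
)
```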
The commands supplied were originally in reference to the 4 layer SHA-LSTM where each layer contains an attention mechanism, not the single headed SHA-LSTM. The batch size 16 model needs a few extra epochs as there are fewer training steps per epoch than with the batch size 8 model. As such the limit of 14 epochs for training is incorrect and was in reference to the full 4 layer SHA-LSTM (i.e. reproducing Figure 3 of the paper), which only used 19 epochs total: 16 at lr=2e-3 and 3 at lr=1e-3.
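To make the "fewer steps per epoch" point concrete, here's the rough arithmetic, assuming the standard 90M-character enwik8 training split and a sequence length (bptt) of 1024 (both assumptions on my part, though they line up with the 5493 batches in your log):

```python
# Steps per epoch ~= training characters / (batch size * bptt).
train_chars = 90_000_000  # assumed standard enwik8 training split
bptt = 1024               # assumed sequence length

steps_bs16 = train_chars // (16 * bptt)  # 5493, matching the "5493 batches" in the log
steps_bs8 = train_chars // (8 * bptt)    # 10986, i.e. twice the optimizer steps per epoch
print(steps_bs16, steps_bs8)
```

So each batch size 16 epoch takes roughly half the optimizer steps of a batch size 8 epoch, hence the extra epochs.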
For reproducing with the single headed SHA-LSTM, training with the first command (lr=2e-3) should continue until the validation perplexity stops improving. You can leave the model running longer too, as it only saves a model when its validation perplexity beats the best seen so far. Training with the second command (lr=1e-3) generally takes only two or three epochs before validation perplexity stops improving, and the majority of the gain is usually in the first epoch.
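The "leave it running longer" behaviour is just the usual best-validation checkpoint pattern. A minimal sketch of the idea (not the repo's actual training loop; train_one_epoch and evaluate are placeholder callables):

```python
import torch

def train_until_plateau(model, optimizer, train_one_epoch, evaluate,
                        save_path='best.pt', patience=2):
    """Train until validation stops improving, saving only on improvement so
    extra epochs can never make the saved checkpoint worse."""
    best = float('inf')
    stale = 0
    while stale < patience:
        train_one_epoch(model, optimizer)
        val_loss = evaluate(model)
        if val_loss < best:
            best, stale = val_loss, 0
            torch.save(model.state_dict(), save_path)  # only overwritten on improvement
        else:
            stale += 1
    return best
```

For the two-stage schedule that means: run this with the optimizer at lr=2e-3, reload the best checkpoint, drop to lr=1e-3, and run it again for the last couple of epochs.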
If you still have the log you can run `grep "valid bpc"` on it (a rough Python equivalent is sketched just after the log below) and we can verify whether the model was plateauing at epoch 14 with the learning rate of 2e-3. As an example, below is a log for the single headed SHA-LSTM showing that there was still a bit of improvement to be had at epoch 14. The error was mine in not updating the commands to use higher epoch counts when switching models.
| end of epoch 1 | time: 1813.52s | valid loss 0.92 | valid ppl 2.52 | valid bpc 1.334
| end of epoch 2 | time: 1811.97s | valid loss 0.87 | valid ppl 2.38 | valid bpc 1.252
| end of epoch 3 | time: 1815.17s | valid loss 0.84 | valid ppl 2.32 | valid bpc 1.214
| end of epoch 4 | time: 1813.15s | valid loss 0.82 | valid ppl 2.28 | valid bpc 1.189
| end of epoch 5 | time: 1812.83s | valid loss 0.81 | valid ppl 2.25 | valid bpc 1.171
| end of epoch 6 | time: 1815.57s | valid loss 0.80 | valid ppl 2.23 | valid bpc 1.160
| end of epoch 7 | time: 1809.98s | valid loss 0.80 | valid ppl 2.22 | valid bpc 1.149
| end of epoch 8 | time: 1806.74s | valid loss 0.79 | valid ppl 2.21 | valid bpc 1.142
| end of epoch 9 | time: 1814.90s | valid loss 0.79 | valid ppl 2.20 | valid bpc 1.138
| end of epoch 10 | time: 1805.94s | valid loss 0.79 | valid ppl 2.19 | valid bpc 1.134
| end of epoch 11 | time: 1803.10s | valid loss 0.78 | valid ppl 2.19 | valid bpc 1.129
| end of epoch 12 | time: 1800.77s | valid loss 0.78 | valid ppl 2.18 | valid bpc 1.125
| end of epoch 13 | time: 1801.50s | valid loss 0.78 | valid ppl 2.18 | valid bpc 1.123
| end of epoch 14 | time: 1797.74s | valid loss 0.78 | valid ppl 2.17 | valid bpc 1.120
| end of epoch 15 | time: 1799.41s | valid loss 0.78 | valid ppl 2.18 | valid bpc 1.121
| end of epoch 16 | time: 1796.43s | valid loss 0.77 | valid ppl 2.17 | valid bpc 1.117
| end of epoch 17 | time: 1797.27s | valid loss 0.77 | valid ppl 2.17 | valid bpc 1.118
| end of epoch 18 | time: 1798.39s | valid loss 0.77 | valid ppl 2.17 | valid bpc 1.116
| end of epoch 19 | time: 1799.21s | valid loss 0.77 | valid ppl 2.17 | valid bpc 1.114
| end of epoch 20 | time: 1799.06s | valid loss 0.77 | valid ppl 2.17 | valid bpc 1.115
| end of epoch 21 | time: 1799.13s | valid loss 0.77 | valid ppl 2.16 | valid bpc 1.114
| end of epoch 22 | time: 1799.01s | valid loss 0.77 | valid ppl 2.16 | valid bpc 1.113
| end of epoch 23 | time: 1803.80s | valid loss 0.77 | valid ppl 2.16 | valid bpc 1.112
| end of epoch 24 | time: 1804.56s | valid loss 0.77 | valid ppl 2.16 | valid bpc 1.112
| end of epoch 25 | time: 1806.76s | valid loss 0.77 | valid ppl 2.16 | valid bpc 1.112
| end of epoch 26 | time: 1807.95s | valid loss 0.77 | valid ppl 2.17 | valid bpc 1.115
| end of epoch 27 | time: 1805.45s | valid loss 0.77 | valid ppl 2.16 | valid bpc 1.112
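For reference, here's a rough Python equivalent of that grep, pulling the per-epoch valid bpc out of a saved log ('train.log' is a placeholder name) and printing the epoch-to-epoch change:

```python
import re

# Rough equivalent of `grep "valid bpc"`: extract per-epoch validation bpc from a
# training log and print the change per epoch to see whether training plateaued.
pattern = re.compile(r'end of epoch\s+(\d+).*valid bpc\s+([\d.]+)')

bpcs = []
with open('train.log') as f:  # placeholder filename; point at your actual log
    for line in f:
        match = pattern.search(line)
        if match:
            bpcs.append((int(match.group(1)), float(match.group(2))))

for (_, prev), (epoch, curr) in zip(bpcs, bpcs[1:]):
    print(f'epoch {epoch}: valid bpc {curr:.3f} ({curr - prev:+.4f})')
```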
That set of 27 epochs is about 13.5 hours of compute and gets to approximately your number. I killed training at that stage as the validation perplexity had stopped improving, then resumed with lr=1e-3 and the same batch size (16). Unfortunately the batch size of 8 you used was because I hadn't updated the README properly, sorry -_-
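The resume step itself is nothing exotic: reload the best checkpoint from the lr=2e-3 stage and lower the optimizer's learning rate before continuing. A sketch of the idea (the codebase has its own resume handling, so treat these names as placeholders):

```python
import torch

def resume_with_lower_lr(model, optimizer, path='best.pt', new_lr=1e-3):
    """Reload the best checkpoint from the first stage and drop the lr before
    continuing. The repo's own resume flow stores more state than this."""
    model.load_state_dict(torch.load(path))
    for group in optimizer.param_groups:
        group['lr'] = new_lr
```

The log of that resumed run is below.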
| end of epoch 1 | time: 1799.11s | valid loss 0.76 | valid ppl 2.14 | valid bpc 1.100
| end of epoch 2 | time: 1810.37s | valid loss 0.76 | valid ppl 2.14 | valid bpc 1.100
| end of epoch 3 | time: 1814.39s | valid loss 0.76 | valid ppl 2.14 | valid bpc 1.100
| end of epoch 4 | time: 1801.80s | valid loss 0.76 | valid ppl 2.15 | valid bpc 1.102
| end of epoch 5 | time: 1800.42s | valid loss 0.76 | valid ppl 2.14 | valid bpc 1.100
As noted, all the benefit in the case above comes in the first epoch; sometimes the drop is spread over a few epochs. There are likely better ways to decay the learning rate, but I haven't explored them.
The above model resulted in a test bpc of 1.078.
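On "better ways to decay the learning rate": one obvious candidate, which is not what this codebase does, would be letting a plateau scheduler handle the drop instead of manually restarting at a lower lr. A sketch (Adam and the helper callables are placeholders):

```python
import torch
from torch.optim.lr_scheduler import ReduceLROnPlateau

def fit_with_plateau_decay(model, train_one_epoch, evaluate, epochs=30):
    """Alternative schedule sketch: halve the lr automatically whenever the
    validation loss stalls, rather than using a manual two-stage drop."""
    optimizer = torch.optim.Adam(model.parameters(), lr=2e-3)
    scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=1)
    for _ in range(epochs):
        train_one_epoch(model, optimizer)
        val_loss = evaluate(model)
        scheduler.step(val_loss)  # lr halves once val_loss stops improving
```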
If you have spare compute and want to try it again, do so and get back to me; otherwise I'll repeat the experiment myself overnight (... it's 5am but I'll pretend it's overnight lol ...) and report back. Do also note whether you prefer the faster model or the slightly better but slower / heavier one. I may have gone the wrong direction by making the slightly faster model the default for the codebase.
Thanks for your experiment! ^_^
Thanks for the quick and detailed reply. I was foolish enough to run this in a terminal, so the logs are mostly lost to the limited scrollback window :(. From watching the first run, I'm pretty sure it plateaued over the last few epochs. Validation bpcs were lagging a bit, but fairly close to those plotted in the article.
Yes, it makes sense to use batch size 16 for the second run and just a few epochs. If I have time I'll repeat the experiment and report here.
Love the speed and the high GPU utilization. Thank you for publishing this. Not everyone has 1K TPUs; this gives us poor guys some hope ;)
| epoch 4 | 5580/ 5493 batches | lr 0.00100 | ms/batch 265.68 | loss 0.70 | ppl 2.02 | bpc 1.012
| epoch 4 | 5590/ 5493 batches | lr 0.00100 | ms/batch 253.69 | loss 0.68 | ppl 1.96 | bpc 0.974
| epoch 4 | 5600/ 5493 batches | lr 0.00100 | ms/batch 268.03 | loss 0.71 | ppl 2.04 | bpc 1.028
| epoch 4 | 5610/ 5493 batches | lr 0.00100 | ms/batch 250.21 | loss 0.70 | ppl 2.02 | bpc 1.011
| epoch 4 | 5620/ 5493 batches | lr 0.00100 | ms/batch 258.41 | loss 0.69 | ppl 1.99 | bpc 0.989
| epoch 4 | 5630/ 5493 batches | lr 0.00100 | ms/batch 265.86 | loss 0.70 | ppl 2.02 | bpc 1.013
-----------------------------------------------------------------------------------------
| end of epoch 5 | time: 1481.99s | valid loss 0.76 | valid ppl 2.14 | valid bpc 1.099
-----------------------------------------------------------------------------------------
Model total parameters: 53790932
=========================================================================================
| End of training | test loss 0.75 | test ppl 2.11 | test bpc 1.077
=========================================================================================
I did this:

- lr=2e-3 and batch size 16 for 2 epochs (to shake it up a bit)
- lr=1e-3 and batch size 16 for 5 epochs (it took 2 epochs to reach the best validation; the rest was wasted)

The latest trained model can be downloaded from here.
Summary: all works as advertised!
Above is the output at the end of the second training run, as in the README.
My setup:
Trained model is here (205Mb)
Other notes: