For full details see the paper *Single Headed Attention RNN: Stop Thinking With Your Head*.
In summary, "stop thinking with your (attention) head".
Model | Test BPC | Params | LSTM Based |
---|---|---|---|
Krause mLSTM | 1.24 | 46M | ✔ |
AWD-LSTM | 1.23 | 44M | ✔ |
SHA-LSTM | 1.07 | 63M | ✔ |
12L Transformer-XL | 1.06 | 41M | |
18L Transformer-XL | 1.03 | 88M | |
Adaptive Span Transformer (Small) | 1.02 | 38M | |
Whilst the model is still some way from the state of the art (~0.98 bpc), it is low resource and highly efficient without yet having been optimized to be so. The model was trained in under 24 hours on a single GPU, with the Adaptive Span Transformer (Small) being the only recent Transformer model to achieve similar training efficiency.
To get started:

```
./getdata.sh
```
By default the model trains the minimal single headed attention model from the paper, inserting a lone attention mechanism in the second-to-last layer of a four-layer LSTM.
This takes only half an hour per epoch on a Titan V or V100.
If you want slightly better results at the cost of a longer training time (an hour per epoch), set `use_attn` to True for all layers in `model.py` and decrease the batch size until the model fits in memory (e.g. 8).
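For orientation, this change amounts to flipping a per-layer attention flag where the layer blocks are built. The sketch below shows that pattern with illustrative stand-in names (`Block`, `use_attn`, `num_layers`); the actual constructor and arguments live in `model.py` and will differ.

```python
import torch.nn as nn

# Illustrative stand-in for the layer block defined in model.py; the real class
# has a richer signature (LSTM, feed-forward, the single attention head, etc.).
class Block(nn.Module):
    def __init__(self, hidden_dim, use_attn=False):
        super().__init__()
        self.use_attn = use_attn

num_layers, hidden_dim = 4, 4096

blocks = nn.ModuleList()
for idx in range(num_layers):
    # Default SHA-LSTM: a lone attention head on the second-to-last layer only.
    # use_attn = (idx == num_layers - 2)
    # All-layer variant (roughly an hour per epoch, slightly better bpc):
    use_attn = True
    blocks.append(Block(hidden_dim, use_attn=use_attn))
```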
Sadly there are no command line options for running the other models - it's manual tinkering.
The code is not kind.
I'll be performing a rewrite in the near future meant for long-term academic and industrial use - contact me if you're interested :)
Note: I'm still shaking out bugs from the commands below. We have near third-party replication but there's still a fix or two outstanding. Feel free to run them and note any discrepancies! If you fiddle with hyperparameters (which I've done very little of - it's a treasure chest of opportunity, with a lower than expected BPC as your reward!) do report that too :)
When running the training command below, continue until the validation bpc stops improving. Don't worry about letting it run longer, as the code only saves the model with the best validation bpc (a sketch of that pattern follows the command).
```
python -u main.py --epochs 32 --dropouth 0.1 --dropouti 0.1 --dropout 0.1 --data data/enwik8/ --save ENWIK8.pt --log-interval 10 --seed 5512 --optimizer lamb --bptt 1024 --warmup 800 --lr 2e-3 --emsize 1024 --nhid 4096 --nlayers 4 --batch_size 16
```
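The "keep only the best model" behaviour is the usual checkpoint-on-best-validation pattern. Below is a minimal sketch of that pattern, not the repo's actual training loop; `train_one_epoch`, `evaluate`, `model`, and `epochs` are assumed to exist, and `evaluate` is assumed to return mean cross-entropy in nats per character.

```python
import math
import torch

# Sketch of checkpoint-on-best-validation (illustrative; main.py's loop differs).
best_val_bpc = math.inf
for epoch in range(epochs):
    train_one_epoch(model)
    val_loss = evaluate(model)        # assumed: mean cross-entropy in nats per character
    val_bpc = val_loss / math.log(2)  # convert nats to bits per character
    if val_bpc < best_val_bpc:
        best_val_bpc = val_bpc
        # Overwrite the checkpoint only when validation improves, so the saved
        # file always holds the best model seen so far.
        torch.save(model.state_dict(), 'ENWIK8.pt')
```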
When training slows down, a second pass with a halved learning rate, run until the validation bpc again stops improving, will shave a little more off the bpc. A smarter learning rate decay is likely the correct way to go here (see the sketch at the end of this section), but that's not what I did for my experiments. Note that in the command below the later `--lr 1e-3` and `--seed 125` flags override the earlier values.
```
python -u main.py --epochs 5 --dropouth 0.1 --dropouti 0.1 --dropout 0.1 --data data/enwik8/ --save ENWIK8.pt --log-interval 10 --seed 5512 --optimizer lamb --bptt 1024 --warmup 800 --lr 2e-3 --emsize 1024 --nhid 4096 --nlayers 4 --batch_size 16 --resume ENWIK8.pt --lr 1e-3 --seed 125
```
Most of the improvement will happen in the first few epochs of this final command.
The final test bpc should be approximately 1.07 for the full four-layer SHA-LSTM (attention at every layer) or 1.08 for the single-headed four-layer SHA-LSTM (the lone attention head default).
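As mentioned above, the manual halved-learning-rate pass is a stand-in for a smarter decay schedule. One off-the-shelf possibility (a suggestion, not what was used for the reported results) is PyTorch's `ReduceLROnPlateau`, which drops the learning rate automatically once validation stops improving:

```python
import torch

# Sketch of an automatic alternative to the manual "halve the LR and resume" pass.
# `optimizer`, `model`, `epochs`, `train_one_epoch`, and `evaluate` are assumed to
# exist elsewhere; this is not wired into main.py.
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, mode='min', factor=0.5, patience=1)

for epoch in range(epochs):
    train_one_epoch(model)
    val_bpc = evaluate(model)
    scheduler.step(val_bpc)  # halves the LR after `patience` epochs without improvement
```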