Smerity / sha-rnn

Single Headed Attention RNN - "Stop thinking with your head"

another implementation + partial reproduction #17

Open · lunixbochs opened this issue 3 years ago

lunixbochs commented 3 years ago

Thanks for the great paper!

I've created another open source implementation of the SHA RNN here: https://github.com/talonvoice/sha-rnn

I trained with parameters similar to the single-head model at the end of your README and achieved a test bpc of 1.113 with LAMB, slightly worse than your 1.077 but still better than the Mogrifier LSTM. My epochs also took 1h instead of 30m. I used PyTorch's built-in MultiheadAttention for now instead of reimplementing the custom attention, which might be the reason for the difference in both speed and bpc.
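Roughly, the swap looks like the sketch below (simplified and not the exact code in my repo; the dimensions and the norm/residual placement are illustrative):

```python
import torch
import torch.nn as nn

class BuiltinAttentionBlock(nn.Module):
    """Illustrative block using PyTorch's nn.MultiheadAttention (single head)
    in place of the custom SHA-RNN attention."""
    def __init__(self, embed_dim, num_heads=1, dropout=0.0):
        super().__init__()
        self.attn = nn.MultiheadAttention(embed_dim, num_heads, dropout=dropout)
        self.norm = nn.LayerNorm(embed_dim)

    def forward(self, x, attn_mask=None):
        # x: (seq_len, batch, embed_dim), the default layout for nn.MultiheadAttention
        h = self.norm(x)
        out, _ = self.attn(h, h, h, attn_mask=attn_mask, need_weights=False)
        return x + out  # residual connection around the attention
```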

I have some notes in my README about the effort. It's possible I made some mistakes in the model or training as nobody else has reviewed the code yet.

Smerity commented 3 years ago

Great work! I had a quick skim over the code and like some of the refactors :)

After your initial training run, did you repeat training with a decreased learning rate? Ah, just saw ReduceLROnPlateau. Potentially try holding off on the plateau behaviour until much later in training? I've had issues where the learning rate was reduced prematurely.
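Something along these lines is what I mean, i.e. only start stepping the plateau scheduler late in training (a minimal sketch; `model`, `train_one_epoch`, `evaluate`, `num_epochs`, and the epoch/patience values are placeholders, not what I used):

```python
import torch

# Hypothetical setup: model, train_one_epoch, evaluate, and num_epochs are placeholders.
optimizer = torch.optim.Adam(model.parameters(), lr=2e-3)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, mode='min', factor=0.5, patience=2)

PLATEAU_START_EPOCH = 20  # placeholder: ignore plateaus until late in training

for epoch in range(num_epochs):
    train_one_epoch(model, optimizer)
    val_loss = evaluate(model)
    # Only let the scheduler cut the LR once training is well underway,
    # so an early noisy plateau can't trigger a premature reduction.
    if epoch >= PLATEAU_START_EPOCH:
        scheduler.step(val_loss)
```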

I also saw your feedforwards are 2048 rather than 4096, which could well be part of the difference.

I could imagine the built-in PyTorch MultiheadAttention may have been part of the bpc drop too. Many of the small choices in my codebase and the paper were about making the gradients flow as cleanly as possible, which seems to be surprisingly important. This may also be the cause of the speed slowdown, given the values pass through un-matmulled (i.e. only inexpensive layernorms or similar).
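Roughly what I mean, as a simplified illustration rather than the actual module in this repo: only the query goes through a matmul, while keys and values are just layernormed and elementwise-gated, so their gradient path stays close to the identity.

```python
import math
import torch
import torch.nn as nn
import torch.nn.functional as F

class MinimalSingleHeadAttention(nn.Module):
    """Simplified illustration (not the actual sha-rnn attention): only the query
    gets a matmul; keys and values are only layernormed and elementwise-gated."""
    def __init__(self, dim):
        super().__init__()
        self.q_proj = nn.Linear(dim, dim)
        self.k_gate = nn.Parameter(torch.zeros(dim))
        self.v_gate = nn.Parameter(torch.zeros(dim))
        self.norm_k = nn.LayerNorm(dim)
        self.norm_v = nn.LayerNorm(dim)

    def forward(self, query, keys, values, attn_mask=None):
        # query: (q_len, batch, dim); keys/values: (k_len, batch, dim)
        q = self.q_proj(query)
        k = self.norm_k(keys) * torch.sigmoid(self.k_gate)
        v = self.norm_v(values) * torch.sigmoid(self.v_gate)
        scores = torch.einsum('qbd,kbd->bqk', q, k) / math.sqrt(q.size(-1))
        if attn_mask is not None:
            scores = scores + attn_mask  # additive mask, e.g. -inf above the diagonal
        weights = F.softmax(scores, dim=-1)
        return torch.einsum('bqk,kbd->qbd', weights, v)
```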

Regardless, thanks for the reproduction! I'm sure with a tiny amount of targeted polish it'd match or beat my own perplexities =]

P.S. Checked out Talon and your various other projects - love the work!

lunixbochs commented 3 years ago

Just coming back to this now. I did some ablation testing with both of our codebases, and swapping basically every combination of the components made almost no difference in speed, so I realized I had made a few mistakes 🤦. I'll hopefully have some new results soon.

Other differences: I'm not using warmup or SplitCrossEntropyLoss (adaptive softmax) yet, though the adaptive softmax shouldn't matter much on the char model.
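For warmup I was thinking of something simple like the sketch below (step counts are placeholders, `model` is a stand-in, and I'm using Adam here in place of LAMB):

```python
import torch

# Hypothetical setup: model and WARMUP_STEPS are placeholders; Adam stands in for LAMB.
optimizer = torch.optim.Adam(model.parameters(), lr=2e-3)

WARMUP_STEPS = 1000

def warmup_factor(step):
    # Ramp the LR linearly from ~0 up to the base LR over WARMUP_STEPS, then hold it.
    return min(1.0, (step + 1) / WARMUP_STEPS)

scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=warmup_factor)

# Per training step:
#   loss.backward(); optimizer.step(); scheduler.step(); optimizer.zero_grad()
```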

(Also, have you looked into SRU++?)