lucidrains / memorizing-transformers-pytorch

Implementation of Memorizing Transformers (ICLR 2022), an attention net augmented with indexing and retrieval of memories using approximate nearest neighbors, in PyTorch
MIT License

stable_softmax #16

Closed. huu4ontocord closed this issue 1 year ago.

huu4ontocord commented 1 year ago

Hi - I assume from the name that stable_softmax is more stable than regular softmax, i.e. fewer infs or NaNs? If so, is this generally applicable to other libraries? I've seen people convert to float32 for more stability and then convert back to the original dtype.
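
For reference, the float32 round trip mentioned here is usually just an upcast around the softmax itself; a minimal sketch of that pattern (not code from this repo):

```python
import torch
import torch.nn.functional as F

def softmax_in_fp32(logits: torch.Tensor, dim: int = -1) -> torch.Tensor:
    # compute the softmax in float32, then cast back to the original dtype
    return F.softmax(logits.float(), dim=dim).to(logits.dtype)
```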

lucidrains commented 1 year ago

@ontocord oh, that was actually because i had a misconception about pytorch's softmax (did not know they subtract out the max)

removed it for clarity!

and yes, you are right that some people like to do everything in float32. as we move towards bfloat16, this will become a non-issue
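
For anyone reading along: subtracting the per-row max (which PyTorch's softmax already does internally) leaves the result unchanged while avoiding overflow in exp. A quick sketch of the identity:

```python
import torch

x = torch.randn(4, 8) * 50  # large logits that could overflow exp() in low precision

# "stable" softmax: subtract the max before exponentiating
manual = torch.exp(x - x.amax(dim=-1, keepdim=True))
manual = manual / manual.sum(dim=-1, keepdim=True)

assert torch.allclose(manual, x.softmax(dim=-1), atol=1e-6)
```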

huu4ontocord commented 1 year ago

Awesome! I am testing the memorizing transformer code now. I'd like to incorporate some of it into the LongLLaMA code, since they didn't implement the kNN memory.
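
For context, the kNN memory retrieves previously stored key/value pairs with a nearest-neighbor index; a rough sketch of the lookup step, assuming a faiss index over stored keys (names here are illustrative, not this repo's API):

```python
import faiss
import numpy as np
import torch

dim = 64
index = faiss.IndexFlatIP(dim)   # exact inner-product index; an ANN index would be used in practice
stored_values = []               # value tensors aligned with the keys added to the index

def add_memories(keys: torch.Tensor, values: torch.Tensor):
    # keys, values: (num_tokens, dim)
    index.add(keys.detach().cpu().numpy().astype(np.float32))
    stored_values.append(values.detach().cpu())

def retrieve(queries: torch.Tensor, topk: int = 32) -> torch.Tensor:
    # look up the top-k nearest stored keys per query (assumes >= topk keys are stored)
    _, idx = index.search(queries.detach().cpu().numpy().astype(np.float32), topk)
    values = torch.cat(stored_values, dim=0)
    return values[torch.from_numpy(idx)]  # (num_queries, topk, dim)
```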

lucidrains commented 1 year ago

nice! i prepared the tanh gating for that exact reason 😄 good luck!
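
The tanh gating referred to above blends the memory-retrieved attention branch into the local attention output with a learned gate that starts at zero; one common form of it, sketched with illustrative names (not the repo's exact module):

```python
import torch
from torch import nn

class TanhGate(nn.Module):
    def __init__(self, heads: int):
        super().__init__()
        # per-head gate parameter, initialized to zero so the memory branch starts switched off
        self.gate = nn.Parameter(torch.zeros(heads, 1, 1))

    def forward(self, local_out: torch.Tensor, memory_out: torch.Tensor) -> torch.Tensor:
        # local_out, memory_out: (batch, heads, seq, dim_head)
        return local_out + memory_out * torch.tanh(self.gate)
```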