lucidrains / memorizing-transformers-pytorch

Implementation of Memorizing Transformers (ICLR 2022), attention net augmented with indexing and retrieval of memories using approximate nearest neighbors, in Pytorch
MIT License

stable_softmax #16

Closed huu4ontocord closed 12 months ago

huu4ontocord commented 12 months ago

Hi - I assume from the name that stable_softmax is more stable than regular softmax - fewer infs or NaNs? If so, is this generally applicable to other libraries? I've seen people convert to float32 for more stability and then convert back to the original dtype.
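
Something like this minimal sketch of the upcast-then-downcast pattern (the helper name is illustrative, not from any particular library): run the softmax in float32 for extra headroom, then cast back to the original half-precision dtype.

```python
import torch

def softmax_in_fp32(logits, dim = -1):
    # hypothetical helper: compute the softmax with float32 headroom,
    # then return to the caller's original (possibly half-precision) dtype
    orig_dtype = logits.dtype
    return logits.float().softmax(dim = dim).to(orig_dtype)

attn_logits = torch.randn(2, 4, 16, 16, dtype = torch.float16)
attn = softmax_in_fp32(attn_logits)
assert attn.dtype == torch.float16
```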

lucidrains commented 12 months ago

@ontocord oh, that was actually because i had a misconception about pytorch's softmax (i did not know it already subtracts out the max internally)

removed it for clarity!

and yes, you are right that some people like to do everything in float32. as we move towards bfloat16, this will become a non-issue
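
For reference, a minimal sketch of the max-subtraction trick being discussed - pytorch's built-in softmax already does this internally, which is why a hand-rolled version is redundant (the function name here is illustrative):

```python
import torch

def stable_softmax(x, dim = -1):
    # subtract the per-row max before exponentiating so exp() cannot overflow;
    # mathematically identical to a plain softmax
    x = x - x.amax(dim = dim, keepdim = True).detach()
    return x.softmax(dim = dim)

logits = torch.randn(2, 8) * 100.  # large enough that a naive exp() would overflow
assert torch.allclose(stable_softmax(logits), logits.softmax(dim = -1))
```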

huu4ontocord commented 12 months ago

Awesome! I am testing the memorizing transformers code now. I'd like to incorporate some of it into the LongLLaMA code, since they didn't implement the kNN memory part.

lucidrains commented 12 months ago

nice! i prepared the tanh gating for that exact reason 😄 good luck!
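
For anyone following along, a rough sketch of the kind of tanh gating referred to here (class and parameter names are illustrative, not the repo's exact API): the retrieved-memory branch is scaled by tanh of a learned per-head scalar initialized at zero, so a pretrained model starts out unchanged and can gradually learn to use the external memories during fine-tuning.

```python
import torch
from torch import nn

class GatedMemoryCombine(nn.Module):
    def __init__(self, heads = 8):
        super().__init__()
        # one learned gate per head, initialized at 0 -> tanh(0) = 0,
        # so the memory branch contributes nothing at the start of fine-tuning
        self.memory_gate = nn.Parameter(torch.zeros(heads, 1, 1))

    def forward(self, local_attn_out, memory_attn_out):
        # both tensors: (batch, heads, seq_len, dim_head)
        return local_attn_out + self.memory_gate.tanh() * memory_attn_out

combine = GatedMemoryCombine(heads = 8)
local = torch.randn(1, 8, 128, 64)
memory = torch.randn(1, 8, 128, 64)
out = combine(local, memory)  # equals `local` exactly at initialization
```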