lucidrains / performer-pytorch

An implementation of Performer, a linear attention-based transformer, in PyTorch
MIT License

Bug fix in original google-research implementation #50

Closed: gulnazaki closed 3 years ago

gulnazaki commented 3 years ago

Hey there,

I've seen that a significant bug involving the data_normalizer was recently fixed in the original implementation, in case you haven't checked it yet. It looks like it exists here too, since you ported the code.

https://github.com/google-research/google-research/commit/b09ac837cd5720bc60f1c16b472a7ab462b0ddb8
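For context, the normalizer in question scales queries and keys by d^(-1/4) so that the softmax's 1/sqrt(d) temperature is absorbed into the kernel features. Below is a minimal NumPy sketch of the positive softmax-kernel feature map (function name and shapes are mine, not the repo's API): the key point the commit addresses is that the stabilizing ||x||^2/2 term must be computed from the *same* normalized data that gets projected, otherwise the estimator is biased.

```python
import numpy as np

def softmax_kernel_features(x, proj, normalize_data=True):
    # Hypothetical sketch of Performer-style positive softmax-kernel features,
    # not the library's exact implementation.
    # x:    (n, d) inputs (queries or keys)
    # proj: (m, d) random Gaussian projection matrix
    d = x.shape[-1]
    # data_normalizer = d**-0.25 absorbs the usual 1/sqrt(d) attention scaling
    data_normalizer = d ** -0.25 if normalize_data else 1.0
    x_norm = x * data_normalizer
    # The ||x||^2 / 2 stabilizer must use the normalized data -- using the raw
    # x here while projecting x_norm is the kind of mismatch the commit fixes.
    diag = np.sum(x_norm ** 2, axis=-1, keepdims=True) / 2.0
    return np.exp(x_norm @ proj.T - diag) / np.sqrt(proj.shape[0])
```

With both terms computed consistently, the dot product of the feature maps is an unbiased Monte-Carlo estimate of exp(q·k / sqrt(d)), which is what makes the linear-attention factorization approximate softmax attention.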

lucidrains commented 3 years ago

@gulnazaki haha, I ported over their JAX code, so it should be fine :) that's their new TensorFlow implementation

lucidrains commented 3 years ago

@gulnazaki thanks for letting me know!

btw, new follow-up paper for Performer! https://arxiv.org/abs/2012.11346

tldr: sorta-gradient checkpointing along the sequence dimension
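To illustrate the idea, here is a hedged NumPy sketch (function names are hypothetical; this is not the paper's exact algorithm): causal linear attention can be computed as a recurrence over running sums, so the sequence can be processed in chunks while carrying only a small (S, z) state between them. During backprop, per-chunk activations can then be recomputed from that state rather than stored, which is the checkpointing-along-the-sequence flavor.

```python
import numpy as np

def causal_linear_attention(phi_q, phi_k, v):
    # Reference step-by-step recurrence: out_i = (phi_q_i @ S_i) / (phi_q_i @ z_i)
    # where S_i = sum_{j<=i} phi_k_j v_j^T and z_i = sum_{j<=i} phi_k_j.
    n, m = phi_q.shape
    dv = v.shape[1]
    S = np.zeros((m, dv))
    z = np.zeros(m)
    out = np.empty((n, dv))
    for i in range(n):
        S += np.outer(phi_k[i], v[i])
        z += phi_k[i]
        out[i] = (phi_q[i] @ S) / (phi_q[i] @ z)
    return out

def causal_linear_attention_chunked(phi_q, phi_k, v, chunk=4):
    # Same computation, but chunked: only the carried (S, z) prefix-sum state
    # crosses chunk boundaries, so within-chunk intermediates can be
    # recomputed on the backward pass instead of stored.
    n, m = phi_q.shape
    dv = v.shape[1]
    S = np.zeros((m, dv))
    z = np.zeros(m)
    outs = []
    for s in range(0, n, chunk):
        q_c, k_c, v_c = phi_q[s:s + chunk], phi_k[s:s + chunk], v[s:s + chunk]
        # causal prefix sums within the chunk, on top of the carried state
        cum_S = np.cumsum(k_c[:, :, None] * v_c[:, None, :], axis=0) + S
        cum_z = np.cumsum(k_c, axis=0) + z
        num = np.einsum('cm,cmd->cd', q_c, cum_S)
        den = np.einsum('cm,cm->c', q_c, cum_z)
        outs.append(num / den[:, None])
        S, z = cum_S[-1], cum_z[-1]
    return np.concatenate(outs, axis=0)
```

Both routes produce identical outputs; the chunked one just bounds how much per-position state has to live in memory at once.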

gulnazaki commented 3 years ago

Oh yes, sorry, I only had a quick look at it. I understand they fixed the TensorFlow implementation to match the one in JAX.

Cool paper also, now the sky is the limit :smile: