lucidrains / FLASH-pytorch

Implementation of the Transformer variant proposed in "Transformer Quality in Linear Time"
MIT License

Speed on TPU #6

Closed: magicknight closed this issue 1 year ago

magicknight commented 1 year ago

Hi, thanks for the code! I tested it on a Google TPU v3, and the training speed is slower than I expected. Maybe some operation is not being lowered efficiently on TPU.
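
For context, a minimal timing sketch of a GAU forward pass on TPU via torch_xla. This is an assumed setup for illustration, not the reporter's actual benchmark; the shapes, iteration counts, and `dim = 512` setting are hypothetical. Note that XLA compiles lazily, so warm-up steps are needed before timing:

```python
import time

import torch
import torch_xla.core.xla_model as xm
from flash_pytorch import GAU

device = xm.xla_device()           # grab the TPU device
gau = GAU(dim = 512).to(device)    # default settings, as in the reports in this thread
x = torch.randn(1, 1024, 512, device = device)

# warm up so XLA traces and compiles the graph before timing
for _ in range(3):
    out = gau(x)
    xm.mark_step()                 # flush the lazy-execution graph

xm.wait_device_ops()               # make sure warm-up work has finished
start = time.perf_counter()
for _ in range(10):
    out = gau(x)
    xm.mark_step()
xm.wait_device_ops()
print(f"avg forward: {(time.perf_counter() - start) / 10:.4f}s")
```

If the measured time drops sharply after the first iteration, the slowness is compilation overhead rather than the op itself; if it stays high, some operation is likely falling back to a slow lowering on TPU.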

wangyuxin87 commented 1 year ago

I have the same issue: GAU is slower than the original MHSA in my implementation, 3.5 s vs. 0.7 s, even though I simply use `from flash_pytorch import GAU` with the default settings.
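
For reference, instantiating GAU with its defaults looks roughly like this (a sketch following the repository's README; the shapes and the `causal = True` choice here are illustrative assumptions):

```python
import torch
from flash_pytorch import GAU

# Gated Attention Unit as a drop-in token-mixing layer
gau = GAU(
    dim = 512,              # model dimension
    query_key_dim = 128,    # low-rank query/key dimension
    causal = True,          # autoregressive masking
    expansion_factor = 2,   # width multiplier for the gating branch
)

x = torch.randn(1, 1024, 512)  # (batch, seq_len, dim)
out = gau(x)                   # (1, 1024, 512)
```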