Hi,
Thanks for the code!
I tested it on a Google TPU v3, and the training speed is slower than I expected. Maybe some operation is not being lowered properly on TPU.
I have the same question: GAU is slower than the original MHSA in my implementation (3.5 s vs 0.7 s), even though I simply use `from flash_pytorch import GAU` with the default settings.