Open blefaudeux opened 2 years ago
Just had a look. I think this would need a Triton kernel to do properly: the current PyTorch implementation works, but its speed and memory use are horrible (it's just not an operation that is easy to express with matmuls).
See https://arxiv.org/pdf/2006.16236.pdf for reference.
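For readers who want to see why this is awkward with matmuls, here is a rough NumPy sketch of the causal product (numerator only, no normalization). It is not the xformers code. The function names and the `(seq, features)` shapes are made up for illustration. The prefix-sum version materializes one `k_j v_j^T` outer product per position, which is where the memory goes; the running-state version keeps a single `(d_feat, d_val)` state but is sequential, which is roughly what a fused kernel would have to do.

```python
import numpy as np


def causal_product_masked(q, k, v):
    # Reference: out_i = sum_{j <= i} (q_i . k_j) v_j, via a lower-triangular mask.
    # Quadratic in sequence length, which a linear attention is meant to avoid.
    return np.tril(q @ k.T) @ v


def causal_product_cumsum(q, k, v):
    # Prefix-sum formulation: out_i = q_i^T (sum_{j <= i} k_j v_j^T).
    # Materializes a (seq, d_feat, d_val) tensor, hence the large memory footprint.
    kv = np.einsum("nd,nm->ndm", k, v)
    kv_prefix = np.cumsum(kv, axis=0)
    return np.einsum("nd,ndm->nm", q, kv_prefix)


def causal_product_running(q, k, v):
    # Same result with a single (d_feat, d_val) running state.
    # Memory stays constant in seq, but the loop over positions is sequential.
    seq, d_feat = k.shape
    d_val = v.shape[1]
    state = np.zeros((d_feat, d_val), dtype=q.dtype)
    out = np.empty((seq, d_val), dtype=q.dtype)
    for i in range(seq):
        state += np.outer(k[i], v[i])
        out[i] = q[i] @ state
    return out
```

All three should agree up to floating-point error on random inputs, e.g. `q`, `k` of shape `(6, 4)` and `v` of shape `(6, 3)`.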
I have a proper causal product implemented in Triton at https://github.com/facebookresearch/xformers/tree/better_favor_with_triton, but it looks like it triggers the same Triton bug as https://github.com/facebookresearch/xformers/pull/162 :( (and that won't be fixed soon)
🚀 Feature
Lower favor+causal memory consumption
Motivation
Using a lot of memory for an approximation kind of defeats the purpose.
Pitch
Would make favor more usable for NLP; not sure how much of a priority this is.
Alternatives
Not doing anything; there are other options for causal approximations (Nystrom, for instance).
Additional context
The previous implementation (prior to #104) was more memory efficient but not trainable, since the variables were modified in place