idiap / fast-transformers

PyTorch library for fast transformer implementations

Query scaling is not consistent for recurrent wrappers #80

Closed: hadaev8 closed this issue 3 years ago

hadaev8 commented 3 years ago

Batch full attention:
https://github.com/idiap/fast-transformers/blob/master/fast_transformers/attention/full_attention.py#L66

Recurrent full attention:
https://github.com/idiap/fast-transformers/blob/master/fast_transformers/recurrent/attention/self_attention/full_attention.py#L64

angeloskath commented 3 years ago

Hi,

They are mathematically equivalent. There is also a test for the equivalence of the recurrent and full versions precisely to avoid such a regression (https://github.com/idiap/fast-transformers/blob/2fe048a14c2e67787f553e899123ca4ba9f27e76/tests/recurrent/attention/self_attention/test_full_attention.py#L54).

Simply put, with a the softmax temperature (a scalar), q a query vector and k a key vector:

(a q)^T k = a (q^T k)

so scaling the queries before the dot product gives the same attention scores as scaling the dot products afterwards.
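A minimal sketch of this in PyTorch (the shapes and the softmax_temp value are just illustrative, not taken from the library):

```python
import torch

# Illustrative shapes: N sequences of length L with head dimension E.
N, L, E = 2, 16, 64
softmax_temp = E ** -0.5  # the scalar "a" above

q = torch.randn(N, L, E)
k = torch.randn(N, L, E)

# Scale the queries before the dot product ...
scores_scaled_queries = torch.einsum("nle,nse->nls", softmax_temp * q, k)
# ... or scale the dot products afterwards.
scores_scaled_products = softmax_temp * torch.einsum("nle,nse->nls", q, k)

# Mathematically identical; numerically they differ only by rounding.
print((scores_scaled_queries - scores_scaled_products).abs().max())
```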

I will close the issue but feel free to reopen it if you have a test case that shows they are not equivalent.

Cheers, Angelos

hadaev8 commented 3 years ago

I also thought they were equivalent, but somehow I got different results:

https://colab.research.google.com/drive/1gKOyl_gMYNBRqKzE2RxmN-T-dftAqXUN?usp=sharing

angeloskath commented 3 years ago

Well, exact equality is hard to get with floating point numbers anyway :-), but even torch.allclose() might be too strict. I tend to check the maximum absolute error and accept it if it is less than 1e-5. In your Colab the maximum absolute error was 2e-6.

Interestingly enough, if we apply the scaling to QK instead, the max abs error is 0, which makes sense since both versions then have exactly the same over- and underflows.
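A rough sketch of that kind of check (the 1e-5 tolerance follows the comment above; the helper name and shapes are illustrative):

```python
import torch

def max_abs_error(a, b):
    """Maximum absolute elementwise difference between two tensors."""
    return (a - b).abs().max().item()

# Compare the two ways of applying the softmax temperature.
q = torch.randn(2, 16, 64)
k = torch.randn(2, 16, 64)
softmax_temp = 64 ** -0.5

err = max_abs_error(
    torch.einsum("nle,nse->nls", softmax_temp * q, k),
    softmax_temp * torch.einsum("nle,nse->nls", q, k),
)

# Accept small rounding differences instead of requiring exact equality.
assert err < 1e-5, err
```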

Thanks for checking things anyway. And if you ever find a test case that fails, please share it so we can fix it and add it to the tests.

Angelos