aladdinpersson / Machine-Learning-Collection

A resource for learning about Machine Learning & Deep Learning
https://www.youtube.com/c/AladdinPersson
MIT License
7.7k stars, 2.7k forks

SelfAttention bug on Scores * V #165

Open: huberemanuel opened this issue 1 year ago

huberemanuel commented 1 year ago

Hey Aladdin, thanks for your tutorials!

I've been implementing the Transformer architecture and learning about einsum. When I compared your einsum-based implementation against one without einsum, I found differences in the final result. Here is code to reproduce it:

import torch

b, s, h, d = 2, 2, 2, 2  # batch, sequence length, heads, head dim
q = torch.randn((b, s, h, d))
k = torch.randn((b, s, h, d))
v = torch.randn((b, s, h, d))

# "Classic" version: permute + matmul
q_mod = q.permute(0, 2, 1, 3)  # [b, h, s, d]
k_mod = k.permute(0, 2, 3, 1)  # [b, h, d, s]
classic_scores = torch.matmul(q_mod, k_mod)  # [b, h, s, s]
classic_scores = torch.softmax(classic_scores / (d ** (1/2)), dim=3)
v_mod = v.permute(0, 2, 1, 3)  # [b, h, s, d]
classic_att = torch.matmul(classic_scores, v_mod).reshape(b, s, h * d)

# einsum version
einstein_scores = torch.einsum("bqhd,bkhd->bhqk", q, k)
einstein_scores = torch.softmax(einstein_scores / (d ** (1/2)), dim=3)
einstein_att = torch.einsum("bhql,blhd->bqhd", einstein_scores, v).reshape(b, s, h * d)

# allclose instead of exact == to tolerate floating-point rounding
assert torch.allclose(classic_scores, einstein_scores), "Scores don't match"
assert torch.allclose(classic_att, einstein_att), "Attention doesn't match"
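For anyone new to einsum, here is a minimal sketch of what the second einsum call, `torch.einsum("bhql,blhd->bqhd", scores, v)`, computes, assuming the same (b, s, h, d) shapes as the snippet. Every index that appears after `->` (b, q, h, d) is looped over, and the index that is missing from the output (l) is summed. The helper name `einsum_bhql_blhd_reference` is hypothetical and only for illustration; it is far slower than einsum.

```python
import torch

def einsum_bhql_blhd_reference(scores: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
    """Explicit-loop version of torch.einsum("bhql,blhd->bqhd", scores, v).

    scores: [b, h, q, l] attention weights
    v:      [b, l, h, d] values
    returns [b, q, h, d]
    """
    b, h, q_len, l_len = scores.shape
    d = v.shape[-1]
    out = torch.zeros(b, q_len, h, d, dtype=scores.dtype)
    # Output indices (b, q, h, d) are free: loop over each one
    for bi in range(b):
        for qi in range(q_len):
            for hi in range(h):
                for di in range(d):
                    # "l" does not appear in the output, so it is summed over
                    total = 0.0
                    for li in range(l_len):
                        total += scores[bi, hi, qi, li] * v[bi, li, hi, di]
                    out[bi, qi, hi, di] = total
    return out

torch.manual_seed(0)
b, s, h, d = 2, 2, 2, 2
scores = torch.softmax(torch.randn(b, h, s, s), dim=3)
v = torch.randn(b, s, h, d)

looped = einsum_bhql_blhd_reference(scores, v)
fast = torch.einsum("bhql,blhd->bqhd", scores, v)
print(looped.shape)                               # torch.Size([2, 2, 2, 2])
print(torch.allclose(looped, fast, atol=1e-6))    # True
```

Note that the output layout is `bqhd`, i.e. [b, s, h, d], so the later `.reshape(b, s, h * d)` concatenates all heads for each position.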

The attention scores match perfectly, but the final attention output doesn't. With my inputs, here is the result:

>>> print(classic_att)
tensor([[[ 1.1246,  0.1376,  1.2368, -0.6316],
         [-2.1842, -0.0181, -2.2082, -0.0023]],

        [[ 0.5911,  0.2132, -0.1727,  0.8552],
         [ 0.2701,  0.0846,  0.2370,  0.1205]]])
>>> print(einstein_att)
tensor([[[ 1.1246,  0.1376, -2.1842, -0.0181],
         [ 1.2368, -0.6316, -2.2082, -0.0023]],

        [[ 0.5911,  0.2132,  0.2701,  0.0846],
         [-0.1727,  0.8552,  0.2370,  0.1205]]])

It seems the values themselves aren't wrong, they are just transposed? I'm a newbie with einsum and couldn't figure it out. Hope someone can find the solution for this :)
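A minimal sketch that checks the "just transposed" hypothesis, assuming the same shapes and einsum strings as the reproduction code. `torch.matmul(classic_scores, v_mod)` returns a tensor laid out as [b, h, s, d], while the einsum output is [b, s, h, d]. `reshape` only reinterprets the elements in row-major order, so applying it directly to the [b, h, s, d] tensor groups values by head rather than by position. This is the pattern visible in the printed output above: the first row of `classic_att` holds head 0 for both positions, while the first row of `einstein_att` holds both heads for position 0. Moving the head axis back with `transpose(1, 2)` before reshaping makes the two paths agree. The variable names `naive` and `fixed` are illustrative.

```python
import torch

torch.manual_seed(0)
b, s, h, d = 2, 2, 2, 2  # batch, sequence length, heads, head dim
q = torch.randn(b, s, h, d)
k = torch.randn(b, s, h, d)
v = torch.randn(b, s, h, d)

# matmul path: put heads in front so matmul runs per (batch, head)
q_mod = q.permute(0, 2, 1, 3)  # [b, h, s, d]
k_mod = k.permute(0, 2, 3, 1)  # [b, h, d, s]
v_mod = v.permute(0, 2, 1, 3)  # [b, h, s, d]
scores = torch.softmax(torch.matmul(q_mod, k_mod) / d ** 0.5, dim=3)  # [b, h, s, s]
per_head = torch.matmul(scores, v_mod)  # [b, h, s, d]

# Reshaping [b, h, s, d] directly mixes the head and position axes
naive = per_head.reshape(b, s, h * d)
# Swapping to [b, s, h, d] first keeps each position's heads together
fixed = per_head.transpose(1, 2).reshape(b, s, h * d)

# einsum path, same strings as in the reproduction code
ein_scores = torch.softmax(torch.einsum("bqhd,bkhd->bhqk", q, k) / d ** 0.5, dim=3)
ein_att = torch.einsum("bhql,blhd->bqhd", ein_scores, v).reshape(b, s, h * d)

print(torch.allclose(naive, ein_att, atol=1e-6))  # False
print(torch.allclose(fixed, ein_att, atol=1e-6))  # True
```

In the usual multi-head attention layout, each row of the [b, s, h * d] output concatenates every head's result for one position, which is what the einsum path produces.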

Tinghao-NTU commented 1 year ago

I believe it's a bug, bro. You can check this link to find the correct way to calculate self-attention.