facebookresearch / ToMe

A method to increase the speed and lower the memory footprint of existing vision transformers.

How do you understand the phrase "This can change the outcome of softmax attention: if we merge two tokens with the same key, that key has less effect in the softmax term" in the "Tracking Token Size" section? #31

Closed Nanuion closed 1 year ago

dbolya commented 1 year ago

Say you had two keys that were very similar: k1 and k2, and one key that was different: k3. In attention, you dot each query with every key and then softmax over the keys, like this:

softmax([q1.dot(k1), q1.dot(k2), q1.dot(k3)])

Softmax normalizes the resulting vector, so the attention weight for q1, k1 will be:

exp(q1.dot(k1)) / (exp(q1.dot(k1)) + exp(q1.dot(k2)) + exp(q1.dot(k3)))

If we instead merged k1 and k2 without duplicating the resulting token (with k1 and k2 similar so the merged token is essentially just k1), this would be akin to

exp(q1.dot(k1)) / (exp(q1.dot(k1)) + exp(q1.dot(k3)))

which is a different value -- the normalization is wrong.

That's why we have to pretend that there are two identical tokens:

exp(q1.dot(k1)) / (exp(q1.dot(k1)) + exp(q1.dot(k1)) + exp(q1.dot(k3)))

which is closer to the original expression.
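A quick numerical sketch of the argument above (using NumPy; the toy logits are made up for illustration, not taken from the repo). ToMe's "proportional attention" realizes the duplicated-token softmax by keeping the merged key once and adding log(size) to its logit, which is equivalent to repeating it size times:

```python
import numpy as np

def softmax(v):
    # Numerically stable softmax over a 1-D array.
    e = np.exp(v - np.max(v))
    return e / e.sum()

# Toy logits: q1.dot(k1), q1.dot(k2), q1.dot(k3), with k1 ~= k2.
x1, x2, x3 = 2.0, 2.0, -1.0

# Original attention over all three keys.
full = softmax(np.array([x1, x2, x3]))

# Naive merge: drop k2 entirely -- the normalization changes.
naive = softmax(np.array([x1, x3]))

# Proportional attention: keep the merged key once, but add log(size)
# to its logit so it counts as 2 tokens inside the softmax.
sizes = np.array([2.0, 1.0])
prop = softmax(np.array([x1, x3]) + np.log(sizes))

# prop[0] recovers the merged key's combined weight (full[0] + full[1]),
# and prop[1] matches full[2] exactly -- unlike the naive merge.
```

Since k1 and k2 are identical here, the size-weighted two-key softmax reproduces the three-key softmax exactly, while the naive merge over-weights k3.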

Nanuion commented 1 year ago

Wow, thank you so much for your reply, I understand!

ZechengLi19 commented 1 year ago

> Say you had two keys that were very similar: k1 and k2, and one key that was different: k3. In attention, you dot each query with every key and then softmax over the keys, like this:
>
> softmax([q1.dot(k1), q1.dot(k2), q1.dot(k3)])
>
> Softmax normalizes the resulting vector, so the attention weight for q1, k1 will be:
>
> exp(q1.dot(k1)) / (exp(q1.dot(k1)) + exp(q1.dot(k2)) + exp(q1.dot(k3)))
>
> If we instead merged k1 and k2 without duplicating the resulting token (with k1 and k2 similar so the merged token is essentially just k1), this would be akin to
>
> exp(q1.dot(k1)) / (exp(q1.dot(k1)) + exp(q1.dot(k3)))
>
> which is a different value -- the normalization is wrong.
>
> That's why we have to pretend that there are two identical tokens:
>
> exp(q1.dot(k1)) / (exp(q1.dot(k1)) + exp(q1.dot(k1)) + exp(q1.dot(k3)))
>
> which is closer to the original expression.

Isn't something missing from this answer?

For k1, the weight will be

2 * exp(q1.dot(k1)) / (exp(q1.dot(k1)) + exp(q1.dot(k1)) + exp(q1.dot(k3)))

However, for k3, it will be

exp(q1.dot(k3)) / (exp(q1.dot(k1)) + exp(q1.dot(k1)) + exp(q1.dot(k3)))

Is that correct?

dbolya commented 1 year ago

> There seems to be something wrong with this answer?
>
> For k1, it will be
>
> 2 * exp(q1.dot(k1)) / (exp(q1.dot(k1)) + exp(q1.dot(k1)) + exp(q1.dot(k3)))
>
> However, for k3, it will be
>
> exp(q1.dot(k3)) / (exp(q1.dot(k1)) + exp(q1.dot(k1)) + exp(q1.dot(k3)))
>
> Is it correct?

Ah yes, my bad. Your expression is correct. So it both handles the normalization properly and acts as if k1 were actually 2 different tokens, so its effect is doubled once multiplied with V.
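This last point can also be checked numerically (a sketch with made-up scalar logits and values, not the repo's code): weighting the merged key by its size reproduces the full attention *output*, i.e. k1's value is effectively counted twice when multiplied with V:

```python
import numpy as np

def softmax(v):
    # Numerically stable softmax over a 1-D array.
    e = np.exp(v - np.max(v))
    return e / e.sum()

# Hypothetical scalar values attached to each key (v2 == v1, since the
# two tokens were similar enough to merge).
v1, v3 = 5.0, -3.0
x1, x3 = 2.0, -1.0  # q1.dot(k1) and q1.dot(k3)

# Unmerged: k1 appears twice, so its value is counted twice.
w_full = softmax(np.array([x1, x1, x3]))
out_full = w_full[0] * v1 + w_full[1] * v1 + w_full[2] * v3

# Merged with size weighting: one slot for k1, logit bumped by log(2).
w_merged = softmax(np.array([x1 + np.log(2.0), x3]))
out_merged = w_merged[0] * v1 + w_merged[1] * v3

# out_merged equals out_full: the size term both fixes the normalization
# and doubles k1's contribution in the weighted sum over values.
```

The same identity holds for vector-valued V; scalars are used here only to keep the arithmetic readable.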

ZechengLi19 commented 1 year ago

> There seems to be something wrong with this answer? For k1, it will be
>
> 2 * exp(q1.dot(k1)) / (exp(q1.dot(k1)) + exp(q1.dot(k1)) + exp(q1.dot(k3)))
>
> However, for k3, it will be
>
> exp(q1.dot(k3)) / (exp(q1.dot(k1)) + exp(q1.dot(k1)) + exp(q1.dot(k3)))
>
> Is it correct?
>
> Ah yes, my bad. Your expression is correct. So it both handles the normalization properly and acts as if k1 were actually 2 different tokens, so its effect is doubled once multiplied with V.

I got it! 😊

Thanks for your quick reply and great work!