google-research / long-range-arena

Long Range Arena for Benchmarking Efficient Transformers
Apache License 2.0

Computing required attention span #31

Closed: yongyi-wu closed this issue 3 years ago

yongyi-wu commented 3 years ago

Hi,

I have a general clarification question: for the required attention span mentioned in the Long Range Arena paper, do you calculate the distance from the attended keys to only the last query token in a sequence? In other words, is the maximum possible distance always 1K (or 2K or 4K, depending on the task)? Also, I am wondering how you deal with query tokens that are in the middle of a sequence.

Thank you.

MostafaDehghani commented 3 years ago

Thank you Yongyi for your interest in LRA.

To compute the required attention span, we loop over all queries and compute the distance from each query token to every key, weighted by the corresponding attention weight (and then average over all heads).

So for one head, we do something like this:

MAS = 0
for q in queries:
    for k in keys:
        MAS += dist(q, k) * Att[q, k]
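
The loop above can be turned into a small runnable sketch. This is not the LRA implementation. It assumes that `dist(q, k)` is the absolute difference between token positions, that each attention row sums to 1, and that the total is normalized by the number of heads and queries. The function name `mean_attention_span` and the nested-list layout `att[h][q][k]` are hypothetical choices made for illustration.

```python
def mean_attention_span(att):
    """Attention-weighted mean query-key distance.

    att[h][q][k] is the weight head h assigns from query q to key k.
    Assumption: each row att[h][q] sums to 1 (softmax output).
    """
    num_heads = len(att)
    num_queries = len(att[0])
    total = 0.0
    for head in att:
        for q, row in enumerate(head):
            for k, weight in enumerate(row):
                # Assumption: dist(q, k) is the positional offset |q - k|.
                total += abs(q - k) * weight
    # Assumption: average over heads and over queries.
    return total / (num_heads * num_queries)
```

With this definition, a head that attends only to the query's own position has span 0, and the value grows as attention mass moves to more distant keys.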

This plot might also make it a bit clearer: [image]

Hope that answers your question.

yongyi-wu commented 3 years ago

Thanks for the clarification. To follow up, suppose the length of a sequence is L. In this case, the maximal possible dist(q, k) for a q in the middle is shorter (in particular, the midpoint q can have at most 1/2 L for dist(q, k)). Thus, the overall required attention span is upper bounded by roughly 3/4 L rather than L. Is my understanding correct?

MostafaDehghani commented 3 years ago

Yes, that is correct: the required attention span is not bounded by L itself but by a tighter limit. For a sequence length of 5, for example, the max required attention span is:

(4 + 3 + 2 + 3 + 4) / 5 = 3.2

For a sequence length of 6, for example, the max required attention span is:

(5 + 4 + 3 + 3 + 4 + 5) / 6 = 4
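
The two worked examples generalize to any sequence length. The sketch below assumes positions are numbered 0 to L - 1 and that the maximum is reached when every query puts all of its attention weight on its farthest key. The function name `max_required_attention_span` is a hypothetical helper, not part of the LRA code.

```python
def max_required_attention_span(seq_len):
    """Mean over queries of the distance to each query's farthest key."""
    # For query position q, the farthest key is at position 0 or seq_len - 1.
    farthest = [max(q, seq_len - 1 - q) for q in range(seq_len)]
    return sum(farthest) / seq_len


for seq_len in (5, 6, 1000):
    print(seq_len, max_required_attention_span(seq_len))
# 5 3.2
# 6 4.0
# 1000 749.5
```

For a length of 1000 the bound is 749.5, which is close to the 3/4 L estimate discussed above.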

yongyi-wu commented 3 years ago

Thanks for the great explanations.