Smerity / sha-rnn

Single Headed Attention RNN - "Stop thinking with your head"

Efficiency claims on attention module used #15


munael commented 3 years ago

[Screenshot of Figure 1 from the paper]

In Figure 1 there's a claim that the attention module is "highly efficient". This is explained by the removal of the K/V transforms. The description of the attention scores block then says:

The A block represents scaled dot product attention, a vector-vector operation

This seems misleading, as the overall complexity of the A block is still a large N^2 matrix-matrix product, which is usually the most expensive part of the classical attention module.
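
For concreteness, here is a rough PyTorch sketch (illustrative shapes, not code from this repo) of why the score block stays O(N^2 * d) even when K and V have no learned projections:

```python
import torch

N, d = 1024, 512                      # sequence length and width (illustrative values)
q = torch.randn(N, d)
k = torch.randn(N, d)
v = torch.randn(N, d)

scores = (q @ k.t()) / d ** 0.5       # N x N scaled dot products -> O(N^2 * d) work
attn = torch.softmax(scores, dim=-1)  # row-wise softmax over the N x N matrix
out = attn @ v                        # another O(N^2 * d) matmul
```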

Can you clarify :D ?

Smerity commented 3 years ago

Apologies for the delayed reply.

You are correct that dot product attention still requires an N by N matrix of dot products to compute the attention scores.

The efficiency claim for the SHA-RNN's attention is along the lines of Shazeer's One Write-Head is All You Need: since the keys and values do not require a matrix multiplication, only the queries involve one, which yields substantial computational savings. That's why I note the vector-vector operation.
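
A minimal sketch of where those savings come from, assuming illustrative names and shapes rather than the repository's actual attention module:

```python
import torch
import torch.nn as nn

N, d = 1024, 512
x = torch.randn(N, d)                     # memory / context vectors
h = torch.randn(N, d)                     # query-side hidden states

# Standard attention: three learned projections, each an O(N * d^2) matmul.
w_q, w_k, w_v = (nn.Linear(d, d, bias=False) for _ in range(3))
q_std, k_std, v_std = w_q(h), w_k(x), w_v(x)

# Single-headed variant in the spirit of SHA-RNN: only the query is projected;
# keys and values reuse x directly, dropping two of the three O(N * d^2) matmuls.
q = w_q(h)
k, v = x, x

# The N x N score block itself is unchanged in cost.
scores = torch.softmax((q @ k.t()) / d ** 0.5, dim=-1)
out = scores @ v
```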

For reducing the N by N attention component you would indeed need to look towards other potential solutions (approximate attention, sparse attention, ...).