Dao-AILab / flash-attention

Fast and memory-efficient exact attention
BSD 3-Clause "New" or "Revised" License

Questions about calculating the number of HBM accesses #1260

Open uniqueness opened 1 week ago

uniqueness commented 1 week ago

I have trouble understanding why the number of HBM accesses is O(N^2 d^2 / M) instead of O(N^2 d^2 / M^2). I thought each of the two for loops has length O(N d / M).
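A minimal sketch of the two loop lengths, assuming the block sizes given in Algorithm 1 of the FlashAttention paper (B_c = ⌈M/(4d)⌉ and B_r = min(⌈M/(4d)⌉, d)). The function name and the example values N = 4096, d = 64, M = 65536 are illustrative only. Under these assumptions the outer count T_c is roughly 4Nd/M, but the inner count T_r is never smaller than ⌈N/d⌉, because B_r is capped at d, so T_r is not O(Nd/M) once M is large relative to d^2.

```python
import math

def flash_loop_lengths(N: int, d: int, M: int) -> dict:
    """Loop trip counts for the forward pass of FlashAttention Algorithm 1.

    Assumes the paper's block sizes: B_c = ceil(M / (4d)) and
    B_r = min(ceil(M / (4d)), d). Hypothetical helper for illustration.
    """
    B_c = math.ceil(M / (4 * d))       # columns of K, V per outer block
    B_r = min(math.ceil(M / (4 * d)), d)  # rows of Q, O per inner block (capped at d)
    T_c = math.ceil(N / B_c)           # outer loop length
    T_r = math.ceil(N / B_r)           # inner loop length
    return {"B_c": B_c, "B_r": B_r, "T_c": T_c, "T_r": T_r}

# Illustrative sizes: N*d/M = 4, so T_c = 4 * N*d/M, while T_r = N/d.
print(flash_loop_lengths(4096, 64, 65536))
```

With these numbers T_c = 16 but T_r = 64 = N/d, so the iteration count T_r T_c does not shrink like (Nd/M)^2 in this regime.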

Also, for standard attention, why is the number of HBM accesses O(N d + N^2)? Doesn't M play a role here? When we do the matrix multiplications, can we still read/write the matrices in blocks of size M?
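A rough tally of the element-level HBM traffic of standard attention (Algorithm 0 in the paper), assuming each matrix is read or written exactly once per step; the function name is hypothetical and the count is an accounting model, not a measurement. The N x N intermediates S and P each make a full round trip through HBM, and that traffic is the same whatever the SRAM size M is. Reading Q, K, P, V in tiles changes how the work is scheduled, but not the fact that every entry of S and P is written out and read back.

```python
def standard_attention_hbm(N: int, d: int) -> int:
    """Element-level HBM reads + writes for standard attention (Algorithm 0),
    counting each matrix once per step. Hypothetical helper; M never appears."""
    step1 = 2 * N * d + N * N      # read Q, K; write S = Q K^T
    step2 = N * N + N * N          # read S; write P = softmax(S)
    step3 = N * N + N * d + N * d  # read P, V; write O = P V
    return step1 + step2 + step3   # = 4 N^2 + 4 N d, i.e. Theta(N d + N^2)

for N in (1024, 2048, 4096):
    print(N, standard_attention_hbm(N, 64))
```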

I'd appreciate it if someone could answer.

tridao commented 1 week ago

The proof is in appendix C of the FlashAttention paper.

uniqueness commented 1 week ago

I read appendix C but didn't find the relevant details. Can you help?

[image] Why is it N d T_c instead of T_r T_c?
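One way to see the difference is to count elements moved rather than loop iterations. The sketch below walks the loop structure of Algorithm 1 under the paper's block sizes (B_c = ⌈M/(4d)⌉, B_r = min(⌈M/(4d)⌉, d)), ignores the O(N) traffic of the softmax statistics ℓ and m, and uses a hypothetical function name. Each inner iteration moves Θ(B_r d) elements, and the T_r inner blocks together cover all N rows of Q and O, so one full inner loop moves Θ(Nd) elements; repeating it T_c times gives Θ(N d T_c) = Θ(N^2 d^2 / M).

```python
import math

def flash_forward_hbm(N: int, d: int, M: int) -> tuple[int, int]:
    """Return (inner iterations, HBM elements read + written) for the loop
    structure of FlashAttention Algorithm 1. Hypothetical helper; assumes the
    paper's block sizes and ignores the row statistics l and m."""
    B_c = math.ceil(M / (4 * d))
    B_r = min(B_c, d)
    iterations = 0
    elements = 0
    for j in range(0, N, B_c):        # outer loop over T_c blocks of K, V
        cols = min(B_c, N - j)
        elements += 2 * cols * d      # load K_j and V_j once
        for i in range(0, N, B_r):    # inner loop over T_r blocks of Q, O
            rows = min(B_r, N - i)
            elements += 2 * rows * d  # load Q_i and O_i
            elements += rows * d      # write the updated O_i back
            iterations += 1
    return iterations, elements

print(flash_forward_hbm(4096, 64, 65536))
```

For N = 4096, d = 64, M = 65536 there are T_r T_c = 1024 inner iterations, but 2Nd + 3NdT_c element accesses: each iteration moves a whole B_r x d tile, so the element count is the iteration count times roughly 3 B_r d.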

[image] Why doesn't M play a role here?