I have trouble understanding why the number of HBM accesses is O(N^2 d^2 / M) instead of O(N^2 d^2 / M^2). I thought each of the two for loops runs O(N d / M) iterations.
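To make the comparison concrete, here is a rough counting sketch of the two nested loops. It assumes the question refers to Algorithm 1 of the FlashAttention paper, with block sizes B_c = ceil(M / (4d)) and B_r = min(ceil(M / (4d)), d), the outer loop over K/V blocks and the inner loop over Q/O blocks. The function name is hypothetical. It tallies both the number of inner-loop iterations and the number of elements moved to/from HBM (ignoring the small l and m statistic vectors), so the two quantities can be compared side by side.

```python
import math

def flash_attention_hbm_counts(N, d, M):
    """Hypothetical tally of loop iterations vs. HBM elements moved.

    Assumed block sizes (as I read Algorithm 1, not verified against code):
    B_c = ceil(M / (4d)), B_r = min(ceil(M / (4d)), d).
    The l and m statistic vectors are ignored.
    """
    B_c = math.ceil(M / (4 * d))
    B_r = min(math.ceil(M / (4 * d)), d)
    T_c = math.ceil(N / B_c)  # outer-loop iterations (K/V blocks)
    T_r = math.ceil(N / B_r)  # inner-loop iterations per outer iteration (Q/O blocks)

    inner_iterations = 0
    elements = 0
    for j in range(T_c):
        elements += 2 * B_c * d      # load K_j and V_j once per outer iteration
        for i in range(T_r):
            inner_iterations += 1
            elements += 2 * B_r * d  # load Q_i and O_i
            elements += B_r * d      # write O_i back
    return T_c, T_r, inner_iterations, elements

T_c, T_r, iters, elems = flash_attention_hbm_counts(N=1024, d=64, M=4096)
print("loop trip counts:", T_c, T_r, "inner iterations:", iters)
print("elements moved:", elems)
print("N^2 d^2 / M   =", 1024**2 * 64**2 // 4096)
print("(N d / M)^2   =", (1024 * 64 // 4096) ** 2)
```

Doubling M in the call above roughly halves the element count, which is one way to see how each tally scales with M.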
Also, for standard attention, why is the number O(N d + N^2)? Doesn't M play a role here? When we do the matrix multiplications, can we still read/write the matrices in blocks of size M?
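For the standard-attention case, a similar tally can be written for the three steps of Algorithm 0 in the FlashAttention paper (S = QK^T, P = softmax(S), O = PV). This sketch assumes, as Algorithm 0 is written, that each matrix is read or written exactly once per step even when it is moved in blocks; a real tiled matmul kernel may re-read its inputs, which this sketch ignores. The function name is hypothetical.

```python
def standard_attention_hbm_counts(N, d):
    """Hypothetical per-step tally of HBM elements read/written.

    Assumes each matrix is read or written exactly once per step,
    regardless of the block size used to move it.
    """
    steps = {
        "S = Q K^T": 2 * N * d + N * N,    # read Q and K, write S
        "P = softmax(S)": N * N + N * N,   # read S, write P
        "O = P V": N * N + N * d + N * d,  # read P and V, write O
    }
    return steps, sum(steps.values())

steps, total = standard_attention_hbm_counts(N=1024, d=64)
for name, count in steps.items():
    print(f"{name:>15}: {count}")
print("total:", total)  # 4*N*d + 4*N^2 under the stated assumption
```

Under this assumption the block size only changes how the transfers are grouped, not how many elements are counted.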
I'd appreciate it if someone could answer.