Dao-AILab / flash-attention

Fast and memory-efficient exact attention
BSD 3-Clause "New" or "Revised" License

In the non-causal case, why do we have a mask? #1256

Open ziyuhuang123 opened 1 month ago

ziyuhuang123 commented 1 month ago
                if constexpr (!Is_causal) {  // Non-causal: mask only columns past the end of seqlen_k
                    if (int(get<1>(tScS(i))) >= int(seqlen_k - n_block * kBlockN)) {
                        tSrS(i) = -INFINITY;
                        printf("seqlen_k - n_block * kBlockN = %d - %d * %d = %d\n", seqlen_k, n_block, kBlockN, seqlen_k - n_block * kBlockN);
                    }
                }

I noticed that if seq_len_k = seq_len_q = 256, we get: seqlen_k - n_block * kBlockN = 256 - 1 * 176 = 80. That isn't even the size of a block (I would expect masking only when seq_len doesn't fill a block; e.g. with block size = 128 and seq_len = 127, of course we want to set location 128 to -inf). So why do we have a mask here?

tridao commented 1 month ago

This just says that for the 2nd block (columns 176 -> 351), we keep the first 80 columns (176 -> 255) and the rest of the columns are masked out (set to -infinity).
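
To make the arithmetic concrete, here is a small standalone host-side sketch (an illustration, not the kernel code; `local_col` stands in for `get<1>(tScS(i))`, i.e. the column index inside the current kBlockN-wide tile) that replays the same predicate for seqlen_k = 256 and kBlockN = 176:

    #include <cstdio>

    int main() {
        const int seqlen_k = 256;
        const int kBlockN  = 176;

        for (int n_block = 0; n_block * kBlockN < seqlen_k; ++n_block) {
            int valid_cols = seqlen_k - n_block * kBlockN;  // columns that actually exist in this tile
            int kept = 0, masked = 0;
            for (int local_col = 0; local_col < kBlockN; ++local_col) {
                // Same predicate as the kernel: out-of-range columns are set to -INFINITY.
                if (local_col >= valid_cols) ++masked; else ++kept;
            }
            printf("n_block=%d covers global cols %d..%d: keep %d, mask %d\n",
                   n_block, n_block * kBlockN, (n_block + 1) * kBlockN - 1, kept, masked);
        }
        return 0;
    }

    // Output:
    //   n_block=0 covers global cols 0..175: keep 176, mask 0
    //   n_block=1 covers global cols 176..351: keep 80, mask 96

So only the last tile of columns is partially masked; the number of masked columns is kBlockN minus the remaining valid columns, 176 - 80 = 96 here.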

ziyuhuang123 commented 1 month ago

Oh, I understand now. Because we use an unusual block size, 128 x 176 (M x N), even though I use 256 as seq_len, the 2nd block does a lot of useless computation! But... why do we use 176 as kBlockN? I noticed it is fixed. Maybe... it uses as much smem as possible? (Well, it's hard to imagine how it matches WGMMA's tile sizes....)
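
For what it's worth, a rough back-of-the-envelope sketch (my own guesses, not the library's actual tuning logic: head_dim = 128, 2-byte elements, and a 2-stage K/V pipeline are all assumptions here) is at least consistent with the "use as much smem as possible" idea, since Hopper only allows roughly 227 KB of shared memory per thread block:

    #include <cstdio>

    int main() {
        const int kBlockM   = 128;  // rows of Q per tile (from this thread)
        const int kBlockN   = 176;  // columns of K/V per tile (from this thread)
        const int head_dim  = 128;  // assumed for illustration
        const int elt_bytes = 2;    // fp16/bf16, assumed
        const int kv_stages = 2;    // assumed K/V pipeline depth

        size_t smem_q  = (size_t)kBlockM * head_dim * elt_bytes;                  // Q tile
        size_t smem_kv = (size_t)kv_stages * 2 * kBlockN * head_dim * elt_bytes;  // K and V tiles
        printf("Q: %zu KB, K/V: %zu KB, total: %zu KB (Hopper limit ~227 KB per block)\n",
               smem_q / 1024, smem_kv / 1024, (smem_q + smem_kv) / 1024);
        return 0;
    }

    // With these assumed numbers: 32 KB + 176 KB = 208 KB, i.e. close to the limit.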