66RING / tiny-flash-attention

flash attention tutorial written in python, triton, cuda, cutlass

About kBlockKSmem #7

Open HuyNguyen-hust opened 3 months ago

HuyNguyen-hust commented 3 months ago

Hi @66RING, thank you for your helpful work. I have one question about the use of kBlockKSmem in csrc/kernel_traits.h. When you define SmemLayoutAtomQ:

using SmemLayoutAtomQ = decltype(
        composition(Swizzle<kSwizzle, 3, 3>{},
                    // This has to be kBlockKSmem, using kHeadDim gives wrong results for d=128
                    Layout<Shape<_8, Int<kBlockKSmem>>,
                           Stride<Int<kBlockKSmem>, _1>>{}));

You note that using kHeadDim instead of kBlockKSmem would lead to wrong results. I don't understand why. Could you please explain?

66RING commented 3 months ago

Hi @HuyNguyen-hust. Honestly, I'm not quite sure; this kernel config was copied directly from the official repository.

I think this is because a large smem tile is split into multiple "pages", and the atom should not exceed one "page". According to this comment:

    // For example, for d=128, smem is split into 2 "pages", each page takes care of columns
    // 0-63 and 64-127. If we have 16 threads per row for gmem read, when we write to smem,
    // thread 0 - 7 will write to the first page and thread 8 - 15 will write to the second page,
    // to the same banks.

If you print SmemLayoutAtomQ{} and SmemLayoutQ{}, you will find that SmemLayoutQ{} looks like Sw<3,3,3> o _0 o (_64,(_64,_4)):(_64,(_1,_4096)). Kind of (64, npages)?
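
To make the printed layout concrete, here is a minimal sketch in plain C++ (no CUTLASS needed) of how (_64,(_64,_4)):(_64,(_1,_4096)) maps a (row, col) coordinate to an offset, and of what Sw<3,3,3> then does to that offset. The names swizzle_333 and paged_offset are hypothetical helpers for illustration, not CuTe functions, and the swizzle is modelled on my understanding of Swizzle<B, M, S> with B = M = S = 3.

```cpp
#include <cassert>

// Model of Sw<3,3,3>: the 3 bits at positions [6, 9) of the offset are
// shifted right by 3 and XORed into bits [3, 6). (Hypothetical helper.)
inline int swizzle_333(int offset) {
    const int mask = 0x7 << 6;
    return offset ^ ((offset & mask) >> 3);
}

// Offset of element (row, col) under (_64,(_64,_4)):(_64,(_1,_4096)),
// i.e. before the swizzle is composed on top. (Hypothetical helper.)
inline int paged_offset(int row, int col) {
    assert(row >= 0 && row < 64 && col >= 0 && col < 256);
    const int col_in_page = col % 64;  // inner column mode: size 64, stride 1
    const int page = col / 64;         // outer column mode: size 4, stride 4096
    return row * 64 + col_in_page + page * 4096;
}
```

Under this model, columns 0-63 of every row stay inside the first 4096-element page, and column 64 of row 0 jumps to offset 4096, the start of the next page. The final shared-memory position of an element would be swizzle_333(paged_offset(row, col)).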

HuyNguyen-hust commented 3 months ago

Thank you for your quick response. I think you are right. IMO, it wouldn't lead to wrong results, only to an unexpected access pattern. If we set the SmemLayoutAtomQ shape to (8, 128), each row is 64 banks wide, so smem is split into 2 pages. The layout then looks like this (each element is one uint128_t):

       0    1    2    3    4    5    6    7    8    9   10   11   12   13   14   15 
    +----+----+----+----+----+----+----+----+----+----+----+----+----+----+----+----+
 0  |  0 |  1 |  2 |  3 |  4 |  5 |  6 |  7 |  9 |  8 | 11 | 10 | 13 | 12 | 15 | 14 |
    +----+----+----+----+----+----+----+----+----+----+----+----+----+----+----+----+
 1  | 18 | 19 | 16 | 17 | 22 | 23 | 20 | 21 | 27 | 26 | 25 | 24 | 31 | 30 | 29 | 28 |
    +----+----+----+----+----+----+----+----+----+----+----+----+----+----+----+----+
 2  | 36 | 37 | 38 | 39 | 32 | 33 | 34 | 35 | 45 | 44 | 47 | 46 | 41 | 40 | 43 | 42 |
    +----+----+----+----+----+----+----+----+----+----+----+----+----+----+----+----+
 3  | 54 | 55 | 52 | 53 | 50 | 51 | 48 | 49 | 63 | 62 | 61 | 60 | 59 | 58 | 57 | 56 |
    +----+----+----+----+----+----+----+----+----+----+----+----+----+----+----+----+

So with or without the swizzle, threads 0 - 7 will write to the first page (the first 8 columns, i.e. 32 banks) and threads 8 - 15 will write to the second page (the next 8 columns, which map to the same 32 banks).
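
The table and the bank argument can be checked with a small plain C++ sketch. It assumes 32 banks of 4 bytes each, that thread t writes logical column t of a row, and that in uint128_t units Swizzle<3, 3, 3> XORs bits [3, 6) of the slot index into bits [0, 3). The helper names are hypothetical.

```cpp
#include <cassert>
#include <set>

// Physical uint128_t slot for logical (row, col) in an (8, 128) half-precision
// atom, viewed as 16 uint128_t columns per row.
inline int swizzled_slot(int row, int col) {
    assert(row >= 0 && row < 8 && col >= 0 && col < 16);
    const int idx = row * 16 + col;
    return idx ^ ((idx >> 3) & 0x7);
}

// First of the four 4-byte banks covered by a 16-byte slot (32 banks assumed).
inline int first_bank(int slot) {
    return (slot * 4) % 32;
}

// Starting banks touched by threads [t_begin, t_end) writing one row,
// assuming thread t writes logical column t.
inline std::set<int> banks_for_threads(int row, int t_begin, int t_end) {
    std::set<int> banks;
    for (int t = t_begin; t < t_end; ++t) {
        banks.insert(first_bank(swizzled_slot(row, t)));
    }
    return banks;
}
```

For every row, banks_for_threads(row, 0, 8) and banks_for_threads(row, 8, 16) come out as the same set under these assumptions: the swizzle only permutes slots within each group of 8, so both halves of the row cover the same 32 banks.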