66RING / tiny-flash-attention

flash attention tutorial written in python, triton, cuda, cutlass
159 stars 11 forks source link

请教一下tile_to_shape这个函数如何和swizzle配合使用的 #6

Open Ddd195 opened 2 months ago

Ddd195 commented 2 months ago

RT,tile_to_shape这个函数的作用是什么

66RING commented 2 months ago

@Ddd195 根据我的理解,大概意思就是你已经有了一个global memory的tile,现在要在smem上也做tile, 那这个tile的shape要和gmem的shape贴合又要和swizzle atom的shape贴合,所以就有了这么一个tile_to_shape来自动创建shape: 给一个gmem的shape和一个swizzle atom的shape就自动创建一个两边都贴合的shape.

我个人理解是这样,也就只能比较高层次的理解了,具体细节也不很清楚。

Ddd195 commented 1 month ago

@66RING 谢谢您的回复!我请教了一下reed,他说仅仅用于把tile扩展到更大的块,后来我知道了使用方法大概就是先定义经过swizzle的atom,然后用这个atom进行tile_to_shape来构造一个存放在share的tensor,虽然具体细节我也不太懂。

Ddd195 commented 1 month ago

@66RING 请问一下size(Kernel_traits::kNThreads)这个是128吗,为什么在写blockdim时不直接写Kernel_traits::kNThreads呢

66RING commented 1 month ago

@Ddd195 是应该直接写成Kernel_traits::kNThreads的,这个可能是standalone用来调试时改着改着忘了

Ddd195 commented 1 month ago

@66RING 谢谢谢谢!,还想再请教一下tiled_mma.get_slice(tidx);这个函数,我读了您的代码,以Q为例,gQ块是6464,mma能力是6416,block线程数是128,这个gQ块是如何通过get_slice和线程id以mma能力为基础分配给每个线程的,在这个例子中一个线程处理怎么样shape的矩阵呢。

66RING commented 1 month ago

@Ddd195 这些就是看cutlass抽象出来的mma原语了,可以用这个脚本可视化,https://gist.github.com/66RING/2e188b73fdf703e9f9dfc7371814dd15

比如:

  {
    auto tiled_mma = make_tiled_mma(SM80_16x8x16_F32F16F16F32_TN{},
                                    Layout<Shape<_1,_1, _1>>{},  // AtomLayoutMNK
                                    Layout<Shape<_1,_2, _1>>{}   // ValLayoutMNK
    );
    print_mma_content("flash2: SM80_16x8x16_F32F16F16F32_TN", tiled_mma);
  }

然后改AtomLayoutMNK和ValLayoutMNK可以微调线程的处理的任务

Ddd195 commented 1 month ago

@66RING 谢谢您的答复,我还想再请教一下retile_S和partition_S有什么区别,没太明白retile的作用。 比如auto tCrA = thr_mma.partition_fragment_A(gA(_, _, 0)); auto tCrA_view = s2r_thr_copy_a.retile_D(tCrA);

66RING commented 1 month ago

@Ddd195 partition_S/D是创建copy对象所需的src和dst。作用在gmem, smem这些"内存"上。而retile是作用在寄存器的,retile来变成寄存器私有数据所需的形状。而partition_fragment_A是构造一个空的寄存器的view,数据需要从smem拷贝进去

比如从smem拷贝到寄存器,那就用partition构造一个对接smem的src,然后用retile构造一个和寄存器对接的dst。

反之同理

  Tensor taccOrO = smem_thr_copy_O.retile_S(rO);        // ((Atom,AtomNum), MMA_M, MMA_N)
  Tensor taccOsO = smem_thr_copy_O.partition_D(sO);     // ((Atom,AtomNum),PIPE_M,PIPE_N)

  // NOTE: 先拷贝到smem
  cute::copy(smem_tiled_copy_O, taccOrO, taccOsO);

我的理解是这样,如果有错麻烦帮我指正一下

Ddd195 commented 1 month ago

@Ddd195 partition_S/D是创建copy对象所需的src和dst。作用在gmem, smem这些"内存"上。而retile是作用在寄存器的,retile来变成寄存器私有数据所需的形状。而partition_fragment_A是构造一个空的寄存器的view,数据需要从smem拷贝进去

比如从smem拷贝到寄存器,那就用partition构造一个对接smem的src,然后用retile构造一个和寄存器对接的dst。

反之同理

  Tensor taccOrO = smem_thr_copy_O.retile_S(rO);        // ((Atom,AtomNum), MMA_M, MMA_N)
  Tensor taccOsO = smem_thr_copy_O.partition_D(sO);     // ((Atom,AtomNum),PIPE_M,PIPE_N)

  // NOTE: 先拷贝到smem
  cute::copy(smem_tiled_copy_O, taccOrO, taccOsO);

我的理解是这样,如果有错麻烦帮我指正一下

我也认为是这样,但是使用起来灵活,感觉我还是无法完全搞明白其中的线程和寄存器如何变换,只能知道这些步骤在做什么,放弃了。谢谢大佬

sleepwalker2017 commented 5 days ago

@Ddd195 这些就是看cutlass抽象出来的mma原语了,可以用这个脚本可视化,https://gist.github.com/66RING/2e188b73fdf703e9f9dfc7371814dd15

比如:

  {
    auto tiled_mma = make_tiled_mma(SM80_16x8x16_F32F16F16F32_TN{},
                                    Layout<Shape<_1,_1, _1>>{},  // AtomLayoutMNK
                                    Layout<Shape<_1,_2, _1>>{}   // ValLayoutMNK
    );
    print_mma_content("flash2: SM80_16x8x16_F32F16F16F32_TN", tiled_mma);
  }

然后改AtomLayoutMNK和ValLayoutMNK可以微调线程的处理的任务

大佬请教下,这个图里的各种颜色是啥意思啊?为什么A的同一列是一个颜色呢?

image