66RING / tiny-flash-attention

flash attention tutorial written in python, triton, cuda, cutlass
217 stars 18 forks source link

三个版本的性能对比结果如何? #4

Open Amanda-Barbara opened 5 months ago

Amanda-Barbara commented 5 months ago

大佬,三个版本各自实现的flash-attention的性能对比结果如何?

66RING commented 5 months ago

@Amanda-Barbara 学习版,不考虑性能,看懂了直接看官方版就行。triton的和cutlass的在特定的shape下能接近官方实现,因为为了简单起见这里的cutlass版写死了分块的大小,而官方版本会根据数据规模选择最优的分块大小。cuda版没做任何优化,纯属熟悉flash流程。

vfdff commented 1 month ago

请问 最优的分块大小 一般要考虑哪些因素?

66RING commented 1 month ago

请问 最优的分块大小 一般要考虑哪些因素?

@vfdff 不太好说,感觉和输入规模,硬件算力,编译版本,驱动版本,smem大小,计算的形状等都有关, 感觉是个申请资源和使用资源的tradeoff。可以看一些别人枚举的例子(下面代码来自FlagAttention)

def get_config(M, D):
    if torch.cuda.get_device_capability() == (8, 0):
        if D <= 64:
            BLOCK_M, BLOCK_N, num_stages, num_warps = 128, 64, 3, 4
        else:
            if M <= 1024:
                BLOCK_M, BLOCK_N, num_stages, num_warps = 128, 32, 3, 4
            else:
                BLOCK_M, BLOCK_N, num_stages, num_warps = 128, 128, 3, 8
    elif torch.cuda.get_device_capability() == (8, 6):
        if D <= 64:
            BLOCK_M, BLOCK_N, num_stages, num_warps = 128, 64, 3, 4
        else:
            BLOCK_M, BLOCK_N, num_stages, num_warps = 128, 32, 2, 4
    else:
        BLOCK_M, BLOCK_N, num_stages, num_warps = 32, 32, 1, 4
    return (BLOCK_M, BLOCK_N, num_stages, num_warps)