Open Amanda-Barbara opened 5 months ago
@Amanda-Barbara 学习版,不考虑性能,看懂了直接看官方版就行。triton的和cutlass的在特定的shape下能接近官方实现,因为为了简单起见这里的cutlass版写死了分块的大小,而官方版本会根据数据规模选择最优的分块大小。cuda版没做任何优化,纯属熟悉flash流程。
请问 最优的分块大小 一般要考虑哪些因素?
请问 最优的分块大小 一般要考虑哪些因素?
@vfdff 不太好说,感觉和输入规模,硬件算力,编译版本,驱动版本,smem大小,计算的形状等都有关, 感觉是个申请资源和使用资源的tradeoff。可以看一些别人枚举的例子(下面代码来自FlagAttention)
def get_config(M, D):
if torch.cuda.get_device_capability() == (8, 0):
if D <= 64:
BLOCK_M, BLOCK_N, num_stages, num_warps = 128, 64, 3, 4
else:
if M <= 1024:
BLOCK_M, BLOCK_N, num_stages, num_warps = 128, 32, 3, 4
else:
BLOCK_M, BLOCK_N, num_stages, num_warps = 128, 128, 3, 8
elif torch.cuda.get_device_capability() == (8, 6):
if D <= 64:
BLOCK_M, BLOCK_N, num_stages, num_warps = 128, 64, 3, 4
else:
BLOCK_M, BLOCK_N, num_stages, num_warps = 128, 32, 2, 4
else:
BLOCK_M, BLOCK_N, num_stages, num_warps = 32, 32, 1, 4
return (BLOCK_M, BLOCK_N, num_stages, num_warps)
大佬,三个版本各自实现的flash-attention的性能对比结果如何?