tspeterkim / flash-attention-minimal

Flash Attention in ~100 lines of CUDA (forward pass only)
Apache License 2.0
548 stars, 48 forks

slow in for loop test #3

Closed: DefTruth closed this issue 5 months ago

DefTruth commented 5 months ago

It is slow when I time it in a for loop:

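import time
import torch

# q, k, v, manual_attn and minimal_attn are assumed to be set up elsewhere
# (e.g. as in the repo's bench.py).
REPEAT = 10

manual_result = manual_attn(q, k, v)  # warmup
torch.cuda.synchronize()  # keep the async warmup kernels out of the timed loop
st = time.time()
for _ in range(REPEAT):
    manual_result = manual_attn(q, k, v)
    torch.cuda.synchronize()  # wait for the kernels of this iteration to finish
print(f"manual attention mean time(ms): {((time.time() - st) * 1000) / REPEAT}")

minimal_result = minimal_attn.forward(q, k, v)  # warmup
torch.cuda.synchronize()  # keep the async warmup kernel out of the timed loop
st = time.time()
for _ in range(REPEAT):
    minimal_result = minimal_attn.forward(q, k, v)
    torch.cuda.synchronize()  # wait for the kernel of this iteration to finish
print(f"minimal attention mean time(ms): {((time.time() - st) * 1000) / REPEAT}")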
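The same measurement can be factored into a small reusable helper. This is a minimal sketch, not code from the repo: the name `mean_time_ms` and its parameters are hypothetical. It takes the function to time plus a `sync` callable (for CUDA work, `torch.cuda.synchronize`) so that asynchronous kernel launches are actually waited on before the clock is read, and it uses `time.perf_counter`, which is intended for measuring short intervals.

```python
import time
from typing import Callable, Optional


def mean_time_ms(fn: Callable[[], object],
                 sync: Optional[Callable[[], None]] = None,
                 repeat: int = 10,
                 warmup: int = 1) -> float:
    """Return the mean wall-clock time of one fn() call in milliseconds.

    `sync` should block until all queued asynchronous work is done
    (e.g. torch.cuda.synchronize). It is called once after the warmup and
    once after the timed loop, so consecutive launches may overlap.
    """
    if repeat <= 0:
        raise ValueError("repeat must be positive")
    sync = sync or (lambda: None)  # no-op for purely synchronous code
    for _ in range(warmup):
        fn()
    sync()  # make sure warmup work is not billed to the timed region
    start = time.perf_counter()
    for _ in range(repeat):
        fn()
    sync()  # wait for the last launched work before stopping the clock
    return (time.perf_counter() - start) * 1000.0 / repeat
```

Usage, assuming the same `q`, `k`, `v` and `minimal_attn` as in the snippet above: `mean_time_ms(lambda: minimal_attn.forward(q, k, v), torch.cuda.synchronize)`. Syncing only at the end measures throughput; syncing inside the loop, as in the original snippet, measures per-call latency including the synchronization cost. Both are valid, but they should not be compared with each other.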
DefTruth commented 5 months ago
=== profiling manual attention ===
manual attention mean time(ms): 0.12569427490234375
=== profiling minimal attention ===
minimal attention mean time(ms): 1.2495994567871094
attn values sanity check: True
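Putting the two reported means side by side, the slowdown is simply their ratio. Here is a quick check using the numbers printed above:

```python
# Mean per-call times reported in the comment above, in milliseconds.
manual_ms = 0.12569427490234375
minimal_ms = 1.2495994567871094

slowdown = minimal_ms / manual_ms
print(f"minimal / manual: {slowdown:.1f}x")  # prints "minimal / manual: 9.9x"
```

So in this loop, the minimal kernel is roughly an order of magnitude slower than the unfused PyTorch attention, even though the sanity check confirms the outputs match.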
DefTruth commented 5 months ago

The custom flash forward kernel does not seem to be captured by torch.autograd.profiler:

                      Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg     Self CUDA   Self CUDA %    CUDA total  CUDA time avg    # of Calls
--------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------
                  aten::to         0.41%      16.000us        23.72%     921.000us     460.500us      17.000us         1.26%     893.000us     446.500us             2
            aten::_to_copy         0.93%      36.000us        23.15%     899.000us     449.500us      22.000us         1.63%     876.000us     438.000us             2
       aten::empty_strided         1.11%      43.000us        13.08%     508.000us     169.333us     747.000us        55.50%     747.000us     249.000us             3
               aten::zero_         7.91%     307.000us         9.55%     371.000us     185.500us     359.000us        26.67%     420.000us     210.000us             2
               aten::zeros         0.44%      17.000us         8.32%     323.000us     323.000us       8.000us         0.59%     364.000us     364.000us             1
               aten::copy_         0.49%      19.000us         8.86%     344.000us     172.000us     109.000us         8.10%     109.000us      54.500us             2
          aten::zeros_like         0.54%      21.000us         3.94%     153.000us     153.000us       9.000us         0.67%      79.000us      79.000us             1
               aten::fill_         0.88%      34.000us         1.83%      71.000us      35.500us      62.000us         4.61%      62.000us      31.000us             2
                aten::full         0.44%      17.000us         1.18%      46.000us      46.000us       7.000us         0.52%      10.000us      10.000us             1
          aten::empty_like         0.28%      11.000us         0.95%      37.000us      37.000us       3.000us         0.22%       5.000us       5.000us             1
--------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------

All the kernels listed come from ATen core ops; the custom flash forward kernel itself does not appear.
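A likely explanation, offered as an assumption rather than a confirmed diagnosis, is that the legacy torch.autograd.profiler attributes time to operators that go through PyTorch's dispatcher. A plain pybind11 extension function such as `minimal_attn.forward` is not a dispatcher op, so only the ATen ops it calls internally (allocations, fills, copies) show up. The sketch below uses the newer `torch.profiler` API instead. With `ProfilerActivity.CUDA`, that API traces device kernels through CUPTI, which should also list a kernel launched directly from an extension. `record_function` adds a named range so that the call itself gets its own row. The helper name `profile_call` is hypothetical.

```python
import torch
from torch.profiler import ProfilerActivity, profile, record_function


def profile_call(fn, label, row_limit=10):
    """Run fn() once under torch.profiler and print an event summary.

    ProfilerActivity.CUDA (added only when a GPU is present) traces device
    kernels, including ones launched by custom extensions outside any ATen
    op. record_function(label) wraps the call in a named CPU range.
    """
    activities = [ProfilerActivity.CPU]
    if torch.cuda.is_available():
        activities.append(ProfilerActivity.CUDA)
    with profile(activities=activities) as prof:
        with record_function(label):
            out = fn()
        if torch.cuda.is_available():
            torch.cuda.synchronize()  # let queued kernels finish inside the trace
    print(prof.key_averages().table(row_limit=row_limit))
    return out, prof
```

Usage, with the same tensors as in the benchmark: `profile_call(lambda: minimal_attn.forward(q, k, v), "minimal_attn.forward")`. The extension's kernel should then appear in the table under its `__global__` function name, alongside the ATen allocation ops shown above.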