tspeterkim / flash-attention-minimal

Flash Attention in ~100 lines of CUDA (forward pass only)
Apache License 2.0
548 stars, 48 forks

Implement backward pass #2

Open leloykun opened 5 months ago

leloykun commented 5 months ago

Description

This PR implements a minimal backward pass for flash attention.

I got these results on my RTX 2060:

=== profiling manual attention (backward pass) ===
...
Self CPU time total: 11.139ms
Self CUDA time total: 1.721ms
=== profiling minimal flash attention (backward pass) === 
...
Self CPU time total: 31.466ms
Self CUDA time total: 629.000us

Roughly a 2.7x speedup in CUDA time (1.721 ms vs. 629 us).

Though my GPU can only handle blocks of size 16 (vs. size 32 on a T4).
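
One constraint that commonly limits the block size in kernels like this is shared memory: CUDA allows at most 48 KB per thread block by default, and more only with an explicit opt-in. The sketch below is a rough budget estimate, not a measurement of this PR. The helper names, the head dimension of 64, and the tile counts (3 row tiles plus 1 score tile for a forward-style layout, 7 plus 2 for a hypothetical backward layout) are all assumptions made for illustration.

```python
FLOAT_BYTES = 4                 # assuming float32 tiles
DEFAULT_SMEM_LIMIT = 48 * 1024  # default per-block shared memory limit (no opt-in)


def smem_bytes(block, head_dim, n_row_tiles, n_score_tiles):
    """Bytes of shared memory for a hypothetical tile layout.

    n_row_tiles:   number of (block x head_dim) tiles, e.g. Q, K, V, dO, ...
    n_score_tiles: number of (block x block) tiles, e.g. S, dS
    """
    row_tiles = n_row_tiles * block * head_dim * FLOAT_BYTES
    score_tiles = n_score_tiles * block * block * FLOAT_BYTES
    return row_tiles + score_tiles


def largest_block(head_dim, n_row_tiles, n_score_tiles,
                  limit=DEFAULT_SMEM_LIMIT, candidates=(32, 16, 8)):
    """Return the first candidate block size whose tiles fit within the limit."""
    for block in candidates:
        if smem_bytes(block, head_dim, n_row_tiles, n_score_tiles) <= limit:
            return block
    return None


# Forward-style layout (assumed): 3 row tiles + 1 score tile.
print(smem_bytes(32, 64, 3, 1))   # 28672 bytes, fits in 48 KB
# Hypothetical backward layout: 7 row tiles + 2 score tiles.
print(smem_bytes(32, 64, 7, 2))   # 65536 bytes, over the 48 KB default
print(largest_block(64, 7, 2))    # 16
```

Under these assumed tile counts, a block size of 32 would exceed the default limit while 16 would fit. The actual layout in the kernel may differ, so treat the numbers as a way to reason about the budget rather than as an explanation of the result above.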

hypertseng commented 5 months ago

@leloykun Hello Franz! I'm having some trouble with the code and with flash attention. First, why does the attn values sanity check return False when seq_len is lower than 32? This breaks inference, where seq_len is usually 1. I guess the block size may be causing this? Second, how should one choose an appropriate block size? Looking forward to your reply!

leloykun commented 4 months ago

Hi @hypertseng!

I believe it was because we weren't exiting the loops after going past the seq length. The forward pass should be fixed in my repo here: https://github.com/leloykun/flash-hyperbolic-attention-minimal
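
To illustrate the kind of bounds check involved, here is a minimal pure-Python sketch of tiled attention with an online softmax. It is not the code from either repo: the function names and loop structure are assumptions, and the inner loops only mimic what the threads in a block would do. The point is the two `break` statements, which stop work on rows and columns past the sequence length when the block size is larger than seq_len.

```python
import math


def dot(a, b):
    return sum(x * y for x, y in zip(a, b))


def naive_attention(q, k, v):
    """Reference: softmax(q k^T / sqrt(d)) v for one head (lists of rows)."""
    n, d = len(q), len(q[0])
    scale = 1.0 / math.sqrt(d)
    out = []
    for i in range(n):
        scores = [scale * dot(q[i], k[j]) for j in range(n)]
        m = max(scores)
        p = [math.exp(s - m) for s in scores]
        total = sum(p)
        out.append([sum(p[j] * v[j][c] for j in range(n)) / total
                    for c in range(d)])
    return out


def tiled_attention(q, k, v, block=32):
    """Tiled attention with a running max and sum per query row."""
    n, d = len(q), len(q[0])
    scale = 1.0 / math.sqrt(d)
    out = [[0.0] * d for _ in range(n)]   # unnormalized accumulators
    row_max = [-math.inf] * n
    row_sum = [0.0] * n
    n_tiles = (n + block - 1) // block    # ceil(n / block)
    for tj in range(n_tiles):             # key/value tiles
        for ti in range(n_tiles):         # query tiles
            for ii in range(block):       # one "thread" per query row
                i = ti * block + ii
                if i >= n:
                    break                 # row is past seq_len
                scores, cols = [], []
                for jj in range(block):
                    j = tj * block + jj
                    if j >= n:
                        break             # column is past seq_len
                    scores.append(scale * dot(q[i], k[j]))
                    cols.append(j)
                m_new = max(row_max[i], max(scores))
                alpha = math.exp(row_max[i] - m_new)  # rescale old state
                p = [math.exp(s - m_new) for s in scores]
                row_sum[i] = alpha * row_sum[i] + sum(p)
                for c in range(d):
                    out[i][c] = alpha * out[i][c] + sum(
                        pj * v[j][c] for pj, j in zip(p, cols))
                row_max[i] = m_new
    # Normalize once at the end.
    return [[x / row_sum[i] for x in out[i]] for i in range(n)]
```

With the checks in place, `tiled_attention(q, k, v, block=32)` should agree with `naive_attention(q, k, v)` even for seq_len = 1. Without them, the padded rows and columns would read out of bounds or contribute spurious terms to the softmax.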

hypertseng commented 4 months ago

@leloykun Recently I found that the flash_attn_bwd implementation in your repo is slower than the manual implementation. This seems to be entirely due to an implicit call to cudaDeviceSynchronize, which increases the CPU time a lot. Do you have any idea how to solve this? By the way, I found that changing the atomicAdd to a normal add decreases the share of time spent in cudaDeviceSynchronize, but I don't know why. I'm a CUDA beginner, haha.

2440020096 commented 1 month ago

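@hypertseng Most likely, the cudaDeviceSynchronize time includes the kernel execution time. Kernel launches are asynchronous, so the host-side wait in the synchronize call absorbs the time the kernel spends running. You can use CUDA events to time it instead: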

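import torch

# Assumes minimal_attn, q, k, and v are already defined elsewhere (e.g., by the benchmark script).
torch.cuda.reset_peak_memory_stats()
start_event = torch.cuda.Event(enable_timing=True)
end_event = torch.cuda.Event(enable_timing=True)

start_event.record()
minimal_result = minimal_attn.forward(q, k, v)
end_event.record()
# Wait for the GPU to finish so both events have completed before reading them.
torch.cuda.synchronize()

elapsed_time_ms = start_event.elapsed_time(end_event)  # GPU time between the two events, in ms
max_vram_MB = torch.cuda.max_memory_allocated() / (1024*1024)  # peak allocated memory, in MiB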