Dao-AILab / flash-attention

Fast and memory-efficient exact attention
BSD 3-Clause "New" or "Revised" License

Loop unrolling in gemm function: do ldmatrix.sync and mma.sync operations overlap? #1170

Closed: phantaurus closed this 2 weeks ago

phantaurus commented 2 weeks ago

Hello! Could you help me understand the loop-unrolling logic in the gemm function? Both cute::copy and cute::gemm use synchronized instructions (ldmatrix.sync for the SmemTiledCopy and mma.sync for the TiledMMA) defined in the kernel traits. The loop is written so that cute::copy(..., 0) is performed first, and then, while iterating over the k dimension, cute::copy(..., i+1) is issued before cute::gemm(..., i). My guess is that this ordering is meant to make the copy and the mma overlap, but will they actually overlap given that both are synchronized operations?

https://github.com/Dao-AILab/flash-attention/blob/32792d37ec66902e5d82e149971daacbee8b55d7/csrc/flash_attn/src/utils.h#L137
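For context, here is a simplified sketch of the structure I am asking about (paraphrased from the linked helper; the retile_D copy views, static asserts, and full template signature are omitted, and the names are abbreviated):

```cpp
#include <cute/tensor.hpp>
using namespace cute;

// Simplified sketch (not the exact code in utils.h): register-resident A/B
// fragments are filled from shared memory one k sub-tile at a time, and the
// copy for sub-tile i+1 is issued before the mma on sub-tile i.
template <class Acc, class TA, class TB, class SA, class SB,
          class TiledMma, class TiledCopyA, class TiledCopyB>
__device__ void gemm_pipelined(Acc &acc, TA &tCrA, TB &tCrB,
                               SA const &tCsA, SB const &tCsB,
                               TiledMma tiled_mma,
                               TiledCopyA smem_tiled_copy_A,
                               TiledCopyB smem_tiled_copy_B) {
    // Prologue: fetch the k = 0 sub-tile (ldmatrix.sync under the hood).
    cute::copy(smem_tiled_copy_A, tCsA(_, _, _0{}), tCrA(_, _, _0{}));
    cute::copy(smem_tiled_copy_B, tCsB(_, _, _0{}), tCrB(_, _, _0{}));
    #pragma unroll
    for (int i = 0; i < size<2>(tCrA); ++i) {
        if (i < size<2>(tCrA) - 1) {
            // Issue the loads for sub-tile i + 1 before computing on sub-tile i,
            // so the next ldmatrix can be in flight while mma.sync executes.
            cute::copy(smem_tiled_copy_A, tCsA(_, _, i + 1), tCrA(_, _, i + 1));
            cute::copy(smem_tiled_copy_B, tCsB(_, _, i + 1), tCrB(_, _, i + 1));
        }
        // Accumulate acc += A_i * B_i (mma.sync under the hood).
        cute::gemm(tiled_mma, tCrA(_, _, i), tCrB(_, _, i), acc);
    }
}
```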

Thank you so much!

tridao commented 2 weeks ago

In this particular case it's just about general CUDA programming, nothing specific about mma or copy. You can read more about how unrolling is generally helpful in CUDA.
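As a generic illustration of that point (not from this repo, just a toy kernel): with a compile-time trip count and #pragma unroll, the compiler sees all iterations at once, keeps the accumulator in registers, and can schedule the independent loads ahead of the dependent arithmetic, hiding memory latency.

```cpp
// Toy example: each thread computes a 4-element dot product.
__global__ void dot4_per_thread(const float *a, const float *b, float *out, int n) {
    int base = (blockIdx.x * blockDim.x + threadIdx.x) * 4;
    if (base + 3 >= n) return;  // skip the ragged tail for simplicity
    float acc = 0.f;
    #pragma unroll
    for (int i = 0; i < 4; ++i) {
        // After unrolling, these four a/b loads are independent of one
        // another, so they can all be in flight before the first fma retires.
        acc = fmaf(a[base + i], b[base + i], acc);
    }
    out[base / 4] = acc;
}
```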

phantaurus commented 2 weeks ago

Thank you so much for your reply! I guess there are some lower-level optimizations the GPU can do, e.g., something like prefetching and branch prediction.
