Dao-AILab / flash-attention

Fast and memory-efficient exact attention
BSD 3-Clause "New" or "Revised" License

[Bug] Potential hazard in epilogue when kUseVarSeqLen=true #1344

Closed by QiZhangNV 1 week ago

QiZhangNV commented 1 week ago

Hello, I'd like to report a potential hazard in the epilogue when kUseVarSeqLen=true. Under this condition, the epilogue uses write_tiled() instead of write_tma() to write O to gmem, which means all threads issue a copy from smem to gmem. However, these threads do not appear to be synchronized before that copy.

To reproduce this bug consistently, insert the following code before the cute::copy(rmem_tiled_copy_O, taccOrO, taccOsO) at https://github.com/Dao-AILab/flash-attention/blob/main/hopper/epilogue_fwd_sm90_tma.hpp#L200. This makes the last warp's rmem-to-smem copy much slower than the others'. Because of the missing synchronization, the data is copied to gmem before the smem contents are fully written.

// Delay the last warp's rmem -> smem copy to widen the race window.
if (cutlass::canonical_warp_idx_sync() == kNWarps - 1) {
    __nanosleep(1 * 1000);
}
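For illustration, the fix is to order the per-thread rmem-to-smem copies before the cooperative smem-to-gmem copy with a CTA-wide (or named) barrier. A minimal sketch of the hazard and the barrier placement follows; the tensor and copy-atom names (taccOrO, taccOsO, tOsO, tOgO, gmem_tiled_copy_O) mirror the epilogue's conventions but this is a simplified illustration, not the actual patched code:

```cuda
// Each thread writes its accumulator fragment from registers to shared memory.
cute::copy(rmem_tiled_copy_O, taccOrO, taccOsO);   // rmem -> smem

// Hazard: without a barrier here, a fast thread can begin the cooperative
// smem -> gmem copy below while a slow warp is still writing its fragment
// to smem, publishing incomplete O data to global memory.
__syncthreads();  // or cutlass::arch::NamedBarrier covering all epilogue threads

// All threads cooperatively copy the O tile from shared to global memory.
cute::copy(gmem_tiled_copy_O, tOsO, tOgO);         // smem -> gmem
```

A plain __syncthreads() suffices when every thread in the CTA participates in the epilogue; a NamedBarrier is the usual choice when only a subset of warps takes part.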
tridao commented 1 week ago

Thanks, I think you're right. We may have fixed it in the decode branch, which we're working on merging into main: https://github.com/Dao-AILab/flash-attention/blob/decode/hopper/epilogue_fwd_sm90_tma.hpp

QiZhangNV commented 1 week ago

That's great — the fix there looks good to me.