Dao-AILab / flash-attention

Fast and memory-efficient exact attention
BSD 3-Clause "New" or "Revised" License

[Feature request] attn_mask support #119

Open junjie18 opened 1 year ago

junjie18 commented 1 year ago

Hi, thanks for your great work. Would you be willing to support an attn_mask in FlashAttention? Query denoising [1, 2] seems to be a common practice in computer vision tasks; a sketch of the semantics I have in mind is below the references.

[1] Li F, Zhang H, Liu S, et al. DN-DETR: Accelerate DETR Training by Introducing Query DeNoising.
[2] Zhang H, Li F, Liu S, et al. DINO: DETR with Improved DeNoising Anchor Boxes for End-to-End Object Detection.
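
For illustration, a minimal sketch of the desired semantics, written against PyTorch's `torch.nn.functional.scaled_dot_product_attention` rather than FlashAttention's API; the block mask layout here is a hypothetical example, not the exact DN-DETR mask.

```python
# Reference semantics of the requested attn_mask, sketched with plain PyTorch
# (not FlashAttention's API). In query denoising, matching queries must not
# attend to denoising queries, expressed as an additive mask of -inf at
# disallowed positions.
import torch
import torch.nn.functional as F

batch, nheads, seqlen, headdim = 2, 8, 128, 64
q = torch.randn(batch, nheads, seqlen, headdim, device="cuda", dtype=torch.float16)
k = torch.randn_like(q)
v = torch.randn_like(q)

# Hypothetical mask: True = attention not allowed.
disallow = torch.zeros(seqlen, seqlen, dtype=torch.bool, device="cuda")
disallow[:64, 64:] = True  # e.g. matching queries cannot see denoising queries

# Convert to an additive float mask (-inf where disallowed).
attn_mask = torch.zeros(seqlen, seqlen, device="cuda", dtype=torch.float16)
attn_mask.masked_fill_(disallow, float("-inf"))

out = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)
```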

tridao commented 1 year ago

The Triton implementation in this repo supports attention bias. However, it's an experimental feature, as I sometimes see race conditions (due to the Triton compiler) in the backward pass with attention bias. The Triton team has just rewritten their backend, so things might be more stable, but I haven't tried it.
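
For anyone who wants to try it, a minimal sketch assuming the `flash_attn_func` entry point in `flash_attn/flash_attn_triton.py`, with `q`/`k`/`v` laid out as `(batch, seqlen, nheads, headdim)` and `bias` broadcastable to `(batch, nheads, seqlen_q, seqlen_k)`; treat the exact signature and shapes as assumptions and check them against the file in this repo.

```python
# Sketch of passing an attention bias to the Triton implementation
# (experimental, per the comment above; backward pass may be flaky).
import torch
from flash_attn.flash_attn_triton import flash_attn_func  # assumed entry point

batch, seqlen, nheads, headdim = 2, 128, 8, 64
q = torch.randn(batch, seqlen, nheads, headdim, device="cuda",
                dtype=torch.float16, requires_grad=True)
k = torch.randn_like(q, requires_grad=True)
v = torch.randn_like(q, requires_grad=True)

# Additive bias: -inf (or a large negative value) where attention is disallowed.
bias = torch.zeros(batch, nheads, seqlen, seqlen, device="cuda", dtype=torch.float16)
bias[..., :64, 64:] = float("-inf")

out = flash_attn_func(q, k, v, bias)  # forward with attention bias
out.sum().backward()                  # backward is where race conditions were observed
```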

vadimkantorov commented 1 year ago

@tridao also perf-wise, how does the Triton implementation compare to your custom CUTLASS implementation? And is there any recent evidence that the backward-pass issues are fixed in more recent Triton versions?

tridao commented 1 year ago

Speed should be around the same. I don't know about recent Triton versions, I haven't had the time to test.