Dao-AILab / flash-attention

Fast and memory-efficient exact attention

Type of gemm. #703

Open gaodaheng opened 10 months ago

gaodaheng commented 10 months ago

For all GEMMs in flash attention (including forward & backward), the inputs are fp16/bf16 (both the left and right matrices) and the output is fp32?

tridao commented 10 months ago

Yes that's right.
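A quick way to see what that means numerically, as a rough PyTorch sketch (an assumption for illustration, not the actual CUDA kernel, which uses tensor-core MMAs with fp32 accumulators):

```python
import torch

# bf16 inputs, fp32 output -- mimicking a gemm whose accumulator is fp32.
a = torch.randn(128, 64, dtype=torch.bfloat16)
b = torch.randn(64, 128, dtype=torch.bfloat16)
c = a.float() @ b.float()   # fp32 result, as the accumulator would hold it
print(c.dtype)              # torch.float32
```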

fate08301017 commented 2 months ago

If the input is bf16, should the output still be fp32? For example, when q and k are bf16, can q @ k^T output bf16? @tridao

tridao commented 2 months ago

Yes, q @ k^T is accumulated in fp32, softmax is done in fp32, then the result is converted to bf16 to do the gemm with V.
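For reference, a rough PyTorch sketch of that precision flow (only an illustration of the dtypes described above, not the fused kernel; causal masking, dropout, and the fp32 accumulation inside the second gemm are omitted):

```python
import torch

def attention_precision_reference(q, k, v):
    """q, k, v: [batch, heads, seqlen, head_dim] in bf16 (or fp16)."""
    softmax_scale = q.shape[-1] ** -0.5
    # q @ k^T: bf16 inputs, fp32 output
    scores = torch.matmul(q.float(), k.float().transpose(-1, -2)) * softmax_scale
    # softmax computed in fp32
    probs = torch.softmax(scores, dim=-1)
    # probabilities converted back to bf16 for the gemm with V
    out = torch.matmul(probs.to(v.dtype), v)
    return out

q = torch.randn(1, 8, 128, 64, dtype=torch.bfloat16)
k = torch.randn(1, 8, 128, 64, dtype=torch.bfloat16)
v = torch.randn(1, 8, 128, 64, dtype=torch.bfloat16)
print(attention_precision_reference(q, k, v).dtype)  # torch.bfloat16
```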