Open gaodaheng opened 10 months ago
Yes, that's right.
If the input is bf16, should the output still be fp32? For example, when q and k are bf16, can q @ k^T output a bf16 dtype? @tridao
Yes, q @ k^T is in fp32 and softmax is done in fp32; the result is then converted to bf16 to do the GEMM with V.
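To make the precision flow in the reply above concrete, here is a minimal pure-Python sketch for a single query row. It is not the flash-attention kernel: there is no tiling or online softmax, the 1/sqrt(d) scaling is left out, and the probabilities are normalized before the cast to bf16, which is a simplification. The helper names `to_fp32`, `to_bf16`, and `attention_row` are hypothetical, and bf16/fp32 are emulated by rounding Python floats with the `struct` module.

```python
import math
import struct


def to_fp32(x):
    """Round a Python float (fp64) to the nearest fp32 value."""
    return struct.unpack("<f", struct.pack("<f", x))[0]


def to_bf16(x):
    """Emulate bf16: keep the top 16 bits of the fp32 encoding,
    rounding to nearest even. NaN/inf are not handled (sketch only)."""
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    bits = (bits + 0x7FFF + ((bits >> 16) & 1)) & 0xFFFF0000
    return struct.unpack("<f", struct.pack("<I", bits))[0]


def attention_row(q, K, V):
    """One query row of attention with the dtype flow from the thread:
    bf16 inputs -> fp32 scores -> fp32 softmax -> bf16 P -> fp32 accumulation."""
    # Inputs are stored in bf16.
    q = [to_bf16(x) for x in q]
    K = [[to_bf16(x) for x in row] for row in K]
    V = [[to_bf16(x) for x in row] for row in V]

    # GEMM 1: s = q @ K^T with bf16 operands and an fp32 accumulator.
    s = []
    for k_row in K:
        acc = 0.0
        for a, b in zip(q, k_row):
            acc = to_fp32(acc + a * b)
        s.append(acc)

    # Softmax is computed in fp32.
    m = max(s)
    e = [to_fp32(math.exp(x - m)) for x in s]
    total = to_fp32(sum(e))
    p = [to_fp32(x / total) for x in e]

    # P is converted to bf16 before the second GEMM.
    p_bf16 = [to_bf16(x) for x in p]

    # GEMM 2: o = P @ V with bf16 operands and an fp32 accumulator.
    o = []
    for j in range(len(V[0])):
        acc = 0.0
        for p_i, v_row in zip(p_bf16, V):
            acc = to_fp32(acc + p_i * v_row[j])
        o.append(acc)
    return s, p_bf16, o


if __name__ == "__main__":
    s, p_bf16, o = attention_row(
        [1.0, 0.0],
        [[1.0, 0.0], [0.0, 1.0]],
        [[1.0, 2.0], [3.0, 4.0]],
    )
    print("scores (fp32):", s)
    print("probabilities (bf16):", p_bf16)
    print("output accumulator (fp32):", o)
```

The point of the sketch is only where the casts happen: both matrix products take low-precision operands but accumulate in fp32, and the only downcast between them is on the softmax output.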
For all GEMMs in flash attention (including forward and backward), the inputs are fp16/bf16 (both the left and right matrices) and the output is fp32?