tenstorrent / tt-metal

:metal: TT-NN operator library, and TT-Metalium low level kernel programming model.
Apache License 2.0

Enable FP32 Accumulate in Flash Attention and Flash Decode #13364

Open caixunshiren opened 2 weeks ago

caixunshiren commented 2 weeks ago

Description

We do not currently support FP32 accumulation in the SDPA family of kernels. This becomes a problem when the number of chunks gets large: the PCC against ground truth diverges. For models that require a 128K sequence length, this is problematic.
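To illustrate why the accumulator precision matters as the chunk count grows, here is a minimal pure-Python sketch. It is a toy model of a running sum, not the actual SDPA kernel, and it reports relative error rather than PCC. The helpers `to_bf16`, `to_fp32`, `accumulate`, and `relative_error` are hypothetical names made up for this example and are not tt-metal APIs. It assumes each chunk contributes a bfloat16 value in roughly [0.5, 1.5] and that the accumulator is rounded after every add.

```python
import math
import random
import struct


def to_fp32(x: float) -> float:
    """Round a Python float (fp64) to the nearest fp32 value."""
    return struct.unpack("<f", struct.pack("<f", x))[0]


def to_bf16(x: float) -> float:
    """Round a finite Python float to the nearest bfloat16 value.

    bfloat16 is the top 16 bits of an fp32 word, so we round the fp32
    bit pattern to 16 bits with round-to-nearest-even.
    """
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    rounding_bias = 0x7FFF + ((bits >> 16) & 1)
    bits = (((bits + rounding_bias) >> 16) << 16) & 0xFFFFFFFF
    return struct.unpack("<f", struct.pack("<I", bits))[0]


def accumulate(chunks, round_fn):
    """Running sum where the accumulator is rounded after every add."""
    acc = 0.0
    for c in chunks:
        acc = round_fn(acc + c)
    return acc


def relative_error(num_chunks: int, round_fn, seed: int = 0) -> float:
    """Relative error of the rounded running sum vs. an exact reference."""
    rng = random.Random(seed)
    # Assumption: per-chunk partial results are already bf16 values.
    chunks = [to_bf16(rng.uniform(0.5, 1.5)) for _ in range(num_chunks)]
    exact = math.fsum(chunks)
    return abs(accumulate(chunks, round_fn) - exact) / exact


if __name__ == "__main__":
    for n in (16, 256, 4096):
        print(n, relative_error(n, to_bf16), relative_error(n, to_fp32))
```

In this toy setup the bf16 accumulator stops growing once its spacing between representable values exceeds the size of a single chunk contribution, so the error increases with the number of chunks, while the fp32 accumulator stays close to the reference. The real kernels also rescale by a running max, so the exact numbers will differ.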

This issue tracks the enabling of fp32 accumulate in the following kernels:

round 1:

round 2:

FYI @cglagovichTT

caixunshiren commented 2 days ago

Update: