Problem Description
This is due to the core logic of the FA (Flash Attention) algorithm: https://github.com/ROCm/aotriton/blob/f6b28a9b7265b69e3df54ea6ba0237e8a8d6f736/tritonsrc/fwd_kernel_inner.py#L95-L97
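For reference, the lines at that location look roughly like the sketch below. Only the two quoted expressions and `p`/`exp2` come from this report; the surrounding names (`m_i`, `tl.math.exp2`) are assumptions for illustration, not the exact upstream source.

```python
# Softmax-rescaling step of the FA inner loop (paraphrased sketch, assumed names).
m_ij = tl.maximum(m_i, qk_scale * tl.max(qk, 1))  # qk_scale * max(...) lowers to vmul
qk = qk * qk_scale - m_ij[:, None]                # lowers to vfma (single rounding)
p = tl.math.exp2(qk)                              # overflows to inf if the residual is spuriously large
```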
`qk_scale * tl.max(qk, 1)` uses `vmul`, but `qk * qk_scale - m_ij[:, None]` uses `vfma`. `vfma` has higher precision than `vmul` + `vadd` because it only rounds once.

More specifically, suppose the vectors each contain only one element, and `qk = round(133120.0 * 133120.0) = 17720934400.0`, `qk_scale = round(0.25 * 1.44269502162933349609) = 0.360673755407333374023`.
Then the exact product `qk * qk_scale = 6391475959.375` is rounded down by `vmul` to `m_ij = 6391475712.0`, while `vfma` computes `qk * qk_scale - m_ij[:, None]` with a single rounding and returns exactly `247.375`.

Therefore, `p` in the code above comes out as `exp2(247.375) = inf`. This consequently leads to `nan` in the following steps.
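The rounding difference can be checked on the host with plain NumPy. The sketch below emulates `vmul` by rounding the product to float32 on its own, and emulates `vfma` by forming the product exactly in float64 (it fits) and rounding only the final result; this assumes float32 arithmetic mirrors the vector instructions described above.

```python
import numpy as np

# Inputs from the example above, all exactly representable in float32.
qk = np.float32(133120.0) * np.float32(133120.0)                  # 17720934400.0
qk_scale = np.float32(0.25) * np.float32(1.44269502162933349609)  # 0.36067375540733337...

# vmul path: the product is rounded to float32 by itself.
m_ij = qk * qk_scale                                               # 6391475712.0

# vfma path (emulated): exact product in float64, one rounding at the end.
fused = np.float32(np.float64(qk) * np.float64(qk_scale) - np.float64(m_ij))

print(float(m_ij))     # 6391475712.0
print(float(fused))    # 247.375
print(np.exp2(fused))  # inf (float32 overflow, with a RuntimeWarning)
```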
Operating System
N/A
CPU
N/A
GPU
MI300X
ROCm Version
ROCm 6.2.3
ROCm Component
No response
Steps to Reproduce
No response
(Optional for Linux users) Output of /opt/rocm/bin/rocminfo --support
No response
Additional Information
No response