ROCm / aotriton

Ahead of Time (AOT) Triton Math Library
MIT License

[Documentation]: large bf16 inputs lead to nan #54

Open xinyazhang opened 3 weeks ago

xinyazhang commented 3 weeks ago

Problem Description

This is due to the core logic of the FA algorithm: https://github.com/ROCm/aotriton/blob/f6b28a9b7265b69e3df54ea6ba0237e8a8d6f736/tritonsrc/fwd_kernel_inner.py#L95-L97

qk_scale * tl.max(qk, 1) uses vmul, but qk * qk_scale - m_ij[:, None] uses vfma. vfma has higher precision than vmul + vadd because it only rounds once.
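As a minimal illustration of the single- vs. double-rounding difference (a sketch in plain Python doubles, not the kernel code; math.fma requires Python 3.13+, and the values below merely mirror the roles of qk and qk_scale):

import math  # math.fma is available in Python 3.13+

a = 133120.0 * 133120.0        # plays the role of qk
b = 0.25 * 1.4426950408889634  # plays the role of qk_scale

# Plain multiply then subtract: both products are rounded the same way, so this is exactly 0.0.
rounded_twice = a * b - a * b
# A fused multiply-add keeps a * b exact until the single final rounding,
# so it exposes the error committed by the plain multiply (usually nonzero).
rounded_once = math.fma(a, b, -(a * b))
print(rounded_twice, rounded_once)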

More specifically, suppose the vectors each contain only one element, with qk = round(133120.0 * 133120.0) = 17720934400.0 and qk_scale = round(0.25 * 1.44269502162933349609) = 0.360673755407333374023. Then

qk * qk_scale - m_ij
= qk * qk_scale - qk_scale * qk
= fma(qk, qk_scale, -vmul(qk, qk_scale))
= fma(qk, qk_scale, -round(17720934400.0 * 0.360673755407333374023))
= fma(17720934400.0, 0.360673755407333374023, -6391475712.0)
= round(17720934400.0 * 0.360673755407333374023 - 6391475712.0)
= round(6391475959.375 - 6391475712.0)
= 247.375

Therefore, p in the code above evaluates to exp2(247.375) = inf, which consequently leads to nan in the following steps.
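For reference, a single-element fp32 sketch of the numbers above (assuming numpy; float64 is used only to emulate the exact intermediate product of the fma, which it can hold exactly for two fp32 operands):

import numpy as np

qk = np.float32(133120.0) * np.float32(133120.0)                  # 17720934400.0, exact in fp32
qk_scale = np.float32(0.25) * np.float32(1.44269502162933349609)  # 0.36067375540733337...

# m_ij as the vmul path produces it: the product is rounded to fp32 here.
m_ij = qk * qk_scale                                              # 6391475712.0

# Emulated vfma path: the product of two fp32 values is exact in float64,
# and only the fused multiply-subtract result is rounded back to fp32.
p_arg = np.float32(np.float64(qk) * np.float64(qk_scale) - np.float64(m_ij))
print(m_ij, p_arg)     # 6391475712.0 247.375

print(np.exp2(p_arg))  # inf: exp2(247.375) overflows fp32, and the inf turns into nan downstream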

Operating System

N/A

CPU

N/A

GPU

MI300X

ROCm Version

ROCm 6.2.3

ROCm Component

No response

Steps to Reproduce

No response

(Optional for Linux users) Output of /opt/rocm/bin/rocminfo --support

No response

Additional Information

No response