triton-lang / triton

Development repository for the Triton language and compiler
https://triton-lang.org/
MIT License

Fused Attention FP8 correctness on L20 (Ada GPU) #4476

Open suluner opened 2 months ago

suluner commented 2 months ago

Hi, I was testing the fused attention FP8 path from the tutorial on an L20 GPU. The test code is as follows:

import pytest
import torch

# `attention` below is the wrapper (`attention = _attention.apply`) defined in
# Triton's fused-attention tutorial (python/tutorials/06-fused-attention.py).

@pytest.mark.parametrize("Z, H, N_CTX, HEAD_DIM", [(1, 2, 1024, 64)])
@pytest.mark.parametrize("causal", [True])
@pytest.mark.parametrize("fp8", [True])
def test_op(Z, H, N_CTX, HEAD_DIM, causal, fp8, dtype=torch.float16):
    torch.manual_seed(20)
    q = (torch.empty((Z, H, N_CTX, HEAD_DIM), dtype=dtype, device="cuda").normal_(mean=0.0, std=0.5).requires_grad_())
    k = (torch.empty((Z, H, N_CTX, HEAD_DIM), dtype=dtype, device="cuda").normal_(mean=0.0, std=0.5).requires_grad_())
    v = (torch.empty((Z, H, N_CTX, HEAD_DIM), dtype=dtype, device="cuda").normal_(mean=0.0, std=0.5).requires_grad_())
    sm_scale = 0.5
    dout = torch.randn_like(q)  # unused here; only the forward pass is checked
    # reference implementation
    M = torch.tril(torch.ones((N_CTX, N_CTX), device="cuda"))
    p = torch.matmul(q, k.transpose(2, 3)) * sm_scale
    if causal:
        p[:, :, M == 0] = float("-inf")
    p = torch.softmax(p.float(), dim=-1).half()
    # p = torch.exp(p)
    ref_out = torch.matmul(p, v)
    if fp8:
        q = q.to(torch.float8_e5m2)
        k = k.to(torch.float8_e5m2)
        # give v a transposed (0, 1, 3, 2) memory layout, which is what the
        # tutorial's FP8 path expects for v
        v = v.permute(0, 1, 3, 2).contiguous()
        v = v.permute(0, 1, 3, 2)
        v = v.to(torch.float8_e5m2)
    # triton implementation
    tri_out = attention(q, k, v, causal, sm_scale).half()

    print(f"ref_out:\n{ref_out}")
    print(f"tri_out:\n{tri_out}")

    # compare
    assert torch.allclose(ref_out, tri_out, atol=0.5, rtol=0.5)

I found that the test passes on a Hopper GPU (e.g. H100), but fails on the L20 (Ada) GPU.

What's more, if I set the q/k dtype to torch.float8_e5m2 and keep the v dtype as torch.float16, the test passes on the L20 too.
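
In other words, the variant that passes on the L20 only changes the cast block, roughly (keeping the rest of the test as above):

    if fp8:
        q = q.to(torch.float8_e5m2)
        k = k.to(torch.float8_e5m2)
        # v stays in torch.float16; this combination passes on the L20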

My question:

Is all-FP8 fused attention (q, k, and v all in FP8) only supported on Hopper GPUs for now? If I want to run it on an Ada GPU, what can I do?
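
In case it helps others hitting this: one workaround I can think of (just a sketch based on the observation above; the helper name is mine, not from the tutorial) is to gate the v cast on the compute capability, since Ada reports sm_89 and Hopper reports sm_90:

import torch

def use_fp8_for_v() -> bool:
    # All-FP8 attention passes for me only on Hopper (sm_90); on Ada (sm_89)
    # keep v in fp16 as described above.
    return torch.cuda.get_device_capability() >= (9, 0)

# then, inside the test's fp8 branch:
#     v = v.to(torch.float8_e5m2) if use_fp8_for_v() else v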

michaelfeil commented 2 months ago

@suluner atol=0.5, rtol=0.5 is a pretty high tolerance. Got tests passing on L40s.
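
To see how far off the FP8 run actually is, it can also help to print an explicit error metric in addition to the allclose check; a quick sketch (using the ref_out / tri_out tensors from the test above):

    abs_err = (tri_out.float() - ref_out.float()).abs()
    rel_err = abs_err / ref_out.float().abs().clamp_min(1e-3)  # clamp to avoid dividing by near-zero
    print(f"max abs err: {abs_err.max().item():.4f}, "
          f"max rel err: {rel_err.max().item():.4f}, "
          f"mean abs err: {abs_err.mean().item():.4f}")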