suluner opened 2 months ago
Hi, I was testing the FP8 fused attention from the tutorial on an L20 GPU. The test code is as follows:
```python
import pytest
import torch

# `attention` is the Triton autograd function defined in the fused-attention tutorial


@pytest.mark.parametrize("Z, H, N_CTX, HEAD_DIM", [(1, 2, 1024, 64)])
@pytest.mark.parametrize("causal", [True])
@pytest.mark.parametrize("fp8", [True])
def test_op(Z, H, N_CTX, HEAD_DIM, causal, fp8, dtype=torch.float16):
    torch.manual_seed(20)
    q = (torch.empty((Z, H, N_CTX, HEAD_DIM), dtype=dtype, device="cuda")
         .normal_(mean=0.0, std=0.5).requires_grad_())
    k = (torch.empty((Z, H, N_CTX, HEAD_DIM), dtype=dtype, device="cuda")
         .normal_(mean=0.0, std=0.5).requires_grad_())
    v = (torch.empty((Z, H, N_CTX, HEAD_DIM), dtype=dtype, device="cuda")
         .normal_(mean=0.0, std=0.5).requires_grad_())
    sm_scale = 0.5
    dout = torch.randn_like(q)  # unused in this forward-only test
    # reference implementation
    M = torch.tril(torch.ones((N_CTX, N_CTX), device="cuda"))
    p = torch.matmul(q, k.transpose(2, 3)) * sm_scale
    if causal:
        p[:, :, M == 0] = float("-inf")
    p = torch.softmax(p.float(), dim=-1).half()
    # p = torch.exp(p)
    ref_out = torch.matmul(p, v)
    if fp8:
        q = q.to(torch.float8_e5m2)
        k = k.to(torch.float8_e5m2)
        # transpose v's memory layout (last two dims) while keeping its
        # logical shape, then cast to FP8
        v = v.permute(0, 1, 3, 2).contiguous()
        v = v.permute(0, 1, 3, 2)
        v = v.to(torch.float8_e5m2)
    # triton implementation
    tri_out = attention(q, k, v, causal, sm_scale).half()
    print(f"ref_out:\n{ref_out}")
    print(f"tri_out:\n{tri_out}")
    # compare
    assert torch.allclose(ref_out, tri_out, atol=0.5, rtol=0.5)
```
I found that the test passes on a Hopper GPU (e.g. H100) but fails on an L20 (Ada) GPU.
What's more, if I set the q/k dtype to torch.float8_e5m2 but keep v in torch.float16, the test passes on the L20 too; see the sketch below.
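For concreteness, this is the only change relative to the `if fp8:` branch in the test above (a sketch of the variant I mean, not code from the tutorial):

```python
# Variant that passes on L20: cast only q/k to FP8, keep v in fp16.
if fp8:
    q = q.to(torch.float8_e5m2)
    k = k.to(torch.float8_e5m2)
    # v keeps its torch.float16 dtype; no cast and no layout change needed
tri_out = attention(q, k, v, causal, sm_scale).half()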
My question:
Is all-FP8 fused attention (q, k, and v all in FP8) only supported on Hopper GPUs for now? If I want to run it on an Ada GPU, what can I do?
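For reference, the two architectures can be told apart at runtime. This is a hypothetical guard I could put around the all-FP8 path (my own dispatch idea, not something the tutorial does):

```python
import torch

# Compute capability: L20/L40S (Ada) report (8, 9); H100 (Hopper) reports (9, 0).
major, minor = torch.cuda.get_device_capability()
use_all_fp8 = (major, minor) >= (9, 0)  # hypothetical gate: all-FP8 only on Hopper+

v_dtype = torch.float8_e5m2 if use_all_fp8 else torch.float16
```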
@suluner atol=0.5, rtol=0.5 is a pretty high tolerance. I got the tests passing on an L40S.
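With tolerances that loose, it may be more informative to look at the actual error than at a pass/fail `allclose`. A minimal sketch (the metric choice is mine, not from the tutorial):

```python
# Quantify the mismatch between the reference and Triton outputs.
diff = (ref_out.float() - tri_out.float()).abs()
rel = diff / ref_out.float().abs().clamp_min(1e-6)
print(f"max abs err: {diff.max().item():.4f}, "
      f"mean abs err: {diff.mean().item():.4f}, "
      f"max rel err: {rel.max().item():.4f}")
```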