Open cailun01 opened 11 hours ago
Hi, I installed Flash-Attention-3 from source and ran test_flash_attn.py. 5946 unit tests pass and 6 fail. Could you please help me resolve these failures?
My env:
Failed UT:
```text
==================================== short test summary info ====================================
FAILED test_flash_attn.py::test_flash_attn_output[640-128-1.0-128-False-False-True-mha-dtype1] - AssertionError: assert 140.0 <= ((2 * 0.0234375) + 3e-05)
FAILED test_flash_attn.py::test_flash_attn_output[640-128-1.0-128-False-False-True-mqa-dtype0] - AssertionError: assert 118.875 <= ((2 * 0.001953125) + 3e-05)
FAILED test_flash_attn.py::test_flash_attn_output[640-128-1.0-128-False-False-True-gqa-dtype1] - AssertionError: assert 204.0 <= ((2 * 0.015625) + 3e-05)
FAILED test_flash_attn.py::test_flash_attn_output[640-128-1.0-128-True-False-True-mha-dtype1] - AssertionError: assert 118.5 <= ((2 * 0.0234375) + 3e-05)
FAILED test_flash_attn.py::test_flash_attn_output[640-128-1.0-128-True-False-True-mqa-dtype1] - AssertionError: assert 119.0 <= ((2 * 0.015625) + 3e-05)
FAILED test_flash_attn.py::test_flash_attn_output[640-128-1.0-128-True-False-True-gqa-dtype1] - AssertionError: assert 204.0 <= ((2 * 0.015625) + 3e-05)
6 failed, 5946 passed in 72.27s (0:01:12)
```
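For context, every failure above is the same error-tolerance check on the dQ gradient. A minimal self-contained sketch of that rule (the tensors here are hypothetical stand-ins for illustration, not the real test data):

```python
# Sketch of the tolerance rule behind the failures above: FlashAttention's
# max error vs an fp32 reference must be at most twice the max error of a
# plain low-precision PyTorch implementation, plus 3e-5 absolute slack.
# dq_ref / dq_pt / dq are hypothetical stand-ins for illustration only.
import torch

torch.manual_seed(0)
dq_ref = torch.randn(4, 640, 6, 128)               # fp32 reference gradient
dq_pt = dq_ref + 1e-3 * torch.randn_like(dq_ref)   # stand-in PyTorch result
dq = dq_ref + 1e-3 * torch.randn_like(dq_ref)      # stand-in FA3 result

assert (dq - dq_ref).abs().max().item() <= 2 * (dq_pt - dq_ref).abs().max().item() + 3e-5
# A reported max diff of ~118-204 against a bound of ~0.03 (as above) points
# to a genuinely wrong gradient rather than ordinary rounding noise.
```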
Full log:
```python
# test_flash_attn.py::test_flash_attn_output -- the test body that pytest
# echoes in each failure report below, reconstructed once:
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
# @pytest.mark.parametrize("dtype", [torch.float16])
# @pytest.mark.parametrize("dtype", [torch.float8_e4m3fn])
@pytest.mark.parametrize("mha_type", ["mha", "mqa", "gqa"])
# @pytest.mark.parametrize("mha_type", ["mha"])
@pytest.mark.parametrize("causal", [False, True])
# @pytest.mark.parametrize("causal", [True])
@pytest.mark.parametrize("local", [False, True])
# @pytest.mark.parametrize("local", [True])
@pytest.mark.parametrize("deterministic", [False, True])
# @pytest.mark.parametrize("deterministic", [True])
# @pytest.mark.parametrize("d", [32, 64, 96, 128, 160, 192, 224, 256])
# @pytest.mark.parametrize('d', [32, 40, 64, 80, 96, 128, 160, 192])
# @pytest.mark.parametrize('d', [32, 64, 96, 128, 160, 192])
# @pytest.mark.parametrize('d', [56, 80])
# @pytest.mark.parametrize("d", [64, 128, 256])
# @pytest.mark.parametrize("d", [64, 96, 128])
# @pytest.mark.parametrize("d", [256])
@pytest.mark.parametrize("d", [64, 128, 256])
@pytest.mark.parametrize("descale", [1.0])
# @pytest.mark.parametrize("descale", [1.0, 2.0, 3.0, 4.0])
@pytest.mark.parametrize(
    "seqlen_q,seqlen_k",
    [
        (1, 1),
        # (257, 1),
        (64, 128),
        (128, 128),
        (256, 256),
        (113, 203),
        (128, 217),
        (113, 211),
        (108, 256),
        (256, 512),
        (384, 256),
        (640, 128),
        (512, 256),
        (1024, 1024),
        (1023, 1024),
        (1024, 1023),
        (4096, 4096),
    ],
)
# @pytest.mark.parametrize('seqlen_q,seqlen_k', [(128, 128)])
def test_flash_attn_output(
    seqlen_q, seqlen_k, d, causal, local, deterministic, mha_type, dtype, descale
):
    device = "cuda"
    if dtype == torch.float8_e4m3fn:
        dtype_init = torch.float16
    else:
        dtype_init = dtype
    print(dtype)
    # set seed
    torch.random.manual_seed(0)
    # batch_size = 40
    # nheads = 16
    batch_size = 4
    nheads = 6
    nheads_kv = 6 if mha_type == "mha" else (2 if mha_type == "gqa" else 1)
    # nheads_kv = 2
    # batch_size = 9
    # nheads = 6
    window_size = (-1, -1) if not local else torch.randint(0, seqlen_k, (2,))
    q = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype_init, requires_grad=True)
    k = torch.randn(batch_size, seqlen_k, nheads_kv, d, device=device, dtype=dtype_init, requires_grad=True)
    v = torch.randn(batch_size, seqlen_k, nheads_kv, d, device=device, dtype=dtype_init, requires_grad=True)
    q = q.to(dtype)
    k = k.to(dtype)
    v = v.to(dtype)
    softmax_scale = q.shape[-1] ** (-0.5)
    descale_q = torch.tensor([descale], dtype=torch.float32, device='cuda')
    descale_k = torch.tensor([descale], dtype=torch.float32, device='cuda')
    descale_v = torch.tensor([descale], dtype=torch.float32, device='cuda')
    if dtype != torch.float8_e4m3fn:
        out, lse = flash_attn_func(q, k, v, causal=causal, window_size=window_size, deterministic=deterministic)
    else:
        out, q, k, v, out_padded, lse, S_dmask = _flash_attn_forward(
            q, k, v, softmax_scale, causal,
            descale_q=descale_q, descale_k=descale_k, descale_v=descale_v,
        )
    q = q.to(dtype_init)
    k = k.to(dtype_init)
    v = v.to(dtype_init)
    if dtype == torch.float8_e4m3fn:
        descale_q = descale_q.to(dtype_init)
        descale_k = descale_k.to(dtype_init)
        descale_v = descale_v.to(dtype_init)
        q = q * descale_q
        k = k * descale_k
        v = v * descale_v
    out_ref, attn_ref = attention_ref(
        q, k, v, None, None, causal=causal, window_size=window_size,
    )
    out_pt, attn_pt = attention_ref(
        q, k, v, None, None, causal=causal, window_size=window_size,
        upcast=False, reorder_ops=True,
    )
    # qk = torch.einsum('bshd,bthd->bhst', q, k).float()
    # m = qk.amax(-1, keepdim=True)
    # s_tmp = torch.exp((qk - m) / math.sqrt(d))
    # exp_sum = s_tmp.sum(-1)
    # qk = torch.einsum('bthd,bshd->bhts', q.float() / math.sqrt(d), k.float())
    # lse_ref = torch.logsumexp(qk, dim=-1)
    print(f"Output max diff: {(out - out_ref).abs().max().item()}")
    print(f"Output mean diff: {(out - out_ref).abs().mean().item()}")
    print(f"Pytorch max diff: {(out_pt - out_ref).abs().max().item()}")
    print(f"Pytorch mean diff: {(out_pt - out_ref).abs().mean().item()}")
    # if not causal:
    #     print(f"LSE max diff: {(lse - lse_ref).abs().max().item()}")
    # breakpoint()
    if d <= 128 and dtype != torch.float8_e4m3fn:
        g = torch.randn_like(out)
        do_o = (g.float() * out.float()).sum(-1)
        dq, dk, dv = torch.autograd.grad(out, (q, k, v), g)
        dq_ref, dk_ref, dv_ref = torch.autograd.grad(out_ref, (q, k, v), g)
        dq_pt, dk_pt, dv_pt = torch.autograd.grad(out_pt, (q, k, v), g)
        print(f"dQ max diff: {(dq - dq_ref).abs().max().item()}")
        print(f"dK max diff: {(dk - dk_ref).abs().max().item()}")
        print(f"dV max diff: {(dv - dv_ref).abs().max().item()}")
        print(f"dQ mean diff: {(dq - dq_ref).abs().mean().item()}")
        print(f"dK mean diff: {(dk - dk_ref).abs().mean().item()}")
        print(f"dV mean diff: {(dv - dv_ref).abs().mean().item()}")
        print(f"dQ Pytorch max diff: {(dq_pt - dq_ref).abs().max().item()}")
        print(f"dK Pytorch max diff: {(dk_pt - dk_ref).abs().max().item()}")
        print(f"dV Pytorch max diff: {(dv_pt - dv_ref).abs().max().item()}")
        print(f"dQ Pytorch mean diff: {(dq_pt - dq_ref).abs().mean().item()}")
        print(f"dK Pytorch mean diff: {(dk_pt - dk_ref).abs().mean().item()}")
        print(f"dV Pytorch mean diff: {(dv_pt - dv_ref).abs().mean().item()}")
    # dS = torch.einsum('bthd,bshd->bhts', g.float(), v.float())
    # P = torch.softmax(qk, -1)
    # dP = P * (dS - do_o.unsqueeze(1))
    # dQ = torch.einsum('bhts,bshd->bthd', dP, k.float())
    # dV = torch.einsum('bhts,bthd->bshd', P, g.float())
    # dK = torch.einsum('bhts,bthd->bshd', dP, q.float())
    # breakpoint()
    # Check that FlashAttention's numerical error is at most twice the numerical error
    # of a Pytorch implementation.
    # breakpoint()
    if dtype != torch.float8_e4m3fn:
        assert (out - out_ref).abs().max().item() <= 2 * (out_pt - out_ref).abs().max().item() + 3e-5
    else:
        # just test correctness of fp8 kernel w/o further quantization techniques
        assert (out - out_ref).abs().max().item() <= 40 * (out_pt - out_ref).abs().max().item()
    if d <= 128 and dtype != torch.float8_e4m3fn:
        assert (dq - dq_ref).abs().max().item() <= 2 * (dq_pt - dq_ref).abs().max().item() + 3e-5
```

```text
(tail of the preceding failure report)
>       assert (dq - dq_ref).abs().max().item() <= 2 * (dq_pt - dq_ref).abs().max().item() + 3e-5
E       AssertionError: assert 204.0 <= ((2 * 0.015625) + 3e-05)
E        +  where 204.0 = tensor(204., device='cuda:0', dtype=torch.bfloat16).item()
E        +  and   0.015625 = tensor(0.0156, device='cuda:0', dtype=torch.bfloat16).item()

test_flash_attn.py:195: AssertionError
____________________ test_flash_attn_output[640-128-1.0-128-True-False-True-mha-dtype1] ____________________

seqlen_q = 640, seqlen_k = 128, d = 128, causal = True, local = False, deterministic = True
mha_type = 'mha', dtype = torch.bfloat16, descale = 1.0

>       assert (dq - dq_ref).abs().max().item() <= 2 * (dq_pt - dq_ref).abs().max().item() + 3e-5
E       AssertionError: assert 118.5 <= ((2 * 0.0234375) + 3e-05)
E        +  where 118.5 = tensor(118.5000, device='cuda:0', dtype=torch.bfloat16).item()
E        +  and   0.0234375 = tensor(0.0234, device='cuda:0', dtype=torch.bfloat16).item()

test_flash_attn.py:195: AssertionError
____________________ test_flash_attn_output[640-128-1.0-128-True-False-True-mqa-dtype1] ____________________

seqlen_q = 640, seqlen_k = 128, d = 128, causal = True, local = False, deterministic = True
mha_type = 'mqa', dtype = torch.bfloat16, descale = 1.0

>       assert (dq - dq_ref).abs().max().item() <= 2 * (dq_pt - dq_ref).abs().max().item() + 3e-5
E       AssertionError: assert 119.0 <= ((2 * 0.015625) + 3e-05)
E        +  where 119.0 = tensor(119., device='cuda:0', dtype=torch.bfloat16).item()
E        +  and   0.015625 = tensor(0.0156, device='cuda:0', dtype=torch.bfloat16).item()

test_flash_attn.py:195: AssertionError
____________________ test_flash_attn_output[640-128-1.0-128-True-False-True-gqa-dtype1] ____________________

seqlen_q = 640, seqlen_k = 128, d = 128, causal = True, local = False, deterministic = True
mha_type = 'gqa', dtype = torch.bfloat16, descale = 1.0

>       assert (dq - dq_ref).abs().max().item() <= 2 * (dq_pt - dq_ref).abs().max().item() + 3e-5
E       AssertionError: assert 204.0 <= ((2 * 0.015625) + 3e-05)
E        +  where 204.0 = tensor(204., device='cuda:0', dtype=torch.bfloat16).item()
E        +  and   0.015625 = tensor(0.0156, device='cuda:0', dtype=torch.bfloat16).item()

test_flash_attn.py:195: AssertionError
==================================== short test summary info ====================================
FAILED test_flash_attn.py::test_flash_attn_output[640-128-1.0-128-False-False-True-mha-dtype1] - AssertionError: assert 140.0 <= ((2 * 0.0234375) + 3e-05)
FAILED test_flash_attn.py::test_flash_attn_output[640-128-1.0-128-False-False-True-mqa-dtype0] - AssertionError: assert 118.875 <= ((2 * 0.001953125) + 3e-05)
FAILED test_flash_attn.py::test_flash_attn_output[640-128-1.0-128-False-False-True-gqa-dtype1] - AssertionError: assert 204.0 <= ((2 * 0.015625) + 3e-05)
FAILED test_flash_attn.py::test_flash_attn_output[640-128-1.0-128-True-False-True-mha-dtype1] - AssertionError: assert 118.5 <= ((2 * 0.0234375) + 3e-05)
FAILED test_flash_attn.py::test_flash_attn_output[640-128-1.0-128-True-False-True-mqa-dtype1] - AssertionError: assert 119.0 <= ((2 * 0.015625) + 3e-05)
FAILED test_flash_attn.py::test_flash_attn_output[640-128-1.0-128-True-False-True-gqa-dtype1] - AssertionError: assert 204.0 <= ((2 * 0.015625) + 3e-05)
6 failed, 5946 passed in 72.27s (0:01:12)
```
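To iterate on a single case, the failing node IDs from the summary can be run individually. A minimal sketch (assuming it is run from the directory containing test_flash_attn.py):

```python
# Re-run one failing parametrization by its node ID (taken from the summary
# above). Assumes the current working directory contains test_flash_attn.py.
import sys
import pytest

failing_id = (
    "test_flash_attn.py::test_flash_attn_output"
    "[640-128-1.0-128-True-False-True-gqa-dtype1]"
)
sys.exit(pytest.main(["-q", failing_id]))
```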
Thanks for the bug report. Can you try this branch: https://github.com/Dao-AILab/flash-attention/tree/tdd

It's a newer backward implementation that we're testing.
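In case it helps, a rough sketch of rebuilding FA3 from that branch (assuming the FA3 sources still live under hopper/, as they do on main; verify before relying on this):

```python
# Rough, untested sketch: clone the tdd branch and rebuild the FA3 extension.
# The hopper/ subdirectory is where the FA3 sources live on the main branch;
# that layout is an assumption for tdd, so adjust if it differs there.
import subprocess

subprocess.run(
    ["git", "clone", "--branch", "tdd", "--depth", "1",
     "https://github.com/Dao-AILab/flash-attention.git"],
    check=True,
)
subprocess.run(
    ["python", "setup.py", "install"],
    cwd="flash-attention/hopper",
    check=True,
)
```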