Dao-AILab / flash-attention

Fast and memory-efficient exact attention

Six Flash-Attention-3 unit tests fail on H20 #1272

Open cailun01 opened 11 hours ago

cailun01 commented 11 hours ago

Hi, I installed Flash-Attention-3 from source and ran test_flash_attn.py. I found that 5946 unit tests pass and 6 fail. Could you please help me solve these problems?

My env:

Failed unit tests:

==================================================== short test summary info =====================================================
FAILED test_flash_attn.py::test_flash_attn_output[640-128-1.0-128-False-False-True-mha-dtype1] - AssertionError: assert 140.0 <= ((2 * 0.0234375) + 3e-05)
FAILED test_flash_attn.py::test_flash_attn_output[640-128-1.0-128-False-False-True-mqa-dtype0] - AssertionError: assert 118.875 <= ((2 * 0.001953125) + 3e-05)
FAILED test_flash_attn.py::test_flash_attn_output[640-128-1.0-128-False-False-True-gqa-dtype1] - AssertionError: assert 204.0 <= ((2 * 0.015625) + 3e-05)
FAILED test_flash_attn.py::test_flash_attn_output[640-128-1.0-128-True-False-True-mha-dtype1] - AssertionError: assert 118.5 <= ((2 * 0.0234375) + 3e-05)
FAILED test_flash_attn.py::test_flash_attn_output[640-128-1.0-128-True-False-True-mqa-dtype1] - AssertionError: assert 119.0 <= ((2 * 0.015625) + 3e-05)
FAILED test_flash_attn.py::test_flash_attn_output[640-128-1.0-128-True-False-True-gqa-dtype1] - AssertionError: assert 204.0 <= ((2 * 0.015625) + 3e-05)
6 failed, 5946 passed in 72.27s (0:01:12)
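
For a quicker repro, any one of the failing cases can be re-run on its own by node ID (a minimal sketch; the hopper/ prefix is an assumption about where the FA3 tests live in the source checkout):

```python
# Re-run a single failing parametrization in isolation; the node ID is copied
# from the short test summary above.
import pytest

pytest.main([
    "hopper/test_flash_attn.py::test_flash_attn_output"
    "[640-128-1.0-128-True-False-True-gqa-dtype1]",
    "-q",
])
```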

Full log:

           window_size=window_size,
        )
        out_pt, attn_pt = attention_ref(
            q,
            k,
            v,
            None,
            None,
            causal=causal,
            window_size=window_size,
            upcast=False,
            reorder_ops=True,
        )

        # qk = torch.einsum('bshd,bthd->bhst', q, k).float()
        # m = qk.amax(-1, keepdim=True)
        # s_tmp = torch.exp((qk - m) / math.sqrt(d))
        # exp_sum = s_tmp.sum(-1)
        # qk = torch.einsum('bthd,bshd->bhts', q.float() / math.sqrt(d), k.float())
        # lse_ref = torch.logsumexp(qk, dim=-1)

        print(f"Output max diff: {(out - out_ref).abs().max().item()}")
        print(f"Output mean diff: {(out - out_ref).abs().mean().item()}")
        print(f"Pytorch max diff: {(out_pt - out_ref).abs().max().item()}")
        print(f"Pytorch mean diff: {(out_pt - out_ref).abs().mean().item()}")

        # if not causal:
        #     print(f"LSE max diff: {(lse - lse_ref).abs().max().item()}")
        # breakpoint()

        if d <= 128 and dtype != torch.float8_e4m3fn:
            g = torch.randn_like(out)
            do_o = (g.float() * out.float()).sum(-1)
            dq, dk, dv = torch.autograd.grad(out, (q, k, v), g)
            dq_ref, dk_ref, dv_ref = torch.autograd.grad(out_ref, (q, k, v), g)
            dq_pt, dk_pt, dv_pt = torch.autograd.grad(out_pt, (q, k, v), g)
            print(f"dQ max diff: {(dq - dq_ref).abs().max().item()}")
            print(f"dK max diff: {(dk - dk_ref).abs().max().item()}")
            print(f"dV max diff: {(dv - dv_ref).abs().max().item()}")
            print(f"dQ mean diff: {(dq - dq_ref).abs().mean().item()}")
            print(f"dK mean diff: {(dk - dk_ref).abs().mean().item()}")
            print(f"dV mean diff: {(dv - dv_ref).abs().mean().item()}")
            print(f"dQ Pytorch max diff: {(dq_pt - dq_ref).abs().max().item()}")
            print(f"dK Pytorch max diff: {(dk_pt - dk_ref).abs().max().item()}")
            print(f"dV Pytorch max diff: {(dv_pt - dv_ref).abs().max().item()}")
            print(f"dQ Pytorch mean diff: {(dq_pt - dq_ref).abs().mean().item()}")
            print(f"dK Pytorch mean diff: {(dk_pt - dk_ref).abs().mean().item()}")
            print(f"dV Pytorch mean diff: {(dv_pt - dv_ref).abs().mean().item()}")

        # dS = torch.einsum('bthd,bshd->bhts', g.float(), v.float())
        # P = torch.softmax(qk, -1)
        # dP = P * (dS - do_o.unsqueeze(1))
        # dQ = torch.einsum('bhts,bshd->bthd', dP, k.float())
        # dV = torch.einsum('bhts,bthd->bshd', P, g.float())
        # dK = torch.einsum('bhts,bthd->bshd', dP, q.float())
        # breakpoint()

        # Check that FlashAttention's numerical error is at most twice the numerical error
        # of a Pytorch implementation.
        # breakpoint()
        if(dtype != torch.float8_e4m3fn):
            assert (out - out_ref).abs().max().item() <= 2 * (out_pt - out_ref).abs().max().item() + 3e-5
        else:
            # just test correctness of fp8 kernel w/o further quantization techniques
            assert (out - out_ref).abs().max().item() <= 40 * (out_pt - out_ref).abs().max().item()

        if d <= 128 and dtype != torch.float8_e4m3fn:
>           assert (dq - dq_ref).abs().max().item() <= 2 * (dq_pt - dq_ref).abs().max().item() + 3e-5
E           AssertionError: assert 204.0 <= ((2 * 0.015625) + 3e-05)
E            +  where 204.0 = <built-in method item of Tensor object at 0x7f01ba249630>()
E            +    where <built-in method item of Tensor object at 0x7f01ba249630> = tensor(204., device='cuda:0', dtype=torch.bfloat16).item
E            +      where tensor(204., device='cuda:0', dtype=torch.bfloat16) = <built-in method max of Tensor object at 0x7f01ba27a540>()
E            +        where <built-in method max of Tensor object at 0x7f01ba27a540> = tensor([[[[0.0000e+00, 0.0000e+00, 0.0000e+00,  ..., 0.0000e+00,\n           0.0000e+00, 0.0000e+00],\n          [0.0000...0.0000e+00, 9.7656e-04,  ..., 0.0000e+00,\n           7.3242e-04, 2.4414e-04]]]], device='cuda:0', dtype=torch.bfloat16).max
E            +          where tensor([[[[0.0000e+00, 0.0000e+00, 0.0000e+00,  ..., 0.0000e+00,\n           0.0000e+00, 0.0000e+00],\n          [0.0000...0.0000e+00, 9.7656e-04,  ..., 0.0000e+00,\n           7.3242e-04, 2.4414e-04]]]], device='cuda:0', dtype=torch.bfloat16) = <built-in method abs of Tensor object at 0x7f01ba27a680>()
E            +            where <built-in method abs of Tensor object at 0x7f01ba27a680> = (tensor([[[[ 0.0000e+00,  0.0000e+00,  0.0000e+00,  ...,  0.0000e+00,\n            0.0000e+00,  0.0000e+00],\n          [...79e-02,  7.4219e-02,  ..., -4.0527e-02,\n           -4.1016e-02,  5.5664e-02]]]], device='cuda:0', dtype=torch.bfloat16) - tensor([[[[ 0.0000e+00,  0.0000e+00,  0.0000e+00,  ...,  0.0000e+00,\n            0.0000e+00,  0.0000e+00],\n          [...79e-02,  7.3242e-02,  ..., -4.0527e-02,\n           -4.0283e-02,  5.5420e-02]]]], device='cuda:0', dtype=torch.bfloat16)).abs
E            +  and   0.015625 = <built-in method item of Tensor object at 0x7f01ba249220>()
E            +    where <built-in method item of Tensor object at 0x7f01ba249220> = tensor(0.0156, device='cuda:0', dtype=torch.bfloat16).item
E            +      where tensor(0.0156, device='cuda:0', dtype=torch.bfloat16) = <built-in method max of Tensor object at 0x7f01ba249270>()
E            +        where <built-in method max of Tensor object at 0x7f01ba249270> = tensor([[[[0.0000e+00, 0.0000e+00, 0.0000e+00,  ..., 0.0000e+00,\n           0.0000e+00, 0.0000e+00],\n          [0.0000...4.8828e-04, 9.7656e-04,  ..., 0.0000e+00,\n           2.4414e-04, 1.2207e-03]]]], device='cuda:0', dtype=torch.bfloat16).max
E            +          where tensor([[[[0.0000e+00, 0.0000e+00, 0.0000e+00,  ..., 0.0000e+00,\n           0.0000e+00, 0.0000e+00],\n          [0.0000...4.8828e-04, 9.7656e-04,  ..., 0.0000e+00,\n           2.4414e-04, 1.2207e-03]]]], device='cuda:0', dtype=torch.bfloat16) = <built-in method abs of Tensor object at 0x7f01ba249310>()
E            +            where <built-in method abs of Tensor object at 0x7f01ba249310> = (tensor([[[[ 0.0000e+00,  0.0000e+00,  0.0000e+00,  ...,  0.0000e+00,\n            0.0000e+00,  0.0000e+00],\n          [...67e-02,  7.2266e-02,  ..., -4.0527e-02,\n           -4.0527e-02,  5.6641e-02]]]], device='cuda:0', dtype=torch.bfloat16) - tensor([[[[ 0.0000e+00,  0.0000e+00,  0.0000e+00,  ...,  0.0000e+00,\n            0.0000e+00,  0.0000e+00],\n          [...79e-02,  7.3242e-02,  ..., -4.0527e-02,\n           -4.0283e-02,  5.5420e-02]]]], device='cuda:0', dtype=torch.bfloat16)).abs

test_flash_attn.py:195: AssertionError
_______________________________ test_flash_attn_output[640-128-1.0-128-True-False-True-mha-dtype1] _______________________________

seqlen_q = 640, seqlen_k = 128, d = 128, causal = True, local = False, deterministic = True, mha_type = 'mha'
dtype = torch.bfloat16, descale = 1.0

    @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
    # @pytest.mark.parametrize("dtype", [torch.float16])
    # @pytest.mark.parametrize("dtype", [torch.float8_e4m3fn])
    @pytest.mark.parametrize("mha_type", ["mha", "mqa", "gqa"])
    # @pytest.mark.parametrize("mha_type", ["mha"])
    @pytest.mark.parametrize("causal", [False, True])
    # @pytest.mark.parametrize("causal", [True])
    @pytest.mark.parametrize("local", [False, True])
    # @pytest.mark.parametrize("local", [True])
    @pytest.mark.parametrize("deterministic", [False, True])
    # @pytest.mark.parametrize("deterministic", [True])
    # @pytest.mark.parametrize("d", [32, 64, 96, 128, 160, 192, 224, 256])
    # @pytest.mark.parametrize('d', [32, 40, 64, 80, 96, 128, 160, 192])
    # @pytest.mark.parametrize('d', [32, 64, 96, 128, 160, 192])
    # @pytest.mark.parametrize('d', [56, 80])
    # @pytest.mark.parametrize("d", [64, 128, 256])
    # @pytest.mark.parametrize("d", [64, 96, 128])
    # @pytest.mark.parametrize("d", [256])
    @pytest.mark.parametrize("d", [64, 128, 256])
    @pytest.mark.parametrize("descale", [1.0])
    # @pytest.mark.parametrize("descale", [1.0, 2.0, 3.0, 4.0])
    @pytest.mark.parametrize(
        "seqlen_q,seqlen_k",
        [
            (1, 1),
            # (257, 1),
            (64, 128),
            (128, 128),
            (256, 256),
            (113, 203),
            (128, 217),
            (113, 211),
            (108, 256),
            (256, 512),
            (384, 256),
            (640, 128),
            (512, 256),
            (1024, 1024),
            (1023, 1024),
            (1024, 1023),
            (4096, 4096),
        ],
    )
    # @pytest.mark.parametrize('seqlen_q,seqlen_k', [(128, 128)])
    def test_flash_attn_output(
        seqlen_q, seqlen_k, d, causal, local, deterministic, mha_type, dtype, descale
    ):
        device = "cuda"
        if(dtype == torch.float8_e4m3fn):
            dtype_init = torch.float16
        else:
            dtype_init = dtype
        print(dtype)
        # set seed
        torch.random.manual_seed(0)
        # batch_size = 40
        # nheads = 16
        batch_size = 4
        nheads = 6
        nheads_kv = 6 if mha_type == "mha" else (2 if mha_type == "gqa" else 1)
        # nheads_kv = 2
        # batch_size = 9
        # nheads = 6
        window_size = (-1, -1) if not local else torch.randint(0, seqlen_k, (2,))
        q = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype_init, requires_grad=True)
        k = torch.randn(batch_size, seqlen_k, nheads_kv, d, device=device, dtype=dtype_init, requires_grad=True)
        v = torch.randn(batch_size, seqlen_k, nheads_kv, d, device=device, dtype=dtype_init, requires_grad=True)

        q = q.to(dtype)
        k = k.to(dtype)
        v = v.to(dtype)

        softmax_scale = q.shape[-1] ** (-0.5)
        descale_q = torch.tensor([descale], dtype=torch.float32, device='cuda')
        descale_k = torch.tensor([descale], dtype=torch.float32, device='cuda')
        descale_v = torch.tensor([descale], dtype=torch.float32, device='cuda')
        if(dtype != torch.float8_e4m3fn):
            out, lse = flash_attn_func(q, k, v, causal=causal, window_size=window_size, deterministic=deterministic)
        else:
            out, q, k, v, out_padded, lse, S_dmask = _flash_attn_forward(
                q, k, v, softmax_scale, causal, descale_q=descale_q, descale_k=descale_k, descale_v=descale_v
            )

        q = q.to(dtype_init)
        k = k.to(dtype_init)
        v = v.to(dtype_init)

        if(dtype == torch.float8_e4m3fn):
            descale_q = descale_q.to(dtype_init)
            descale_k = descale_k.to(dtype_init)
            descale_v = descale_v.to(dtype_init)
            q = q * descale_q
            k = k * descale_k
            v = v * descale_v

        out_ref, attn_ref = attention_ref(
            q,
            k,
            v,
            None,
            None,
            causal=causal,
            window_size=window_size,
        )
        out_pt, attn_pt = attention_ref(
            q,
            k,
            v,
            None,
            None,
            causal=causal,
            window_size=window_size,
            upcast=False,
            reorder_ops=True,
        )

        # qk = torch.einsum('bshd,bthd->bhst', q, k).float()
        # m = qk.amax(-1, keepdim=True)
        # s_tmp = torch.exp((qk - m) / math.sqrt(d))
        # exp_sum = s_tmp.sum(-1)
        # qk = torch.einsum('bthd,bshd->bhts', q.float() / math.sqrt(d), k.float())
        # lse_ref = torch.logsumexp(qk, dim=-1)

        print(f"Output max diff: {(out - out_ref).abs().max().item()}")
        print(f"Output mean diff: {(out - out_ref).abs().mean().item()}")
        print(f"Pytorch max diff: {(out_pt - out_ref).abs().max().item()}")
        print(f"Pytorch mean diff: {(out_pt - out_ref).abs().mean().item()}")

        # if not causal:
        #     print(f"LSE max diff: {(lse - lse_ref).abs().max().item()}")
        # breakpoint()

        if d <= 128 and dtype != torch.float8_e4m3fn:
            g = torch.randn_like(out)
            do_o = (g.float() * out.float()).sum(-1)
            dq, dk, dv = torch.autograd.grad(out, (q, k, v), g)
            dq_ref, dk_ref, dv_ref = torch.autograd.grad(out_ref, (q, k, v), g)
            dq_pt, dk_pt, dv_pt = torch.autograd.grad(out_pt, (q, k, v), g)
            print(f"dQ max diff: {(dq - dq_ref).abs().max().item()}")
            print(f"dK max diff: {(dk - dk_ref).abs().max().item()}")
            print(f"dV max diff: {(dv - dv_ref).abs().max().item()}")
            print(f"dQ mean diff: {(dq - dq_ref).abs().mean().item()}")
            print(f"dK mean diff: {(dk - dk_ref).abs().mean().item()}")
            print(f"dV mean diff: {(dv - dv_ref).abs().mean().item()}")
            print(f"dQ Pytorch max diff: {(dq_pt - dq_ref).abs().max().item()}")
            print(f"dK Pytorch max diff: {(dk_pt - dk_ref).abs().max().item()}")
            print(f"dV Pytorch max diff: {(dv_pt - dv_ref).abs().max().item()}")
            print(f"dQ Pytorch mean diff: {(dq_pt - dq_ref).abs().mean().item()}")
            print(f"dK Pytorch mean diff: {(dk_pt - dk_ref).abs().mean().item()}")
            print(f"dV Pytorch mean diff: {(dv_pt - dv_ref).abs().mean().item()}")

        # dS = torch.einsum('bthd,bshd->bhts', g.float(), v.float())
        # P = torch.softmax(qk, -1)
        # dP = P * (dS - do_o.unsqueeze(1))
        # dQ = torch.einsum('bhts,bshd->bthd', dP, k.float())
        # dV = torch.einsum('bhts,bthd->bshd', P, g.float())
        # dK = torch.einsum('bhts,bthd->bshd', dP, q.float())
        # breakpoint()

        # Check that FlashAttention's numerical error is at most twice the numerical error
        # of a Pytorch implementation.
        # breakpoint()
        if(dtype != torch.float8_e4m3fn):
            assert (out - out_ref).abs().max().item() <= 2 * (out_pt - out_ref).abs().max().item() + 3e-5
        else:
            # just test correctness of fp8 kernel w/o further quantization techniques
            assert (out - out_ref).abs().max().item() <= 40 * (out_pt - out_ref).abs().max().item()

        if d <= 128 and dtype != torch.float8_e4m3fn:
>           assert (dq - dq_ref).abs().max().item() <= 2 * (dq_pt - dq_ref).abs().max().item() + 3e-5
E           AssertionError: assert 118.5 <= ((2 * 0.0234375) + 3e-05)
E            +  where 118.5 = <built-in method item of Tensor object at 0x7f01ba296db0>()
E            +    where <built-in method item of Tensor object at 0x7f01ba296db0> = tensor(118.5000, device='cuda:0', dtype=torch.bfloat16).item
E            +      where tensor(118.5000, device='cuda:0', dtype=torch.bfloat16) = <built-in method max of Tensor object at 0x7f01ba2968b0>()
E            +        where <built-in method max of Tensor object at 0x7f01ba2968b0> = tensor([[[[0.0000e+00, 0.0000e+00, 0.0000e+00,  ..., 0.0000e+00,\n           0.0000e+00, 0.0000e+00],\n          [0.0000...4.8828e-04, 1.9531e-03,  ..., 4.8828e-04,\n           9.7656e-04, 0.0000e+00]]]], device='cuda:0', dtype=torch.bfloat16).max
E            +          where tensor([[[[0.0000e+00, 0.0000e+00, 0.0000e+00,  ..., 0.0000e+00,\n           0.0000e+00, 0.0000e+00],\n          [0.0000...4.8828e-04, 1.9531e-03,  ..., 4.8828e-04,\n           9.7656e-04, 0.0000e+00]]]], device='cuda:0', dtype=torch.bfloat16) = <built-in method abs of Tensor object at 0x7f01ba296b80>()
E            +            where <built-in method abs of Tensor object at 0x7f01ba296b80> = (tensor([[[[ 0.0000e+00,  0.0000e+00,  0.0000e+00,  ...,  0.0000e+00,\n            0.0000e+00,  0.0000e+00],\n          [...33e-02,  2.8516e-01,  ...,  5.9326e-02,\n            1.6406e-01,  3.5156e-01]]]], device='cuda:0', dtype=torch.bfloat16) - tensor([[[[ 0.0000e+00,  0.0000e+00,  0.0000e+00,  ...,  0.0000e+00,\n            0.0000e+00,  0.0000e+00],\n          [...45e-02,  2.8320e-01,  ...,  5.8838e-02,\n            1.6309e-01,  3.5156e-01]]]], device='cuda:0', dtype=torch.bfloat16)).abs
E            +  and   0.0234375 = <built-in method item of Tensor object at 0x7f01ba277f90>()
E            +    where <built-in method item of Tensor object at 0x7f01ba277f90> = tensor(0.0234, device='cuda:0', dtype=torch.bfloat16).item
E            +      where tensor(0.0234, device='cuda:0', dtype=torch.bfloat16) = <built-in method max of Tensor object at 0x7f01ba277310>()
E            +        where <built-in method max of Tensor object at 0x7f01ba277310> = tensor([[[[0.0000e+00, 0.0000e+00, 0.0000e+00,  ..., 0.0000e+00,\n           0.0000e+00, 0.0000e+00],\n          [0.0000...0.0000e+00, 0.0000e+00,  ..., 4.8828e-04,\n           9.7656e-04, 0.0000e+00]]]], device='cuda:0', dtype=torch.bfloat16).max
E            +          where tensor([[[[0.0000e+00, 0.0000e+00, 0.0000e+00,  ..., 0.0000e+00,\n           0.0000e+00, 0.0000e+00],\n          [0.0000...0.0000e+00, 0.0000e+00,  ..., 4.8828e-04,\n           9.7656e-04, 0.0000e+00]]]], device='cuda:0', dtype=torch.bfloat16) = <built-in method abs of Tensor object at 0x7f01ba296a90>()
E            +            where <built-in method abs of Tensor object at 0x7f01ba296a90> = (tensor([[[[ 0.0000e+00,  0.0000e+00,  0.0000e+00,  ...,  0.0000e+00,\n            0.0000e+00,  0.0000e+00],\n          [...45e-02,  2.8320e-01,  ...,  5.8350e-02,\n            1.6211e-01,  3.5156e-01]]]], device='cuda:0', dtype=torch.bfloat16) - tensor([[[[ 0.0000e+00,  0.0000e+00,  0.0000e+00,  ...,  0.0000e+00,\n            0.0000e+00,  0.0000e+00],\n          [...45e-02,  2.8320e-01,  ...,  5.8838e-02,\n            1.6309e-01,  3.5156e-01]]]], device='cuda:0', dtype=torch.bfloat16)).abs

test_flash_attn.py:195: AssertionError
_______________________________ test_flash_attn_output[640-128-1.0-128-True-False-True-mqa-dtype1] _______________________________

seqlen_q = 640, seqlen_k = 128, d = 128, causal = True, local = False, deterministic = True, mha_type = 'mqa'
dtype = torch.bfloat16, descale = 1.0

    @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
    # @pytest.mark.parametrize("dtype", [torch.float16])
    # @pytest.mark.parametrize("dtype", [torch.float8_e4m3fn])
    @pytest.mark.parametrize("mha_type", ["mha", "mqa", "gqa"])
    # @pytest.mark.parametrize("mha_type", ["mha"])
    @pytest.mark.parametrize("causal", [False, True])
    # @pytest.mark.parametrize("causal", [True])
    @pytest.mark.parametrize("local", [False, True])
    # @pytest.mark.parametrize("local", [True])
    @pytest.mark.parametrize("deterministic", [False, True])
    # @pytest.mark.parametrize("deterministic", [True])
    # @pytest.mark.parametrize("d", [32, 64, 96, 128, 160, 192, 224, 256])
    # @pytest.mark.parametrize('d', [32, 40, 64, 80, 96, 128, 160, 192])
    # @pytest.mark.parametrize('d', [32, 64, 96, 128, 160, 192])
    # @pytest.mark.parametrize('d', [56, 80])
    # @pytest.mark.parametrize("d", [64, 128, 256])
    # @pytest.mark.parametrize("d", [64, 96, 128])
    # @pytest.mark.parametrize("d", [256])
    @pytest.mark.parametrize("d", [64, 128, 256])
    @pytest.mark.parametrize("descale", [1.0])
    # @pytest.mark.parametrize("descale", [1.0, 2.0, 3.0, 4.0])
    @pytest.mark.parametrize(
        "seqlen_q,seqlen_k",
        [
            (1, 1),
            # (257, 1),
            (64, 128),
            (128, 128),
            (256, 256),
            (113, 203),
            (128, 217),
            (113, 211),
            (108, 256),
            (256, 512),
            (384, 256),
            (640, 128),
            (512, 256),
            (1024, 1024),
            (1023, 1024),
            (1024, 1023),
            (4096, 4096),
        ],
    )
    # @pytest.mark.parametrize('seqlen_q,seqlen_k', [(128, 128)])
    def test_flash_attn_output(
        seqlen_q, seqlen_k, d, causal, local, deterministic, mha_type, dtype, descale
    ):
        device = "cuda"
        if(dtype == torch.float8_e4m3fn):
            dtype_init = torch.float16
        else:
            dtype_init = dtype
        print(dtype)
        # set seed
        torch.random.manual_seed(0)
        # batch_size = 40
        # nheads = 16
        batch_size = 4
        nheads = 6
        nheads_kv = 6 if mha_type == "mha" else (2 if mha_type == "gqa" else 1)
        # nheads_kv = 2
        # batch_size = 9
        # nheads = 6
        window_size = (-1, -1) if not local else torch.randint(0, seqlen_k, (2,))
        q = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype_init, requires_grad=True)
        k = torch.randn(batch_size, seqlen_k, nheads_kv, d, device=device, dtype=dtype_init, requires_grad=True)
        v = torch.randn(batch_size, seqlen_k, nheads_kv, d, device=device, dtype=dtype_init, requires_grad=True)

        q = q.to(dtype)
        k = k.to(dtype)
        v = v.to(dtype)

        softmax_scale = q.shape[-1] ** (-0.5)
        descale_q = torch.tensor([descale], dtype=torch.float32, device='cuda')
        descale_k = torch.tensor([descale], dtype=torch.float32, device='cuda')
        descale_v = torch.tensor([descale], dtype=torch.float32, device='cuda')
        if(dtype != torch.float8_e4m3fn):
            out, lse = flash_attn_func(q, k, v, causal=causal, window_size=window_size, deterministic=deterministic)
        else:
            out, q, k, v, out_padded, lse, S_dmask = _flash_attn_forward(
                q, k, v, softmax_scale, causal, descale_q=descale_q, descale_k=descale_k, descale_v=descale_v
            )

        q = q.to(dtype_init)
        k = k.to(dtype_init)
        v = v.to(dtype_init)

        if(dtype == torch.float8_e4m3fn):
            descale_q = descale_q.to(dtype_init)
            descale_k = descale_k.to(dtype_init)
            descale_v = descale_v.to(dtype_init)
            q = q * descale_q
            k = k * descale_k
            v = v * descale_v

        out_ref, attn_ref = attention_ref(
            q,
            k,
            v,
            None,
            None,
            causal=causal,
            window_size=window_size,
        )
        out_pt, attn_pt = attention_ref(
            q,
            k,
            v,
            None,
            None,
            causal=causal,
            window_size=window_size,
            upcast=False,
            reorder_ops=True,
        )

        # qk = torch.einsum('bshd,bthd->bhst', q, k).float()
        # m = qk.amax(-1, keepdim=True)
        # s_tmp = torch.exp((qk - m) / math.sqrt(d))
        # exp_sum = s_tmp.sum(-1)
        # qk = torch.einsum('bthd,bshd->bhts', q.float() / math.sqrt(d), k.float())
        # lse_ref = torch.logsumexp(qk, dim=-1)

        print(f"Output max diff: {(out - out_ref).abs().max().item()}")
        print(f"Output mean diff: {(out - out_ref).abs().mean().item()}")
        print(f"Pytorch max diff: {(out_pt - out_ref).abs().max().item()}")
        print(f"Pytorch mean diff: {(out_pt - out_ref).abs().mean().item()}")

        # if not causal:
        #     print(f"LSE max diff: {(lse - lse_ref).abs().max().item()}")
        # breakpoint()

        if d <= 128 and dtype != torch.float8_e4m3fn:
            g = torch.randn_like(out)
            do_o = (g.float() * out.float()).sum(-1)
            dq, dk, dv = torch.autograd.grad(out, (q, k, v), g)
            dq_ref, dk_ref, dv_ref = torch.autograd.grad(out_ref, (q, k, v), g)
            dq_pt, dk_pt, dv_pt = torch.autograd.grad(out_pt, (q, k, v), g)
            print(f"dQ max diff: {(dq - dq_ref).abs().max().item()}")
            print(f"dK max diff: {(dk - dk_ref).abs().max().item()}")
            print(f"dV max diff: {(dv - dv_ref).abs().max().item()}")
            print(f"dQ mean diff: {(dq - dq_ref).abs().mean().item()}")
            print(f"dK mean diff: {(dk - dk_ref).abs().mean().item()}")
            print(f"dV mean diff: {(dv - dv_ref).abs().mean().item()}")
            print(f"dQ Pytorch max diff: {(dq_pt - dq_ref).abs().max().item()}")
            print(f"dK Pytorch max diff: {(dk_pt - dk_ref).abs().max().item()}")
            print(f"dV Pytorch max diff: {(dv_pt - dv_ref).abs().max().item()}")
            print(f"dQ Pytorch mean diff: {(dq_pt - dq_ref).abs().mean().item()}")
            print(f"dK Pytorch mean diff: {(dk_pt - dk_ref).abs().mean().item()}")
            print(f"dV Pytorch mean diff: {(dv_pt - dv_ref).abs().mean().item()}")

        # dS = torch.einsum('bthd,bshd->bhts', g.float(), v.float())
        # P = torch.softmax(qk, -1)
        # dP = P * (dS - do_o.unsqueeze(1))
        # dQ = torch.einsum('bhts,bshd->bthd', dP, k.float())
        # dV = torch.einsum('bhts,bthd->bshd', P, g.float())
        # dK = torch.einsum('bhts,bthd->bshd', dP, q.float())
        # breakpoint()

        # Check that FlashAttention's numerical error is at most twice the numerical error
        # of a Pytorch implementation.
        # breakpoint()
        if(dtype != torch.float8_e4m3fn):
            assert (out - out_ref).abs().max().item() <= 2 * (out_pt - out_ref).abs().max().item() + 3e-5
        else:
            # just test correctness of fp8 kernel w/o further quantization techniques
            assert (out - out_ref).abs().max().item() <= 40 * (out_pt - out_ref).abs().max().item()

        if d <= 128 and dtype != torch.float8_e4m3fn:
>           assert (dq - dq_ref).abs().max().item() <= 2 * (dq_pt - dq_ref).abs().max().item() + 3e-5
E           AssertionError: assert 119.0 <= ((2 * 0.015625) + 3e-05)
E            +  where 119.0 = <built-in method item of Tensor object at 0x7f01ba28b950>()
E            +    where <built-in method item of Tensor object at 0x7f01ba28b950> = tensor(119., device='cuda:0', dtype=torch.bfloat16).item
E            +      where tensor(119., device='cuda:0', dtype=torch.bfloat16) = <built-in method max of Tensor object at 0x7f01ba28bc20>()
E            +        where <built-in method max of Tensor object at 0x7f01ba28bc20> = tensor([[[[0.0000e+00, 0.0000e+00, 0.0000e+00,  ..., 0.0000e+00,\n           0.0000e+00, 0.0000e+00],\n          [0.0000...4.8828e-04, 2.4414e-04,  ..., 0.0000e+00,\n           2.4414e-04, 0.0000e+00]]]], device='cuda:0', dtype=torch.bfloat16).max
E            +          where tensor([[[[0.0000e+00, 0.0000e+00, 0.0000e+00,  ..., 0.0000e+00,\n           0.0000e+00, 0.0000e+00],\n          [0.0000...4.8828e-04, 2.4414e-04,  ..., 0.0000e+00,\n           2.4414e-04, 0.0000e+00]]]], device='cuda:0', dtype=torch.bfloat16) = <built-in method abs of Tensor object at 0x7f01ba28bbd0>()
E            +            where <built-in method abs of Tensor object at 0x7f01ba28bbd0> = (tensor([[[[ 0.0000e+00,  0.0000e+00,  0.0000e+00,  ...,  0.0000e+00,\n            0.0000e+00,  0.0000e+00],\n          [...82e-01, -2.2705e-02,  ..., -2.2559e-01,\n           -4.4189e-02,  1.1719e-01]]]], device='cuda:0', dtype=torch.bfloat16) - tensor([[[[ 0.0000e+00,  0.0000e+00,  0.0000e+00,  ...,  0.0000e+00,\n            0.0000e+00,  0.0000e+00],\n          [...33e-01, -2.2949e-02,  ..., -2.2559e-01,\n           -4.3945e-02,  1.1719e-01]]]], device='cuda:0', dtype=torch.bfloat16)).abs
E            +  and   0.015625 = <built-in method item of Tensor object at 0x7f01ba28bb80>()
E            +    where <built-in method item of Tensor object at 0x7f01ba28bb80> = tensor(0.0156, device='cuda:0', dtype=torch.bfloat16).item
E            +      where tensor(0.0156, device='cuda:0', dtype=torch.bfloat16) = <built-in method max of Tensor object at 0x7f01ba28b860>()
E            +        where <built-in method max of Tensor object at 0x7f01ba28b860> = tensor([[[[0.0000e+00, 0.0000e+00, 0.0000e+00,  ..., 0.0000e+00,\n           0.0000e+00, 0.0000e+00],\n          [0.0000...4.8828e-04, 8.5449e-04,  ..., 9.7656e-04,\n           4.8828e-04, 0.0000e+00]]]], device='cuda:0', dtype=torch.bfloat16).max
E            +          where tensor([[[[0.0000e+00, 0.0000e+00, 0.0000e+00,  ..., 0.0000e+00,\n           0.0000e+00, 0.0000e+00],\n          [0.0000...4.8828e-04, 8.5449e-04,  ..., 9.7656e-04,\n           4.8828e-04, 0.0000e+00]]]], device='cuda:0', dtype=torch.bfloat16) = <built-in method abs of Tensor object at 0x7f01ba28bae0>()
E            +            where <built-in method abs of Tensor object at 0x7f01ba28bae0> = (tensor([[[[ 0.0000e+00,  0.0000e+00,  0.0000e+00,  ...,  0.0000e+00,\n            0.0000e+00,  0.0000e+00],\n          [...82e-01, -2.3804e-02,  ..., -2.2656e-01,\n           -4.4434e-02,  1.1719e-01]]]], device='cuda:0', dtype=torch.bfloat16) - tensor([[[[ 0.0000e+00,  0.0000e+00,  0.0000e+00,  ...,  0.0000e+00,\n            0.0000e+00,  0.0000e+00],\n          [...33e-01, -2.2949e-02,  ..., -2.2559e-01,\n           -4.3945e-02,  1.1719e-01]]]], device='cuda:0', dtype=torch.bfloat16)).abs

test_flash_attn.py:195: AssertionError
_______________________________ test_flash_attn_output[640-128-1.0-128-True-False-True-gqa-dtype1] _______________________________

seqlen_q = 640, seqlen_k = 128, d = 128, causal = True, local = False, deterministic = True, mha_type = 'gqa'
dtype = torch.bfloat16, descale = 1.0

    @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
    # @pytest.mark.parametrize("dtype", [torch.float16])
    # @pytest.mark.parametrize("dtype", [torch.float8_e4m3fn])
    @pytest.mark.parametrize("mha_type", ["mha", "mqa", "gqa"])
    # @pytest.mark.parametrize("mha_type", ["mha"])
    @pytest.mark.parametrize("causal", [False, True])
    # @pytest.mark.parametrize("causal", [True])
    @pytest.mark.parametrize("local", [False, True])
    # @pytest.mark.parametrize("local", [True])
    @pytest.mark.parametrize("deterministic", [False, True])
    # @pytest.mark.parametrize("deterministic", [True])
    # @pytest.mark.parametrize("d", [32, 64, 96, 128, 160, 192, 224, 256])
    # @pytest.mark.parametrize('d', [32, 40, 64, 80, 96, 128, 160, 192])
    # @pytest.mark.parametrize('d', [32, 64, 96, 128, 160, 192])
    # @pytest.mark.parametrize('d', [56, 80])
    # @pytest.mark.parametrize("d", [64, 128, 256])
    # @pytest.mark.parametrize("d", [64, 96, 128])
    # @pytest.mark.parametrize("d", [256])
    @pytest.mark.parametrize("d", [64, 128, 256])
    @pytest.mark.parametrize("descale", [1.0])
    # @pytest.mark.parametrize("descale", [1.0, 2.0, 3.0, 4.0])
    @pytest.mark.parametrize(
        "seqlen_q,seqlen_k",
        [
            (1, 1),
            # (257, 1),
            (64, 128),
            (128, 128),
            (256, 256),
            (113, 203),
            (128, 217),
            (113, 211),
            (108, 256),
            (256, 512),
            (384, 256),
            (640, 128),
            (512, 256),
            (1024, 1024),
            (1023, 1024),
            (1024, 1023),
            (4096, 4096),
        ],
    )
    # @pytest.mark.parametrize('seqlen_q,seqlen_k', [(128, 128)])
    def test_flash_attn_output(
        seqlen_q, seqlen_k, d, causal, local, deterministic, mha_type, dtype, descale
    ):
        device = "cuda"
        if(dtype == torch.float8_e4m3fn):
            dtype_init = torch.float16
        else:
            dtype_init = dtype
        print(dtype)
        # set seed
        torch.random.manual_seed(0)
        # batch_size = 40
        # nheads = 16
        batch_size = 4
        nheads = 6
        nheads_kv = 6 if mha_type == "mha" else (2 if mha_type == "gqa" else 1)
        # nheads_kv = 2
        # batch_size = 9
        # nheads = 6
        window_size = (-1, -1) if not local else torch.randint(0, seqlen_k, (2,))
        q = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype_init, requires_grad=True)
        k = torch.randn(batch_size, seqlen_k, nheads_kv, d, device=device, dtype=dtype_init, requires_grad=True)
        v = torch.randn(batch_size, seqlen_k, nheads_kv, d, device=device, dtype=dtype_init, requires_grad=True)

        q = q.to(dtype)
        k = k.to(dtype)
        v = v.to(dtype)

        softmax_scale = q.shape[-1] ** (-0.5)
        descale_q = torch.tensor([descale], dtype=torch.float32, device='cuda')
        descale_k = torch.tensor([descale], dtype=torch.float32, device='cuda')
        descale_v = torch.tensor([descale], dtype=torch.float32, device='cuda')
        if(dtype != torch.float8_e4m3fn):
            out, lse = flash_attn_func(q, k, v, causal=causal, window_size=window_size, deterministic=deterministic)
        else:
            out, q, k, v, out_padded, lse, S_dmask = _flash_attn_forward(
                q, k, v, softmax_scale, causal, descale_q=descale_q, descale_k=descale_k, descale_v=descale_v
            )

        q = q.to(dtype_init)
        k = k.to(dtype_init)
        v = v.to(dtype_init)

        if(dtype == torch.float8_e4m3fn):
            descale_q = descale_q.to(dtype_init)
            descale_k = descale_k.to(dtype_init)
            descale_v = descale_v.to(dtype_init)
            q = q * descale_q
            k = k * descale_k
            v = v * descale_v

        out_ref, attn_ref = attention_ref(
            q,
            k,
            v,
            None,
            None,
            causal=causal,
            window_size=window_size,
        )
        out_pt, attn_pt = attention_ref(
            q,
            k,
            v,
            None,
            None,
            causal=causal,
            window_size=window_size,
            upcast=False,
            reorder_ops=True,
        )

        # qk = torch.einsum('bshd,bthd->bhst', q, k).float()
        # m = qk.amax(-1, keepdim=True)
        # s_tmp = torch.exp((qk - m) / math.sqrt(d))
        # exp_sum = s_tmp.sum(-1)
        # qk = torch.einsum('bthd,bshd->bhts', q.float() / math.sqrt(d), k.float())
        # lse_ref = torch.logsumexp(qk, dim=-1)

        print(f"Output max diff: {(out - out_ref).abs().max().item()}")
        print(f"Output mean diff: {(out - out_ref).abs().mean().item()}")
        print(f"Pytorch max diff: {(out_pt - out_ref).abs().max().item()}")
        print(f"Pytorch mean diff: {(out_pt - out_ref).abs().mean().item()}")

        # if not causal:
        #     print(f"LSE max diff: {(lse - lse_ref).abs().max().item()}")
        # breakpoint()

        if d <= 128 and dtype != torch.float8_e4m3fn:
            g = torch.randn_like(out)
            do_o = (g.float() * out.float()).sum(-1)
            dq, dk, dv = torch.autograd.grad(out, (q, k, v), g)
            dq_ref, dk_ref, dv_ref = torch.autograd.grad(out_ref, (q, k, v), g)
            dq_pt, dk_pt, dv_pt = torch.autograd.grad(out_pt, (q, k, v), g)
            print(f"dQ max diff: {(dq - dq_ref).abs().max().item()}")
            print(f"dK max diff: {(dk - dk_ref).abs().max().item()}")
            print(f"dV max diff: {(dv - dv_ref).abs().max().item()}")
            print(f"dQ mean diff: {(dq - dq_ref).abs().mean().item()}")
            print(f"dK mean diff: {(dk - dk_ref).abs().mean().item()}")
            print(f"dV mean diff: {(dv - dv_ref).abs().mean().item()}")
            print(f"dQ Pytorch max diff: {(dq_pt - dq_ref).abs().max().item()}")
            print(f"dK Pytorch max diff: {(dk_pt - dk_ref).abs().max().item()}")
            print(f"dV Pytorch max diff: {(dv_pt - dv_ref).abs().max().item()}")
            print(f"dQ Pytorch mean diff: {(dq_pt - dq_ref).abs().mean().item()}")
            print(f"dK Pytorch mean diff: {(dk_pt - dk_ref).abs().mean().item()}")
            print(f"dV Pytorch mean diff: {(dv_pt - dv_ref).abs().mean().item()}")

        # dS = torch.einsum('bthd,bshd->bhts', g.float(), v.float())
        # P = torch.softmax(qk, -1)
        # dP = P * (dS - do_o.unsqueeze(1))
        # dQ = torch.einsum('bhts,bshd->bthd', dP, k.float())
        # dV = torch.einsum('bhts,bthd->bshd', P, g.float())
        # dK = torch.einsum('bhts,bthd->bshd', dP, q.float())
        # breakpoint()

        # Check that FlashAttention's numerical error is at most twice the numerical error
        # of a Pytorch implementation.
        # breakpoint()
        if(dtype != torch.float8_e4m3fn):
            assert (out - out_ref).abs().max().item() <= 2 * (out_pt - out_ref).abs().max().item() + 3e-5
        else:
            # just test correctness of fp8 kernel w/o further quantization techniques
            assert (out - out_ref).abs().max().item() <= 40 * (out_pt - out_ref).abs().max().item()

        if d <= 128 and dtype != torch.float8_e4m3fn:
>           assert (dq - dq_ref).abs().max().item() <= 2 * (dq_pt - dq_ref).abs().max().item() + 3e-5
E           AssertionError: assert 204.0 <= ((2 * 0.015625) + 3e-05)
E            +  where 204.0 = <built-in method item of Tensor object at 0x7f01ba296860>()
E            +    where <built-in method item of Tensor object at 0x7f01ba296860> = tensor(204., device='cuda:0', dtype=torch.bfloat16).item
E            +      where tensor(204., device='cuda:0', dtype=torch.bfloat16) = <built-in method max of Tensor object at 0x7f01ba296d60>()
E            +        where <built-in method max of Tensor object at 0x7f01ba296d60> = tensor([[[[0.0000e+00, 0.0000e+00, 0.0000e+00,  ..., 0.0000e+00,\n           0.0000e+00, 0.0000e+00],\n          [0.0000...0.0000e+00, 9.7656e-04,  ..., 0.0000e+00,\n           7.3242e-04, 2.4414e-04]]]], device='cuda:0', dtype=torch.bfloat16).max
E            +          where tensor([[[[0.0000e+00, 0.0000e+00, 0.0000e+00,  ..., 0.0000e+00,\n           0.0000e+00, 0.0000e+00],\n          [0.0000...0.0000e+00, 9.7656e-04,  ..., 0.0000e+00,\n           7.3242e-04, 2.4414e-04]]]], device='cuda:0', dtype=torch.bfloat16) = <built-in method abs of Tensor object at 0x7f01ba296a40>()
E            +            where <built-in method abs of Tensor object at 0x7f01ba296a40> = (tensor([[[[ 0.0000e+00,  0.0000e+00,  0.0000e+00,  ...,  0.0000e+00,\n            0.0000e+00,  0.0000e+00],\n          [...79e-02,  7.4219e-02,  ..., -4.0527e-02,\n           -4.1016e-02,  5.5664e-02]]]], device='cuda:0', dtype=torch.bfloat16) - tensor([[[[ 0.0000e+00,  0.0000e+00,  0.0000e+00,  ...,  0.0000e+00,\n            0.0000e+00,  0.0000e+00],\n          [...79e-02,  7.3242e-02,  ..., -4.0527e-02,\n           -4.0283e-02,  5.5420e-02]]]], device='cuda:0', dtype=torch.bfloat16)).abs
E            +  and   0.015625 = <built-in method item of Tensor object at 0x7f01ba296e00>()
E            +    where <built-in method item of Tensor object at 0x7f01ba296e00> = tensor(0.0156, device='cuda:0', dtype=torch.bfloat16).item
E            +      where tensor(0.0156, device='cuda:0', dtype=torch.bfloat16) = <built-in method max of Tensor object at 0x7f01ba296ea0>()
E            +        where <built-in method max of Tensor object at 0x7f01ba296ea0> = tensor([[[[0.0000e+00, 0.0000e+00, 0.0000e+00,  ..., 0.0000e+00,\n           0.0000e+00, 0.0000e+00],\n          [0.0000...4.8828e-04, 9.7656e-04,  ..., 0.0000e+00,\n           2.4414e-04, 1.2207e-03]]]], device='cuda:0', dtype=torch.bfloat16).max
E            +          where tensor([[[[0.0000e+00, 0.0000e+00, 0.0000e+00,  ..., 0.0000e+00,\n           0.0000e+00, 0.0000e+00],\n          [0.0000...4.8828e-04, 9.7656e-04,  ..., 0.0000e+00,\n           2.4414e-04, 1.2207e-03]]]], device='cuda:0', dtype=torch.bfloat16) = <built-in method abs of Tensor object at 0x7f01ba296db0>()
E            +            where <built-in method abs of Tensor object at 0x7f01ba296db0> = (tensor([[[[ 0.0000e+00,  0.0000e+00,  0.0000e+00,  ...,  0.0000e+00,\n            0.0000e+00,  0.0000e+00],\n          [...67e-02,  7.2266e-02,  ..., -4.0527e-02,\n           -4.0527e-02,  5.6641e-02]]]], device='cuda:0', dtype=torch.bfloat16) - tensor([[[[ 0.0000e+00,  0.0000e+00,  0.0000e+00,  ...,  0.0000e+00,\n            0.0000e+00,  0.0000e+00],\n          [...79e-02,  7.3242e-02,  ..., -4.0527e-02,\n           -4.0283e-02,  5.5420e-02]]]], device='cuda:0', dtype=torch.bfloat16)).abs

test_flash_attn.py:195: AssertionError
==================================================== short test summary info =====================================================
FAILED test_flash_attn.py::test_flash_attn_output[640-128-1.0-128-False-False-True-mha-dtype1] - AssertionError: assert 140.0 <= ((2 * 0.0234375) + 3e-05)
FAILED test_flash_attn.py::test_flash_attn_output[640-128-1.0-128-False-False-True-mqa-dtype0] - AssertionError: assert 118.875 <= ((2 * 0.001953125) + 3e-05)
FAILED test_flash_attn.py::test_flash_attn_output[640-128-1.0-128-False-False-True-gqa-dtype1] - AssertionError: assert 204.0 <= ((2 * 0.015625) + 3e-05)
FAILED test_flash_attn.py::test_flash_attn_output[640-128-1.0-128-True-False-True-mha-dtype1] - AssertionError: assert 118.5 <= ((2 * 0.0234375) + 3e-05)
FAILED test_flash_attn.py::test_flash_attn_output[640-128-1.0-128-True-False-True-mqa-dtype1] - AssertionError: assert 119.0 <= ((2 * 0.015625) + 3e-05)
FAILED test_flash_attn.py::test_flash_attn_output[640-128-1.0-128-True-False-True-gqa-dtype1] - AssertionError: assert 204.0 <= ((2 * 0.015625) + 3e-05)
6 failed, 5946 passed in 72.27s (0:01:12)
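
All of the traces above stop at the same dQ tolerance check in the test. Condensed, with the values from the gqa-dtype1 trace, the comparison that fails is:

```python
# Condensed form of the failing assertion (values from the gqa-dtype1 trace above).
# dq is FA3's gradient, dq_ref an fp32 PyTorch reference, dq_pt the same reference
# computed without upcasting (bf16 here).
fa_max_err = 204.0      # (dq - dq_ref).abs().max().item()
pt_max_err = 0.015625   # (dq_pt - dq_ref).abs().max().item()
assert fa_max_err <= 2 * pt_max_err + 3e-5   # 204.0 <= 0.03128 -> AssertionError
```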
tridao commented 2 hours ago

Thanks for the bug report. Can you try the tdd branch (https://github.com/Dao-AILab/flash-attention/tree/tdd)? It's a newer backward implementation that we're testing.
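
For anyone wanting to verify, here is a minimal sketch of building that branch and re-running one of the failing cases (the hopper/ path and the setup.py install step are assumptions about the repo layout, not instructions from this thread):

```python
# Hypothetical helper: fetch the tdd branch, rebuild the FA3 extension, and
# re-run one failing node ID. Adjust the build step to your environment.
import subprocess

def run(cmd, cwd=None):
    print("+", " ".join(cmd))
    subprocess.run(cmd, cwd=cwd, check=True)

run(["git", "clone", "--branch", "tdd", "--depth", "1",
     "https://github.com/Dao-AILab/flash-attention.git", "fa-tdd"])
run(["python", "setup.py", "install"], cwd="fa-tdd/hopper")  # FA3 build dir (assumed)
run(["pytest",
     "test_flash_attn.py::test_flash_attn_output"
     "[640-128-1.0-128-True-False-True-gqa-dtype1]",
     "-q"],
    cwd="fa-tdd/hopper")
```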