NVIDIA / cudnn-frontend

cudnn_frontend provides a c++ wrapper for the cudnn backend API and samples on how to use it
MIT License
448 stars 90 forks source link

[Question] How to match Flash Attention 2 performance? #98

Open vedantroy opened 2 months ago

vedantroy commented 2 months ago

I wrote a helper that allows someone to use CuDNN attention within Pytorch seamlessly.

import cudnn
import torch
import math

# export CUDNN_FRONTEND_LOG_FLIE=fe.log
# export CUDNN_FRONTEND_LOG_INFO=1

# import os
# os.environ["CUDNN_FRONTEND_LOG_FILE"] = "fe.log"
# os.environ["CUDNN_FRONTEND_LOG_INFO"] = "1"

def convert_to_cudnn_type(torch_type):
    if torch_type == torch.float16:
        return cudnn.data_type.HALF
    elif torch_type == torch.bfloat16:
        return cudnn.data_type.BFLOAT16
    elif torch_type == torch.float32:
        return cudnn.data_type.FLOAT
    elif torch_type == torch.int32:
        return cudnn.data_type.INT32
    elif torch_type == torch.int64:
        return cudnn.data_type.INT64
    else:
        raise ValueError("Unsupported tensor data type.")

def make_cudnn_autograd(*, num_heads, head_dim, dtype):
    assert dtype in [torch.float16, torch.bfloat16], f"Invalid dtype {dtype}"
    dtype = convert_to_cudnn_type(dtype)
    # match CuDNN's docs
    H, D = num_heads, head_dim
    del num_heads, head_dim

    cache = {}

    def init_or_check_tensor_attrs(tensor_name, tensor):
        nonlocal cache
        for attr in ['shape', 'stride', 'dtype', 'device']:
            key = f'{tensor_name}_{attr}'
            if key not in cache:
                cache[key] = getattr(tensor, attr)
                if callable(cache[key]):
                    cache[key] = cache[key]()
            else:
                v = cache[key]() if callable(cache[key]) else cache[key]
                assert cache[key] == v, f"Expected {cache[key]} but got {v}"

    class CuDNNAttention(torch.autograd.Function):
        @staticmethod
        def forward(ctx, B, N, L, q, kv, seqlens_kv):
            assert q.shape == (B, N, H, D)
            assert kv.shape == (B, N + L, 2, H, D)
            assert seqlens_kv.shape == (B,)

            # CuDNN plans are compiled for a specific shape, stride, dtype
            # So we need to verify those attributes
            init_or_check_tensor_attrs('q', q)
            init_or_check_tensor_attrs('kv', kv)
            init_or_check_tensor_attrs('seqlens_kv', seqlens_kv)

            q = q.permute(0, 2, 1, 3)  # B N H D -> B H N D
            kv_view = kv.permute(2, 0, 3, 1, 4) # B S KV H D -> KV B H S D
            k_view, v_view = torch.unbind(kv_view, dim=0)

            assert not k_view.is_contiguous() and not v_view.is_contiguous(), f"kv should not be contiguous (unnecessary copy)"
            assert k_view.shape == (B, H, (N + L),  D), f"Got shape {k_view.shape} instead of {(B, num_heads, (N + L),  D)}"
            assert v_view.shape == (B, H, (N + L), D)

            # TODO: Is this safe?
            if 'stats' not in cache:
                cache['stats'] = torch.empty(B, H, N, 1, dtype=torch.float32, device=q.device)
                cache['seqlens_q'] = torch.tensor([N] * B, device=q.device, dtype=torch.int32).view(B, 1, 1, 1)
                cache['o'] = torch.empty_like(q)

            stats = cache['stats']
            seqlens_q = cache['seqlens_q']
            o = cache['o']

            seqlens_kv = seqlens_kv.view(B, 1, 1, 1)

            if 'compiled_graph_fwd' not in cache:
                print("Compiling CuDNN graphs ...")
                g_fwd = cudnn.pygraph(
                    io_data_type=dtype,
                    intermediate_data_type=cudnn.data_type.FLOAT,
                    compute_data_type=cudnn.data_type.FLOAT,
                )
                cache['name_to_cu_tensor'] = {
                    'q_cu': g_fwd.tensor_like(q.detach()),
                    'k_cu': g_fwd.tensor_like(k_view.detach()),
                    'v_cu': g_fwd.tensor_like(v_view.detach()),
                    'seqlens_q_cu': g_fwd.tensor_like(seqlens_q.detach()),
                    'seqlens_kv_cu': g_fwd.tensor_like(seqlens_kv.detach())
                }
                cu_tens = cache['name_to_cu_tensor']

                o_forward, stats_forward = g_fwd.sdpa(
                    name="sdpa",
                    q=cu_tens['q_cu'],
                    k=cu_tens['k_cu'],
                    v=cu_tens['v_cu'],
                    is_inference=False,
                    attn_scale=1.0 / math.sqrt(D),
                    use_causal_mask=False,
                    use_padding_mask=True,
                    seq_len_q=cu_tens['seqlens_q_cu'],
                    seq_len_kv=cu_tens['seqlens_kv_cu']
                )

                o_forward.set_output(True).set_dim(o.shape).set_stride(o.stride()).set_data_type(dtype)
                stats_forward.set_output(True).set_data_type(cudnn.data_type.FLOAT).set_dim(stats.shape).set_stride(stats.stride())

                cu_tens['o_forward_cu'] = o_forward
                cu_tens['stats_forward_cu'] = stats_forward

                def assert_cudnn_shape(tensor, expected_shape):
                    assert tuple(tensor.get_dim()) == expected_shape, f"Expected shape {expected_shape} but got {tensor.get_dim()}"

                assert_cudnn_shape(cu_tens['q_cu'], (B, H, N, D))
                assert_cudnn_shape(cu_tens['k_cu'], (B, H, N + L, D))
                assert_cudnn_shape(cu_tens['v_cu'], (B, H, N + L, D))
                assert_cudnn_shape(cu_tens['o_forward_cu'], (B, H, N, D))
                assert_cudnn_shape(cu_tens['stats_forward_cu'], (B, H, N, 1))
                assert_cudnn_shape(cu_tens['seqlens_q_cu'], (B, 1, 1, 1))
                assert_cudnn_shape(cu_tens['seqlens_kv_cu'], (B, 1, 1, 1))

                g_fwd.validate()
                g_fwd.build_operation_graph()
                g_fwd.create_execution_plans([cudnn.heur_mode.A, cudnn.heur_mode.FALLBACK])
                g_fwd.check_support()
                g_fwd.build_plans()

                cache['compiled_graph_fwd'] = g_fwd

                g_bwd = cudnn.pygraph(
                     io_data_type=dtype,
                     intermediate_data_type=cudnn.data_type.FLOAT,
                     compute_data_type=cudnn.data_type.FLOAT,
                )

                cu_tens['q_cu_bwd'] = g_bwd.tensor_like(q.detach())
                cu_tens['k_cu_bwd'] = g_bwd.tensor_like(k_view.detach())
                cu_tens['v_cu_bwd'] = g_bwd.tensor_like(v_view.detach())
                cu_tens['o_cu_bwd'] = g_bwd.tensor_like(o.detach())
                cu_tens['dO_cu_bwd'] = g_bwd.tensor_like(o.detach())
                cu_tens['stats_cu_bwd'] = g_bwd.tensor_like(stats.detach())
                cu_tens['seqlens_q_cu_bwd'] = g_bwd.tensor_like(seqlens_q.detach())
                cu_tens['seqlens_kv_cu_bwd'] = g_bwd.tensor_like(seqlens_kv.detach())

                dQ_bwd_cu, dK_bwd_cu, dV_bwd_cu = g_bwd.sdpa_backward(
                    name="sdpa_backward",
                    q=cu_tens['q_cu_bwd'],
                    k=cu_tens['k_cu_bwd'],
                    v=cu_tens['v_cu_bwd'],
                    o=cu_tens['o_cu_bwd'],
                    dO=cu_tens['dO_cu_bwd'],
                    stats=cu_tens['stats_cu_bwd'],
                    attn_scale=1.0 / math.sqrt(D),
                    use_causal_mask=False,
                    use_padding_mask=True,
                    seq_len_q=cu_tens['seqlens_q_cu_bwd'],
                    seq_len_kv=cu_tens['seqlens_kv_cu_bwd']
                )

                # TODO: Is this safe?
                # cache['dQ'] = torch.empty_like(q).contiguous()
                # cache['dK'] = torch.empty_like(k_view).contiguous()
                # cache['dV'] = torch.empty_like(v_view).contiguous()

                cache['dQ'] = torch.empty_like(q)
                cache['dK'] = torch.empty_like(k_view)
                cache['dV'] = torch.empty_like(v_view)

                dQ_bwd_cu.set_output(True).set_dim(cache['dQ'].size()).set_stride(cache['dQ'].stride())
                dK_bwd_cu.set_output(True).set_dim(cache['dK'].size()).set_stride(cache['dK'].stride())
                dV_bwd_cu.set_output(True).set_dim(cache['dV'].size()).set_stride(cache['dV'].stride())

                cu_tens['dQ_cu_bwd'] = dQ_bwd_cu
                cu_tens['dK_cu_bwd'] = dK_bwd_cu
                cu_tens['dV_cu_bwd'] = dV_bwd_cu

                assert_cudnn_shape(cu_tens['q_cu_bwd'], (B, H, N, D))
                assert_cudnn_shape(cu_tens['k_cu_bwd'], (B, H, N + L, D))
                assert_cudnn_shape(cu_tens['v_cu_bwd'], (B, H, N + L, D))
                assert_cudnn_shape(cu_tens['dQ_cu_bwd'], (B, H, N, D))
                assert_cudnn_shape(cu_tens['dK_cu_bwd'], (B, H, N + L, D))
                assert_cudnn_shape(cu_tens['dV_cu_bwd'], (B, H, N + L, D))
                assert_cudnn_shape(cu_tens['o_cu_bwd'], (B, H, N, D))
                assert_cudnn_shape(cu_tens['dO_cu_bwd'], (B, H, N, D))
                assert_cudnn_shape(cu_tens['stats_cu_bwd'], (B, H, N, 1))
                assert_cudnn_shape(cu_tens['seqlens_q_cu_bwd'], (B, 1, 1, 1))
                assert_cudnn_shape(cu_tens['seqlens_kv_cu_bwd'], (B, 1, 1, 1))

                g_bwd.validate()
                g_bwd.build_operation_graph()
                g_bwd.create_execution_plans([cudnn.heur_mode.A, cudnn.heur_mode.FALLBACK])
                g_bwd.check_support()
                g_bwd.build_plans()

                cache['compiled_graph_bwd'] = g_bwd

                # TODO: Is this safe?
                cache['workspace'] = torch.empty(
                    max(g_fwd.get_workspace_size(), g_bwd.get_workspace_size()),
                    device=q.device, dtype=torch.uint8
                )

            name_to_cu_tensor = cache['name_to_cu_tensor']
            variant_pack_forward = {
                name_to_cu_tensor[name]: tensor for name, tensor in [
                    ('q_cu', q),
                    ('k_cu', k_view),
                    ('v_cu', v_view),
                    ('o_forward_cu', o),
                    ('stats_forward_cu', stats),
                    ('seqlens_q_cu', seqlens_q),
                    ('seqlens_kv_cu', seqlens_kv)
                ]
            }
            cache['compiled_graph_fwd'].execute(variant_pack_forward, cache['workspace'])
            ctx.save_for_backward(q, k_view, v_view, o, stats, seqlens_kv)
            ctx.B, ctx.N, ctx.L = B, N, L
            ctx.dtype = dtype
            return o

        @staticmethod
        def backward(ctx, grad_output):
            q, k, v, o, stats, seqlens = ctx.saved_tensors
            B, N, L = ctx.B, ctx.N, ctx.L
            seqlens_q = cache['seqlens_q']

            cu_tens = cache['name_to_cu_tensor']

            assert tuple(grad_output.shape) ==  (B, H, N, D)
            assert tuple(grad_output.shape) == tuple(cu_tens['dO_cu_bwd'].get_dim())
            # For batch size 1, the stride can have 2 1s, I think this is a Pytorch bug
            # https://discuss.pytorch.org/t/stride-has-2-1s-in-it/208036
            assert tuple(grad_output.stride())[1:] == tuple(cu_tens['dO_cu_bwd'].get_stride())[1:], f"{tuple(cu_tens['dO_cu_bwd'].get_stride())} (expected) != {tuple(grad_output.stride())} (actual) for shape {tuple(grad_output.shape)}"
            assert convert_to_cudnn_type(grad_output.dtype) == cu_tens['dO_cu_bwd'].get_data_type()

            variant_pack_backward = {
                cu_tens[name]: tensor for name, tensor in [
                    ('dQ_cu_bwd', cache['dQ']),
                    ('dK_cu_bwd', cache['dK']),
                    ('dV_cu_bwd', cache['dV']),
                    ('q_cu_bwd', q),
                    ('k_cu_bwd', k),
                    ('v_cu_bwd', v),
                    ('o_cu_bwd', o),
                    ('dO_cu_bwd', grad_output),
                    ('stats_cu_bwd', stats),
                    ('seqlens_q_cu_bwd', seqlens_q),
                    ('seqlens_kv_cu_bwd', seqlens)
                ]
            }

            cache['compiled_graph_bwd'].execute(variant_pack_backward, cache['workspace'])
            assert cache['dQ'].shape == (B, H, N, D)
            dQ = cache['dQ'].permute(0, 2, 1, 3) # B H N D -> B N H D

            assert cache['dK'].shape == (B, H, N + L, D)
            assert cache['dV'].shape == (B, H, N + L, D)

            dKV = torch.stack([cache['dK'], cache['dV']], dim=2)
            assert dKV.shape == (B, H, 2, N + L, D)

            dKV = dKV.permute(0, 3, 2, 1, 4) # B H 2 N D -> B N 2 H D

            return None, None, None, dQ, dKV, None

    return CuDNNAttention

However, while this gets better forward pass performance. It gets far worse backwards pass performance. Any thoughts on why this might be the case? I'm hoping there might be some obvious deficiency in my code.

(Unit is ms).

attention-forward-performance:
   batch_size     CuDNN  FlashAttention
0         1.0  0.022976        0.033024
1         2.0  0.021664        0.039456
2         4.0  0.047680        0.058112
3         6.0  0.056800        0.072208

attention-backward-performance:
   batch_size     CuDNN  FlashAttention
0         2.0  0.386144        0.282272
1         4.0  0.741664        0.301184
2         6.0  1.108608        0.464320
Anerudhan commented 2 months ago
vedantroy commented 2 months ago

CuDNN version: 9.1.0.

nvidia-smi:

+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 535.183.01             Driver Version: 535.183.01   CUDA Version: 12.2     |
|-----------------------------------------+----------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |         Memory-Usage | GPU-Util  Compute M. |
|                                         |                      |               MIG M. |
|=========================================+======================+======================|
|   0  NVIDIA H100 80GB HBM3          On  | 00000002:00:01.0 Off |                    0 |
| N/A   31C    P0             116W / 700W |    885MiB / 81559MiB |      0%      Default |
|                                         |                      |             Disabled |
+-----------------------------------------+----------------------+----------------------+
|   1  NVIDIA H100 80GB HBM3          On  | 00000002:00:02.0 Off |                    0 |
| N/A   30C    P0              72W / 700W |      3MiB / 81559MiB |      0%      Default |
|                                         |                      |             Disabled |
+-----------------------------------------+----------------------+----------------------+
|   2  NVIDIA H100 80GB HBM3          On  | 00000002:00:03.0 Off |                    0 |
| N/A   29C    P0              70W / 700W |      3MiB / 81559MiB |      0%      Default |
|                                         |                      |             Disabled |
+-----------------------------------------+----------------------+----------------------+
|   3  NVIDIA H100 80GB HBM3          On  | 00000002:00:04.0 Off |                    0 |
| N/A   27C    P0              68W / 700W |      3MiB / 81559MiB |      0%      Default |
|                                         |                      |             Disabled |
+-----------------------------------------+----------------------+----------------------+
|   4  NVIDIA H100 80GB HBM3          On  | 00000003:00:01.0 Off |                    0 |
| N/A   28C    P0              68W / 700W |      3MiB / 81559MiB |      0%      Default |
|                                         |                      |             Disabled |
+-----------------------------------------+----------------------+----------------------+
|   5  NVIDIA H100 80GB HBM3          On  | 00000003:00:02.0 Off |                    0 |
| N/A   29C    P0              73W / 700W |      3MiB / 81559MiB |      0%      Default |
|                                         |                      |             Disabled |
+-----------------------------------------+----------------------+----------------------+
|   6  NVIDIA H100 80GB HBM3          On  | 00000003:00:03.0 Off |                    0 |
| N/A   30C    P0              70W / 700W |      3MiB / 81559MiB |      0%      Default |
|                                         |                      |             Disabled |
+-----------------------------------------+----------------------+----------------------+
|   7  NVIDIA H100 80GB HBM3          On  | 00000003:00:04.0 Off |                    0 |
| N/A   27C    P0              69W / 700W |      3MiB / 81559MiB |      0%      Default |
|                                         |                      |             Disabled |
+-----------------------------------------+----------------------+----------------------+

+---------------------------------------------------------------------------------------+
| Processes:                                                                            |
|  GPU   GI   CI        PID   Type   Process name                            GPU Memory |
|        ID   ID                                                             Usage      |
|=======================================================================================|
|    0   N/A  N/A    239227      C   /tmp/env/bin/python                         876MiB |
+---------------------------------------------------------------------------------------+
vedantroy commented 2 months ago

Archive.zip

I've attached both the benchmarking and CuDNN wrapper code to this post. I suspect the benchmarking code is off, so I'll switch to something simpler (like the Pytorch profiler), and see what the results are.

Anerudhan commented 2 months ago

You can try improvising on this Install FAV2 pip inside the container and go from there. Try out the latest container (24.07, just in case).