NVIDIA / TransformerEngine

A library for accelerating Transformer models on NVIDIA GPUs, including using 8-bit floating point (FP8) precision on Hopper and Ada GPUs, to provide better performance with lower memory utilization in both training and inference.
https://docs.nvidia.com/deeplearning/transformer-engine/user-guide/index.html
Apache License 2.0

TransformerEngine FP8 is slower & more memory intensive than FlashAttention FP16? #1119

Closed · darius-lam closed this 2 months ago

darius-lam commented 2 months ago

I'm running some benchmarks comparing TransformerEngine MHA in FP8 against FlashAttention MHA in FP16. However, I'm consistently finding that TE FP8 is not only 50-60% slower than FlashAttention, it also uses much more memory (11 GB for FlashAttention vs. 27 GB for TE).

I'm scratching my head because FP8 should use less memory for the same sequence length and MHA parameters. I'm using the latest TE built from source on a single H100, with cuDNN installed. Here's the benchmarking code:

TransformerEngine:

import torch
import transformer_engine.pytorch as te
from transformer_engine.common import recipe
from time import perf_counter
from tqdm import tqdm

b = 32
n = 1024 #w * h
nhead = 32
nhead_k = 8
head_dim = 128

hidden_dim = 4096
n_iters = 1000

print("BASELINE TEST, %d TOKENS" % n)

#seq = torch.randn((b, n, hidden_dim)).cuda()
seq = torch.randn((n, b, hidden_dim)).cuda()
mha = te.TransformerLayer(hidden_dim, ffn_hidden_size = hidden_dim*4, num_attention_heads = nhead, num_gqa_groups = nhead_k, hidden_dropout=0, attention_dropout=0, kv_channels=head_dim, self_attn_mask_type='no_mask', bias=True).cuda()
params = list(mha.parameters())
opt = torch.optim.AdamW(params, lr=1e-4)

fp8_recipe = recipe.DelayedScaling(margin=0, fp8_format=recipe.Format.HYBRID)#, amax_history_len=32, amax_compute_algo="max")

num_params = sum(p.numel() for p in mha.parameters())
print(f"Number of parameters: %f M" % (num_params/1e6))

start = perf_counter()
for _ in tqdm(range(n_iters)):
    opt.zero_grad()

    with te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe):
        out = mha(seq)
        loss = out.sum()

    loss.backward()
    opt.step()

torch.cuda.synchronize()
print('perf: %f' % (1000*(perf_counter() - start)/n_iters))

FlashAttention Benchmark:

import torch
import torch.nn as nn
from time import perf_counter
from tqdm import tqdm

from flash_attn import flash_attn_qkvpacked_func, flash_attn_func
from flash_attn.ops.fused_dense import FusedDense
from flash_attn.modules.mha import MHA

# Baseline flash_attn test

b = 32
n = 1024
nhead = 32
nhead_k = 8
head_dim = 64  # unused below; MHA derives head_dim = hidden_dim // nhead = 128

hidden_dim = 4096
n_iters = 1000

print("BASELINE TEST, %d TOKENS" % n)

seq = torch.randn((b, n, hidden_dim)).cuda().to(torch.bfloat16)
class TransformerLayer(nn.Module):
    def __init__(self, hidden_dim, nhead, nhead_k, ffn_hidden_size):
        super().__init__()
        self.mha = MHA(hidden_dim, nhead, nhead_k, use_flash_attn=True, fused_bias_fc=True, causal=False)

        self.ffn = nn.Sequential(
            FusedDense(hidden_dim, ffn_hidden_size, bias=True),
            nn.GELU(),
            FusedDense(ffn_hidden_size, hidden_dim, bias=True),
        )

    def forward(self, x):
        x = x + self.mha(x)
        x = x + self.ffn(x)
        return x 

#mha = MHA(hidden_dim, nhead, nhead_k, use_flash_attn=True, fused_bias_fc=True, causal=False).cuda()
mha = TransformerLayer(hidden_dim, nhead, nhead_k, ffn_hidden_size = hidden_dim*4).cuda()

params = list(mha.parameters())
opt = torch.optim.AdamW(params, lr=1e-4)
num_params = sum(p.numel() for p in mha.parameters())
print(f"Number of parameters: %f M" % (num_params/1e6))

a = perf_counter()
for _ in tqdm(range(n_iters)):
    # note: backward() and step() would normally sit outside the autocast context
    with torch.cuda.amp.autocast(dtype=torch.bfloat16):
        opt.zero_grad()
        out = mha(seq)
        loss = out.sum()

        loss.backward()
        opt.step()

torch.cuda.synchronize()  # wait for all GPU work to finish before reading the timer
print('flash_attn_func: %f' % (1000*(perf_counter() - a)/n_iters))

Any ideas?

ptrendx commented 2 months ago

I tested the scripts you posted. The reason TE is slower in your example is that fp8_autocast only affects the FP8 execution; we preserve the precision of the other parts of the model. In this case that precision is FP32, since you did not use AMP or cast the model to FP16/BF16. With the default TE recipe, attention is not computed in FP8, so without casting the model it runs in the original precision (FP32). Neither the flash-attention nor the cuDNN attention backend in TE supports FP32 execution, so what you got was the slowest, unfused attention (which also uses the most memory), which explains the poor performance you observed. Changing your TE script to this:

import torch
import transformer_engine.pytorch as te
from transformer_engine.common import recipe
from time import perf_counter
from tqdm import tqdm

b = 32
n = 1024 #w * h
nhead = 32
nhead_k = 8
head_dim = 128

hidden_dim = 4096
n_iters = 1000

print("BASELINE TEST, %d TOKENS" % n)

#seq = torch.randn((b, n, hidden_dim)).cuda()
seq = torch.randn((n, b, hidden_dim)).cuda().bfloat16()
mha = te.TransformerLayer(hidden_dim, ffn_hidden_size = hidden_dim*4, num_attention_heads = nhead, num_gqa_groups = nhead_k, hidden_dropout=0, attention_dropout=0, kv_channels=head_dim, self_attn_mask_type='no_mask', bias=True).cuda()
mha = mha.bfloat16()
params = list(mha.parameters())
opt = torch.optim.AdamW(params, lr=1e-4)

fp8_recipe = recipe.DelayedScaling(margin=0, fp8_format=recipe.Format.HYBRID)#, amax_history_len=32, amax_compute_algo="max")

num_params = sum(p.numel() for p in mha.parameters())
print(f"Number of parameters: %f M" % (num_params/1e6))

start = perf_counter()
for _ in tqdm(range(n_iters)):
    opt.zero_grad()

    with te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe):
        out = mha(seq)
        loss = out.sum()

    loss.backward()
    opt.step()

torch.cuda.synchronize()
print('perf: %f' % (1000*(perf_counter() - start)/n_iters))

(I added the missing imports and cast the model and input to BF16.) I get the following performance results on an H100 PCIe (forward/backward only, omitting the optimizer, which adds a constant overhead in all cases):

One additional note is that your FA script does not include LayerNorm, but that is a small difference in this case.
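
For a closer apples-to-apples comparison, here is a minimal sketch of the FlashAttention baseline above with pre-LayerNorm added around the attention and FFN blocks (the class name TransformerLayerWithLN and the exact LayerNorm placement are assumptions, roughly mirroring a standard pre-LN layer; the hyperparameters are the ones from the script above):

import torch.nn as nn
from flash_attn.modules.mha import MHA
from flash_attn.ops.fused_dense import FusedDense

class TransformerLayerWithLN(nn.Module):
    # Same structure as the baseline above, but with pre-LayerNorm residual blocks.
    def __init__(self, hidden_dim, nhead, nhead_k, ffn_hidden_size):
        super().__init__()
        self.norm1 = nn.LayerNorm(hidden_dim)
        self.norm2 = nn.LayerNorm(hidden_dim)
        self.mha = MHA(hidden_dim, nhead, nhead_k, use_flash_attn=True,
                       fused_bias_fc=True, causal=False)
        self.ffn = nn.Sequential(
            FusedDense(hidden_dim, ffn_hidden_size, bias=True),
            nn.GELU(),
            FusedDense(ffn_hidden_size, hidden_dim, bias=True),
        )

    def forward(self, x):
        # Normalize before each sub-block, as most pre-LN Transformer layers do.
        x = x + self.mha(self.norm1(x))
        x = x + self.ffn(self.norm2(x))
        return x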

darius-lam commented 2 months ago

Very helpful, thank you

darius-lam commented 2 months ago

Hi @ptrendx, I have a follow-up question: what is actually cast to FP8 with TE in the code above? First, we cast the model to bfloat16, so the model weights are 16-bit. Then, fp8_autocast converts the TE Linear layer activations to FP8. Is that correct? So the activations and gradients for the TE Linear layers are in FP8, but the weights and optimizer states are all bfloat16?

On the other hand, if we don't directly cast the model to bfloat16, the model weights are in FP32 and stay that way throughout training? And what happens if we use torch.cuda.amp.autocast(dtype=torch.bfloat16)?

                                    GEMM   Weight   Gradient   Optimizer State   Activation Comm   Gradient Comm
Model BFloat16 / AMP + FP8 cast     fp8    bf16     fp32 (?)   fp32 (?)          bf16 (?)          bf16 (?)
Model no BFloat16 + FP8 cast        fp8    fp32     fp8 (?)    fp32 (?)          fp32 (?)          fp32 (?)
Model no BFloat16 + no FP8 cast     fp32   fp32     fp32       fp32              fp32              fp32

I am trying to wrap my head around how the FP8 is implemented with TE

ptrendx commented 2 months ago

Functionality-wise, you can think of fp8_autocast as changing just the internal execution of the operators, so a functionally equivalent execution of the forward pass would be:

x = x.to(fp8).to(fp32)
weight = weight.to(fp8).to(fp32)
y = linear(x, weight).to(original input type)
y = y + bias
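
For intuition only, here is a runnable plain-PyTorch sketch of this quantize-dequantize view (this is not TE's actual kernel path; it assumes PyTorch >= 2.1 for torch.float8_e4m3fn, uses a hypothetical helper name emulated_fp8_linear, and omits the per-tensor scaling that the FP8 recipe manages):

import torch
import torch.nn.functional as F

def emulated_fp8_linear(x, weight, bias):
    # Cast input and weight to FP8 (E4M3) and immediately back to FP32.
    # The real TE GEMM keeps them in FP8 and only accumulates in FP32,
    # and also applies per-tensor scaling factors, omitted here.
    x_fp32 = x.to(torch.float8_e4m3fn).to(torch.float32)
    w_fp32 = weight.to(torch.float8_e4m3fn).to(torch.float32)
    y = F.linear(x_fp32, w_fp32).to(x.dtype)  # back to the original input dtype
    return y + bias

x = torch.randn(16, 4096, dtype=torch.bfloat16)
w = torch.randn(4096, 4096, dtype=torch.bfloat16)
b = torch.zeros(4096, dtype=torch.bfloat16)
print(emulated_fp8_linear(x, w, b).dtype)  # torch.bfloat16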

What is different in the actual execution is that the linear layer takes x and weight in FP8 and only the internal accumulator is in FP32, but the result is the same. So, in your table:

We also have an additional context manager, fp8_model_init, which makes the layers actually hold FP8 parameters rather than higher-precision ones. It is not the default behavior, though, since the user needs to make sure they have a high-precision copy of the parameters somewhere else (or that those parameters are not actually trainable, as in inference or LoRA); otherwise convergence would suffer.
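
For example, a minimal usage sketch (assuming a recent TE release and an FP8-capable GPU; the shapes are illustrative):

import torch
import transformer_engine.pytorch as te

# Layers created inside fp8_model_init hold their parameters directly in FP8,
# so no higher-precision copy is kept inside the module itself.
with te.fp8_model_init(enabled=True):
    linear = te.Linear(4096, 4096, device="cuda")

x = torch.randn(32, 4096, device="cuda", dtype=torch.bfloat16)

# Execution still goes through fp8_autocast (default recipe here).
with te.fp8_autocast(enabled=True):
    out = linear(x)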