I tested the scripts you posted. The reason TE is slower in your example is that `fp8_autocast` only affects the FP8 execution; the precision of the other parts of the model is preserved. In this case that is FP32, since you did not use AMP or cast the model to FP16/BF16. In the default TE recipe, attention is not computed in FP8, so without casting the model it runs in the original precision (FP32). Neither the flash-attention nor the cuDNN attention backend in TE supports FP32 execution, so you got the slowest, unfused attention (which also uses the most memory), which explains the poor performance you observed.
Changing your TE script to this:
```python
import torch
import transformer_engine.pytorch as te
from transformer_engine.common import recipe
from time import perf_counter
from tqdm import tqdm

b = 32
n = 1024  # w * h
nhead = 32
nhead_k = 8
head_dim = 128
hidden_dim = 4096
n_iters = 1000

print("BASELINE TEST, %d TOKENS" % n)

# seq = torch.randn((b, n, hidden_dim)).cuda()
seq = torch.randn((n, b, hidden_dim)).cuda().bfloat16()

mha = te.TransformerLayer(hidden_dim, ffn_hidden_size=hidden_dim * 4, num_attention_heads=nhead,
                          num_gqa_groups=nhead_k, hidden_dropout=0, attention_dropout=0,
                          kv_channels=head_dim, self_attn_mask_type='no_mask', bias=True).cuda()
mha = mha.bfloat16()

params = list(mha.parameters())
opt = torch.optim.AdamW(params, lr=1e-4)

fp8_recipe = recipe.DelayedScaling(margin=0, fp8_format=recipe.Format.HYBRID)  # , amax_history_len=32, amax_compute_algo="max"

num_params = sum(p.numel() for p in mha.parameters())
print("Number of parameters: %f M" % (num_params / 1e6))

start = perf_counter()
for _ in tqdm(range(n_iters)):
    opt.zero_grad()
    with te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe):
        out = mha(seq)
    loss = out.sum()
    loss.backward()
    opt.step()
torch.cuda.synchronize()
print('perf: %f' % (1000 * (perf_counter() - start) / n_iters))
```
(I added the missing imports and cast the model and input to BF16.) I get the following performance results on an H100 PCIe (forward/backward only, omitting the optimizer, which adds a constant overhead in all cases):
One additional note: your FA script does not include LayerNorm, but that makes only a small difference in this case.
Very helpful, thank you
Hi @ptrendx, I have a follow-up question: what is actually cast to FP8 by TE in the code above? First, we cast the model to bfloat16, so the model weights are 16-bit. Then, we use `fp8_autocast` to convert the TE Linear layer activations to FP8. Is that correct? So the activations and gradients of the TE Linear layers are in FP8, but the weights and the optimizer states are all bfloat16?
On the other hand, if we don't cast to bfloat16 directly, the model weights are in FP32 and stay that way throughout training? And what happens if we use `torch.cuda.amp.autocast(dtype=torch.bfloat16)`?
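For example, something like this (just to illustrate what I mean; same names as in the script above, but without the `.bfloat16()` casts):

```python
# Keep the model in FP32 and let AMP handle the BF16 compute,
# with fp8_autocast nested inside for the FP8 GEMMs.
with torch.cuda.amp.autocast(dtype=torch.bfloat16):
    with te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe):
        out = mha(seq)
```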
|  | GEMM | Weight | Gradient | Optimizer State | Activation Comm | Gradient Comm |
|---|---|---|---|---|---|---|
| Model BFloat16 / AMP + Cast FP8 | fp8 | bf16 | fp32 (?) | fp32 (?) | bf16 (?) | bf16 (?) |
| Model No BFloat16 + Cast FP8 | fp8 | fp32 | fp8 (?) | fp32 (?) | fp32 (?) | fp32 (?) |
| Model No BFloat16 + No Cast FP8 | fp32 | fp32 | fp32 | fp32 | fp32 | fp32 |
I am trying to wrap my head around how FP8 is implemented in TE.
Functionality-wise, you can think of `fp8_autocast` as changing just the internal execution of the operators, so a functionally equivalent execution of the forward pass would be:
```
x = x.to(fp8).to(fp32)
weight = weight.to(fp8).to(fp32)
y = linear(x, weight).to(original input type)
y = y + bias
```
What is different in the actual execution is that the linear layer takes `x` and `weight` in FP8 and only the internal accumulator is in FP32, but the result is the same.
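As a rough standalone illustration of this quantize-dequantize view in plain PyTorch (just a sketch, not TE's actual code; it assumes a PyTorch build that exposes `torch.float8_e4m3fn` and ignores the per-tensor scaling factors that the DelayedScaling recipe applies):

```python
import torch

x = torch.randn(16, 4096, device="cuda", dtype=torch.bfloat16)
weight = torch.randn(4096, 4096, device="cuda", dtype=torch.bfloat16)
bias = torch.zeros(4096, device="cuda", dtype=torch.bfloat16)

# Round inputs and weights to FP8 and bring them back to FP32:
# the values lose precision, but this GEMM still runs in high precision.
x_qdq = x.to(torch.float8_e4m3fn).to(torch.float32)
w_qdq = weight.to(torch.float8_e4m3fn).to(torch.float32)

# GEMM on the FP8-rounded values, cast back to the original input type, add bias.
y = torch.nn.functional.linear(x_qdq, w_qdq).to(x.dtype) + bias
```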
So in your table:
- when you cast the model (or leave it in FP32), the weights and activations stay in BF16/FP32; they are cast to FP8 by `fp8_autocast` just in time, but the GEMM outputs BF16/FP32 (the original precision)
- under AMP the weights stay in FP32; they are still cast to FP8 by `fp8_autocast` just in time, but the GEMM outputs BF16 (the autocast dtype)

We also have an additional context manager, `fp8_model_init`, which makes the layers actually hold FP8 parameters rather than higher-precision ones. It is not the default behavior, though, since the user needs to make sure they have a high-precision copy of the parameters somewhere else (or that those parameters are not actually trainable, as in inference or LoRA); otherwise convergence would suffer.
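For reference, a minimal sketch of how that looks (assuming a TE version that provides `fp8_model_init`):

```python
import transformer_engine.pytorch as te

# Layers created inside this context store their parameters directly in FP8.
# This is appropriate when the weights are frozen (e.g. inference or the base
# weights in LoRA); for regular training you need a high-precision master copy
# of the parameters kept elsewhere.
with te.fp8_model_init(enabled=True):
    proj = te.Linear(4096, 4096, bias=True)
```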
I'm running some benchmarks comparing TransformerEngine MHA in FP8 against FlashAttention MHA in FP16. However, I'm consistently finding that TE FP8 is not only 50-60% slower than FlashAttention, it also uses much more memory (27 GB vs. 11 GB).
I'm scratching my head because FP8 should use less memory for the same sequence length and MHA parameters. I'm using the latest TE build from source on 1x H100 with cuDNN installed. Here's the benchmarking code:
TransformerEngine:
FlashAttention Benchmark:
Any ideas?