NVIDIA / TransformerEngine

A library for accelerating Transformer models on NVIDIA GPUs, including using 8-bit floating point (FP8) precision on Hopper and Ada GPUs, to provide better performance with lower memory utilization in both training and inference.
https://docs.nvidia.com/deeplearning/transformer-engine/user-guide/index.html
Apache License 2.0

Pure bfloat16 vs. mixed precision bfloat16: what's recommended? #950

Closed · jasonkrone closed this issue 1 week ago

jasonkrone commented 1 week ago

Hi there,

Given a fixed compute budget, what's recommended:

  1. pure bfloat16 training, where all computations are done in bfloat16

  2. bfloat16 MixedPrecision training, where the optimizer maintains a copy of the network weights in fp32 that are updated with optimizer.step() (sketched below, after this list)

    • this appears to be the setup used for most of the LLM examples in Megatron-LM, which was used to train the newly released Nemotron-4 model. I presume MixedPrecision was used for Nemotron-4; however, the technical report does not explicitly state whether it was pure bfloat16 or MixedPrecision bfloat16.
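For concreteness, here is a minimal hand-rolled sketch of setup 2 (not from the question or from TransformerEngine): bf16 parameters are used for compute, while the optimizer updates an fp32 master copy that is then cast back into the model. The toy Linear layer, AdamW optimizer, and MSE loss are illustrative assumptions.

```python
import torch

# Hypothetical sketch of setup 2: bf16 parameters for compute, fp32 master
# copies for the optimizer update. Layer size, AdamW, and MSE are arbitrary.
model = torch.nn.Linear(1024, 1024, device="cuda", dtype=torch.bfloat16)
master_params = [p.detach().clone().float() for p in model.parameters()]
optimizer = torch.optim.AdamW(master_params, lr=1e-4)

def train_step(x, target):
    loss = torch.nn.functional.mse_loss(model(x), target)
    loss.backward()
    # Move the bf16 gradients onto the fp32 master weights, then step in fp32.
    for p, mp in zip(model.parameters(), master_params):
        mp.grad = p.grad.detach().float()
        p.grad = None
    optimizer.step()
    optimizer.zero_grad()
    # Cast the updated fp32 master weights back into the bf16 model.
    with torch.no_grad():
        for p, mp in zip(model.parameters(), master_params):
            p.copy_(mp)
    return loss

x = torch.randn(8, 1024, device="cuda", dtype=torch.bfloat16)
target = torch.randn(8, 1024, device="cuda", dtype=torch.bfloat16)
train_step(x, target)
```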

P.S. Appreciate all your work on TransformerEngine

timmoon10 commented 1 week ago

Typically the optimizer needs fp32 weights to ensure good convergence (although see the fp8 optimizer in https://arxiv.org/pdf/2310.18313, which uses per-tensor scaling to get away with fp16 weights). For data-parallel bf16 training, there's a memory-performance tradeoff that depends on your use-case:

  1. fp32 weights, compute in bf16: performance overhead from casting to bf16
  2. fp32 weights, compute in fp32: probably better convergence, but slow performance and extra memory usage from storing activations in fp32
  3. both bf16 and fp32 weights, compute in bf16: can amortize casts over multiple gradient accumulation steps, but extra memory usage and code complexity from storing two copies of the weights

I'd recommend option 1 at first since it just requires running with torch.autocast. Option 3 may be better (as you mention, it's what's used in Megatron-LM), but it depends on many factors and ultimately there's no alternative to benchmarking.
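As an illustration of option 1 (a sketch, not an official TransformerEngine recipe), a plain torch.autocast training step might look like the following; the toy MLP, AdamW optimizer, and MSE loss are assumptions for the example. Note that bf16 autocast does not need a GradScaler.

```python
import torch

# Option 1 sketch: parameters and optimizer state stay in fp32, while the
# forward pass runs in bf16 inside the autocast region.
model = torch.nn.Sequential(
    torch.nn.Linear(1024, 4096),
    torch.nn.GELU(),
    torch.nn.Linear(4096, 1024),
).cuda()                                   # weights are fp32
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)

x = torch.randn(8, 1024, device="cuda")
target = torch.randn(8, 1024, device="cuda")

with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
    out = model(x)                         # matmuls execute in bf16
    loss = torch.nn.functional.mse_loss(out.float(), target)

loss.backward()                            # gradients are fp32, matching the weights
optimizer.step()
optimizer.zero_grad()
```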

Option 3 is more attractive with FSDP since the fp32 weights can be sharded, while the gathered weights used for compute can be in bf16. This should just be a matter of configuring FSDP with torch.distributed.fsdp.MixedPrecision. I want to say it's sufficient to initialize the model in fp32 and configure MixedPrecision with param_dtype=torch.bfloat16 and reduce_dtype=torch.float32, although I haven't fully verified this.
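A possible configuration along those lines (untested, per the caveat above) could look like the sketch below; it assumes torch.distributed and the FSDP process group are already initialized, and uses a toy module in place of a real Transformer.

```python
import torch
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP, MixedPrecision

# Sketch: fp32 sharded weights, bf16 all-gathered params for compute,
# fp32 gradient reduction. Assumes torch.distributed is already initialized.
model = torch.nn.Sequential(
    torch.nn.Linear(1024, 4096),
    torch.nn.GELU(),
    torch.nn.Linear(4096, 1024),
).cuda()                                  # model initialized in fp32

mp_policy = MixedPrecision(
    param_dtype=torch.bfloat16,           # gathered params / compute in bf16
    reduce_dtype=torch.float32,           # gradient reduce-scatter in fp32
)
fsdp_model = FSDP(model, mixed_precision=mp_policy)

optimizer = torch.optim.AdamW(fsdp_model.parameters(), lr=1e-4)
```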