NVIDIA / TransformerEngine

A library for accelerating Transformer models on NVIDIA GPUs, including using 8-bit floating point (FP8) precision on Hopper and Ada GPUs, to provide better performance with lower memory utilization in both training and inference.
Apache License 2.0

Pure bfloat16 vs. mixed precision bfloat16: what's recommended? #950

Closed jasonkrone closed 1 week ago

jasonkrone commented 1 week ago

Hi there,

Given a fixed compute budget, what's recommended:

  1. pure bfloat16 training, where all computations are done in bfloat16

  2. bfloat16 MixedPrecision training, where the optimizer maintains an fp32 copy of the network weights that is updated with optimizer.step()

    • this appears to be the setup used for most of the LLM examples in Megatron-LM, which was used to train the newly released Nemotron-4 model. I presume MixedPrecision was used for Nemotron-4; however, the technical report did not explicitly clarify if it was pure bfloat16 or MixedPrecision bfloat16.

P.S. Appreciate all your work on TransformerEngine

timmoon10 commented 1 week ago

Typically the optimizer needs fp32 weights to ensure good convergence (although see the fp8 optimizer in https://arxiv.org/pdf/2310.18313, which uses per-tensor scaling to get away with fp16 weights). For data-parallel bf16 training, there's a memory-performance tradeoff that depends on your use case:

  1. fp32 weights, compute in bf16: performance overhead from casting to bf16
  2. fp32 weights, compute in fp32: probably better convergence, but slow performance and extra memory usage from storing activations in fp32
  3. both bf16 and fp32 weights, compute in bf16: can amortize casts over multiple gradient accumulation steps, but extra memory usage and code complexity from storing two copies of the weights
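
As a rough illustration of why the optimizer typically wants fp32 weights, here is a sketch in plain Python that emulates bf16 rounding. The `to_bf16` helper, the weight value, and the update size are all made up for illustration (they are not from TransformerEngine or Megatron-LM), and a Python float stands in for the fp32 master copy:

```python
import struct


def to_bf16(x: float) -> float:
    """Round x to the nearest bfloat16 value (round-to-nearest-even).

    Hypothetical helper for illustration: bf16 is the top 16 bits of an
    fp32 value, so we round the fp32 bit pattern and clear the low 16 bits.
    Inf/NaN handling is deliberately ignored.
    """
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    bits = (bits + 0x7FFF + ((bits >> 16) & 1)) & 0xFFFF0000
    return struct.unpack("<f", struct.pack("<I", bits))[0]


update = 1e-3  # a small per-step weight update (assumed value)
steps = 100

# Pure bf16: the weight itself is stored in bf16 and rounded after every step.
pure_bf16_weight = 1.0
for _ in range(steps):
    pure_bf16_weight = to_bf16(pure_bf16_weight + update)

# Master weights: accumulate in higher precision, cast to bf16 only for compute.
master_weight = 1.0
for _ in range(steps):
    master_weight += update
bf16_copy = to_bf16(master_weight)

print(pure_bf16_weight)  # the small updates are rounded away every step
print(master_weight, bf16_copy)
```

Near 1.0 the spacing between adjacent bf16 values is 2^-7, so each individual update is smaller than half a step and gets rounded away in the pure-bf16 loop, while the higher-precision copy accumulates all of them.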

I'd recommend option 1 at first since it just requires running with torch.autocast. Option 3 may be better (as you mention, it's what's used in Megatron-LM), but it depends on many factors and ultimately there's no alternative to benchmarking.
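
A minimal sketch of option 1, assuming a toy `nn.Linear` model, random data, and plain SGD (all placeholders, not taken from any real training script). The parameters stay in fp32 and `torch.autocast` casts to bf16 for the forward compute:

```python
import torch
import torch.nn as nn

device_type = "cuda" if torch.cuda.is_available() else "cpu"

# Toy model and data (placeholders for illustration).
model = nn.Linear(16, 4).to(device_type)  # parameters are created in fp32
optimizer = torch.optim.SGD(model.parameters(), lr=1e-2)
x = torch.randn(8, 16, device=device_type)
target = torch.randn(8, 4, device=device_type)

# Forward pass runs eligible ops (e.g. the linear layer) in bf16.
with torch.autocast(device_type=device_type, dtype=torch.bfloat16):
    out = model(x)

# Compute the loss in fp32, outside the autocast region.
loss = nn.functional.mse_loss(out.float(), target)

optimizer.zero_grad()
loss.backward()   # gradients land in fp32, matching the parameter dtype
optimizer.step()  # the optimizer updates the fp32 weights directly
```

The cast from fp32 to bf16 happens inside the autocast region on each forward pass, which is the casting overhead mentioned for option 1.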

Option 3 is more attractive with FSDP since the fp32 weights can be sharded, while the gathered weights can be in bf16. This should just be a matter of configuring FSDP with torch.distributed.fsdp.MixedPrecision. I want to say it's sufficient to initialize the model in fp32 and construct the MixedPrecision config with param_dtype=torch.bfloat16 and reduce_dtype=torch.float32, although I haven't fully verified this.
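
A sketch of what that FSDP configuration might look like. This is untested in a real multi-GPU run, and `wrap_model` is a hypothetical helper name; it assumes `torch.distributed.init_process_group` has already been called (e.g. under `torchrun`), so the example only defines it and does not call it:

```python
import torch
import torch.nn as nn
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp import MixedPrecision

# Gathered parameters are cast to bf16 for compute, while gradient
# reduction across ranks is done in fp32.
bf16_policy = MixedPrecision(
    param_dtype=torch.bfloat16,
    reduce_dtype=torch.float32,
)


def wrap_model(model: nn.Module) -> FSDP:
    """Hypothetical helper: wrap an fp32-initialized model with FSDP.

    Requires an initialized process group, so it is not called here.
    """
    return FSDP(model, mixed_precision=bf16_policy)
```

With this setup the sharded parameters that the optimizer sees should remain in fp32 as long as the model was initialized in fp32, which is the point of option 3. It's worth confirming the dtypes on your own model before relying on it.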