Closed jasonkrone closed 1 week ago

Hi there,

Given a fixed compute budget, what's recommended:

- pure bfloat16 training, where all computations are done in bfloat16
- bfloat16 MixedPrecision training, where the optimizer maintains a copy of the network weights in fp32 that are updated with optimizer.step()

P.S. Appreciate all your work on TransformerEngine
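For concreteness, a hand-rolled version of the fp32-master-weights scheme might look roughly like this (a minimal sketch; the nn.Linear, batch shape, and loss are toy stand-ins for a real network, not from the original post):

```python
import torch
import torch.nn as nn

# bf16 copy of the model used for forward/backward compute
model = nn.Linear(1024, 1024, device="cuda", dtype=torch.bfloat16)

# fp32 master copies of the weights; the optimizer updates these
master_params = [p.detach().clone().float().requires_grad_(True)
                 for p in model.parameters()]
optimizer = torch.optim.AdamW(master_params, lr=1e-4)

for _ in range(10):  # toy training loop
    x = torch.randn(8, 1024, device="cuda", dtype=torch.bfloat16)
    loss = model(x).float().pow(2).mean()
    loss.backward()
    with torch.no_grad():
        for p, mp in zip(model.parameters(), master_params):
            mp.grad = p.grad.float()   # cast bf16 grads up to fp32
            p.grad = None
        optimizer.step()               # weight update happens in fp32
        optimizer.zero_grad(set_to_none=True)
        for p, mp in zip(model.parameters(), master_params):
            p.copy_(mp)                # write fp32 weights back as bf16
```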
Typically the optimizer needs fp32 weights to ensure good convergence (although see the fp8 optimizer in https://arxiv.org/pdf/2310.18313, which uses per-tensor scaling to get away with fp16 weights). For data-parallel bf16 training, there's a memory-performance tradeoff that depends on your use-case:
I'd recommend option 1 at first, since it just requires running with torch.autocast. Option 3 may be better (as you mention, it's what's used in Megatron-LM), but it depends on many factors and ultimately there's no alternative to benchmarking.
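A minimal sketch of option 1 (toy nn.Linear and synthetic data standing in for a real model and data loader; the only change from a plain fp32 loop is the torch.autocast context):

```python
import torch
import torch.nn as nn

model = nn.Linear(1024, 1024, device="cuda")   # weights stay fp32
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)

for _ in range(10):  # toy training loop
    x = torch.randn(8, 1024, device="cuda")
    optimizer.zero_grad(set_to_none=True)
    # matmuls run in bf16 under autocast; weights and optimizer state stay fp32
    with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
        loss = model(x).pow(2).mean()
    # bf16 shares fp32's exponent range, so no GradScaler is needed
    loss.backward()
    optimizer.step()
```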
Option 3 is more attractive with FSDP, since the fp32 weights can be sharded while the gathered weights can be in bf16. This should just be a matter of configuring FSDP with torch.distributed.fsdp.MixedPrecision. I want to say it's sufficient to initialize the model in fp32 and construct the MixedPrecision policy with param_dtype=torch.bfloat16 and reduce_dtype=torch.float32, although I haven't fully checked this out.
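A sketch of that FSDP setup, with the same caveat that it hasn't been verified end to end (assumes a torchrun launch and a toy model):

```python
import os
import torch
import torch.distributed as dist
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP, MixedPrecision

# assumes launch via torchrun, which sets LOCAL_RANK and the rendezvous env vars
dist.init_process_group("nccl")
torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))

# sharded master weights stay fp32; params are gathered in bf16 for compute;
# gradients are reduce-scattered across ranks in fp32
policy = MixedPrecision(param_dtype=torch.bfloat16, reduce_dtype=torch.float32)

model = torch.nn.Linear(1024, 1024, device="cuda")   # initialize in fp32
model = FSDP(model, mixed_precision=policy)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
```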