pytorch / torchtitan

A native PyTorch Library for large model training

reproducible numerics for loss, weights, and gradients on a single node (8 GPUs) #593

Open weifengpy opened 1 day ago

weifengpy commented 1 day ago

by default, torchtitan uses FSDP2 mixed precision (param_dtype=bfloat16, reduce_dtype=float32)
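
For concreteness, a minimal sketch of that default (assuming the FSDP2 prototype API under `torch.distributed._composable.fsdp`; the module below is a stand-in, not torchtitan's actual model code):

```python
import torch
import torch.nn as nn
from torch.distributed._composable.fsdp import MixedPrecisionPolicy, fully_shard

# bf16 compute for params/activations, fp32 for the gradient reduce-scatter
mp_policy = MixedPrecisionPolicy(param_dtype=torch.bfloat16, reduce_dtype=torch.float32)

# stand-in block; assumes torch.distributed is already initialized (e.g. via torchrun)
block = nn.TransformerEncoderLayer(d_model=256, nhead=8)
fully_shard(block, mp_policy=mp_policy)
```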

for low-precision dtypes (float8 and int8), it's natural to compare the loss curve against bfloat16 and see how well they match (it's also a good idea to compare weight norms and gradient norms)

for bfloat16 itself, multiple runs yield different loss curves, and that nondeterminism should be understood and documented (e.g., NCCL gradient reduction, attention kernels, seeding). Otherwise it's hard to tell whether numeric differences are coming from the low-precision dtypes
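
For reference, a sketch of the usual single-process determinism knobs (these only cover kernel selection and seeding; they do not pin down NCCL reduction order):

```python
import os
import torch

os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"  # needed for deterministic cuBLAS paths
torch.manual_seed(0)                               # same init / dropout streams every run
torch.use_deterministic_algorithms(True)           # error on known-nondeterministic kernels
torch.backends.cudnn.benchmark = False             # avoid autotuned (run-dependent) algos
```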

I plotted a gradient metric (the sum of `p.grad` over `model.parameters()`) for Llama3-8B on 8 GPUs, with deterministic model init and a deterministic data loader
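
Roughly, the plotted scalar was computed along these lines (a hypothetical helper, not the exact script; with FSDP2 the gradients are sharded DTensors, hence the `full_tensor()` gather):

```python
import torch

def grad_sum(model: torch.nn.Module) -> float:
    """Hypothetical per-step scalar to compare across runs: sum of all gradient entries."""
    total = 0.0
    for p in model.parameters():
        if p.grad is None:
            continue
        g = p.grad
        if hasattr(g, "full_tensor"):  # FSDP2 sharded DTensor -> gather to a plain tensor
            g = g.full_tensor()
        total += g.float().sum().item()
    return total
```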

for bfloat16, gradients are quite different in repeated runs

[screenshot: gradient metric diverging across repeated bfloat16 runs]

turning off gradient norm clipping helps a lot, but does not explain all of the divergence
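
To illustrate why clipping amplifies run-to-run differences, a small sketch (hypothetical `max_norm` value; torchtitan's actual clipping config may differ):

```python
import torch
import torch.nn as nn

model = nn.Linear(16, 16)                   # stand-in module
model(torch.randn(4, 16)).sum().backward()

# clip_grad_norm_ rescales every gradient by max_norm / total_norm when total_norm
# exceeds max_norm, so a tiny nondeterministic wiggle in total_norm perturbs every update
max_norm = 1.0                              # hypothetical value; skip the call to disable clipping
total_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=max_norm)
```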

[screenshot: gradient metric with gradient clipping disabled; curves are closer but still diverge]

filing the issue here; hopefully it can be a good candidate for what's next

awgu commented 1 day ago

IIUC, the default SDPA backend for us is flash, and flash backward is non-deterministic?

I think we can try to enable some deterministic SDPA: https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html
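
For example, a sketch forcing a deterministic backend (assuming PyTorch ≥ 2.3's `torch.nn.attention.sdpa_kernel`; older releases expose a similar `torch.backends.cuda.sdp_kernel` context manager):

```python
import torch
import torch.nn.functional as F
from torch.nn.attention import SDPBackend, sdpa_kernel

q = k = v = torch.randn(2, 8, 128, 64, device="cuda", dtype=torch.bfloat16, requires_grad=True)

# the math backend has a deterministic backward (at a speed/memory cost),
# unlike the default flash kernel
with sdpa_kernel(SDPBackend.MATH):
    out = F.scaled_dot_product_attention(q, k, v, is_causal=True)
out.float().sum().backward()
```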

weifengpy commented 1 day ago

> IIUC, the default SDPA backend for us is flash, and flash backward is non-deterministic?
>
> I think we can try to enable some deterministic SDPA: https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html

good call out!