Open weifengpy opened 1 day ago
IIUC, the default SDPA backend for us is flash, and flash backward is non-deterministic?
I think we can try to enable some deterministic SDPA: https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html
good call out!
by default, torchtitan uses FSDP2 mixed precision (param_dtype=bfloat16, reduce_dtype=float32)
for low-precision dtypes (float8 and int8), it's natural to compare the loss curve with bfloat16 and see how well they match (it's also a good idea to compare weight norms and gradient norms)
for bfloat16 itself, multiple runs yield different loss curves, and the nondeterminism should be understood and documented (say NCCL gradient reduction, attention, seeding). Otherwise it's hard to tell whether numeric differences are coming from the low-precision dtypes
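A sketch of the usual knobs for pinning down the seed-related part of that list (a hypothetical helper, not torchtitan code; NCCL reduction order is not covered by these flags):

```python
import random

import numpy as np
import torch

def set_deterministic(seed: int = 0) -> None:
    # Seed every RNG a typical training loop touches.
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)  # also seeds CUDA RNGs
    # Pin cuDNN to deterministic kernels and disable autotuning.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    # Warn on any remaining known non-deterministic ops.
    torch.use_deterministic_algorithms(True, warn_only=True)
```

Even with all of this set, cross-rank collectives can still reduce gradients in a different order between runs, so some divergence may remain.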
I plotted gradient norms (loss = sum(model.parameters.grad)) using llama3-8b on 8 GPUs, with deterministic model init and a deterministic data loader. For bfloat16, gradients are quite different across repeated runs
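The metric above can be sketched like this (a hypothetical helper for comparing repeated runs, not the exact torchtitan code):

```python
import torch

def total_grad_norm(model: torch.nn.Module) -> float:
    # Sum of per-parameter gradient L2 norms: a cheap per-step scalar
    # to log and diff between repeated runs.
    return float(
        sum(p.grad.norm() for p in model.parameters() if p.grad is not None)
    )

# Tiny usage example on a toy model.
torch.manual_seed(0)
model = torch.nn.Linear(8, 8)
model(torch.randn(4, 8)).sum().backward()
norm = total_grad_norm(model)
```

Logging this scalar each step makes it easy to spot where two runs first diverge.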
turning off gradient norm clipping helps a lot, but does not explain all of the divergence
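For reference, the clipping step in question is the standard one below; one plausible (unverified) reason it amplifies divergence is that tiny bf16 differences near the threshold can change whether and how much gradients get rescaled.

```python
import torch

torch.manual_seed(0)
model = torch.nn.Linear(4, 4)
model(torch.randn(2, 4)).sum().backward()

# clip_grad_norm_ rescales gradients in place when the total norm exceeds
# max_norm, and returns the pre-clip total norm; skipping this call is what
# "turning off gradient norm clipping" means here.
total_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
```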
filing the issue here; hopefully it can be a good candidate for what's next