facebookresearch / fairscale

PyTorch extensions for high performance and large scale training.

clip_grad_norm_ from fairscale downcasts to bf16 before all reduce #1092

Open glample opened 1 year ago

glample commented 1 year ago

Copied from: https://github.com/fairinternal/xlformers/issues/117

Shouldn't we remove the .to(dtype=parameters[0].dtype) from this line? https://github.com/facebookresearch/fairscale/blob/ee647b976cf4c8fdd37bc9ae3fd6331d225ba2a0/fairscale/internal/params.py#L75 It seems weird (and it results in inaccuracies) to convert partial gradient norms to fp16/bf16 before summing them.

Context:

We use: https://github.com/facebookresearch/fairscale/blob/ee647b976cf4c8fdd37bc9ae3fd6331d225ba2a0/fairscale/nn/data_parallel/fully_sharded_data_parallel.py#L621

which calculates grad norms via: https://github.com/facebookresearch/fairscale/blob/ee647b976cf4c8fdd37bc9ae3fd6331d225ba2a0/fairscale/internal/params.py#L59

which downcasts to param dtype via: https://github.com/facebookresearch/fairscale/blob/ee647b976cf4c8fdd37bc9ae3fd6331d225ba2a0/fairscale/internal/params.py#L75

before the allreduce: https://github.com/facebookresearch/fairscale/blob/ee647b976cf4c8fdd37bc9ae3fd6331d225ba2a0/fairscale/nn/data_parallel/fully_sharded_data_parallel.py#L672
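
The effect of that call chain can be sketched without PyTorch. The following is only an illustration, not fairscale's code: `to_bf16` is a hand-written bfloat16 rounding emulation, `combine_norms` is a hypothetical stand-in for "raise each partial norm to the power p, all-reduce, take the p-th root", and the partial norms are made-up values. The non-downcast path uses ordinary Python floats (fp64) as the "full precision" reference.

```python
import struct


def to_bf16(x: float) -> float:
    """Round x to the nearest bfloat16 value (ties to even).

    Emulation for finite, moderate-magnitude inputs only.
    """
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    bits += 0x7FFF + ((bits >> 16) & 1)  # rounding bias for the dropped 16 bits
    bits &= 0xFFFF0000                   # keep sign, exponent, 7 mantissa bits
    return struct.unpack("<f", struct.pack("<I", bits))[0]


def combine_norms(local_norms, p=2.0, downcast=False):
    """Combine per-rank partial norms as (sum_i n_i**p) ** (1/p).

    downcast=True rounds every intermediate to bf16, imitating a reduction
    that runs entirely in the parameter dtype. The sequential loop stands in
    for the all-reduce; a real collective may sum in a different order.
    """
    rnd = to_bf16 if downcast else (lambda v: v)
    acc = 0.0
    for n in local_norms:
        acc = rnd(acc + rnd(rnd(n) ** p))
    return rnd(acc ** (1.0 / p))


# Hypothetical partial norms from 8 ranks (made-up values, not from a real run).
local_norms = [257.3, 255.9, 260.1, 254.2, 258.8, 256.4, 259.5, 253.7]

exact = combine_norms(local_norms)
low = combine_norms(local_norms, downcast=True)
print(f"float accumulation: {exact:.4f}")
print(f"bf16 throughout:    {low:.4f}")
```

With the downcast, the result always lands on a bfloat16-representable value, and rounding error is picked up at every step of the sum rather than once at the end.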

I spotted this by looking at how unusually round the grad norms are at each training step:

"g_norm": 5.6875
"g_norm": 11.1875
"g_norm": 23.0
"g_norm": 45.25
"g_norm": 89.5
"g_norm": 176.0
"g_norm": 360.0
"g_norm": 704.0
"g_norm": 720.0
"g_norm": 724.0
"g_norm": 728.0
"g_norm": 716.0
"g_norm": 724.0
"g_norm": 728.0
"g_norm": 752.0
"g_norm": 736.0
"g_norm": 728.0
"g_norm": 728.0
"g_norm": 736.0
"g_norm": 728.0
"g_norm": 728.0
"g_norm": 724.0
"g_norm": 724.0
"g_norm": 724.0
"g_norm": 732.0
"g_norm": 764.0
"g_norm": 720.0
"g_norm": 728.0
"g_norm": 728.0
"g_norm": 740.0
"g_norm": 732.0
"g_norm": 736.0
"g_norm": 704.0
"g_norm": 700.0
"g_norm": 728.0
"g_norm": 740.0
"g_norm": 724.0
"g_norm": 752.0
"g_norm": 712.0
"g_norm": 716.0
"g_norm": 724.0
"g_norm": 744.0
"g_norm": 728.0
"g_norm": 736.0
"g_norm": 720.0
"g_norm": 716.0
"g_norm": 724.0
"g_norm": 716.0
"g_norm": 720.0
"g_norm": 712.0
"g_norm": 744.0
"g_norm": 724.0
"g_norm": 708.0
"g_norm": 708.0
"g_norm": 716.0
"g_norm": 704.0
"g_norm": 712.0
"g_norm": 724.0
"g_norm": 708.0
"g_norm": 708.0
"g_norm": 728.0
"g_norm": 720.0
"g_norm": 724.0
"g_norm": 716.0
"g_norm": 712.0
"g_norm": 704.0
"g_norm": 700.0
"g_norm": 688.0
"g_norm": 692.0
"g_norm": 696.0
"g_norm": 732.0
"g_norm": 620.0
"g_norm": 1168.0
"g_norm": 1152.0
"g_norm": 1144.0
"g_norm": 1112.0
"g_norm": 1128.0
"g_norm": 1136.0
"g_norm": 1128.0
"g_norm": 1128.0
"g_norm": 1104.0
"g_norm": 1112.0
"g_norm": 1088.0
"g_norm": 1112.0
"g_norm": 1112.0
"g_norm": 1120.0
"g_norm": 1112.0
"g_norm": 1064.0
"g_norm": 1040.0
"g_norm": 1024.0
"g_norm": 1056.0
"g_norm": 1032.0
"g_norm": 1032.0
"g_norm": 1024.0
"g_norm": 1048.0
"g_norm": 1016.0
"g_norm": 1040.0
"g_norm": 1016.0
"g_norm": 936.0
"g_norm": 828.0
"g_norm": 764.0
"g_norm": 732.0
"g_norm": 692.0
"g_norm": 676.0
"g_norm": 1376.0
"g_norm": 1360.0
"g_norm": 1328.0
"g_norm": 1360.0
"g_norm": 1360.0
"g_norm": 1312.0
"g_norm": 1328.0
"g_norm": 1264.0
"g_norm": 1304.0
"g_norm": 1280.0
"g_norm": 1296.0
"g_norm": 1224.0
"g_norm": 1256.0
"g_norm": 1264.0
"g_norm": 1224.0
"g_norm": 1152.0
"g_norm": 1160.0
"g_norm": 1184.0
"g_norm": 1184.0
"g_norm": 1144.0
"g_norm": 1128.0
"g_norm": 1112.0
"g_norm": 1080.0
"g_norm": 1072.0
"g_norm": 1048.0
"g_norm": 1040.0
"g_norm": 1040.0
"g_norm": 1072.0
"g_norm": 1032.0
"g_norm": 1024.0
"g_norm": 996.0
"g_norm": 976.0
"g_norm": 988.0
"g_norm": 976.0
"g_norm": 956.0
"g_norm": 988.0
"g_norm": 944.0
"g_norm": 924.0
"g_norm": 924.0
"g_norm": 904.0
"g_norm": 1840.0
"g_norm": 1872.0
"g_norm": 1816.0
"g_norm": 1760.0
"g_norm": 1752.0
"g_norm": 1808.0
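
One way to check that these values sit on the bfloat16 grid is sketched below. It assumes the standard bfloat16 layout (8 significant bits: 7 stored plus 1 implicit). `bf16_spacing` is a hypothetical helper written for this illustration, and `samples` is a handful of values copied from the log above.

```python
import math


def bf16_spacing(x: float) -> float:
    """Gap between adjacent bfloat16 values around a positive normal x."""
    # math.frexp gives x = m * 2**exp with 0.5 <= m < 1. With 8 significant
    # bits, the last representable bit has weight 2**(exp - 8).
    _, exp = math.frexp(x)
    return 2.0 ** (exp - 8)


# A few g_norm values copied from the log above.
samples = [5.6875, 45.25, 360.0, 728.0, 1128.0, 1840.0]

for g in samples:
    step = bf16_spacing(g)
    on_grid = (g / step).is_integer()
    print(f"g_norm={g:<8} spacing={step:<8} on bf16 grid: {on_grid}")
```

The logged values between 512 and 1024 are all multiples of 4, and those between 1024 and 2048 are all multiples of 8, which matches the bfloat16 spacing in those ranges.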

min-xu-ai commented 1 year ago

Yeah, it does seem like an unnecessary thing to do. It seems to be from this commit:

https://github.com/facebookresearch/fairscale/commit/8dc2030b09eab71c7aace9e6a674ca4ea8429346

@blefaudeux, do you remember why it was done in the first place? If not, we can try removing it. @glample feel free to send a PR.

blefaudeux commented 1 year ago

Hey there, seeing this a bit late. No context from me really; I guess that it was a type "fix" at some point.

edit: I don't really understand your link, @min-xu-ai. The linked commit made sure that the norm was computed in fp32 locally (even if the type is fp16, for instance), but this is not what @glample is suggesting here, right? I'm a bit lost with this PR title.

edit2: OK, so the commit you point to introduced both the upcast and the cast back to the original type. I agree that the cast back can be delayed if it helps any operation; it's not crucial here.

min-xu-ai commented 1 year ago

Yes, that's what I meant. Thanks a lot for the context, Ben!