glample opened this issue 1 year ago
Yeah, it does seem like an unnecessary thing to do. It seems to be from this commit:
https://github.com/facebookresearch/fairscale/commit/8dc2030b09eab71c7aace9e6a674ca4ea8429346
@blefaudeux, do you remember why it was done in the first place? If not, we can try removing it. @glample feel free to send a PR.
hey there, seeing this a bit late. No real context from me; I guess it was a type "fix" at some point.
edit: I don't really understand your link @min-xu-ai. The linked commit made sure that the norm was computed in fp32 locally (even if the type is fp16, for instance), but that is not what @glample is suggesting here, right? I'm a bit lost with this PR title.
edit2: ok, so the commit you point to introduced both the upcast and the cast back to the original type. I agree that the cast back can be delayed if it helps any operation; it's not crucial here.
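To make that concrete, here is a rough single-rank sketch of the pattern being discussed (an illustration with made-up tensors, not the actual fairscale code): the local norm is computed in fp32 even when the grads are fp16, and the result is then cast back to the parameters' dtype. It is this second step that could be dropped, or delayed until after the all-reduce.

```python
import torch

# Hypothetical fp16 grad shards held by one rank, purely for illustration.
grads = [torch.randn(1024, dtype=torch.float16) for _ in range(4)]

# Upcast: each local norm is computed in fp32 even though the grads are fp16 ...
local_norm = torch.stack([g.float().norm() for g in grads]).norm()

# ... cast back: the result is returned in the original (fp16) dtype; this is
# the step that could be removed or delayed until after the all-reduce.
local_norm = local_norm.to(grads[0].dtype)
print(local_norm.dtype, local_norm.item())
```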
Yes, that's what I meant. Thanks a lot for the context, Ben!
Copied from: https://github.com/fairinternal/xlformers/issues/117
Shouldn't we remove the `.to(dtype=parameters[0].dtype)` from this line? https://github.com/facebookresearch/fairscale/blob/ee647b976cf4c8fdd37bc9ae3fd6331d225ba2a0/fairscale/internal/params.py#L75 It seems weird (and it results in inaccuracies) to convert partial gradient norms to fp16/bf16 before summing them.

Context:
We use: https://github.com/facebookresearch/fairscale/blob/ee647b976cf4c8fdd37bc9ae3fd6331d225ba2a0/fairscale/nn/data_parallel/fully_sharded_data_parallel.py#L621
which calculates grad norms via: https://github.com/facebookresearch/fairscale/blob/ee647b976cf4c8fdd37bc9ae3fd6331d225ba2a0/fairscale/internal/params.py#L59
which downcasts to param dtype via: https://github.com/facebookresearch/fairscale/blob/ee647b976cf4c8fdd37bc9ae3fd6331d225ba2a0/fairscale/internal/params.py#L75
before the allreduce: https://github.com/facebookresearch/fairscale/blob/ee647b976cf4c8fdd37bc9ae3fd6331d225ba2a0/fairscale/nn/data_parallel/fully_sharded_data_parallel.py#L672
Spotted by noticing how unusually even the grad norms look at each training step.
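For illustration, here is a minimal single-process sketch (not the fairscale code itself; a plain list stands in for the per-rank shards and an ordinary sum stands in for `dist.all_reduce(SUM)`). It shows how rounding each partial norm to fp16/bf16 before the sum, which is what the `.to(dtype=parameters[0].dtype)` cast amounts to when parameters are in half precision, perturbs the resulting global grad norm, whereas keeping the partials in fp32 until after the reduction matches the full-precision reference.

```python
import torch

torch.manual_seed(0)

# Pretend each of 8 "ranks" holds one shard of the flattened gradients.
# Shards are kept in fp32 so that only the dtype of the partial norms varies.
shards = [torch.randn(1_000_000) * 1e-3 for _ in range(8)]

# Reference: global 2-norm computed end to end in fp32.
reference = torch.cat(shards).norm().item()

def reduced_norm(partial_dtype: torch.dtype) -> float:
    # Each "rank" computes its local norm in fp32, then casts the partial
    # result to `partial_dtype` -- the cast the issue proposes to remove
    # (or at least delay until after the all-reduce).
    partials = [shard.float().norm().to(partial_dtype) for shard in shards]
    # Simulated all-reduce: sum the squared partial norms, then take the root.
    total_sq = sum(p.float() ** 2 for p in partials)
    return total_sq.sqrt().item()

print(f"fp32 end to end: {reference:.6f}")
print(f"fp32 partials  : {reduced_norm(torch.float32):.6f}")
print(f"fp16 partials  : {reduced_norm(torch.float16):.6f}")
print(f"bf16 partials  : {reduced_norm(torch.bfloat16):.6f}")
```

With bf16, each partial norm keeps only about 8 bits of mantissa before the reduction, which would be consistent with the grad norms looking unusually even from step to step.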