mlfoundations / open_clip

An open source implementation of CLIP.

deprecate LayerNormFp32 #850

Closed: EIFY closed this issue 1 month ago

EIFY commented 3 months ago

Modern PyTorch (1.10+) always performs LayerNorm in fp32:

For example, LayerNorm has to be done in fp32 and recent pytorch (1.10+) has been fixed to do that regardless of the input types, but earlier pytorch versions accumulate in the input type which can be an issue.

So it's no longer necessary to use LayerNormFp32 to explicitly cast the input to fp32. However, the built-in torch.nn.LayerNorm always returns fp32 output when run under an autocast() context, so we still need the LayerNorm subclass to cast the result back to the input dtype. See also https://github.com/pytorch/pytorch/issues/66707#issuecomment-2028904230.
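A minimal sketch of the "cast back" subclass described above, assuming a recent PyTorch. The class name `LayerNormCastBack` is hypothetical, not necessarily what open_clip calls it. Under CUDA autocast, `layer_norm` is on the list of ops that autocast to float32, so a plain `nn.LayerNorm` returns float32 even for a float16 input; the subclass simply restores the input dtype afterwards. The autocast demo only runs if a CUDA device is present.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class LayerNormCastBack(nn.LayerNorm):
    """Hypothetical name: a LayerNorm that returns its input's dtype.

    Under CUDA autocast, layer_norm is run in float32, so without the
    final cast the output would be float32 even for float16 inputs.
    """

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        orig_dtype = x.dtype
        out = F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
        # Cast back so downstream layers see the original (e.g. fp16) dtype.
        return out.to(orig_dtype)


if torch.cuda.is_available():
    ln_plain = nn.LayerNorm(64).cuda()
    ln_cast = LayerNormCastBack(64).cuda()
    x = torch.randn(2, 64, device="cuda", dtype=torch.float16)
    with torch.autocast(device_type="cuda", dtype=torch.float16):
        print(ln_plain(x).dtype)  # torch.float32: autocast runs layer_norm in fp32
        print(ln_cast(x).dtype)   # torch.float16: cast back to the input dtype
```

Outside autocast the subclass behaves exactly like `nn.LayerNorm`; the extra `.to()` is a no-op when the dtypes already match.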

rwightman commented 2 months ago

@EIFY I don't think this is quite the case. In an autocast context it returns float32 because the op is upcast to float32 under AMP. But we aren't using this class when AMP is enabled; it's used when pure float16/bfloat16 is enabled, and then it does make a difference. Even if the reduction is done internally in float32, the affine ops will be done in low precision, whereas in LayerNormFp32 everything is done in float32 regardless of the dtype.
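To make the pure low-precision case concrete, here is a sketch comparing a plain bfloat16 `nn.LayerNorm` with an "everything in float32" variant. The class name `LayerNormFp32Sketch` is hypothetical, and explicitly upcasting the weight and bias is an assumption of this sketch: it keeps the affine step in float32 even if the parameters themselves were cast to bfloat16. The example assumes the whole module, including the LayerNorm weight and bias, has been cast to bfloat16 and that no autocast context is active.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class LayerNormFp32Sketch(nn.LayerNorm):
    """Sketch of the LayerNormFp32 idea: do everything in float32.

    Upcasts the input *and* the affine parameters, so both the reduction
    and the weight/bias ops run in float32; only the final result is
    rounded back to the input dtype.
    """

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        orig_dtype = x.dtype
        weight = self.weight.float() if self.weight is not None else None
        bias = self.bias.float() if self.bias is not None else None
        out = F.layer_norm(x.float(), self.normalized_shape, weight, bias, self.eps)
        return out.to(orig_dtype)


torch.manual_seed(0)
dim = 64
ln_lp = nn.LayerNorm(dim)
with torch.no_grad():
    # Non-trivial affine params so the weight/bias step actually matters.
    ln_lp.weight.normal_(1.0, 0.5)
    ln_lp.bias.normal_(0.0, 0.5)
ln_fp32 = LayerNormFp32Sketch(dim)
ln_fp32.load_state_dict(ln_lp.state_dict())

# Pure low precision (no autocast): cast modules and input to bfloat16.
ln_lp = ln_lp.to(torch.bfloat16)
ln_fp32 = ln_fp32.to(torch.bfloat16)
x = torch.randn(4, dim).to(torch.bfloat16)

y_lp = ln_lp(x)      # computed by the backend's bf16 LayerNorm kernel
y_fp32 = ln_fp32(x)  # computed fully in float32, rounded to bf16 once
print(y_lp.dtype, y_fp32.dtype)  # torch.bfloat16 torch.bfloat16
print((y_lp.float() - y_fp32.float()).abs().max())
```

The size of the printed difference depends on how the backend's bfloat16 kernel handles the affine step, so it may be zero on some devices and nonzero on others; the float32 variant, by construction, rounds to low precision only once.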