Closed: EIFY closed this issue 1 month ago.
@EIFY I don't think this is quite the case. In an autocast context it returns float32 because it's upcast to float32 when AMP is enabled. But we aren't using this class when AMP is enabled; it's used when pure float16/bfloat16 is enabled, and then it does make a difference. Even if the reduction is done internally in float32, the affine ops will be done in low precision, whereas in `LayerNormFp32` everything is done in float32 regardless of the dtype.
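For concreteness, a minimal sketch of what that fp32-everything variant can look like (illustrative, not necessarily the exact open_clip code; the explicit weight/bias casts are my assumption to keep the sketch self-contained):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class LayerNormFp32(nn.LayerNorm):
    """LayerNorm that runs the reduction *and* the affine ops in fp32,
    then casts the result back to the input dtype. Relevant for pure
    float16/bfloat16 models (no AMP)."""

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        orig_type = x.dtype
        # Cast input and affine params to fp32 so nothing runs in low precision.
        # (Param casts are an assumption; in some setups they are already fp32.)
        weight = self.weight.float() if self.weight is not None else None
        bias = self.bias.float() if self.bias is not None else None
        x = F.layer_norm(x.float(), self.normalized_shape, weight, bias, self.eps)
        return x.to(orig_type)
```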
Modern PyTorch (1.10+) always performs LN in fp32, so it's no longer necessary to use `LayerNormFp32` to explicitly cast to fp32. However, the built-in `torch.nn.LayerNorm` always returns fp32 when run under the `autocast()` context, so we still need the `LayerNorm` subclass to cast back. See also https://github.com/pytorch/pytorch/issues/66707#issuecomment-2028904230.
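A sketch of that cast-back subclass, with a quick dtype check of the `autocast()` behavior described above (the CUDA guard and prints are mine, for illustration):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class LayerNorm(nn.LayerNorm):
    """LayerNorm that casts its output back to the input dtype,
    undoing the fp32 upcast that autocast applies to layer_norm."""

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        orig_type = x.dtype
        x = F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
        return x.to(orig_type)

if torch.cuda.is_available():
    builtin = nn.LayerNorm(8).cuda()
    cast_back = LayerNorm(8).cuda()
    x = torch.randn(2, 8, device="cuda", dtype=torch.float16)
    with torch.autocast("cuda", dtype=torch.float16):
        print(builtin(x).dtype)    # torch.float32: autocast runs layer_norm in fp32
        print(cast_back(x).dtype)  # torch.float16: result cast back to input dtype
```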