FlagOpen / FlagGems

FlagGems is an operator library for large language models implemented in Triton Language.
Apache License 2.0

Save mean and rstd in float32 when the inputs are in half-precision floating-point types #221

Closed iclementine closed 2 months ago

iclementine commented 2 months ago

PR Category

Operator

Type of Change

Bug Fix

Description

Fix layer_norm_backward: when the inputs are in a half-precision floating-point dtype, save mean and rstd in float32 to avoid numerical instability or errors in the backward pass.

NOTE: ATen's implementation also saves mean and rstd in fp32 in these cases.
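To illustrate why the storage dtype of the saved statistics matters, here is a minimal NumPy sketch, not the FlagGems Triton kernel. The function `layer_norm_stats`, its `save_dtype` parameter, and the helper `max_abs_err` are hypothetical names introduced for this example only. It assumes rstd is defined as 1/sqrt(var + eps), and it uses float16 because NumPy has no bfloat16.

```python
import numpy as np

def layer_norm_stats(x, eps=1e-5, save_dtype=None):
    """Per-row mean and rstd = 1/sqrt(var + eps) (hypothetical helper).

    The statistics are computed in at least float32. By default they are
    also *saved* in at least float32, so float16 inputs get float32 stats.
    """
    compute_dtype = np.promote_types(x.dtype, np.float32)
    if save_dtype is None:
        save_dtype = compute_dtype
    xc = x.astype(compute_dtype)
    mean = xc.mean(axis=-1)
    rstd = 1.0 / np.sqrt(xc.var(axis=-1) + eps)
    return mean.astype(save_dtype), rstd.astype(save_dtype)

def max_abs_err(a, ref):
    """Largest absolute deviation from a float64 reference."""
    return float(np.max(np.abs(a.astype(np.float64) - ref)))

# Half-precision input with a non-zero mean.
rng = np.random.default_rng(0)
x = rng.normal(10.0, 0.5, size=(4, 64)).astype(np.float16)

# float64 reference statistics for the same float16 input values.
x64 = x.astype(np.float64)
ref_mean = x64.mean(axis=-1)
ref_rstd = 1.0 / np.sqrt(x64.var(axis=-1) + 1e-5)

mean32, rstd32 = layer_norm_stats(x)                         # saved in float32
mean16, rstd16 = layer_norm_stats(x, save_dtype=np.float16)  # saved in float16

print("rstd error, float32 stats:", max_abs_err(rstd32, ref_rstd))
print("rstd error, float16 stats:", max_abs_err(rstd16, ref_rstd))
```

Rounding the saved statistics to float16 discards bits that the backward pass later reuses for every element of the row, which is the kind of error this change is meant to avoid.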

Issue

Progress

Performance

iclementine commented 2 months ago

Known issue: the backward pass cannot be tracked by coverage.