All `step` implementations match their corresponding `__call__` implementations up to a tolerance of `(atol=1e-8, rtol=1e-7)` in double-precision mode.
In single-precision mode, the tolerance has to be raised to `(atol~=5e-3, rtol=0)` because of various sources of numerical imprecision. These sources have been documented where identified (by manually stepping through the code). The tests are forced to run in double-precision mode so that these errors are controlled for and genuine implementation errors are revealed.
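To make the step-vs-call comparison concrete, here is a minimal NumPy sketch. It assumes a toy diagonal linear recurrence in place of the real operators. `parallel_forward` stands in for `__call__` (whole sequence at once, via a causal convolution), and `step_forward` stands in for `step` (one token at a time with a carried state). The two are checked against each other with the tolerances quoted above. All function names here are hypothetical and not taken from the codebase.

```python
import numpy as np

# Tolerances quoted above, keyed by floating-point dtype.
TOLERANCES = {
    np.dtype(np.float64): {"atol": 1e-8, "rtol": 1e-7},
    np.dtype(np.float32): {"atol": 5e-3, "rtol": 0.0},
}


def parallel_forward(a, b, c, u):
    """Hypothetical stand-in for __call__: the whole sequence at once via a causal convolution."""
    length = u.shape[0]
    # powers[n, m] = a_n ** m, kept in the input dtype
    powers = a[:, None] ** np.arange(length, dtype=a.dtype)[None, :]
    kernel = (c * b) @ powers  # K_m = sum_n c_n * a_n**m * b_n
    return np.convolve(u, kernel)[:length]


def step_forward(a, b, c, u):
    """Hypothetical stand-in for step: one token at a time with a carried state."""
    state = np.zeros_like(a)
    outputs = []
    for u_k in u:
        state = a * state + b * u_k  # x_k = a * x_{k-1} + b * u_k
        outputs.append(c @ state)    # y_k = c . x_k
    return np.array(outputs, dtype=a.dtype)


def check_step_matches_call(dtype, state_size=8, length=64, seed=0):
    """Run both paths on random inputs and assert agreement at the dtype's tolerance."""
    rng = np.random.default_rng(seed)
    a = rng.uniform(0.5, 0.99, state_size).astype(dtype)  # stable decay rates
    b = rng.standard_normal(state_size).astype(dtype)
    c = rng.standard_normal(state_size).astype(dtype)
    u = rng.standard_normal(length).astype(dtype)
    y_call = parallel_forward(a, b, c, u)
    y_step = step_forward(a, b, c, u)
    np.testing.assert_allclose(y_step, y_call, **TOLERANCES[np.dtype(dtype)])
    return y_call, y_step
```

Calling `check_step_matches_call(np.float32)` exercises the looser single-precision tolerance. With `np.float64`, the same check runs at the tight tolerance.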
The parallel and autoregressive implementations of `LMBackbone` with more than one `S5Operator` layer were found to diverge after the first block; this is believed to be due to imprecision in the carried state. See `test_simply_lm.py::test_lmbackbone_step` for more details.
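One way to localize this kind of divergence is to compare the two paths layer by layer rather than only at the final output. The following is a minimal sketch under stated assumptions. It stacks toy diagonal-recurrence layers (not the real `S5Operator`), gives each layer its own carried state in the autoregressive path, and reports the index of the first layer whose outputs disagree. It is not `test_lmbackbone_step`, and all names are hypothetical.

```python
import numpy as np


def layer_parallel(params, u):
    """Hypothetical toy layer, parallel mode: diagonal linear recurrence via causal convolution."""
    a, b, c = params
    length = u.shape[0]
    powers = a[:, None] ** np.arange(length, dtype=a.dtype)[None, :]
    kernel = (c * b) @ powers
    return np.convolve(u, kernel)[:length]


def layer_step(params, state, u_k):
    """Hypothetical toy layer, autoregressive mode: consumes one token, returns (output, new_state)."""
    a, b, c = params
    state = a * state + b * u_k
    return c @ state, state


def per_layer_outputs_parallel(layers, u):
    """Feed the full sequence through each layer in turn, recording every layer's output."""
    outs, h = [], u
    for params in layers:
        h = layer_parallel(params, h)
        outs.append(h)
    return outs


def per_layer_outputs_step(layers, u):
    """Feed one token at a time through the stack, carrying a separate state per layer."""
    states = [np.zeros_like(params[0]) for params in layers]
    outs = [[] for _ in layers]
    for u_k in u:
        h = u_k
        for i, params in enumerate(layers):
            h, states[i] = layer_step(params, states[i], h)
            outs[i].append(h)
    return [np.array(o, dtype=u.dtype) for o in outs]


def first_divergent_layer(layers, u, atol, rtol):
    """Return the index of the first layer whose step output disagrees with parallel, or None."""
    par = per_layer_outputs_parallel(layers, u)
    stp = per_layer_outputs_step(layers, u)
    for i, (p, s) in enumerate(zip(par, stp)):
        if not np.allclose(s, p, atol=atol, rtol=rtol):
            return i
    return None
```

Checking each layer separately shows whether a mismatch originates in a later block or is inherited from an earlier one. For a single layer, that distinction is invisible.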

cc: Kelly for visibility