lindermanlab / S5

MIT License

add autoregressive capability to `S5Operator`, `Block`, and `LMBackbone` #20

Closed ezhang94 closed 3 months ago

ezhang94 commented 3 months ago

All `step` implementations match their corresponding `__call__` implementations to within a tolerance of (atol=1e-8, rtol=1e-7) in double-precision mode.
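For readers unfamiliar with this kind of check, here is a minimal sketch of a step-versus-full-sequence comparison. It is not the repository's test code: the toy diagonal recurrence and the names `run_full`, `run_stepwise`, `step`, and `allclose` are all hypothetical stand-ins, and plain Python floats (which are double precision) are used in place of JAX arrays.

```python
# Hypothetical toy model: a diagonal linear recurrence x_k = a * x_{k-1} + b * u_k.
# This is NOT the repository's S5Operator; every name here is illustrative.

def allclose(xs, ys, atol, rtol):
    """Elementwise |x - y| <= atol + rtol * |y| (the numpy.allclose convention)."""
    return all(abs(x - y) <= atol + rtol * abs(y) for x, y in zip(xs, ys))

def run_full(a, b, inputs):
    """Whole-sequence evaluation via the unrolled sum
    x_k = sum_{j <= k} a**(k - j) * b * u_j."""
    return [
        sum(a ** (k - j) * b * inputs[j] for j in range(k + 1))
        for k in range(len(inputs))
    ]

def step(a, b, state, u):
    """One autoregressive step: returns the new carried state."""
    return a * state + b * u

def run_stepwise(a, b, inputs):
    """Feed the inputs one at a time, carrying the state between steps."""
    state = 0j
    outputs = []
    for u in inputs:
        state = step(a, b, state, u)
        outputs.append(state)
    return outputs

a = complex(0.9, 0.1)    # |a| < 1 keeps the recurrence stable
b = complex(0.5, -0.2)
inputs = [((7 * k) % 11) / 11.0 - 0.5 for k in range(64)]

full = run_full(a, b, inputs)
stepwise = run_stepwise(a, b, inputs)

# Python floats are double precision, so the tight tolerance is appropriate here.
matches = allclose(stepwise, full, atol=1e-8, rtol=1e-7)
```

The two code paths compute the same quantity with a different order of operations, so in double precision they should agree to far better than the stated tolerance; a failure at this tolerance would point to a logic error rather than rounding.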

In single-precision mode, the tolerance needs to be raised to (atol≈5e-3, rtol=0) because of various sources of numerical imprecision. These sources have been documented wherever they were identified (by manually stepping through the code). The tests are forced to run in double-precision mode to control for this kind of error, so that any remaining mismatch reveals a genuine implementation error.
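The effect of precision on a carried state can be illustrated with a small sketch. It assumes a toy scalar recurrence rather than the actual model, emulates float32 by rounding through the standard `struct` module, and uses hypothetical helper names (`f32`, `run_steps`). The error it produces is much smaller than the 5e-3 reported above for the full model; it only shows the mechanism by which rounding enters at every step.

```python
import struct

def f32(x):
    """Round a Python float (double precision) to the nearest float32 value."""
    return struct.unpack("f", struct.pack("f", x))[0]

def run_steps(a, b, inputs, rnd=lambda x: x):
    """Stepwise recurrence x_k = a * x_{k-1} + b * u_k.
    `rnd` is applied after every arithmetic operation; the identity keeps
    full double precision, while `f32` emulates single precision."""
    state = 0.0
    outputs = []
    for u in inputs:
        state = rnd(rnd(a * state) + rnd(b * u))
        outputs.append(state)
    return outputs

a, b = 0.9, 0.5
inputs = [((7 * k) % 11) / 11.0 - 0.5 for k in range(64)]

# Double-precision reference.
reference = run_steps(a, b, inputs)

# Emulated single precision: parameters, inputs, and every intermediate
# result are rounded to float32.
single = run_steps(f32(a), f32(b), [f32(u) for u in inputs], rnd=f32)

max_abs_err = max(abs(s - r) for s, r in zip(single, reference))
loose_ok = max_abs_err <= 5e-3   # the looser single-precision tolerance (rtol=0)
```

Running the comparison in double precision first, as the tests do, separates the two failure modes: rounding noise of this kind disappears, and whatever mismatch remains is attributable to the implementation.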

The parallel and autoregressive implementations of `LMBackbone` with multiple (>1) `S5Operator` layers were found to disagree after the first block; this is believed to be due to imprecision in the carried state. Please refer to the `test_simply_lm.py::test_lmbackbone_step` function for more details.

cc: Kelly for visibility

jimmysmith1919 commented 3 months ago

This all looks good to me.