ezhang94 closed 2 months ago
Let's "Rebase and merge" instead of "Squash and merge", because some minor commits that are unrelated to the AR generation crept in here.
Adding @ekellbuch for visibility of changes (and of course any review)
Primary changes
tests:
  - apply_ssm using parallel scan vs. a simple for-loop implementation
  - S5SSM.step vs. a single S5SSM.__call__
additionally:
  - Merge apply_ssm_with_feedthrough into apply_ssm, and return both the last state and the output sequence
  - S5SSM docstring
  - C_init='complex_normal', explicitly cast C as a complex dtype
Other changes
Some minor modifications to simple_lm.py and ICLDataset:
  - aux_rng_stream=['dropout'] at model level (not fully standardized yet)
  - collate_fn handles casting of arrays to jax.Array and reshaping (batch_sz,) -> (n_devices, batch_sz_per_device,)
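
For reviewers, here is a rough sketch of the collate_fn behavior described above: stack per-example arrays, cast them to jax.Array, and split the leading batch axis across devices. The name collate_for_devices, its n_devices argument, and the assumption that every sample is a NumPy array of identical shape are hypothetical and not taken from this PR; the actual collate_fn may differ.

```python
import jax
import jax.numpy as jnp
import numpy as np


def collate_for_devices(samples, n_devices):
    """Hypothetical collate function (not the PR's implementation).

    Stacks a list of equally shaped arrays into one batch, casts it to a
    jax.Array, and reshapes
    (batch_sz, ...) -> (n_devices, batch_sz_per_device, ...).
    """
    batch = jnp.asarray(np.stack(samples))  # cast to jax.Array
    batch_sz = batch.shape[0]
    if batch_sz % n_devices != 0:
        raise ValueError(
            f"batch size {batch_sz} is not divisible by n_devices={n_devices}"
        )
    return batch.reshape(n_devices, batch_sz // n_devices, *batch.shape[1:])


# Example: 8 sequences of length 5, split across 2 (assumed) devices.
samples = [np.full((5,), i, dtype=np.int32) for i in range(8)]
sharded = collate_for_devices(samples, n_devices=2)
```

If the leading axis is meant for jax.pmap, n_devices would typically be jax.local_device_count().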
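
The first test listed under the primary changes checks a parallel scan against a plain loop. Below is a minimal sketch of that kind of check, assuming a diagonal linear recurrence x_k = lambda_bar * x_{k-1} + bu_k with x_0 = 0. The function names, shapes, and random inputs are illustrative assumptions and are not the PR's apply_ssm.

```python
import jax
import jax.numpy as jnp


def binary_operator(elem_i, elem_j):
    """Associative operator composing the affine maps x -> a * x + b."""
    a_i, b_i = elem_i
    a_j, b_j = elem_j
    return a_j * a_i, a_j * b_i + b_j


def scan_states(lambda_bar, bu):
    """States x_k = lambda_bar * x_{k-1} + bu_k (x_0 = 0) via a parallel scan."""
    a = jnp.broadcast_to(lambda_bar, bu.shape)
    _, xs = jax.lax.associative_scan(binary_operator, (a, bu))
    return xs


def loop_states(lambda_bar, bu):
    """Same recurrence, written as a simple Python for loop."""
    x = jnp.zeros_like(bu[0])
    xs = []
    for bu_k in bu:
        x = lambda_bar * x + bu_k
        xs.append(x)
    return jnp.stack(xs)


key_re, key_im = jax.random.split(jax.random.PRNGKey(0))
seq_len, state_dim = 16, 4
# Hypothetical stable diagonal dynamics (|lambda| < 1) and complex inputs.
lambda_bar = 0.9 * jnp.exp(1j * jnp.linspace(0.1, 1.0, state_dim))
bu = jax.random.normal(key_re, (seq_len, state_dim)) + 1j * jax.random.normal(
    key_im, (seq_len, state_dim)
)

xs_scan = scan_states(lambda_bar, bu)
xs_loop = loop_states(lambda_bar, bu)
# Largest absolute difference between the two implementations.
max_err = float(jnp.max(jnp.abs(xs_scan - xs_loop)))
```

In this sketch, xs_scan is the full state sequence and xs_scan[-1] is the last state, loosely mirroring the "last state and output sequence" pair mentioned above.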