srush / annotated-s4

Implementation of https://srush.github.io/annotated-s4
MIT License
460 stars, 61 forks

Implement the DSS RNN parameterization #51

Closed: srush closed this 2 years ago

srush commented 2 years ago

Implements the RNN version of the DSS model in the format of S4.

A test confirms that the RNN form gives the same result as the convolutional (kernel) form. Training an MNIST generator now to confirm.
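The equivalence such a test checks can be sketched in a few lines of NumPy. This is a minimal illustration under our own assumptions, not the code from this PR or the exact DSS parameterization: a diagonal state matrix `Lambda` discretized with zero-order hold, output taken as the real part, and hypothetical helper names (`discretize_diag`, `conv_kernel`, `run_conv`, `run_rnn`) and initialization.

```python
import numpy as np

def discretize_diag(Lambda, B, step):
    """Zero-order-hold discretization of a diagonal SSM x' = Lambda*x + B*u."""
    Ab = np.exp(Lambda * step)            # A_bar = exp(Lambda * step), elementwise
    Bb = (Ab - 1.0) / Lambda * B          # B_bar = Lambda^-1 (A_bar - 1) B
    return Ab, Bb

def conv_kernel(Ab, Bb, C, L):
    """K[k] = Re(sum_n C[n] * Ab[n]**k * Bb[n]) for k = 0..L-1."""
    powers = Ab[None, :] ** np.arange(L)[:, None]   # shape (L, N)
    return (powers * (C * Bb)[None, :]).sum(axis=1).real

def run_conv(K, u):
    """CNN mode: causal convolution y[k] = sum_j K[j] * u[k - j]."""
    return np.convolve(u, K)[: len(u)]

def run_rnn(Ab, Bb, C, u):
    """RNN mode: x[k] = Ab * x[k-1] + Bb * u[k], y[k] = Re(C . x[k])."""
    x = np.zeros_like(Ab)
    ys = []
    for u_k in u:
        x = Ab * x + Bb * u_k
        ys.append((C * x).sum().real)
    return np.array(ys)

rng = np.random.default_rng(0)
N, L = 8, 64
Lambda = -0.5 + 1j * np.pi * np.arange(N)   # hypothetical init; negative real part keeps it stable
B = np.ones(N, dtype=complex)
C = rng.normal(size=N) + 1j * rng.normal(size=N)
Ab, Bb = discretize_diag(Lambda, B, step=0.1)
u = rng.normal(size=L)

y_cnn = run_conv(conv_kernel(Ab, Bb, C, L), u)
y_rnn = run_rnn(Ab, Bb, C, u)
print(np.allclose(y_cnn, y_rnn))  # True
```

Because the state matrix is diagonal, each RNN step is an elementwise multiply over the `N` state entries, which is what makes the recurrent form cheap.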

This is based on this writeup from @ag1988 https://github.com/srush/annotated-s4/pull/49#issuecomment-1090963170

As he notes, the hidden states are not unique, but they do let us do generation efficiently.
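One way to see the non-uniqueness, sketched under our own assumptions rather than taken from the write-up: in a diagonal SSM only the elementwise product `C * B_bar` enters the kernel, so rescaling the state by any nonzero diagonal factor `d` (with `B_bar -> d * B_bar` and `C -> C / d`) leaves every output unchanged while changing the hidden states. The `scan` helper and the values of `d` are hypothetical.

```python
import numpy as np

def scan(Ab, Bb, C, u):
    """Run the diagonal recurrence; return the outputs and the final hidden state."""
    x = np.zeros_like(Ab)
    ys = []
    for u_k in u:
        x = Ab * x + Bb * u_k
        ys.append((C * x).sum().real)
    return np.array(ys), x

rng = np.random.default_rng(1)
N, L = 4, 32
Ab = np.exp((-0.5 + 1j * np.arange(N)) * 0.1)
Bb = rng.normal(size=N) + 1j * rng.normal(size=N)
C = rng.normal(size=N) + 1j * rng.normal(size=N)
u = rng.normal(size=L)

# Hypothetical rescaling of the state: x' = d * x, so B' = d * B and C' = C / d.
d = 2.0 + 0.5j * np.arange(1, N + 1)
y1, x1 = scan(Ab, Bb, C, u)
y2, x2 = scan(Ab, d * Bb, C / d, u)
print(np.allclose(y1, y2), np.allclose(x1, x2))  # True False
```

Both parameterizations produce the same sequence, so either set of states can be cached for generation; the states just are not a canonical quantity.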

srush commented 2 years ago

@ag1988 whoa, this worked out-of-the-box with our old code.

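[Image: MNIST completion samples; the white pixels are given to the model and the red pixels are generated]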

ag1988 commented 2 years ago

Looks cool!!

P.S. I'll actually go over your tutorial again to understand this part 😅

srush commented 2 years ago

Oh, not that exciting. We just give the model the white part and have it generate the red part, pixel by pixel. If we computed it the usual way, by convolving with the kernel K, we would need to feed in all the previous pixels each time to compute the next one. With the SSM form, we can compute one pixel at a time using the cached hidden state.
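A toy sketch of the two decoding strategies, assuming a single diagonal SSM layer whose output is fed back through a hypothetical `tanh` readout as the next input (a real model stacks several layers, each with its own cached state). The kernel version re-convolves the whole history for every new value; the cached version does one O(N) recurrence step. All names and the initialization are illustrative assumptions.

```python
import numpy as np

def make_toy_ssm(N, step=0.1, seed=0):
    """Hypothetical discretized diagonal SSM (A_bar, B_bar, C) for the demo."""
    rng = np.random.default_rng(seed)
    Ab = np.exp((-0.5 + 1j * np.pi * np.arange(N)) * step)
    Bb = rng.normal(size=N) + 1j * rng.normal(size=N)
    C = rng.normal(size=N) + 1j * rng.normal(size=N)
    # Rescale C so sum_k |K[k]| <= 0.5: the feedback loop below stays contractive.
    gain = np.abs(C * Bb).sum() / (1.0 - np.abs(Ab).max())
    return Ab, Bb, C * (0.5 / gain)

def kernel(Ab, Bb, C, L):
    """Convolution kernel K[k] = Re(sum_n C[n] * Ab[n]**k * Bb[n])."""
    return (Ab[None, :] ** np.arange(L)[:, None] * (C * Bb)[None, :]).sum(axis=1).real

def generate_with_kernel(K, prefix, n_new):
    """No cache: each new value re-convolves K with the whole history, O(t) per step."""
    seq = list(prefix)
    for _ in range(n_new):
        t = len(seq) - 1
        y_t = sum(K[j] * seq[t - j] for j in range(t + 1))
        seq.append(np.tanh(y_t))  # toy readout: next input = tanh(current output)
    return np.array(seq)

def generate_with_state(Ab, Bb, C, prefix, n_new):
    """Cached hidden state: each new value costs O(N), whatever the history length."""
    x = np.zeros_like(Ab)
    seq = list(prefix)
    for u_k in seq[:-1]:           # absorb the given context, except its last value
        x = Ab * x + Bb * u_k
    for _ in range(n_new):
        x = Ab * x + Bb * seq[-1]  # one O(N) recurrence step
        y_t = (C * x).sum().real
        seq.append(np.tanh(y_t))
    return np.array(seq)

Ab, Bb, C = make_toy_ssm(N=8)
prefix = np.random.default_rng(1).uniform(size=10)  # stands in for the given "white" pixels
n_new = 20
a = generate_with_kernel(kernel(Ab, Bb, C, len(prefix) + n_new), prefix, n_new)
b = generate_with_state(Ab, Bb, C, prefix, n_new)
print(np.allclose(a, b))  # True
```

Both loops produce the same sequence; the difference is cost. Generating T values with the kernel takes O(T²) work in total, while the cached state keeps it at O(T·N).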

ag1988 commented 2 years ago

Thank you for the clarification. I now get the task and see why having the states is useful for autoregressive decoding.