srush closed 2 years ago
@ag1988 whoa, this worked out-of-the-box with our old code.
Looks cool!!
PS. I'll actually go over your tutorial again to understand this part :sweat_smile:
Oh, not that exciting. We just give the model the white part and have it generate the red part, pixel by pixel. If we did it the regular way (with the convolution kernel K), we would need to feed in all the previous pixels each time to compute the next one. With the SSM form, we can compute one pixel at a time using the cached hidden state.
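To make that concrete, here is a rough sketch of the two ways to get the same outputs from a toy diagonal SSM. This is not the code in this PR: the parameter values and the names `kernel`, `conv_output`, `rnn_step` and `rnn_output` are made up for illustration, and it uses plain Python instead of JAX to keep it small.

```python
# Toy diagonal SSM: x_t = lam * x_{t-1} + b * u_t,  y_t = Re(sum_n c_n * x_t[n]).
# All values below are illustrative assumptions, not taken from the PR.
lam = [0.9 + 0.1j, 0.5 - 0.3j]   # diagonal of the discretized state matrix
b = [1.0 + 0.0j, 1.0 + 0.0j]     # input projection
c = [0.3 - 0.2j, -0.7 + 0.4j]    # output projection
u = [1.0, 0.5, -0.25, 2.0, 0.0, 1.5]  # input sequence (e.g. pixel values)


def kernel(length):
    """Convolution kernel K[k] = sum_n c_n * lam_n**k * b_n."""
    return [sum(cn * ln**k * bn for cn, ln, bn in zip(c, lam, b))
            for k in range(length)]


def conv_output(u):
    """Kernel form: output t re-reads every input u[0..t]."""
    K = kernel(len(u))
    return [sum(K[k] * u[t - k] for k in range(t + 1)).real
            for t in range(len(u))]


def rnn_step(x, u_t):
    """Recurrent form: one state update from the cached state x."""
    x = [ln * xn + bn * u_t for ln, xn, bn in zip(lam, x, b)]
    y = sum(cn * xn for cn, xn in zip(c, x)).real
    return x, y


def rnn_output(u):
    """Run the recurrence over the sequence, carrying only the state."""
    x = [0j] * len(lam)
    ys = []
    for u_t in u:
        x, y = rnn_step(x, u_t)
        ys.append(y)
    return ys
```

During generation only `rnn_step` is needed: each new pixel costs one state update, while the kernel form has to revisit the whole prefix for every output.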
Thank you for the clarification. I now get the task and see why having the states is useful for autoregressive decoding.
Implements the RNN version of the DSS model in the format of S4.
A test confirms it gives the same result. Training an MNIST generator now to confirm.
This is based on this writeup from @ag1988 https://github.com/srush/annotated-s4/pull/49#issuecomment-1090963170
As he notes, the hidden states are not unique, but they do let us do generation efficiently.
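One way to see the non-uniqueness, sketched on a toy diagonal SSM (an assumption for illustration, not the DSS parameterization used in this PR): rescaling each entry of `b` by a nonzero factor and dividing the matching entry of `c` by the same factor leaves the outputs unchanged but changes the hidden states. The helper `run` and all values are hypothetical.

```python
def run(lam, b, c, u):
    """Run x_t = lam * x_{t-1} + b * u_t; return (outputs, final state)."""
    x = [0j] * len(lam)
    ys = []
    for u_t in u:
        x = [ln * xn + bn * u_t for ln, xn, bn in zip(lam, x, b)]
        ys.append(sum(cn * xn for cn, xn in zip(c, x)).real)
    return ys, x


# Illustrative values only.
lam = [0.9 + 0.1j, 0.5 - 0.3j]
b = [1.0 + 0.0j, 1.0 + 0.0j]
c = [0.3 - 0.2j, -0.7 + 0.4j]
u = [1.0, 0.5, -0.25, 2.0]

# Arbitrary nonzero per-state scales: b is multiplied, c is divided.
s = [2.0 + 0.0j, -0.5 + 1.0j]
b_scaled = [bn * sn for bn, sn in zip(b, s)]
c_scaled = [cn / sn for cn, sn in zip(c, s)]

ys_a, x_a = run(lam, b, c, u)
ys_b, x_b = run(lam, b_scaled, c_scaled, u)

same_outputs = all(abs(p - q) < 1e-9 for p, q in zip(ys_a, ys_b))
same_states = all(abs(p - q) < 1e-9 for p, q in zip(x_a, x_b))
```

Here `same_outputs` comes out true while `same_states` does not, so any of these state sequences works equally well as a cache for generation.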