srush / annotated-s4

Implementation of https://srush.github.io/annotated-s4
MIT License

[Draft] First attempt at a DSS extension #49

Closed srush closed 2 years ago

srush commented 2 years ago

Implementation of the DSS model. Checked on MNIST classification / MNIST generation.

srush commented 2 years ago

Hi @ag1988 ,

We were going to add a DSS addendum to our write-up of S4. Let us know if there are any tips or tricks you want us to include.

Also, just to check: in the paper I didn't see an easy way to recover the RNN form of the model. Is that true? Are there other parameterizations that allow it?

(Oh, I think this might just be because of the paper's parameterization. It looks like we may be able to recover the RNN after all.)

ag1988 commented 2 years ago

Hi @srush,

Thank you for including DSS in your wonderful tutorial! The implementation looks great.

  1. The init for Lambda can be double-checked here.
  2. In the current code, by default I use separate Deltas for the real and imaginary parts of Lambda, i.e. instead of directly using (Lambda.real * Delta + 1j * Lambda.imag * Delta), I use (Lambda.real * Delta_1 + 1j * Lambda.imag * Delta_2), sketched below. This doesn't help a lot, but I wanted to mention it just in case.
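For concreteness, here is a minimal sketch of that variant. The helper name is made up and `jax.numpy` is imported in the style of this repo; the exact DSS code may differ slightly.

```python
import jax.numpy as np

def discretize_lambda_two_deltas(Lambda, log_Delta_1, log_Delta_2):
    # Hypothetical helper (not the exact DSS code): use separate step sizes
    # for the real and imaginary parts of Lambda, as in item 2 above.
    Delta_1 = np.exp(log_Delta_1)
    Delta_2 = np.exp(log_Delta_2)
    # The single-Delta version would simply be Lambda * Delta.
    return Lambda.real * Delta_1 + 1j * Lambda.imag * Delta_2
```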

Otherwise your code looks wonderful! Thanks again for including DSS :)

About recovering RNN reps, do you mean the following: given inputs u = (u_1, ..., u_L), kernel K = (k_1, ..., k_L), and outputs y = (y_1, ..., y_L), how do we recover the states? The set of states is not unique, as the state spaces (A, B, C) ~ (V*AV, V*B, CV) are equivalent in terms of outputs. But the recurrences x_k = \bar{A}.x_{k-1} + \bar{B}.u_k and x_k = (V*\bar{A}V).x_{k-1} + (V*\bar{B}).u_k will have different states (transformed by V*).
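To make the non-uniqueness concrete, here is a small numerical sketch (random matrices, and a generic invertible V in place of the unitary V above, so V^{-1} plays the role of V*; none of this is from the paper): the two parameterizations produce identical outputs even though their internal states differ.

```python
import jax
import jax.numpy as np

def run_ssm(A, B, C, u):
    # Discrete-time recurrence: x_k = A x_{k-1} + B u_k,  y_k = C x_k.
    def step(x, u_k):
        x = A @ x + B * u_k
        return x, C @ x
    x0 = np.zeros(A.shape[0])
    _, ys = jax.lax.scan(step, x0, u)
    return ys

N, L = 4, 16
kA, kB, kC, kV, ku = jax.random.split(jax.random.PRNGKey(0), 5)
A = 0.3 * jax.random.normal(kA, (N, N))
B = jax.random.normal(kB, (N,))
C = jax.random.normal(kC, (N,))
V = jax.random.normal(kV, (N, N))
u = jax.random.normal(ku, (L,))
Vinv = np.linalg.inv(V)

y1 = run_ssm(A, B, C, u)                          # states x_k
y2 = run_ssm(Vinv @ A @ V, Vinv @ B, C @ V, u)    # states V^{-1} x_k
print(np.allclose(y1, y2, atol=1e-3))             # True: same outputs, different states
```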

srush commented 2 years ago

Thanks for responding!

For (1), I was thinking of not training Lambda at all. Do you think that would make it much worse? We didn't train A in our S4 code, and it seemed basically fine.

For (2), the RNN mode: I wasn't sure how to recover A, B, C at all from Lambda and W. For generation tasks, it is really nice to have these matrices.

My thought was that if we don't train Lambda, we may still have V from the eigenvectors, and could learn B and C instead of W?

siddk commented 2 years ago

Hey @ag1988 - some interesting results! To better test the implementation, we ran the DSS implementation we wrote up (in this PR) and our S4 implementation (with a fixed A matrix) on the sequential CIFAR classification task (given an image as a sequence of pixels, predict the class).

The original S4 paper reports SOTA at ~91% accuracy (but that's with training A, adding an LR schedule, etc.). So all in all, a strong result for the simple DSS model! Nice!

ag1988 commented 2 years ago

@srush let me get back to you on the states question tomorrow. Sorry for the delay :pray:.

@siddk Woo hoo! That's really cool Sidd! :tada:

ag1988 commented 2 years ago

Hi @srush ,

Here's a derivation of how to compute the states for the state space parameterized by DSS. I know it's a bit strange that the original state space here depends on Delta, but this is just to show you that it's indeed possible to construct a valid set of states. Let me know if there's an error.

[image: derivation of the recurrent states for the DSS parameterization]

srush commented 2 years ago

Thanks, @ag1988, this is super clear. Going to add it to our implementation.

It's really convenient that it factors this way: N separate RNNs. I think JAX can actually optimize for that.
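Roughly, what I mean is something like the following sketch (hypothetical names, not the PR's exact generation code): because \bar{Lambda} is diagonal, the scan is just N decoupled scalar recurrences running in parallel.

```python
import jax
import jax.numpy as np

def diagonal_rnn(Lambda_bar, B_bar, C_tilde, u):
    # Because Lambda_bar is diagonal (a vector of N eigenvalues), the update
    # x_k = Lambda_bar * x_{k-1} + B_bar * u_k decouples into N scalar RNNs.
    def step(x, u_k):
        x = Lambda_bar * x + B_bar * u_k   # elementwise: N independent recurrences
        return x, (C_tilde @ x).real       # mix the N states into one output
    x0 = np.zeros_like(Lambda_bar)
    _, ys = jax.lax.scan(step, x0, u)
    return ys
```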

ag1988 commented 2 years ago

Sounds great, Sasha! Someone asked me this question before, but it wasn't super clear to me why recovering the states would be helpful, as they're not even unique. But after you asked, it looks like it'll be helpful to include it in the DSS write-up for the interested reader. So thanks for asking this question.
