lindermanlab / S5

MIT License

Recursive S5 implementation #4

Closed EelcoHoogendoorn closed 1 year ago

EelcoHoogendoorn commented 1 year ago

Hi,

I'm interested in a recursive variant of the S5 approach, since I want to apply this to inherently sequential/interleaved tasks such as control. I think S5 could be a great fit there, and the single state space strikes me as more general than the S4 formulation. I intend to combine this as a linear pre-filter feeding into a nonlinear gated unit, along the lines of the original HiPPO paper; though I bet many variants are possible.

It occurs to me that this functionality should be trivial to build on top of the parallel scan implementation provided in this repo; the code below looks to me like it should 'just work', but if anyone has comments or suggestions, they would be very welcome.

```python
import jax.numpy as np

from s5.ssm import S5SSM


class S5SSMRecursive(S5SSM):
    """Single-step (recurrent) variant of the S5 SSM."""

    def __call__(self, x, u):
        # Account for the conjugate-symmetry trick used in the repo.
        C = 2 * self.C_tilde if self.conj_sym else self.C_tilde
        # One step of the discretized linear recurrence.
        x = self.Lambda_bar * x + self.B_bar @ u
        y = (C @ x).real + self.D * u
        return x, y

    def init_carry(self, key=None):
        # Zero complex state of size P.
        x = np.zeros(self.P)
        return x + 1j * x
```
EelcoHoogendoorn commented 1 year ago

I suppose it should be trivial to write a test of correctness by just comparing it against an input sequence of length 1. Will get to that shortly.
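A minimal, self-contained sketch of that length-1 check, using stand-in parameters rather than the repo's `S5SSM` module, and `jax.lax.associative_scan` as the parallel scan (all names here are illustrative assumptions, not the repo's API):

```python
import jax
import jax.numpy as jnp

# Hypothetical stand-ins for a trained diagonal SSM; in the repo these
# live on the S5SSM module (Lambda_bar, B_bar, C_tilde, D).
P, H = 4, 3
Lambda_bar = jnp.exp(-jax.random.uniform(jax.random.PRNGKey(0), (P,)))  # stable diagonal
B_bar = jax.random.normal(jax.random.PRNGKey(1), (P, H))
C_tilde = jax.random.normal(jax.random.PRNGKey(2), (H, P))
D = jax.random.normal(jax.random.PRNGKey(3), (H,))

def step(x, u):
    """One recurrent step of the diagonal SSM."""
    x = Lambda_bar * x + B_bar @ u
    y = C_tilde @ x + D * u
    return x, y

def scan_ssm(us):
    """Parallel-scan formulation over a whole sequence (implicit x0 = 0)."""
    def binop(a, b):
        # Elements are (Lambda product, accumulated Bu); this is the
        # associative operator for a diagonal linear recurrence.
        a_l, a_b = a
        b_l, b_b = b
        return a_l * b_l, b_l * a_b + b_b

    Lambda_elems = jnp.repeat(Lambda_bar[None, :], us.shape[0], axis=0)
    Bu_elems = jax.vmap(lambda u: B_bar @ u)(us)
    _, xs = jax.lax.associative_scan(binop, (Lambda_elems, Bu_elems))
    return jax.vmap(lambda x, u: C_tilde @ x + D * u)(xs, us)

# Length-1 check: one recurrent step from x0 = 0 must match the scan.
u = jax.random.normal(jax.random.PRNGKey(4), (H,))
_, y_step = step(jnp.zeros(P), u)
y_scan = scan_ssm(u[None, :])[0]
assert jnp.allclose(y_step, y_scan, atol=1e-5)
```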

If chaining together chunks of parallel-executed iterations, I suppose one sensible way to fit that in would be to add a line:

```python
Lambda_elements = Lambda_elements.at[0].set(carry_x * Lambda_elements[0])
```

This might be useful in the context of a chatbot-type interface or the like.

EDIT: Thinking about it, I suppose that means the implicit x/carry init in the scanned case is a vector of ones, not zeros. EDIT2: Ah, no; initializing the carry with zero does give identical results between the recurrent and scanned implementations. So the above does appear to be correct.

jimmysmith1919 commented 1 year ago

Thanks for reaching out!

You are correct that the parallel scan in apply_ssm() currently just assumes x0 = 0. However, if we want to chain together parallel computations, the parallel scan needs to be stateful as well.

Assuming apply_ssm() now takes an argument x0, you can add an extra line after Bu_elements is defined:

```python
Bu_elements = jax.vmap(lambda u: B_bar @ u)(input_sequence)
Bu_elements = Bu_elements.at[0].add(Lambda_bar * x0)  # add this line
```

Now the first element of xs returned by the parallel scan will be Lambda_bar * x0 + Bu_1, which is what we want.
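A self-contained sketch of this stateful scan, checked against a plain sequential recurrence. The parameter names are illustrative stand-ins (real-valued, for simplicity); the repo's apply_ssm() handles the complex, conjugate-symmetric details:

```python
import jax
import jax.numpy as jnp

# Stand-in diagonal SSM parameters, an initial state x0, and a sequence of inputs.
P, H, L = 4, 3, 6
Lambda_bar = jnp.exp(-jax.random.uniform(jax.random.PRNGKey(0), (P,)))
B_bar = jax.random.normal(jax.random.PRNGKey(1), (P, H))
us = jax.random.normal(jax.random.PRNGKey(2), (L, H))
x0 = jax.random.normal(jax.random.PRNGKey(3), (P,))

def binop(a, b):
    # Associative operator for the diagonal linear recurrence.
    a_l, a_b = a
    b_l, b_b = b
    return a_l * b_l, b_l * a_b + b_b

Lambda_elems = jnp.repeat(Lambda_bar[None, :], L, axis=0)
Bu_elems = jax.vmap(lambda u: B_bar @ u)(us)
# The suggested one-line change: fold the carried state into the first element.
Bu_elems = Bu_elems.at[0].add(Lambda_bar * x0)
_, xs = jax.lax.associative_scan(binop, (Lambda_elems, Bu_elems))

# Reference: plain sequential recurrence starting from x0.
def ref(x, u):
    x = Lambda_bar * x + B_bar @ u
    return x, x

_, xs_ref = jax.lax.scan(ref, x0, us)
assert jnp.allclose(xs, xs_ref, atol=1e-5)
```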

In addition, with this stateful modification, one way to get the recurrent update you were asking for in your original question would be to simply call apply_ssm() on a length-one sequence at each step.

You could also use a separate sequential implementation as you suggested, or even JAX's sequential scan as used in the annotated S4 codebase, which our codebase borrows heavily from: https://github.com/srush/annotated-s4/blob/main/s4/s4.py#L439. One extra point in that linked code: when they run the SSM sequentially, they cache the discretized parameters to avoid redoing the discretization every step (assuming a constant step size, of course), which you might want to do as well.
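A minimal sketch of that caching pattern for the diagonal case, assuming zero-order-hold discretization with a constant step size (function and parameter names here are illustrative, not the repo's API):

```python
import jax.numpy as jnp

def discretize_zoh(Lambda, B_tilde, Delta):
    """Zero-order-hold discretization of a diagonal continuous-time SSM."""
    Lambda_bar = jnp.exp(Lambda * Delta)
    B_bar = ((Lambda_bar - 1.0) / Lambda)[:, None] * B_tilde
    return Lambda_bar, B_bar

P, H = 4, 3
Lambda = -jnp.arange(1.0, P + 1.0)  # stable diagonal (real-valued, for the sketch)
B_tilde = jnp.ones((P, H))
Delta = 0.1

# Discretize once, then reuse the cached (Lambda_bar, B_bar) at every step,
# instead of re-discretizing inside the recurrent update.
Lambda_bar, B_bar = discretize_zoh(Lambda, B_tilde, Delta)

def cached_step(x, u):
    return Lambda_bar * x + B_bar @ u

x = cached_step(jnp.zeros(P), jnp.ones(H))
```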

Our current repository is just a minimal implementation to perform the experiments in the paper, but we hope to add some of this extra functionality ourselves soon!