i404788 / s5-pytorch

Pytorch implementation of Simplified State Space Layers for Sequence Modeling (S5)
Mozilla Public License 2.0

How to carry state in apply_ssm? #3

Closed · looper99 closed this 7 months ago

looper99 commented 7 months ago

Hi, how can I use a prev_state in the apply_ssm function? I see it is currently purely feed-forward. Ideally I would like to write x, states = s5(x, states), where apply_ssm carries the state so that I can train with memory.

i404788 commented 7 months ago

Hey,

apply_ssm is the parallel formulation of S5. If you want to carry state, you should use forward_rnn (https://github.com/i404788/s5-pytorch/blob/f0fb13226e13cd508deecc5dda280d99ae5bdda1/s5/s5_model.py#L227) with initial_state (https://github.com/i404788/s5-pytorch/blob/f0fb13226e13cd508deecc5dda280d99ae5bdda1/s5/s5_model.py#L313) as the prev_state.

This code path hasn't been tested much, though, so there might be bugs (let me know).

For training you usually want the parallel formulation, which speeds up training and reduces memory use (when using regular autograd); use forward_rnn for inference speed/memory.
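For intuition, here is a self-contained sketch of the recurrence that forward_rnn steps through, with an explicit carried state. The names and shapes are illustrative only, not the repo's actual API:

```python
import torch

def ssm_step_rnn(Lambda_bar, B_bar, C_tilde, D, u_seq, state):
    """Sequentially apply x_k = Lambda_bar * x_{k-1} + B_bar @ u_k and
    y_k = Re(C_tilde @ x_k) + D * u_k, returning the final state.

    Lambda_bar: (P,)   complex diagonal state matrix
    B_bar:      (P, H) complex input matrix
    C_tilde:    (H, P) complex output matrix
    D:          (H,)   real feedthrough
    u_seq:      (L, H) real inputs
    state:      (P,)   complex carried state x_0
    """
    ys = []
    for u in u_seq:
        state = Lambda_bar * state + B_bar @ u.to(B_bar.dtype)
        ys.append((C_tilde @ state).real + D * u)
    return torch.stack(ys), state  # hand the state back so the caller can carry it

# Carrying state across chunks:
P, H, L = 4, 3, 5
Lambda_bar = torch.rand(P, dtype=torch.cfloat)
B_bar = torch.rand(P, H, dtype=torch.cfloat)
C_tilde = torch.rand(H, P, dtype=torch.cfloat)
D = torch.rand(H)
state = torch.zeros(P, dtype=torch.cfloat)
for chunk in torch.rand(2, L, H):
    y, state = ssm_step_rnn(Lambda_bar, B_bar, C_tilde, D, chunk, state)
```

The parallel formulation computes the same xs for a whole sequence at once via associative_scan, which is why it has no natural slot for a carried state.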

i404788 commented 7 months ago

If you want to mix them, I think you'll need to extract the last element of xs (the per-step states) from apply_ssm:

```diff
- return torch.vmap(lambda x: (C_tilde @ x).real)(xs) + Du
+ return torch.vmap(lambda x: (C_tilde @ x).real)(xs) + Du, xs[-1]
```
looper99 commented 7 months ago

I understand, but what I meant is: if I now return xs[-1], how can it be reused in apply_ssm? I tried Lambda_bars = Lambda_bars * prev_state (where prev_state is the returned xs[-1]) followed by xs = associative_scan(binary_operator, (Lambda_bars, Bu_elements)), but this didn't work.

What I am looking for is the same functionality as S4's state-forwarding mode: https://github.com/HazyResearch/state-spaces/tree/main/models/s4#state-forwarding

i404788 commented 7 months ago

Ah, I see. Referencing the paper, it seems you would need to inject it as the initial state of associative_scan; however, the JAX implementation I ported doesn't actually support that.

I think Lambda_bars[0] = Lambda_bars[0] * prev_state with prev_state = xs[-1] should be equivalent. Note this would be after Lambda_bars has been tiled (just before the first associative_scan call).
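(For context: the scan elements are pairs (Lambda_bar_k, Bu_k), combined with the associative operator below, which composes two steps of the recurrence. This mirrors the S5 formulation; the repo's binary_operator should be equivalent.)

```python
# The associative operator the scan uses (as in the S5 formulation):
# composing two steps of x -> a*x + b gives
#   x -> a_j * (a_i * x + b_i) + b_j = (a_j * a_i) * x + (a_j * b_i + b_j)
def binary_operator(q_i, q_j):
    A_i, b_i = q_i
    A_j, b_j = q_j
    return A_j * A_i, A_j * b_i + b_j
```

One thing worth noting: combining (a_0, b_0) with (a_1, b_1) yields b = a_1 * b_0 + b_1, so the a-component of the first element never enters the emitted states, which may help explain the deviation reported for this proposal later in the thread.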

looper99 commented 7 months ago

I see, so:

```python
if Lambda_bars.ndim == 1:  # Repeat for associative_scan
    Lambda_bars = Lambda_bars.tile(input_sequence.shape[0], 1)

Lambda_bars[0] = Lambda_bars[0] * prev_state
_, xs = associative_scan(binary_operator, (Lambda_bars, Bu_elements))
```

should do it?

Then I can have apply_ssm take a prev_state, i.e. apply_ssm(Lambda_bars: torch.Tensor, B_bars, C_tilde, D, input_sequence, prev_state, bidir: bool = False), and carry it across calls in a loop: x, states = s5(x, states).
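For readers following along, here is the proposed stateful apply_ssm assembled from the snippets in this thread. The Bu_elements/Du lines are assumptions about the surrounding repo code (associative_scan and binary_operator are the repo's own helpers), bidir handling is omitted, and the state injection shown here was later superseded by the v0.2.1 fix further down:

```python
import torch

def apply_ssm(Lambda_bars, B_bars, C_tilde, D, input_sequence,
              prev_state, bidir: bool = False):
    # Project each input into the state space (assumed; cast to complex)
    Bu_elements = torch.vmap(lambda u: B_bars @ u.to(B_bars.dtype))(input_sequence)
    Du = torch.vmap(lambda u: D * u)(input_sequence)

    if Lambda_bars.ndim == 1:  # Repeat for associative_scan
        Lambda_bars = Lambda_bars.tile(input_sequence.shape[0], 1)

    # Initial proposal: inject the carried state into the first element
    Lambda_bars[0] = Lambda_bars[0] * prev_state

    _, xs = associative_scan(binary_operator, (Lambda_bars, Bu_elements))
    # Emit outputs plus the final state so the caller can carry it on
    return torch.vmap(lambda x: (C_tilde @ x).real)(xs) + Du, xs[-1]
```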

i404788 commented 7 months ago

Yes, with the other return patch, that should be correct. I'll probably add this functionality to the main repo after it's validated.

looper99 commented 7 months ago

Yes, of course, returning result, xs[-1]. Thank you once again for the fast replies and for resolving my problem.

kallegrens commented 2 months ago

We too are very interested in this. Is there any possibility of getting this implemented?

Best regards!

i404788 commented 2 months ago

@kallegrens I'll look into it sometime this week; I'll need some time to validate that it's working as expected.

kallegrens commented 2 months ago

Fantastic! Thank you!

i404788 commented 2 months ago

@kallegrens version 0.2.0 has been released with state carrying. See https://github.com/i404788/s5-pytorch/blob/2da6de720df0e033b9bbb1930f8cf1a0dcb9a077/s5/s5_model.py#L411-L424.

@looper99 Not sure if it's still relevant for you, but it seems the way to carry state was different from what I previously thought, so you may want to check the changes.

If you find any issues, let me know. Experimentally it should be correct, but the combination of zero-padding and replacing all state values is somewhat confusing to me, so the carrying may still not be quite equivalent to the RNN formulation. I will try to test further a bit later.
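If it helps anyone landing here, chunked usage then looks roughly like the loop below. The real forward signature is in the linked lines of s5_model.py; the constructor arguments and the way the state is passed and returned here are assumptions, not verified against the release:

```python
# Hypothetical usage of the v0.2.0 state-carrying API (unverified sketch)
import torch
from s5 import S5

model = S5(32, 32)                   # assumed constructor arguments
x = torch.rand(2, 256, 32)           # (batch, length, dim)
state = None
for chunk in x.split(64, dim=1):     # process the sequence in chunks
    y, state = model(chunk, state)   # carry the returned state forward (assumed)
```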

looper99 commented 2 months ago

@i404788 Thanks for the info. Why is this one now the correct one? By the way, I was using the previous one and it worked great, but the current one is not giving me results as good as the previous one we did.

i404788 commented 2 months ago

@looper99 Thanks for testing so quickly and reporting on the initial proposal.

The issue with Lambda_bars[0] *= state is that it deviates from the equivalent RNN and parallel output. Basically, the RNN and parallel formulations agree quite closely on the last state and last element (eps < 1e-5), while the chunked parallel output deviates a lot, which means that when trained with the parallel formulation it wouldn't extrapolate (it might still work fine within the chunk, though).

I've now figured out what went wrong and released v0.2.1 on GitHub and PyPI. It turns out the first step needed to be calculated manually with the state (or the state concatenated to the front, but that's less efficient). I expect you'll get better performance with this version than with the original proposal.
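Concretely, folding the carried state into the first scan element (rather than scaling Lambda_bars[0]) makes chunked processing match the full sequence. Below is a self-contained check of the idea, using a sequential stand-in for associative_scan; this is a sketch of the principle, not the repo's actual v0.2.1 code:

```python
import torch

def binary_operator(q_i, q_j):  # same operator as in the scan
    A_i, b_i = q_i
    A_j, b_j = q_j
    return A_j * A_i, A_j * b_i + b_j

def scan(Lambda, Bu):
    """Sequential stand-in for associative_scan; returns the b-components."""
    acc, out = (Lambda[0], Bu[0]), [Bu[0]]
    for k in range(1, len(Bu)):
        acc = binary_operator(acc, (Lambda[k], Bu[k]))
        out.append(acc[1])
    return torch.stack(out)

P, L = 4, 8
Lambda = torch.rand(P, dtype=torch.cfloat)
Bu = torch.rand(L, P, dtype=torch.cfloat)
Lam = Lambda.tile(L, 1)

full = scan(Lam, Bu)                      # whole sequence in one go
xs1 = scan(Lam[: L // 2], Bu[: L // 2])   # first chunk, zero initial state
Bu2 = Bu[L // 2 :].clone()
Bu2[0] = Lambda * xs1[-1] + Bu2[0]        # first step with carried state, done manually
xs2 = scan(Lam[L // 2 :], Bu2)            # second chunk
assert torch.allclose(torch.cat([xs1, xs2]), full, atol=1e-5)
```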

looper99 commented 2 months ago

@i404788 Thank you. I tried it and it worked, though the results are only slightly better.