Closed looper99 closed 7 months ago
Hey, `apply_ssm` is the parallel formulation of S5. If you want to carry state you should use `forward_rnn` (https://github.com/i404788/s5-pytorch/blob/f0fb13226e13cd508deecc5dda280d99ae5bdda1/s5/s5_model.py#L227) with `initial_state` (https://github.com/i404788/s5-pytorch/blob/f0fb13226e13cd508deecc5dda280d99ae5bdda1/s5/s5_model.py#L313) as `prev_state`.
This code path hasn't been tested too much though so there might be bugs (let me know).
For training you usually want the parallel formulation, which increases training speed and reduces memory (if using regular autograd); use `forward_rnn` for inference speed/memory.
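To make the parallel/recurrent equivalence concrete, here is a minimal self-contained sketch (toy names, not the s5-pytorch API) of a diagonal SSM `x_k = lam * x_{k-1} + b_k` computed both as a sequential RNN and via an associative prefix scan:

```python
import torch

# Toy sketch (not the s5-pytorch API): a diagonal SSM x_k = lam * x_{k-1} + b_k
# computed (a) sequentially as an RNN and (b) with an associative prefix scan.
torch.manual_seed(0)
L, N = 16, 4
lam = torch.rand(N)            # diagonal of Lambda_bar (real-valued for the toy)
b = torch.randn(L, N)          # per-step input contributions (B_bar @ u_k)

def binop(ei, ej):
    # Compose scan elements: (Ai, bi) followed by (Aj, bj).
    Ai, bi = ei
    Aj, bj = ej
    return (Aj * Ai, Aj * bi + bj)

def prefix_scan(elems):
    # Divide-and-conquer prefix scan; correctness relies only on
    # binop being associative, which is what permits parallelism.
    if len(elems) == 1:
        return elems
    mid = len(elems) // 2
    left, right = prefix_scan(elems[:mid]), prefix_scan(elems[mid:])
    return left + [binop(left[-1], e) for e in right]

par = torch.stack([x for _, x in prefix_scan([(lam, b[k]) for k in range(L)])])

x, seq = torch.zeros(N), []    # RNN formulation: one cheap step at a time
for k in range(L):
    x = lam * x + b[k]
    seq.append(x)
rnn = torch.stack(seq)

assert torch.allclose(par, rnn, atol=1e-5)  # both formulations agree
```

The divide-and-conquer scan relies only on the operator's associativity, which is what lets the parallel formulation run at logarithmic depth during training.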
If you want to mix them I think you'll need to extract the last element of `xs` from `apply_ssm`:

```diff
- return torch.vmap(lambda x: (C_tilde @ x).real)(xs) + Du
+ return torch.vmap(lambda x: (C_tilde @ x).real)(xs) + Du, xs[-1]
```
I understand, but what I meant by this question is: if I now return `xs[-1]`, how can it be reused in `apply_ssm`? I tried `Lambda_bars = Lambda_bars * prev_state` (where `prev_state` is the returned `xs[-1]`), followed by `xs = associative_scan(binary_operator, (Lambda_bars, Bu_elements))`, but this didn't work.
What I am talking about is functionality like S4's state-forwarding mode: https://github.com/HazyResearch/state-spaces/tree/main/models/s4#state-forwarding
Ah I see; referencing the paper, it seems you would need to inject it as the initial state of `associative_scan`, but the JAX implementation I ported doesn't actually support that.
I think `Lambda_bars[0] = Lambda_bars[0] * prev_state` with `prev_state = xs[-1]` should do the equivalent. Note this would be after it has been tiled (just before the first `associative_scan` call).
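For illustration, here is a toy sketch (assumed names, not the library's `associative_scan`) of what injecting a previous state into the scan means in principle: prepend a pseudo-element carrying the state, then drop that element's own output:

```python
import torch

# Toy sketch (assumed names, not the library's associative_scan): injecting a
# previous state by prepending a pseudo-element (ones, prev_state) to the scan
# inputs and discarding that element's own output afterwards.
torch.manual_seed(0)
L, N = 8, 4
lam, b = torch.rand(N), torch.randn(L, N)
prev_state = torch.randn(N)

def binary_operator(ei, ej):
    # Compose scan elements: (Ai, bi) followed by (Aj, bj).
    Ai, bi = ei
    Aj, bj = ej
    return (Aj * Ai, Aj * bi + bj)

elems = [(torch.ones(N), prev_state)] + [(lam, b[k]) for k in range(L)]
prefixes = [elems[0]]
for e in elems[1:]:            # sequential stand-in for associative_scan
    prefixes.append(binary_operator(prefixes[-1], e))
xs = torch.stack([x for _, x in prefixes[1:]])  # drop the injected element

x, ref = prev_state, []        # reference: RNN started from prev_state
for k in range(L):
    x = lam * x + b[k]
    ref.append(x)

assert torch.allclose(xs, torch.stack(ref), atol=1e-5)
```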
I see, so:

```python
if Lambda_bars.ndim == 1:  # Repeat for associative_scan
    Lambda_bars = Lambda_bars.tile(input_sequence.shape[0], 1)
Lambda_bars[0] = Lambda_bars[0] * prev_state
_, xs = associative_scan(binary_operator, (Lambda_bars, Bu_elements))
```

should do it?
Then I can have `apply_ssm` take `prev_state`, i.e. `apply_ssm(Lambda_bars: torch.Tensor, B_bars, C_tilde, D, input_sequence, prev_state, bidir: bool = False)`, and carry it on in a loop: `x, states = s5(x, states)`.
Yes, with the other return patch, that should be correct. I'll probably add this functionality to the main repo after it's validated.
Yes, of course, returning `result, xs[-1]`. Thank you once again for the fast replies and for resolving my problem.
We too are very interested in this. Is there any possibility of getting this implemented?
Best regards!
@kallegrens I'll look into it sometime this week; I'll need some time to validate that it's working as expected.
Fantastic! Thank you!
@kallegrens version 0.2.0 has been released with state carrying. See https://github.com/i404788/s5-pytorch/blob/2da6de720df0e033b9bbb1930f8cf1a0dcb9a077/s5/s5_model.py#L411-L424.
@looper99 not sure if it's still relevant for you, but it seems the way to carry state was different than I previously thought, so you may want to check the changes.
If you find any issues let me know. It should be experimentally correct, but the combination of zero-padding and replacing all state values is somewhat confusing to me, so the carrying may still not be quite equivalent to the RNN formulation. I will try to test further a bit later.
@i404788 thanks for the info. Why is this one now the correct one? By the way, I was using the previous one and it worked great, but this current one isn't giving me results as good as the previous one we did.
@looper99 Thanks for testing so quickly and reporting on the initial proposal.
The issue with `Lambda_bars[0] *= state` is that it deviates from the equivalent RNN and parallel output. Basically, the RNN and parallel formulations agree quite closely on the last state and last element (eps < 1e-5), while the chunked parallel output deviates a lot, which means a model trained with the parallel formulation wouldn't extrapolate (it might still work fine within the chunk, though).
I've now figured out what went wrong and released v0.2.1 on GitHub and PyPI. It turns out the first step needed to be calculated manually with the state (or the state concatenated to the front, but that's less efficient). I expect you'll get better performance than with the original proposal in this version.
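As a sanity check of the idea behind v0.2.1, here is a toy sketch (assumed names, a diagonal recurrence rather than the library code) showing that folding the carried state into the first input element of the next chunk, i.e. computing the first step "manually", reproduces the full-sequence recurrence:

```python
import torch

# Toy sketch (assumed names, a diagonal recurrence rather than the library
# code): carrying state across chunks by folding lam * state into the first
# input element of the next chunk.
torch.manual_seed(0)
L, N = 32, 4
lam, b = torch.rand(N), torch.randn(L, N)

def run(lam, b, state):
    # Plain RNN over one chunk, returning all states.
    xs = []
    for k in range(b.shape[0]):
        state = lam * state + b[k]
        xs.append(state)
    return torch.stack(xs)

full = run(lam, b, torch.zeros(N))         # whole sequence in one pass

half = L // 2
first = run(lam, b[:half], torch.zeros(N))
carried = b[half:].clone()
carried[0] = carried[0] + lam * first[-1]  # first step computed "manually"
second = run(lam, carried, torch.zeros(N))

assert torch.allclose(full, torch.cat([first, second]), atol=1e-5)
```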
@i404788 Thank you. I tried it and it worked, but only slightly better.
How can I use `prev_state` in the `apply_ssm` function, since I see it is now purely forward? Ideally I want `x, states = s5(x, states)`, where `apply_ssm` carries state so that I can train with memory.