looper99 closed this issue 4 months ago
Hey again @looper99,
It's available on the main branch and in all the releases (I created it initially with that goal: https://github.com/lindermanlab/S5/issues/2). You can use it like so:
```python
import torch
from s5 import S5

dim = 24 * 24
x = torch.rand(2, 50, dim)
model = S5(dim, 256)

# Random time steps: one step size per batch element and sequence position
out = model(x, step_scale=torch.rand(*x.shape[:2]))

# Example loss (squared error keeps the argument non-negative)
(out - x).pow(2).sum().backward()
```
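Since `step_scale` takes one step size per batch element and position, irregular sampling can be expressed by passing normalized timestamp deltas instead of random values. A minimal sketch, assuming timestamps `t` of shape `(batch, L)`; the mean-normalization convention is my own choice, not something the repo prescribes:

```python
import torch

# Hypothetical irregular timestamps, sorted along the sequence axis
t, _ = torch.sort(torch.rand(2, 50), dim=1)

# Inter-sample deltas; repeat the first delta so the length stays L
dt = torch.diff(t, dim=1)
dt = torch.cat([dt[:, :1], dt], dim=1)

# Normalize so the average step is 1 (an assumed convention)
step_scale = dt / dt.mean()

# out = model(x, step_scale=step_scale)  # model as in the snippet above
```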
I didn't integrate it into S5Block because it ended up using too much memory for any large-scale task. Here is the trace for the example with model = S5(24*24, 512) and x.shape = (2, 8192, 24*24)
(I saw peak memory usage of ~60GB):
With a smaller sequence length (L=50 in the paper) it works fine, but each element now has its own B_bar matrix, which contributes significantly to memory usage.
There may be optimizations I'm missing, but it should currently be equivalent to the reference code in terms of operations.
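As a rough sanity check on the ~60GB figure: assuming B_bar is materialized per batch element and per time step with state size P = 512 and input size H = 24*24 = 576, stored as complex64 (all assumptions about the internals, not confirmed from the code), that tensor alone is already tens of gigabytes:

```python
# Back-of-envelope estimate; shapes and the complex64 dtype are assumptions
# about how the per-timestep discretized B_bar would be materialized.
batch, L = 2, 8192
P, H = 512, 24 * 24          # state size, input size
bytes_per_complex64 = 8

elements = batch * L * P * H
gib = elements * bytes_per_complex64 / 2**30
print(f"per-timestep B_bar alone: {gib:.1f} GiB")
```

That lands in the right ballpark before counting activations, gradients, and intermediate buffers, which is consistent with the observed ~60GB peak.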
Dear @i404788, do you maybe have an S5 model that works with irregularly sampled data? I mean one that can be trained on it, like in the pendulum task?
If yes, could you share that branch?
Thanks.