i404788 / s5-pytorch

Pytorch implementation of Simplified Structured State-Spaces for Sequence Modeling (S5)
Mozilla Public License 2.0

Pendulum task #7

Closed: looper99 closed this issue 4 months ago

looper99 commented 5 months ago

Dear @i404788, do you have an S5 model that works with irregularly sampled data, i.e. one that can be trained on such data, as in the pendulum task?

If so, could you share that branch?

Thanks.

i404788 commented 5 months ago

Hey again @looper99,

It's available on the main branch and in all releases (I originally created it with that goal in mind: https://github.com/lindermanlab/S5/issues/2). You can use it like so:

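```python
import torch
from s5 import S5

dim = 24 * 24
x = torch.rand(2, 50, dim)  # (batch, sequence length, features)
model = S5(dim, 256)        # width=dim, state width=256

# Random per-element time steps: one scale per (batch, position), shape (2, 50)
out = model(x, step_scale=torch.rand(*x.shape[:2]))

# Example loss (mean squared error; log(out - x) is NaN wherever out < x)
((out - x) ** 2).mean().backward()
```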
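To make the role of `step_scale` concrete, here is a minimal single-state sketch of zero-order-hold (ZOH) discretization with a separate step size per element. It assumes the effective step for element k is the learned Δ scaled by that element's `step_scale` (an assumption about how `s5` uses the argument, not read from its source); `zoh_discretize` and `run_irregular` are hypothetical helpers, not part of the `s5` package. Because Λ̄ₖ and B̄ₖ are recomputed from each step size, every element gets its own pair.

```python
import cmath


def zoh_discretize(lam: complex, b: complex, delta: float) -> tuple[complex, complex]:
    """ZOH discretization of dx/dt = lam*x + b*u over one step of length delta (lam != 0)."""
    lam_bar = cmath.exp(lam * delta)
    b_bar = (lam_bar - 1) / lam * b
    return lam_bar, b_bar


def run_irregular(lam, b, inputs, deltas):
    """Run x_k = lam_bar_k * x_{k-1} + b_bar_k * u_k with a per-element step size."""
    x = 0j
    states = []
    for u, delta in zip(inputs, deltas):
        # Hypothetical stand-in for Delta * step_scale[k]: each element gets its own b_bar
        lam_bar, b_bar = zoh_discretize(lam, b, delta)
        x = lam_bar * x + b_bar * u
        states.append(x)
    return states


# Irregular gaps that sum to 2.0, constant input u = 1
states = run_irregular(-1.0, 1.0, [1.0] * 4, [0.1, 0.5, 0.2, 1.2])
print(round(states[-1].real, 4))  # 0.8647
```

With Λ = -1, B = 1 and a constant input, ZOH is exact, so any split of the same total time (here 2.0) ends at 1 - e⁻² ≈ 0.8647 regardless of how irregular the gaps are.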

I didn't integrate it into S5Block because it ended up using too much memory for any large-scale task. Here is the profiler trace for the example with `model = S5(24*24, 512)` and `x.shape == (2, 8192, 24*24)` (I saw peak memory usage of ~60 GB):

```
---------------------  ----------  ---------  -----------  ---------  ------------  ---------  ------------  ----------  --------------------------------------------------------------------
Name                   Self CPU %   Self CPU  CPU total %  CPU total  CPU time avg    CPU Mem  Self CPU Mem  # of Calls  Source Location
---------------------  ----------  ---------  -----------  ---------  ------------  ---------  ------------  ----------  --------------------------------------------------------------------
aten::mul                  26.81%     6.490s       26.99%     6.532s      45.677ms   97.19 Gb      96.50 Gb         143
aten::bmm                  10.75%     2.603s       12.83%     3.106s        1.035s   32.12 Gb      32.00 Gb           3
aten::empty                 0.00%  569.000us        0.00%  569.000us       5.220us    1.12 Gb       1.12 Gb         109
aten::empty_strided         0.00%  424.000us        0.00%  424.000us       3.007us  864.98 Mb     864.98 Mb         141
aten::cat                   0.10%   23.964ms        0.10%   24.113ms     927.423us  255.97 Mb     255.97 Mb          26  /media/sata0/Projects/s5-pytorch/s5/jax_compat.py(160): _interleave
aten::mul                   0.05%   11.732ms        0.05%   11.732ms     378.452us  133.02 Mb     133.02 Mb          31  /media/sata0/Projects/s5-pytorch/s5/s5_model.py(23): binary_operator
aten::addcmul               0.06%   13.481ms        0.06%   13.481ms     421.281us  131.03 Mb     131.03 Mb          32  /media/sata0/Projects/s5-pytorch/s5/s5_model.py(23): binary_operator
aten::cat                   0.05%   11.396ms        0.05%   11.515ms     460.600us  128.97 Mb     128.97 Mb          25
aten::sub                   0.06%   15.543ms        0.06%   15.549ms       7.774ms   96.00 Mb      96.00 Mb           2
aten::add                   0.07%   15.761ms        0.07%   15.761ms       1.212ms   95.98 Mb      95.98 Mb          13
---------------------  ----------  ---------  -----------  ---------  ------------  ---------  ------------  ----------  --------------------------------------------------------------------
Self CPU time total: 24.204s
```

With a smaller sequence length (the paper uses L=50) it works fine, but each element now has its own B_bar matrix, which contributes significantly to memory usage.
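A back-of-envelope estimate shows why per-element B_bar matrices hurt at long sequence lengths. This sketch assumes each element's B_bar is materialized as a dense P×H complex64 matrix (8 bytes per entry) with P=512 and H=24·24, and no conjugate-symmetry halving of P (which would halve every figure). The helper `per_element_b_bar_bytes` is hypothetical, and the result is an order-of-magnitude illustration rather than a reconstruction of the profiled peak.

```python
def per_element_b_bar_bytes(batch, seq_len, state_dim, model_dim, bytes_per_entry=8):
    """Memory for one dense (state_dim x model_dim) B_bar per sequence element.

    Assumes complex64 entries (8 bytes each) and no conjugate-symmetry halving.
    """
    return batch * seq_len * state_dim * model_dim * bytes_per_entry


GIB, MIB = 2**30, 2**20
H = 24 * 24

long_seq = per_element_b_bar_bytes(batch=2, seq_len=8192, state_dim=512, model_dim=H)
short_seq = per_element_b_bar_bytes(batch=2, seq_len=50, state_dim=512, model_dim=H)
# With a fixed step size, a single B_bar is shared by all elements
shared = per_element_b_bar_bytes(batch=1, seq_len=1, state_dim=512, model_dim=H)

print(long_seq / GIB)   # 36.0  (GiB at L=8192)
print(short_seq / MIB)  # 225.0 (MiB at L=50)
print(shared / MIB)     # 2.25  (MiB for one shared B_bar)
```

The per-element tensor grows linearly with batch × L, which fits L=50 being manageable while L=8192 is not.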

There may be optimizations I'm missing, but it should currently be equivalent to the reference code in terms of operations.
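For reference, the S5 JAX code evaluates the linear recurrence with an associative scan, and the `binary_operator` frames in the trace (`aten::mul`, `aten::addcmul`) appear to be that scan's combining step. Below is a minimal pure-Python sketch of the same operator for a diagonal state, applied sequentially for clarity; a parallel implementation (for example `jax.lax.associative_scan`) relies on its associativity to reach O(log L) depth. Each element is a pair (Λ̄ₖ, B̄ₖuₖ), and with irregular steps both entries vary per element. The `sequential_scan` helper and the sample values are hypothetical.

```python
def binary_operator(q_i, q_j):
    """Combine two scan elements (a, bu); mirrors the reference S5 operator."""
    a_i, b_i = q_i
    a_j, b_j = q_j
    return a_j * a_i, a_j * b_i + b_j


def sequential_scan(elements):
    """Inclusive scan: result[k] = e_0 o e_1 o ... o e_k, computed left to right."""
    result = [elements[0]]
    for e in elements[1:]:
        result.append(binary_operator(result[-1], e))
    return result


# Hypothetical per-element (lam_bar_k, b_bar_k * u_k) pairs; lam_bar differs per step
elems = [(0.9, 0.5), (0.5, 1.0), (0.8, -0.2)]
scan_states = [bu for _, bu in sequential_scan(elems)]

# Direct recurrence x_k = lam_bar_k * x_{k-1} + b_bar_k * u_k, starting from x = 0
x, direct_states = 0.0, []
for a, bu in elems:
    x = a * x + bu
    direct_states.append(x)

print(scan_states)    # [0.5, 1.25, 0.8]
print(direct_states)  # [0.5, 1.25, 0.8]
```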