Open stergiosba opened 1 year ago
Hey there! I think this would definitely be a worthy addition.
I've been putting these more sophisticated architectures as full examples in the documentation, e.g. see here for BERT. (So I'm going to label this issue under "documentation".)
Certainly, I can try and provide a bit of guidance on the sharp bits.
but I am already deeply invested in Equinox to turn back :)
:D
Hey @stergiosba, nice that you're working on S4s with equinox.
I have my own implementation here, which is also based on the original FLAX implementation. Happy to contribute here as well :)
Hey @yardenas
I am glad you have already done some steps into this, cause frankly time is of the essence.
I took a brief look at your implementation and it's pretty close to what I have set up. I will test to see if I get the same values as you.
So S4 layers are made to handle a one-dimensional sequence of size L and produce a one-dimensional output.
This means that if you need to pass multiple "parallel" sequences you have to stack the layers. The following is stated in the Flax implementation and I quote: "Since our SSMs operate on scalars, we make H different, stacked copies (H different SSMs!) with different parameters".
In the Flax implementation, they use Flax vmap to define these copies. A point of confusion is whether the same is achieved with the following code:
key = jax.random.PRNGKey(0)
N = 50 #state_dimension [same for all layers]
L = 500 #sequence_dimension [same for all layers]
H = 64 #stacked layers of S4
keys = jax.random.split(key, H)
make_s4 = lambda k: S4Layer(N, L, key=k)
model_S4 = eqx.filter_vmap(make_s4)(keys)
This would be inside a class of course like the example in Equinox docs: improve-compilation-speed-with-scan-over-layers.
Note: The Flax docs page is currently offline, but essentially Flax vmap copies an nn.Module a specified number of times.
That looks correct to me!
@stergiosba, one thing that was a bit annoying was having to pass the ssm around whenever I wanted to use the recurrent mode of S4. IIRC, the original FLAX implementation bypasses this by caching the ssm. A similar approach here would be to use Equinox's stateful modules. Maybe we can use this feature to make the API more user-friendly.
I don't think you'll need to cache anything -- write it like GRUs/LSTMs already are in Equinox.
@yardenas, I saw that was the case. I am trying to write it as Patrick says. It works for the 1D sequence. I will test today for the H-dimensional sequence.
@patrick-kidger, @stergiosba, I think I didn't explain myself well enough and there's a slight misunderstanding😅😅 Let me dive a bit deeper into the details, hopefully giving you a better idea of what I mean.
The S4 model is parameterized by the ssm's matrices; however, when you run a forward pass, you need to discretize the (continuous) linear system (discrete_DPLR in the FLAX implementation).
Since it is static and changes only when you update the model, you can (and should) compute the discretized ssm only once before running full sequence predictions. This is opposed to computing the same thing on each and every prediction step in a sequence.
In the FLAX implementation, the authors avoid re-evaluating the (same) discretized ssm by computing it once and then caching it. Once you have it cached, you don't need to pass it around alongside the S4 state.
In contrast to the discretized ssm, this state changes between time steps, so you'd need to pass it around as you'd do in GRUs and LSTMs.
In my implementation, I did pass it around together with the state, which made the interface a bit annoying. This was before stateful operations were introduced to Equinox, so I think there's room for improvement. Hope this makes sense to you🤝
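To make the distinction concrete, here is a rough sketch in plain JAX. It assumes a generic dense state space model and a bilinear discretization, not the DPLR-specific discrete_DPLR from the FLAX code; the names `discretize` and `make_step` are hypothetical. The discretized ssm is computed once, and the step function closes over it, so only the state and the input are passed at each time step.

```python
import jax
import jax.numpy as jnp

def discretize(A, B, C, step):
    """Bilinear discretization of x' = A x + B u, y = C x."""
    I = jnp.eye(A.shape[0])
    BL = jnp.linalg.inv(I - (step / 2.0) * A)
    Ab = BL @ (I + (step / 2.0) * A)
    Bb = (BL * step) @ B
    return Ab, Bb, C

def make_step(ssm):
    """Close over the discretized ssm so the step takes only (state, input)."""
    Ab, Bb, Cb = ssm
    def rnn_step(x_prev, u_k):
        x_k = Ab @ x_prev + Bb @ u_k
        y_k = Cb @ x_k
        return x_k, y_k
    return rnn_step

N, L = 4, 16
kA, kB, kC, kU = jax.random.split(jax.random.PRNGKey(0), 4)
A = -jnp.eye(N) + 0.1 * jax.random.normal(kA, (N, N))  # toy continuous system
B = jax.random.normal(kB, (N, 1))
C = jax.random.normal(kC, (1, N))
u = jax.random.normal(kU, (L,))

ssm = discretize(A, B, C, step=0.1)  # computed once, outside the time loop
rnn_step = make_step(ssm)
x_final, ys = jax.lax.scan(rnn_step, jnp.zeros(N), u[:, None])
ys = ys.reshape(-1)                  # one scalar output per time step
```

With this shape, `rnn_step` has the same (state, input) signature as a GRU-style cell, while the matrix inverse inside `discretize` runs only once per parameter update.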
@yardenas In your code, the forward pass of the model in RNN form is the following:
@jax.vmap
def __call__(self, x_k_1, u_k, ssm):
    ab, bb, cb = ssm
    if u_k.ndim == 0:
        u_k = u_k[None]
    x_k = ab @ x_k_1 + bb @ u_k
    y_k = cb @ x_k
    return x_k, (y_k + self.d * u_k).real.squeeze(-1)
I thought this is what you were referring to, or am I wrong?
@stergiosba, exactly, so whenever I use the RNN mode (typically for forecasting), I compute the ssm and then pass it to this function in each step in autoregression.
This becomes a bit awkward if you want to quickly swap S4 with another RNN architecture (like GRUs for example) because now you're not only passing an input (u_k) and state (x_k_1), but also the ssm.
To alleviate this, I see three options:
[...] (__call__(...) above). This can be potentially slow since it involves solving multiple linear equations (one per ssm).
[...] (eqx.nn.State). This requires closer treatment from the developer, but IMO it is more user-friendly.

class S4Layer(eqx.Module):
    log_step: Array
    Lambda_re: Array
    Lambda_im: Array
    P: Array
    B: Array
    C: Array
    D: Array
    cell_size: int = eqx.field(static=True)  # N
    sequence_size: int = eqx.field(static=True)  # L
    ssm: Tuple = ()
    conv_kernel: Array = jnp.empty(shape=(1,))
    decode: bool = True

    def __init__(
        self,
        cell_size: int,
        sequence_size: int,
        decode: bool,
        key: PRNGKeyArray
    ):
        self.cell_size = cell_size
        self.sequence_size = sequence_size
        # This part is handled by HIPPO initialization
        self.Lambda_re, self.Lambda_im, self.P, self.B = hippo_initializer(cell_size)
        Lambda = jnp.clip(self.Lambda_re, None, -1e-4) + 1j * self.Lambda_im
        # This part is handled by randomized C and D as ones.
        C = normal(stddev=0.5**0.5)(key, shape=(self.cell_size, 2))
        self.C = C[..., 0] + 1j * C[..., 1]
        self.D = ones(key, (1,), jnp.float32)
        self.log_step = log_step_initializer(key, shape=(1,))
        self.decode = decode
        if self.decode:
            # RNN mode, discretize
            self.ssm = discrete_DPLR(
                Lambda,
                self.P,
                self.P,
                self.B,
                self.C,
                jnp.exp(self.log_step),
                self.sequence_size,
            )
        else:
            # CNN mode, calculate convolution kernel
            self.conv_kernel = kernel_DPLR(
                Lambda,
                self.P,
                self.P,
                self.B,
                self.C,
                jnp.exp(self.log_step),
                self.sequence_size,
            )

    def __call__(self, input: Array, *, key: Optional[PRNGKeyArray] = None):
        if self.decode:
            # RNN Mode
            def rnn_mode(input, ssm):
                def scan_SSM(input, x0):
                    def step(x_k_1, u_k):
                        x_k = ssm[0] @ x_k_1 + ssm[1] @ u_k
                        y_k = ssm[2] @ x_k
                        return x_k, y_k
                    return lax.scan(step, x0, input)
                _, output = scan_SSM(input[:, None],
                                     jnp.zeros((self.cell_size), dtype=jnp.complex64))
                return output.flatten().real
            vscan = vmap(rnn_mode, in_axes=(0, 0), out_axes=0)
            return vscan(input, self.ssm) + self.D * input
        else:
            # CNN Mode
            return causal_convolution(input, self.conv_kernel).real + self.D * input
To use this class you do the following:
key = jax.random.PRNGKey(1)
N = 64 #cell_size
L = 5000 #sequence_size
H = 256 #number of parallel sequences i.e. height of S4Block
keys = jax.random.split(key, H)
#dummy input of correct shape=(H,L)
u = jax.random.normal(key, shape=(H,L))
#CNN mode
make_s4 = lambda k: S4Layer(N, L, decode=False, key=k)
model_S4_cnn = eqx.filter_vmap(make_s4)(keys)
y_cnn = model_S4_cnn(u)
#RNN mode
make_s4 = lambda k: S4Layer(N, L, decode=True, key=k)
model_S4_rnn = eqx.filter_vmap(make_s4)(keys)
y_rnn = model_S4_rnn(u)
The good news:
These two work with no serious issues and they produce the same output for H parallel sequences of size L each.
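As a small sanity check of why the two modes should agree, here is a sketch in plain JAX on a toy discrete SSM. The matrices are random rather than HiPPO-initialized, and `run_rnn`, `conv_kernel` and `causal_conv` are hypothetical helpers written for this example, not the functions from the FLAX code. The recurrence x_k = Ab x_{k-1} + Bb u_k, y_k = Cb x_k unrolls into a causal convolution with kernel K[k] = Cb Ab^k Bb.

```python
import jax
import jax.numpy as jnp

def run_rnn(Ab, Bb, Cb, u):
    """Step the recurrence over the whole sequence (RNN mode)."""
    def step(x, u_k):
        x = Ab @ x + Bb * u_k
        return x, Cb @ x
    _, ys = jax.lax.scan(step, jnp.zeros(Ab.shape[0]), u)
    return ys

def conv_kernel(Ab, Bb, Cb, L):
    """Materialize K[k] = Cb @ Ab^k @ Bb for k = 0..L-1."""
    def step(v, _):
        return Ab @ v, Cb @ v
    _, K = jax.lax.scan(step, Bb, None, length=L)
    return K

def causal_conv(u, K):
    """FFT-based causal convolution, zero-padded to avoid wrap-around (CNN mode)."""
    L = u.shape[0]
    ud = jnp.fft.rfft(u, n=2 * L)
    Kd = jnp.fft.rfft(K, n=2 * L)
    return jnp.fft.irfft(ud * Kd, n=2 * L)[:L]

N, L = 4, 32
kA, kB, kC, kU = jax.random.split(jax.random.PRNGKey(0), 4)
Ab = 0.3 * jax.random.normal(kA, (N, N)) / jnp.sqrt(N)  # small, so powers decay
Bb = jax.random.normal(kB, (N,))
Cb = jax.random.normal(kC, (N,))
u = jax.random.normal(kU, (L,))

y_rnn = run_rnn(Ab, Bb, Cb, u)
y_cnn = causal_conv(u, conv_kernel(Ab, Bb, Cb, L))
```

Comparing `y_rnn` and `y_cnn` with jnp.allclose (with a float32-sized tolerance) is a cheap regression test for the two code paths.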
Now the issues:
1) Of course there is this ugly ssm: Tuple = () that explodes when used in CNN mode.
2) Similarly, conv_kernel: Array = jnp.empty(shape=(1,)) explodes in RNN mode.
3) Points 1/2 can be solved if we initialize both modes and use the one needed (i.e left to the user). Anything better?
4) The CNN mode is fast: 40.3 ms ± 2.73 ms per loop (mean ± std. dev. of 7 runs, 100 loops each) for a state N=64, H=256 parallel sequences of L=5000 timesteps. The RNN mode is not fast; for the same sequence it takes 6 seconds per loop. Any suggestions?
5) It takes approximately 5 seconds to compile the model for the sequence above (1.1 M params) on my laptop's 11th Gen Intel Core i5-1135G7 @ 2.40GHz. Is there a way we can get an accurate measurement of the compilation time for the model?
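On point 5, one possible approach (a sketch, not something confirmed in this thread) is JAX's ahead-of-time API: lowering and compiling explicitly lets you time compilation separately from execution. The `model` function below is a stand-in workload and the input shape is made up; for the real layer you would jit a function that calls the S4 model instead.

```python
import time

import jax
import jax.numpy as jnp

def model(u):
    # Stand-in workload (hypothetical); replace with the actual forward pass.
    return jnp.cumsum(jnp.tanh(u), axis=-1)

u = jnp.ones((64, 1000))

# Trace, lower and compile without running the computation.
t0 = time.perf_counter()
compiled = jax.jit(model).lower(u).compile()
compile_seconds = time.perf_counter() - t0

# Run the already-compiled function; block so async dispatch is included.
t0 = time.perf_counter()
y = compiled(u).block_until_ready()
run_seconds = time.perf_counter() - t0

print(f"compile: {compile_seconds:.3f} s, run: {run_seconds:.6f} s")
```

The block_until_ready call matters for the runtime number as well, since JAX dispatches work asynchronously and a timer stopped without it may only measure the dispatch.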
Issue #649 and PR #656 are relevant to this issue.
I've gone ahead and implemented an upgraded version of S4 called the Linear Recurrent Unit (LRU) in the memorax library, which uses Equinox modules. I'll probably get around to adding S5 at some point as well, but empirically I haven't found any cases where S4/S5 beats the LRU.
Oh this is awesome!
We should add this to https://github.com/matteoguarrera/awesome-equinox, which I'd like to link to from the Equinox docs, now that the list is starting to get fleshed out.
(Anyone else in this thread -- please do send any projects you think are interesting to the list above!)
Hey Patrick,
I have been using Equinox for one of my projects and up until now it has helped immensely in using JAX effectively and seamlessly.
For this project, I was going to use the Structured State Space Sequence (S4) model from Gu et al. 2022. In short, in the original paper that introduces S4, the authors compared it to attention-based architectures on time-series analysis problems with long-range time dependencies in the data (e.g. EEG signals, audio recognition, etc.). They showed that S4 outperformed every other model. I think it would be a very nice addition to the package for all those who might want to use S4+Equinox in the future.
Without going into further details, a new nn Layer must be created. I am unsure if you are familiar with the S4 concept or if you even have time to implement it. The original implementation of the Layer is written in PyTorch but they also provided a Flax implementation. I could use the Flax implementation as it is but I am already deeply invested in Equinox to turn back :). Hence, I started rewriting it for Equinox. To this end, I have two questions for you.
1) Would this be a worthy addition to the main package? 2) If need be could you provide some guidance for the sharp bits?
Best, Stergios