state-spaces / s4

Structured state space sequence models
Apache License 2.0

Conceptual Questions regarding S4/HiPPO #40

Closed: oezyurty closed this issue 2 years ago

oezyurty commented 2 years ago

Dear authors/contributors,

First of all, thank you so much for publishing such great work. I think it is truly inspirational, and we will see this model (or its variants) deployed to solve a variety of real-world problems in the coming years.

I have tried to go through your most recent papers, starting from HiPPO, and I would like to kindly ask some conceptual questions to deepen my understanding. As I couldn't find sources of information other than your papers (and a couple of your recorded talks on YouTube and The Annotated S4), I think this could be an appropriate place to ask them. If you prefer another discussion platform, please let me know.

PS: These questions turned out to be a bit longer than I intended, but I don't expect you to clarify them all at once :)

  1. A matrix <-> polynomial basis: From my understanding of your HiPPO paper, you derive the A matrix for various measures and polynomial bases. Therefore, for a given A (and hence a given polynomial basis), we know how to reconstruct the original signal u(t) from the state/coefficients x(t). My question is: what does the model learn when we initialize A as HiPPO but then train it (i.e. when A is not fixed)? In other words, how does the polynomial basis change, and how is the model still able to reconstruct the original signal u(t) as A varies?
  2. Learning the step size: In The Annotated S4, the step size is another parameter that’s learned during training. (I am not sure whether you do the same, as I haven't gone through your code yet.)
    • May I ask for the intuition behind learning this step size, and what its potential effects are? For instance, if we use a measure that decays exponentially over time, can we say that a larger step size prioritizes more recent history, while a smaller step size gives more weight to the distant past (because its weight decays more slowly)?
    • If we work on a signal that has a natural sense of time (e.g. an ECG signal), should we still make the step size trainable (in the first and all intermediate layers), given that the formulation itself (to my understanding) has no notion of the units of the step size (e.g. seconds or days)?
  3. Irregular sampling of time series: I am convinced by the continuous-time view of S4 that it can naturally handle irregularly sampled time series arising from underlying continuous dynamics. However, I am confused by the discretization step, where we use convolution for training and recurrence for fast inference. If I have an irregular time series, how can I train S4?
    • Small comment: I think that if the training data is regularly sampled, we can still handle irregular time series during real-time inference, since the bilinear transform relates A_bar, B_bar, etc. to their continuous-time equivalents. Is that true?
  4. The effect of "deep" S4 layers: Figure 2 of your paper “Efficiently Modeling Long Sequences with Structured State Spaces” visualizes the kernels of the first and last layers for the Path-X task. It shows that the first layers mostly capture local context, whereas the last layers capture more global context. Why is this the case if HiPPO offers continuous-time memorization? In other words, why can't the first layers memorize the distant past, and why do we need to stack more layers to aggregate more context from the past? I assume it is related to the chosen measure and/or the step size itself, but I am really curious about your opinion.
    • For deep CNN-based models, the usual explanation is that the receptive field grows as more layers are stacked (exponentially with dilated convolutions as in TCNs, and linearly for some other types). Is there an analogy or similar explanation for S4?
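For reference, the receptive-field arithmetic behind the CNN analogy in the last bullet can be sketched as follows. This is a minimal illustration assuming causal 1-D convolutions with kernel size 3 and, in the dilated case, dilations that double per layer (1, 2, 4, ...), a common TCN-style configuration; the function name `receptive_field` is hypothetical.

```python
def receptive_field(kernel_size: int, num_layers: int, dilated: bool) -> int:
    """Receptive field (in time steps) of a stack of causal 1-D convolutions.

    A layer with kernel size k and dilation d extends the receptive field by
    (k - 1) * d. With dilated=True the dilation doubles per layer (1, 2, 4, ...),
    as assumed for a TCN-style stack; otherwise every layer uses d = 1.
    """
    field = 1
    for layer in range(num_layers):
        dilation = 2 ** layer if dilated else 1
        field += (kernel_size - 1) * dilation
    return field


for depth in (1, 2, 4, 8):
    print(depth,
          receptive_field(3, depth, dilated=False),
          receptive_field(3, depth, dilated=True))
```

With kernel size 3, eight undilated layers see 17 steps, while eight layers with doubling dilation see 511 steps, which is the linear vs. exponential growth the question refers to.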

It is a great pleasure for me to learn more about your exciting work. Many thanks in advance. I would also be happy to hear about any other resources you can suggest.

albertfgu commented 2 years ago

I apologize for replying so late. This is an appropriate place to ask questions; I have just spent several weeks/months releasing the new preprints and the corresponding V3 of this codebase. The preprints hopefully provide some new technical content to help with such conceptual questions.

  1. In general, the way I think about HiPPO is that it defines a set of basis functions $e^{tA}B$ (note that this is a length-N vector of functions). Depending on the properties of these basis functions, they may allow reconstructing the original signal. Note that these basis functions are not necessarily polynomials, and in general they may not allow reconstruction. It's unclear what exactly the model learns during training, but training seems to help empirically even if the resulting system of bases is no longer orthogonal. HTTYH (the "How to Train Your HiPPO" preprint) explains some of these interpretations and derives the basis functions for the original S4 matrix (HiPPO-LegS). For example, it shows how HiPPO produces "orthogonal SSMs" that allow "online function reconstruction", whereas general A/B matrices may not.
  2. We do learn the step size in training. This was first mentioned in LSSL (the predecessor to S4) as a way to improve performance.
    • HTTYH formalizes the interpretation of the step size, and your intuition is exactly right: a larger step size prioritizes recent history, while a smaller step size weights history further back. HTTYH defines a notion of "timescale" that is inversely proportional to the step size.
    • I view the step size as something "intrinsic to the model" rather than to the data. What I mean is that it affects the context length of the model, and should be set based on how much context you want to capture. (Of course these are related: if the data were sampled at 100 Hz you might want the model to have a context length of 200, whereas if it were sampled at 10000 Hz you might need a much longer context, so the step size might end up tied to the data's natural sampling rate, getting smaller as the sampling rate increases.) I usually initialize the step size based on any intuitions about the data (e.g. short sequences should have a larger $\Delta$), and then train it, i.e. let the model learn its desired context length.
  3. The convolutional form of S4 can't handle irregular sampling because convolutions implement "linear time-invariant" systems: a single convolution kernel requires the same discretized matrices (i.e. the same $\Delta$) at every step, which no longer holds when the gaps between samples vary. It is possible to compute the model reasonably efficiently purely in recurrent mode (see ongoing efforts like https://github.com/lindermanlab/S5), which can be extended to handle the irregular sampling case.
    • Yes, at inference time you can still handle irregular sampling by choosing a different $\Delta$ per time step.
    • Also, note that in practice many "irregularly sampled" time series are really "missing values" time series, i.e. time series that were regularly sampled but have missing values. Oftentimes you can handle these cases simply by using uniformly sampled inputs with a mask indicating the missing values.
  4. To be honest, I'm not quite sure about the interpretation of the learned layers. I think it's still intuitively reasonable that the first layers learn shorter-context features. One possible explanation is that although a single HiPPO layer can learn long dependencies, it is still a limited form of dependency (e.g. SSMs are linear), so long-range dependencies within a single layer may not be very useful. The model therefore learns only simple features with more limited context early on, and as the features become progressively more complex through the network, it becomes useful to learn longer dependencies on them.
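To make the "basis functions" view of point 1 concrete, here is a minimal sketch that evaluates $e^{tA}B$ numerically. It assumes the HiPPO-LegS matrices in the negated (stable) convention of the S4 paper and uses `scipy.linalg.expm`; the helper names `hippo_legs` and `basis_functions` are hypothetical, not functions from this repository. It does not check orthogonality or reconstruction; it only produces the N functions, so one could plot them before and after training A and B.

```python
import numpy as np
from scipy.linalg import expm


def hippo_legs(N: int):
    """HiPPO-LegS (A, B), negated so the state decays (assumed S4 convention):
    A[n, k] = -sqrt(2n+1) * sqrt(2k+1) for n > k, -(n+1) for n == k, 0 for n < k;
    B[n] = sqrt(2n+1)."""
    q = np.sqrt(2 * np.arange(N) + 1.0)
    A = -np.tril(np.outer(q, q), k=-1) - np.diag(np.arange(N) + 1.0)
    B = q.copy()
    return A, B


def basis_functions(A, B, ts):
    """Evaluate the length-N vector of functions K(t) = e^{tA} B at each t in ts.

    Returns shape (len(ts), N); column n samples the n-th basis function."""
    return np.stack([expm(t * A) @ B for t in ts])


A, B = hippo_legs(4)
ts = np.linspace(0.0, 5.0, 6)
K = basis_functions(A, B, ts)
print(K.shape)
```

Using the matrix exponential directly, rather than diagonalizing A, is a deliberate choice here: the HiPPO matrices are non-normal and their eigenvector bases become badly conditioned as N grows.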
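The step-size intuition in point 2 (a larger $\Delta$ favors recent history; the timescale grows roughly like $1/\Delta$) can be checked numerically by discretizing an SSM and inspecting its convolution kernel. The sketch below assumes the bilinear discretization mentioned in the question and a small diagonal A chosen purely for illustration (not the HiPPO matrix); `discretize_bilinear` and `ssm_kernel` are hypothetical helper names.

```python
import numpy as np


def discretize_bilinear(A, B, dt):
    """Bilinear discretization of x'(t) = A x(t) + B u(t) with step dt:
    A_bar = (I - dt/2 A)^{-1} (I + dt/2 A),  B_bar = (I - dt/2 A)^{-1} dt B."""
    I = np.eye(A.shape[0])
    inv = np.linalg.inv(I - dt / 2 * A)
    return inv @ (I + dt / 2 * A), inv @ (dt * B)


def ssm_kernel(A, B, C, dt, L):
    """Convolution kernel K_bar[k] = C A_bar^k B_bar for k = 0, ..., L-1."""
    A_bar, B_bar = discretize_bilinear(A, B, dt)
    K, x = np.empty(L), B_bar
    for k in range(L):
        K[k] = C @ x
        x = A_bar @ x
    return K


# Illustrative (assumed) toy system: two decaying modes.
A = np.diag([-1.0, -2.0])
B = np.ones(2)
C = np.ones(2)

for dt in (0.1, 0.01):
    K = ssm_kernel(A, B, C, dt, L=200)
    # Fraction of the kernel's initial magnitude left after 100 steps.
    print(dt, abs(K[100]) / abs(K[0]))
```

With $\Delta = 0.1$ the kernel has essentially vanished after 100 steps, whereas with $\Delta = 0.01$ roughly a quarter of its initial magnitude remains, i.e. the smaller step size looks further back in time.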
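Point 3 can be illustrated by comparing the two computation modes. The convolutional mode precomputes one kernel, which only works when every step shares the same discretized matrices; the recurrent mode re-discretizes at every step, so it accepts a different $\Delta$ per sample. This is a minimal sketch assuming bilinear discretization and a toy diagonal A; the helper names are hypothetical, and the plain Python loop is a naive sequential recurrence, not an efficient implementation like the one in the linked S5 repository.

```python
import numpy as np


def discretize_bilinear(A, B, dt):
    """Bilinear discretization: returns (A_bar, B_bar) for step size dt."""
    I = np.eye(A.shape[0])
    inv = np.linalg.inv(I - dt / 2 * A)
    return inv @ (I + dt / 2 * A), inv @ (dt * B)


def ssm_recurrent(A, B, C, u, dts):
    """Recurrent mode: x_k = A_bar_k x_{k-1} + B_bar_k u_k, y_k = C x_k.
    dts[k] is the step size at step k, so irregular sampling simply means a
    different dt per step (e.g. dts = np.diff(timestamps))."""
    x = np.zeros(A.shape[0])
    ys = []
    for u_k, dt in zip(u, dts):
        A_bar, B_bar = discretize_bilinear(A, B, dt)
        x = A_bar @ x + B_bar * u_k
        ys.append(C @ x)
    return np.array(ys)


def ssm_convolutional(A, B, C, u, dt):
    """Convolutional mode, valid only for one constant dt (an LTI system):
    y = K_bar * u with K_bar[k] = C A_bar^k B_bar."""
    A_bar, B_bar = discretize_bilinear(A, B, dt)
    L = len(u)
    K, x = np.empty(L), B_bar
    for k in range(L):
        K[k] = C @ x
        x = A_bar @ x
    return np.convolve(u, K)[:L]  # causal convolution truncated to length L


rng = np.random.default_rng(0)
A = np.diag([-1.0, -2.0, -3.0])
B, C = np.ones(3), rng.standard_normal(3)
u = rng.standard_normal(50)

# Regular sampling: both modes give the same outputs.
y_rec = ssm_recurrent(A, B, C, u, np.full(50, 0.1))
y_conv = ssm_convolutional(A, B, C, u, 0.1)
print(np.allclose(y_rec, y_conv))  # True

# Irregular sampling: only the recurrent mode applies; u[k] is observed at timestamps[k + 1].
timestamps = np.cumsum(rng.uniform(0.05, 0.2, size=51))
y_irregular = ssm_recurrent(A, B, C, u, np.diff(timestamps))
```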
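For the missing-values case in the last bullet of point 3, one common way to build "uniformly sampled inputs with a mask" is to zero-fill the gaps and append the mask as extra input channels. This is a minimal sketch assuming NaN marks a missing value and that the mask is concatenated as features; other choices (e.g. forward-filling, or feeding the mask elsewhere in the model) are equally reasonable, and the function name `with_missing_mask` is hypothetical.

```python
import numpy as np


def with_missing_mask(x, fill_value=0.0):
    """Turn a regularly sampled series with NaNs at missing positions into
    model inputs: missing entries are replaced by fill_value and a binary
    mask channel (1 = observed, 0 = missing) is appended per feature.

    x: array of shape (L, D). Returns an array of shape (L, 2 * D)."""
    mask = ~np.isnan(x)
    filled = np.where(mask, x, fill_value)
    return np.concatenate([filled, mask.astype(x.dtype)], axis=-1)


series = np.array([[1.0], [np.nan], [3.0], [np.nan]])  # L = 4, D = 1
inputs = with_missing_mask(series)
print(inputs)
```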

Overall, these are great questions and your intuitions are quite accurate. Sorry again for taking so long to respond; hopefully the recent HTTYH preprint addresses many of these points. Feel free to ask more questions here; I often get questions by email, but I think publicly available responses might be useful to the community.

oezyurty commented 2 years ago

Dear @albertfgu, thank you so much for all the time and effort you spent clarifying the points I raised!

Overall, you are definitely right that your most recent HTTYH paper addresses many of these points. Your preprint arrived just two weeks after my post (a great coincidence :) ), and it shows how carefully you are closing the remaining gaps in your work and taking the next steps accordingly. This makes your work more user-friendly with each step, which ultimately broadens its applicability.

It is a great pleasure for me to exchange ideas with the first author of such an inspirational paper! I wish you and your colleagues every success.

I will close the issue now and will post any new questions in a new thread. Many thanks!