
Audio spoken digit generation and other S4 questions #67

Closed peterfeifanchen closed 1 year ago

peterfeifanchen commented 1 year ago

Hi! Like everyone else on here, I would like to first say, thank you very much for putting this all together.

I was looking at the unconditioned audio generation and have a few questions about sequence generation, both specific and conceptual:

  1. At https://srush.github.io/annotated-s4/#experiments-spoken-digits, it seems the embedded links are broken?
  2. For the spoken digit generation, is it conditioned on a prefix of samples, the way the MNIST completion task is given the first 300 samples, or is it generated from scratch? I am also curious whether it is possible to generate conditioned samples (e.g., ask it to generate audio for a given digit and speaker id).
  3. Also, looking through the annotated-S4 post, I see it mentions:

> Note as well that in the original paper Lambda, P, Q are also learned. However, in this post, we leave them fixed for simplicity.

But looking at the code:

```python
def setup(self):
    # Learned Parameters (C is complex!)
    init_A_re, init_A_im, init_P, init_B = hippo_initializer(self.N)
    self.Lambda_re = self.param("Lambda_re", init_A_re, (self.N,))
    self.Lambda_im = self.param("Lambda_im", init_A_im, (self.N,))
    # Ensure the real part of Lambda is negative
    # (described in the SaShiMi follow-up to S4)
    self.Lambda = np.clip(self.Lambda_re, None, -1e-4) + 1j * self.Lambda_im
    self.P = self.param("P", init_P, (self.N,))
    self.B = self.param("B", init_B, (self.N,))
```

It seems Lambda and P (I guess P = Q here) are declared as params, so wouldn't optax still optimize them? I am new to flax/jax, so I am wondering how they are kept fixed.

  4. For my general understanding: it seems the matrix A is initialized with the HiPPO matrix rather than randomly, but is it still supposed to be updated during training like the B, C, D matrices?
  5. Finally, I was wondering, intuitively, how does the HiPPO matrix initialization change? From Albert Gu's presentations, it seems the matrix is derived from a chosen measure, and the HiPPO matrix in the code here is based on an EMA-like measure. Is that the correct high-level understanding (e.g., if my measure is a window, I would need a different HiPPO matrix), or is the HiPPO matrix initialization completely general?
albertfgu commented 1 year ago

Hi Peter, I'll reply to some of these questions.

  1. Looks like @srush took care of it!
  2. In the Sashimi paper it was unconditional. I believe it's the same in this codebase, although I haven't run the experiments to check. Generating conditioned on extra information (e.g., the digit or speaker id) is orthogonal to generating conditioned on a prefix (e.g., the first 300 samples of the sequence). There are standard ways to do this, such as passing an embedding of the label (e.g., the digit) to every layer of the model (see the sketch at the end of this comment).
  3. I think the original version had non-learnable P/Q. I later updated this in a fork, which was merged in https://github.com/srush/annotated-s4/pull/60 and included several improvements to the original implementation. Some of the descriptions in the original post were probably not updated; I'll go through and update them in the next few weeks.
  4. Yes, the A/B matrices are initialized to fixed values and are still trained (at a lower learning rate than the rest of the network; see the optimizer sketch at the end of this comment).
  5. HiPPO is a more complicated framework. Couple of points corresponding to some of your questions:

The way I think about training is like this:

TL;DR:

  1. Theoretically, for fixed $(A, B)$ matrices, HiPPO theory gives matrices with very strong performance compared to others
  2. Empirically, nothing is lost by continuing to train the matrices, although they might not be "HiPPO" anymore. The HiPPO initialization is still often better than other initializations, although the gap closes after training
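Regarding point 2 above, here is a minimal sketch of what label conditioning could look like in Flax. The module and its names (`ConditionedBlock`, `num_classes`) are hypothetical and not part of the SaShiMi code; the idea is just to add a learned embedding of the label to the input of every layer:

```python
from flax import linen as nn


class ConditionedBlock(nn.Module):
    # Hypothetical wrapper, not the SaShiMi implementation: it adds a learned
    # label embedding to this layer's input before the usual sequence layer.
    d_model: int
    num_classes: int

    @nn.compact
    def __call__(self, x, label):
        # x: (L, d_model) sequence features; label: scalar integer class id.
        emb = nn.Embed(self.num_classes, self.d_model)(label)  # (d_model,)
        x = x + emb[None, :]  # broadcast the embedding over all timesteps
        # Stand-in for the real sequence layer (e.g. an S4 layer + feedforward).
        return nn.Dense(self.d_model)(x)
```

Stacking such blocks and passing the same label to each of them gives the model access to the class information at every depth, independently of whether a prefix is also provided.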
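Similarly, regarding points 3 and 4, one standard way in optax to give the SSM parameters their own learning rate is `optax.multi_transform`. This is only an illustrative sketch (the learning rates are placeholders and the parameter names are taken from the `setup` shown above, not the exact training recipe used here); mapping the SSM group to `optax.set_to_zero()` instead would keep those parameters completely fixed:

```python
import optax
from flax import traverse_util

# Parameters to treat specially (names from the setup() shown above).
SSM_PARAMS = {"Lambda_re", "Lambda_im", "P", "B"}


def label_params(params):
    # Tag every parameter leaf as "ssm" or "regular" based on its name.
    flat = traverse_util.flatten_dict(params)
    labels = {k: ("ssm" if k[-1] in SSM_PARAMS else "regular") for k in flat}
    return traverse_util.unflatten_dict(labels)


tx = optax.multi_transform(
    {
        "regular": optax.adamw(learning_rate=1e-3),  # main network
        "ssm": optax.adam(learning_rate=1e-4),       # lower LR, no weight decay
        # "ssm": optax.set_to_zero(),                # ...or freeze them entirely
    },
    label_params,
)
```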
peterfeifanchen commented 1 year ago

Thanks for the prompt response, Albert!

A quick follow-up on 2): could you point me to where in the Sashimi code the setup for unconditional generation happens? And to clarify, the generation was not conditioned even on a prefix, and the loss was maximizing the probability $p(x_t \mid x_{<t})$?

albertfgu commented 1 year ago

Yes to both. The generation script is here: https://github.com/HazyResearch/state-spaces/blob/main/generate.py. Instructions are in the main READMEs, which provide the commands for Sashimi generation. I think you can specify the prefix length, i.e. `l_prefix=0` means unconditional generation.
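For reference, the training objective in this setting is the standard next-sample cross-entropy. A minimal sketch, assuming 8-bit (256-class) quantized audio and a causal model whose output at step t predicts the sample at t+1 (the exact shifting depends on how the dataloader prepares inputs and targets):

```python
import jax.numpy as np
import optax


def ar_loss(logits, x):
    # logits: (L, 256) per-step class predictions; x: (L,) integer audio samples.
    # Score the prediction at step t against the *next* sample, so training
    # maximizes sum_t log p(x_t | x_{<t}).
    return np.mean(
        optax.softmax_cross_entropy_with_integer_labels(logits[:-1], x[1:])
    )
```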

peterfeifanchen commented 1 year ago

I see, thank you!