johnma2006 / mamba-minimal

Simple, minimal implementation of the Mamba SSM in one file of PyTorch.
Apache License 2.0

Is it recurrent? #11

Closed. kroggen closed this issue 10 months ago.

kroggen commented 10 months ago

Nice implementation!

I thought that Mamba was somewhat recurrent, like keeping an internal state and then outputting one token at a time. But your code shows that for each new output token, the entire sentence must be passed to the model, including the last outputted token.

Is this the only mode of operation?

It also has the l parameter (sequence length). Does it mean it has a maximum sequence length? The paper shows up to 1 million, so I was expecting it to be recurrent and without limit

From the paper:

Table 2: (Induction Heads.) Models are trained on sequence length 2^8 = 256, and tested on increasing sequence lengths of 2^6 = 64 up to 2^20 = 1048576

It is not clear how they do that

kroggen commented 10 months ago

I mean, it is redoing the same computations over and over again

Imagine a sentence (prompt + generation) of ~1024 tokens

This is not good for inference, at least not for CPU

kroggen commented 10 months ago

The paper states that

Fast training and inference: computation and memory scales linearly in sequence length during training, and unrolling the model autoregressively during inference requires only constant time per step since it does not require a cache of previous elements.

Commonly, the model uses the convolutional mode (3) for efficient parallelizable training (where the whole input sequence is seen ahead of time), and switched into recurrent mode (2) for efficient autoregressive inference (where the inputs are seen one timestep at a time).

OK, so I can see that this repo only implements the convolutional mode

It would be cool to have a minimal code for the recurrent mode as well

I will check if I can do that, but it appears to be not so easy :flushed:

johnma2006 commented 10 months ago

Overall, the repo was kept deliberately simple and indeed lacks many of the important optimizations in the original repo.

A few comments:

I mean, it is redoing the same computations over and over again

Yeah, at inference time, just as with RNNs, we should keep track of the latest hidden state and use it to bootstrap each new token's generation. This is how it's done in the original repo. I left it out to keep things simpler.
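To illustrate the cached-state idea, here is a toy scalar linear recurrence (an illustrative sketch, not the actual Mamba code): recomputing the whole sequence per token costs O(L) per step, while carrying the hidden state costs O(1) per step, with identical outputs.

```python
import numpy as np

# Toy 1-D linear state space recurrence: h_t = a*h_{t-1} + b*x_t, y_t = c*h_t.
# (Illustrative constants; not from this repo.)
a, b, c = 0.9, 0.5, 2.0

def full_pass(xs):
    """Recompute from scratch over the whole sequence (what mamba-minimal does)."""
    h, ys = 0.0, []
    for x in xs:
        h = a * h + b * x
        ys.append(c * h)
    return ys

def step(h, x):
    """Advance one token using only the cached state (what the official repo does)."""
    h = a * h + b * x
    return c * h, h

xs = [1.0, 2.0, 3.0, 4.0]
ys_full = full_pass(xs)

# Stateful generation: feed one token at a time, carrying h forward.
h, ys_step = 0.0, []
for x in xs:
    y, h = step(h, x)
    ys_step.append(y)

assert np.allclose(ys_full, ys_step)  # identical outputs, O(1) work per new token
```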

It also has the l parameter (sequence length). Does it mean it has a maximum sequence length? The paper shows up to 1 million, so I was expecting it to be recurrent and without limit

l is variable in the same way the batch size b is variable, so there is no limit.

this repo only implements the convolutional mode

Actually, our implementation is recurrent.

Best, John

kroggen commented 10 months ago

The algorithm in this repo implements the convolutional mode

To be recurrent:

  1. It must keep the internal state (see the inference_params and _decoding_cache here)
  2. The generation step should not send the entire sentence at each new generated token
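A hedged sketch of what such a generation loop could look like, using a hypothetical `step` interface and a toy stand-in model (none of these names are from this repo). The point is the loop structure: the internal state is threaded through every call, and after the prompt each call consumes only the single most recent token.

```python
import numpy as np

class ToyRecurrentLM:
    """Stand-in model: the state is a running sum; 'logits' favour (state mod vocab)."""
    def __init__(self, vocab_size=5):
        self.vocab_size = vocab_size

    def step(self, token, state):
        state = state + token                  # update the cached state in O(1)
        logits = np.zeros(self.vocab_size)
        logits[state % self.vocab_size] = 1.0  # deterministic toy prediction
        return logits, state

def generate(model, prompt_tokens, n_new):
    state = 0                                  # 1. keep an internal state
    for tok in prompt_tokens:                  #    consume the prompt once
        logits, state = model.step(tok, state)
    out = []
    for _ in range(n_new):                     # 2. feed ONLY the last token
        tok = int(np.argmax(logits))           #    (greedy sampling for simplicity)
        out.append(tok)
        logits, state = model.step(tok, state)
    return out

tokens = generate(ToyRecurrentLM(), [1, 2, 3], n_new=4)
# tokens == [1, 2, 4, 3] with this toy model
```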
swfsql commented 9 months ago

Thanks a lot for the implementation!

I also think this appears to be a convolutional mode. The input is the full list of tokens (instead of just the last one), and there is a result value for each input token (instead of just the result for the last given token).

kroggen commented 9 months ago

I have a fork of mamba for CPU: mamba-cpu

Check also the recurrent-only branch

albertfgu commented 9 months ago

Mamba has no convolutional mode. During training time, it sees the entire input sentence at once. This is compatible with recurrence. Please read the vast literature on basics of autoregressive sequence modeling.

kroggen commented 9 months ago

@albertfgu

I was using the term from your paper:

After the parameters have been transformed from (∆, A, B, C) ↦ (A, B, C), the model can be computed in two ways, either as a linear recurrence (2) or a global convolution (3).

Commonly, the model uses the convolutional mode (3) for efficient parallelizable training (where the whole input sequence is seen ahead of time), and switched into recurrent mode (2) for efficient autoregressive inference (where the inputs are seen one timestep at a time).

And also referring to inference only

So what do you call the implementation that starts here in the forward() function? (the same and only implementation in this minimal repo)

I call the step() function recurrent mode: it stores an internal state between tokens. This is clearly recurrent, and we can even save the internal state (conv and ssm)

The method implemented in this repo is slow for inference (it needs to pass the entire sequence again to generate each new token), while in recurrent mode we only pass the last token chosen by sampling

johnma2006 commented 9 months ago

The passage you quoted is from “Section 2: State Space Models”, which is background on the LTI S4, not Mamba.

Mamba is described in “Section 3: Selective State Space Models”, which says

The resulting time-varying SSMs cannot use convolutions.

Another way to think about it is, there aren’t any convolutions in the scan for any standard definition of convolution.

swfsql commented 9 months ago

I understand that the whole Mamba block cannot be described as a single convolution (K * u), but the forward implementation does apply a convolution operation (conv1d) along the sequence, and there is an x_dbl, ∆, B, C for each sequence position, whereas a recurrent-like interface would have no sequence axis. Yes, it's wrong to call the Mamba block a single convolution, but the forward call interface is convolution-like, in the sense that during inference each new token takes longer to compute than the previous one (across successive forward calls, with an ever-growing input sequence).

For example, I'm trying to port mamba-minimal to a library running on the CPU: the first forward call takes <300 ms/token, and after a few dozen tokens it takes >2000 ms/token.

edit (extra specific info): in case it's needed, some more detail on this port I mentioned: all shapes are potentially compile-time constants (for the input, the output, and all intermediate calculations, including the convolutions; every axis and dimension can in principle be known at compile time). From this it's "certain" that the forward call's output contains a result for each sequence position, otherwise I'd be getting a compile-time error. As another positive, there is no dynamic memory (no garbage collector), so that also couldn't contribute to a loss in performance over time.

But here (johnma2006/mamba-minimal) this is not visible (only the output for the last input token is kept), because the outputs for all other sequence positions are discarded after each forward call.

albertfgu commented 9 months ago

This repo is meant to provide a minimal implementation for the forward (training) pass over an entire sequence, which is the harder part and is where the official Mamba implementation has a specialized low-level kernel. The implementation for recurrence is much simpler and the one provided in the official implementation should be very readable already.

johnma2006 commented 9 months ago

Let me try to address some potential misconceptions!

  1. "there would be no sequence axis for a recurrent-like interface" This is not true; for example, please see PyTorch's RNN implementation, which takes in input of shape (sequence length, input size) and an initial hidden state h_0.

Then, during training: inputs of shape (sequence length, input size), and h_0 = None. During inference: inputs of shape (1, input size), passing in the current hidden state h_0.

Notice that both of these are "recurrent"; the latter just keeps track of and passes around the hidden state. The distinction is stateless vs. stateful calls, not "convolutional vs. recurrent".

  2. Whenever we talk about convolutions in the context of SSMs, it's most likely referring to the nice trick in LTI SSMs where you can express the scan as a convolution: see The Annotated S4 for a tutorial.
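For illustration, here is that LTI trick in the scalar case (a toy example, not S4 or Mamba code): the recurrence h_t = a*h_{t-1} + b*x_t, y_t = c*h_t unrolls to a causal convolution of x with the kernel K = (c·b, c·a·b, c·a²·b, ...). The equivalence requires a, b, c to be constant over time, which Mamba's input-dependent parameters break.

```python
import numpy as np

# Toy scalar LTI SSM parameters (illustrative constants, not from any repo).
a, b, c = 0.8, 0.3, 1.5
x = np.array([1.0, -2.0, 0.5, 3.0, 1.0])
L = len(x)

# Recurrent mode: one step at a time.
h, y_rec = 0.0, []
for xt in x:
    h = a * h + b * xt
    y_rec.append(c * h)
y_rec = np.array(y_rec)

# Convolutional mode: precompute the kernel once, then one causal convolution.
K = c * b * a ** np.arange(L)
y_conv = np.convolve(x, K)[:L]

assert np.allclose(y_rec, y_conv)  # same outputs, only while (a, b, c) are time-invariant
```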
swfsql commented 9 months ago

I think this is just an optimization; it's not that RNNs inherently require the whole sequence axis* as input in one go. Isn't the goal of RNNs to not require the sequence axis like this?

*It's perfectly fine to train an RNN with a single sequence point at a time. As a training optimization, they can accept the entire known sequence in one call (depending on how you propagate the training loss over the sequence).

Note: I'm not including bi-directional RNNs, which for Mamba would correspond to negative deltas (inputs in reverse order); in that case the whole available sequence would be required for each token's training and prediction.
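To illustrate the point with a minimal numpy RNN (not PyTorch's nn.RNN, just the same contract): invoking the recurrence one timestep at a time while threading the hidden state reproduces the full-sequence pass exactly. The weights and sizes below are arbitrary.

```python
import numpy as np

rng = np.random.default_rng(0)
W_x = rng.normal(size=(3, 4))   # input -> hidden
W_h = rng.normal(size=(3, 3))   # hidden -> hidden

def rnn(xs, h):
    """xs: (seq_len, input_size). Returns the output at every step and the final h."""
    ys = []
    for x in xs:
        h = np.tanh(W_x @ x + W_h @ h)
        ys.append(h)
    return np.stack(ys), h

xs = rng.normal(size=(6, 4))

# "Training-style" call: the full sequence, zero initial state.
ys_full, _ = rnn(xs, np.zeros(3))

# "Inference-style" calls: one timestep each, threading the hidden state through.
h, ys_step = np.zeros(3), []
for t in range(6):
    y, h = rnn(xs[t:t + 1], h)
    ys_step.append(y[0])
ys_step = np.stack(ys_step)

assert np.allclose(ys_full, ys_step)  # same recurrence either way
```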


edit: I think I'm commenting on the basis that the original question "Is it recurrent?" should, for practical purposes, be closed with a "No". That is, anyone intending to get constant-time generation (commonly implied by a recurrent algorithm) still needs to implement (or port) more functionality.

But by no means do I disregard the importance of this code, which is sufficient for training, and that (in my opinion) is much more important and essential than a recurrent-style implementation. I appreciate everyone's work and willingness to respond!

jinlovespho commented 2 months ago

Hi @kroggen, I thought the code was recurrent, since the hidden state x is consistently updated to produce the output y at line 317: https://github.com/johnma2006/mamba-minimal/blob/03de542a36d873f6e6c4057ad687278cc6ae944d/model.py#L315-L320 Could you let me know if I got this wrong?