Closed: kroggen closed this issue 10 months ago
I mean, it is redoing the same computations over and over again
Imagine a sentence (prompt + generation) of ~1024 tokens
This is not good for inference, at least not for CPU
The paper states that
Fast training and inference: computation and memory scales linearly in sequence length during training, and unrolling the model autoregressively during inference requires only constant time per step since it does not require a cache of previous elements.
Commonly, the model uses the convolutional mode (3) for efficient parallelizable training (where the whole input sequence is seen ahead of time), and switched into recurrent mode (2) for efficient autoregressive inference (where the inputs are seen one timestep at a time).
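For context, here is a sketch of the two computation modes that passage refers to, using the discretized S4 parameters (written here as Ā, B̄). This is the standard background formulation from the paper, not Mamba-specific code:

```latex
% Recurrent mode, eq. (2): one step per token, carrying a hidden state h_t.
h_t = \bar{A} h_{t-1} + \bar{B} x_t, \qquad y_t = C h_t

% Convolutional mode, eq. (3): materialize a kernel and convolve the whole input at once.
\bar{K} = (C\bar{B},\; C\bar{A}\bar{B},\; \dots,\; C\bar{A}^{k}\bar{B},\; \dots), \qquad y = x * \bar{K}
```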
OK, so I can see that this repo only implements the convolutional mode
It would be cool to have a minimal code for the recurrent mode as well
I will check if I can do that, but it appears to be not so easy :flushed:
Overall, the repo was kept deliberately simple and indeed lacks many of the important optimizations in the original repo.
A few comments:
I mean, it is redoing the same computations over and over again
Yeah, during inference time, the same as in RNNs, we should keep track of the latest hidden state and use that to bootstrap each new token's generation. This is how it's done in the original repo. I kept it out to keep things simpler.
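A minimal sketch of what that could look like, assuming a hypothetical `model.step(token_id, state)` method that returns the next-token logits plus the updated (conv + ssm) state. This is not the API of this repo or the official one, just the shape of the idea:

```python
import torch

def generate_stateful(model, prompt_ids, n_new_tokens):
    """Hypothetical recurrent generation: O(1) work per new token.

    Assumes `model.step(token_id, state)` returns (logits, new_state);
    this repo does not implement it, names are illustrative only.
    """
    state = None   # would hold the per-layer conv buffer + SSM hidden state
    logits = None
    # Prime the state on the prompt, one token at a time.
    for tok in prompt_ids:
        logits, state = model.step(tok, state)
    generated = list(prompt_ids)
    for _ in range(n_new_tokens):
        next_tok = int(torch.argmax(logits, dim=-1))   # greedy, for simplicity
        generated.append(next_tok)
        logits, state = model.step(next_tok, state)    # only the new token is passed
    return generated
```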
It also has the `l` parameter (sequence length). Does it mean it has a maximum sequence length? The paper shows up to 1 million, so I was expecting it to be recurrent and without limit
`l` is variable in the same way the batch size `b` is variable, so there is no limit.
this repo only implements the convolutional mode
Actually, our implementation is recurrent.
Best, John
The algorithm in this repo implements the convolutional mode.
To be recurrent, it would need the `inference_params` and `_decoding_cache` mechanisms from the original repo (linked here).
Thanks a lot for the implementation!
I also think this appears to be a convolutional mode. The input is the full list of tokens (instead of being just the last one), and there is a result value for each input token (instead of just being the result for the last given token).
I have a fork of mamba for CPU: mamba-cpu
Check also the `recurrent-only` branch
Mamba has no convolutional mode. During training time, it sees the entire input sentence at once. This is compatible with recurrence. Please read the vast literature on the basics of autoregressive sequence modeling.
@albertfgu
I was using the term from your paper:
After the parameters have been transformed from (∆, A, B, C) ↦ (A, B, C), the model can be computed in two ways, either as a linear recurrence (2) or a global convolution (3).
Commonly, the model uses the convolutional mode (3) for efficient parallelizable training (where the whole input sequence is seen ahead of time), and switched into recurrent mode (2) for efficient autoregressive inference (where the inputs are seen one timestep at a time).
And I was also referring to inference only.
So what do you call the implementation that starts here in the forward() function? (the same and only implementation in this minimal repo)
I call the step() function "recurrent mode": it stores an internal state between tokens. This is clearly recurrent, and we can even save the internal state (conv and ssm).
The method implemented in this repo is slow for inference (it needs to pass the entire sequence again to generate each new token), while in recurrent mode we only pass the last token chosen by the sampling.
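To make the cost concrete, here is a sketch of the naive generation loop that a full-sequence `forward()` implies (hypothetical helper, greedy sampling only): each new token re-runs the whole prefix, so step t costs O(t) and generating L tokens is roughly O(L²) in total.

```python
import torch

def generate_naive(model, prompt_ids, n_new_tokens):
    """Sketch of generation when only a full-sequence forward() is available.

    Every iteration re-processes the entire sequence, so the per-token cost
    grows with the sequence length (the slowdown described above).
    """
    ids = prompt_ids  # shape (1, l)
    for _ in range(n_new_tokens):
        logits = model(ids)                                # re-runs the whole prefix
        next_tok = logits[:, -1].argmax(dim=-1, keepdim=True)
        ids = torch.cat([ids, next_tok], dim=1)
    return ids
```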
The passage you quoted is from “Section 2: State Space Models”, which is background on the LTI S4, not Mamba.
Mamba is described in “Section 3: Selective State Space Models”, which says
The resulting time-varying SSMs cannot use convolutions.
Another way to think about it is, there aren’t any convolutions in the scan for any standard definition of convolution.
I understand that the whole Mamba block cannot be described as a single convolution (K * u), but the forward implementation has a convolution operation (conv1d) specifically over the sequence, and there is an x_dbl, ∆, B, C for each sequence point, whereas a recurrent-like interface would have no sequence axis. Yes, it's wrong to call the Mamba block a single convolution, but the forward call interface is convolution-like, in the sense that during inference each new token takes longer to compute than the previous one (across different forward calls, with an ever-growing input sequence).
For example, I'm trying to port the minimal implementation to a library running on the CPU, and the first forward call takes <300 ms/token, while after a few dozen tokens it takes >2000 ms/token.
But here (johnma2006/mamba-minimal) this appears transparent (only the last output from the last input is received), because all the other outputs for each sequence point are discarded after each forward call.
This repo is meant to provide a minimal implementation for the forward (training) pass over an entire sequence, which is the harder part and is where the official Mamba implementation has a specialized low-level kernel. The implementation for recurrence is much simpler and the one provided in the official implementation should be very readable already.
Let me try to address some potential misconceptions!
there would be no sequence axis for a recurrent-like interface
This is not true; for example, please see PyTorch's RNN implementation, which takes an input of shape (sequence length, input size) and an initial hidden state h_0. Then:
During training: inputs of shape (sequence length, input size), with h_0 = None.
During inference: inputs of shape (1, input_size), with the hidden state h_0 provided.
Notice that these are both "recurrent"; the latter just keeps track of and passes around the hidden state. The distinction is not "convolutional vs. recurrent".
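To make the RNN analogy concrete, here is a small self-contained example with PyTorch's `nn.RNN` (the shapes below include a batch dimension, unlike the unbatched shapes quoted above):

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
rnn = nn.RNN(input_size=8, hidden_size=16)   # batch_first=False: (L, N, H_in)

# "Training-style" call: the whole sequence at once, h_0 defaults to zeros.
seq = torch.randn(10, 1, 8)                  # (sequence length, batch, input size)
full_out, h_n = rnn(seq)

# "Inference-style" call: one timestep at a time, carrying the hidden state.
h = None
for t in range(seq.size(0)):
    step_out, h = rnn(seq[t:t+1], h)         # input (1, batch, input size) plus h_{t-1}

# Both paths run the same recurrence; only the bookkeeping differs.
print(torch.allclose(full_out[-1], step_out[0], atol=1e-5))  # True
```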
I think this is just an optimization; it's not that RNNs inherently require the whole sequence axis* as input in one go. Isn't the goal of RNNs to not require the sequence axis like this?
*It's perfectly fine to train an RNN with a single sequence point at a time. For training-efficiency purposes, implementations can accept the whole known sequence as input (depending on how you propagate the loss over the sequence).
Note: I'm not including bidirectional RNNs, which would correspond to negative deltas (inputs in reverse order) for Mamba, and which would require the whole available sequence for each token's training and prediction.
edit: I think I'm commenting on the basis that the original question "Is it recurrent?" should be closed with a "No", for practical purposes. That is, anyone intending to have constant-time generation per token (which is what "a recurrent algorithm" commonly implies) still needs to implement (or port) more functionality.
But by no means do I disregard the importance of this code, which is also sufficient for training, and which in my opinion is much more important and essential than a recurrent-like implementation. I appreciate everyone's work and willingness to respond!
Hi @kroggen, I thought the code was in a recurrent fashion, since the hidden state x is consistently updated to produce the output y at line 317. https://github.com/johnma2006/mamba-minimal/blob/03de542a36d873f6e6c4057ad687278cc6ae944d/model.py#L315-L320 Can you help me if I got it wrong?
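For context, the linked lines implement a sequential scan roughly of the following shape (a paraphrased sketch with its own documented shapes, not the repo's exact code):

```python
import torch

def selective_scan_sketch(deltaA, deltaB_u, C):
    """Schematic of the loop being discussed (illustrative only).

    deltaA:    (b, l, d_in, n)  discretized A per timestep
    deltaB_u:  (b, l, d_in, n)  discretized B times input per timestep
    C:         (b, l, n)
    """
    b, l, d_in, n = deltaA.shape
    x = torch.zeros(b, d_in, n, device=deltaA.device)
    ys = []
    for i in range(l):
        # The hidden state x is updated recurrently within a single forward call...
        x = deltaA[:, i] * x + deltaB_u[:, i]
        # ...and projected to an output for every position in the sequence.
        y = torch.einsum('bdn,bn->bd', x, C[:, i])
        ys.append(y)
    # But the whole sequence is passed in on every call: the state is not kept
    # between forward calls, which is what the discussion above is about.
    return torch.stack(ys, dim=1)            # (b, l, d_in)
```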
Nice implementation!
I thought that Mamba was somewhat recurrent, like keeping an internal state and then outputting one token at a time. But your code shows that for each new output token, the entire sentence must be passed to the model, including the last outputted token.
Is this the only mode of operation?
It also has the `l` parameter (sequence length). Does it mean it has a maximum sequence length? The paper shows up to 1 million, so I was expecting it to be recurrent and without limit.
From the paper, inference is supposed to require only constant time per step. It is not clear how they do that.