state-spaces / mamba

Mamba SSM architecture

Chunked inference #536

Open yhv-wt opened 2 months ago

yhv-wt commented 2 months ago

Hi,

First thank you for your incredible work!

From what I can tell from the code, there are two paths for inference: either send the complete input sequence through the fast path when inference_params is None, or, much more slowly, send one input (a token or whatever) at a time when inference_params.seqlen_offset > 0. This design mostly caters to chat: process the prompt fast using the Triton kernels, then generate the response slowly token by token. But not every AI problem is a chat. It's still unclear to me what the third path is for, i.e. when inference_params is not None but inference_params.seqlen_offset == 0; it doesn't seem to be what I'm asking about. So my question is: is there any way to implement fast chunked inference of potentially very long sequences (that otherwise won't fit into memory) by passing the final SSM and conv states between chunks while still using the fast Triton kernel path, or is there some fundamental restriction of this architecture that only allows continued inference one token at a time?
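To make the two paths I mean concrete, here is a quick sketch with a single Mamba2 block (the arguments and the InferenceParams bookkeeping are my own assumptions from reading the code, not an official recipe):

import torch
from mamba_ssm import Mamba2
from mamba_ssm.utils.generation import InferenceParams

batch, seqlen, dim = 2, 1024, 256
x = torch.randn(batch, seqlen, dim, device="cuda")
block = Mamba2(d_model=dim, layer_idx=0).to("cuda")

# Fast path: whole sequence at once, inference_params is None, Triton kernels.
y_fast = block(x)

# Slow path: one token at a time through `step`, carrying conv/ssm state in a cache.
params = InferenceParams(max_seqlen=seqlen, max_batch_size=batch)
params.key_value_memory_dict[0] = block.allocate_inference_cache(batch, max_seqlen=seqlen)
params.seqlen_offset = 1  # any value > 0 routes the call through `step`
y_slow = []
for t in range(seqlen):
    y_slow.append(block(x[:, t:t + 1], inference_params=params))
    params.seqlen_offset += 1
# The two should agree up to numerics, but the second is dramatically slower.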

Thanks!

OleguerCanal commented 2 months ago

I also have the same doubt. Just to clarify:

We'd like the step function in the Mamba modules to be able to take t > 1 tokens. As of now there is this assert:

assert hidden_states.shape[1] == 1, "Only support decoding with 1 token at a time for now"
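For example, something like the following should hit that assert today (a minimal sketch; the module arguments are just illustrative, and I'm assuming the forward dispatches to step whenever seqlen_offset > 0):

import torch
from mamba_ssm import Mamba2
from mamba_ssm.utils.generation import InferenceParams

dim = 256
m = Mamba2(d_model=dim, layer_idx=0).to("cuda")
params = InferenceParams(max_seqlen=64, max_batch_size=1)
params.key_value_memory_dict[0] = m.allocate_inference_cache(1, max_seqlen=64)
params.seqlen_offset = 1  # > 0 routes through the recurrent `step` path

chunk = torch.randn(1, 8, dim, device="cuda")  # a chunk of t = 8 tokens at once
m(chunk, inference_params=params)  # should raise: "Only support decoding with 1 token at a time for now"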

Are there plans to address this case? Is it very hard to do @tridao? Thanks!

assafbk commented 1 week ago

+1, it would be great if a chunked parallel mode were supported (even if just for inference).

Just to clarify a few things: @yhv-wt the first path (inference_params is None) is for training in parallel mode. The second path is for inference in sequential mode. The third path is for inference in parallel mode. Together, the latter two paths are an efficient way to perform inference: process the context in parallel mode, then predict new tokens in recurrent mode (the first hidden state for new-token prediction is the last hidden state of the context). A rough sketch of that flow is below.
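Roughly like this (a sketch, not the library's official generation utility; the Mamba2 arguments, the InferenceParams bookkeeping and the 16-token decode length are my assumptions):

import torch
from mamba_ssm import Mamba2
from mamba_ssm.utils.generation import InferenceParams

batch, ctx_len, dim, n_new = 2, 512, 256, 16
context = torch.randn(batch, ctx_len, dim, device="cuda")
model = Mamba2(d_model=dim, layer_idx=0).to("cuda")

params = InferenceParams(max_seqlen=ctx_len + n_new, max_batch_size=batch)
params.key_value_memory_dict[0] = model.allocate_inference_cache(batch, max_seqlen=ctx_len + n_new)

# Third path: inference_params is not None and seqlen_offset == 0, so the context is
# processed with the parallel kernels and the final conv/ssm states are written into the cache.
y_ctx = model(context, inference_params=params)
params.seqlen_offset += ctx_len

# Second path: seqlen_offset > 0 dispatches to `step`, one token at a time,
# continuing from the states left behind by the context.
for _ in range(n_new):
    tok = torch.randn(batch, 1, dim, device="cuda")  # stand-in for the next input token
    y_tok = model(tok, inference_params=params)
    params.seqlen_offset += 1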

@OleguerCanal I'm not sure I fully understand your request, but sequential inference mode ('step') predicts one token at a time. Since each token depends on the previous tokens, I don't see how this could be sped up.

OleguerCanal commented 1 week ago

@assafbk only if you are running the model auto-regressively (as in next-token-prediction applications such as language models). Some sequence-to-sequence applications have all input tokens available from the beginning, even at inference time. It is just too much memory to run it all at once; that's why I want to chunk it.

Not sure I am expressing myself correctly.

assafbk commented 1 week ago

Right, I think we are talking about the same thing. I just wanted to highlight that what you are asking for is not possible via 'step' (the second path). It is possible via the third path (inference in parallel mode).

843773493 commented 1 week ago

Maybe like this?

import torch
import easydict  # EasyDict stands in for mamba_ssm.utils.generation.InferenceParams
from mamba_ssm import Mamba2

batch, length, dim, num_chunks = 2, 64, 64, 4
chunk_len = length // num_chunks
x = torch.randn(batch, length, dim).to("cuda")
model = Mamba2(
    # This module uses roughly 3 * expand * d_model^2 parameters
    d_model=dim,    # Model dimension d_model
    d_state=64,     # SSM state expansion factor, typically 64 or 128
    d_conv=4,       # Local convolution width
    expand=2,       # Block expansion factor
    headdim=dim // 8,
    layer_idx='0',  # Layer index for the inference cache
).to("cuda")

# Minimal inference_params: a state cache keyed by layer_idx plus a seqlen offset.
inference_params = easydict.EasyDict()
inference_params.key_value_memory_dict = {'0': model.allocate_inference_cache(batch, max_seqlen=None)}
inference_params.seqlen_offset = 0

# Feed the sequence chunk by chunk. seqlen_offset stays 0, so every chunk takes the
# parallel (third) path and the final conv/ssm states are written into the cache.
# NOTE: whether those cached states are read back as the initial state of the next
# chunk depends on the Mamba2 forward implementation - worth verifying before relying on it.
for i in range(num_chunks):
    y_ = model(x[:, i * chunk_len:(i + 1) * chunk_len], inference_params=inference_params)
    print(
        y_.mean().item(),
        inference_params.key_value_memory_dict['0'][0].shape,
        inference_params.key_value_memory_dict['0'][0].mean().item(),
        inference_params.key_value_memory_dict['0'][1].shape,
        inference_params.key_value_memory_dict['0'][1].mean().item(),
    )