alxndrTL / mamba.py

A simple and efficient Mamba implementation in pure PyTorch and MLX.

How to use cache in mamba2? #52

Closed wwwqqyy closed 2 months ago

wwwqqyy commented 2 months ago

Thank you for providing such excellent code. I have a question: how is the cache used by the step function in the following code block? What does it do? Can you give an example of how it works? Thank you.

class Mamba2(nn.Module):
    def __init__(self, config: Mamba2Config):
        super().__init__()
        self.config = config
        self.layers = nn.ModuleList([ResidualBlock(config) for _ in range(config.n_layers)])

    def forward(self, x):
        # x : (B, L, D)
        # y : (B, L, D)

        for layer in self.layers:
            x = layer(x)

        return x

    def step(self, x, caches):
        # x : (B, L, D)
        # caches : [cache(layer) for all layers], cache : (h, inputs)

        # y : (B, L, D)
        # caches : [cache(layer) for all layers], cache : (h, inputs)

        for i, layer in enumerate(self.layers):
            x, caches[i] = layer.step(x, caches[i])

        return x, caches
alxndrTL commented 2 months ago

Thanks! I have updated the mamba2.py code; it now implements a proper caching mechanism. I have put an explanation of how it works at the top of the file, hope it helps you. If you still have questions, don't hesitate to ask!

For how to use it, you can take a look at the lm.py file, which wraps a Mamba(2) object in a language model. Specifically, see the generate function (which decodes from the model without any caching) and the generate4 function (which decodes with prompt prefill followed by step-by-step decoding, i.e. with caching).
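
To give a rough idea of the pattern, here is a minimal sketch of prefill + step-by-step decoding using only the Mamba2 class quoted above. The init_caches helper and the (None, None) per-layer starting cache are placeholders invented for this illustration, not part of the repo's API; generate4 in lm.py is the actual reference implementation.

import torch

def init_caches(model):
    # Hypothetical helper: one (h, inputs) cache per layer, starting empty.
    return [(None, None) for _ in model.layers]

@torch.no_grad()
def decode(model, prompt, num_new_tokens):
    # prompt : (B, L, D) -- already-embedded prompt tokens
    caches = init_caches(model)

    # Prefill: feed the whole prompt through step() so every layer's cache
    # (hidden state h + recent inputs) reflects the prompt.
    y, caches = model.step(prompt, caches)
    x = y[:, -1:, :]            # last position becomes the next input

    outputs = []
    for _ in range(num_new_tokens):
        # Each call only processes the new position; the caches carry what the
        # layers need from the past, so earlier tokens are never recomputed.
        y, caches = model.step(x, caches)
        outputs.append(y)
        x = y                   # in a real LM you would sample a token and re-embed it

    return torch.cat(outputs, dim=1)   # (B, num_new_tokens, D)

The key point is that after the prefill, generating each new token only costs one step() call per layer instead of re-running the full sequence through forward().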

wwwqqyy commented 2 months ago

Thank you very much; I have seen your recent update!