state-spaces / mamba

Mamba SSM architecture
Apache License 2.0

Simple inference example #187

Closed miraodasilva closed 7 months ago

miraodasilva commented 7 months ago

Hi, so I'm trying to make an example equivalent to the one presented in the README, but for autoregressive generation.

Basically, I want to make sure that step-by-step inference gives the same result as feeding in the full sequence. This verifies that the model is causal and that the way I'm doing the forward pass is equivalent between training and inference.

I took a look at utils/generation.py and tried to play with inference_params but unfortunately couldn't find a way to get y3 to be equivalent to y1 or y2. What am I doing wrong?

import torch

from mamba_ssm import Mamba
from mamba_ssm.utils.generation import InferenceParams


@torch.inference_mode()
def run():
    batch, length, dim = 2, 64, 16
    x = torch.randn(batch, length, dim).to("cuda")
    model = Mamba(
        # This module uses roughly 3 * expand * d_model^2 parameters
        d_model=dim,  # Model dimension d_model
        d_state=16,  # SSM state expansion factor
        d_conv=4,  # Local convolution width
        expand=2,  # Block expansion factor
        layer_idx=0,
    ).to("cuda")

    # Training-style forward pass (full sequence in parallel)
    y1 = model(x)
    assert y1.shape == x.shape

    # Inference-style forward pass (full sequence in parallel)
    infer_params = InferenceParams(max_batch_size=batch, max_seqlen=length)
    y2 = model(x, inference_params=infer_params)

    # Inference-style forward pass (step by step using for loop)
    infer_params = InferenceParams(max_batch_size=batch, max_seqlen=length)
    outs = []
    for i in range(length):
        out = model(x[:, i : i + 1, :], inference_params=infer_params)
        infer_params.seqlen_offset += 1
        outs.append(out)
    y3 = torch.cat(outs, 1)

    print(torch.allclose(y1, y2))  # prints True
    print(torch.allclose(y2, y3))  # prints False
    print(torch.allclose(y1, y3))  # prints False

Thanks for reading :)

CompRhys commented 7 months ago

This relates to the ability to pass the model an initial state: when a fresh sequence is passed to the model, it starts from a zero initial state and has no cache for the 1d-conv values.

I made a tracker of issues relating to this concept, https://github.com/state-spaces/mamba/issues/175
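For context, the two per-layer caches involved look roughly like this; a minimal sketch, assuming allocate_inference_cache from mamba_simple.py and d_inner = expand * d_model:

import torch
from mamba_ssm import Mamba

model = Mamba(d_model=16, d_state=16, d_conv=4, expand=2, layer_idx=0).to("cuda")
# Zero-initialized per-layer caches used during step-wise decoding.
conv_state, ssm_state = model.allocate_inference_cache(batch_size=2, max_seqlen=64)
print(conv_state.shape)  # (2, 32, 4):  rolling window for the causal 1d conv
print(ssm_state.shape)   # (2, 32, 16): recurrent SSM hidden state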

miraodasilva commented 7 months ago

I'm not sure I understand - as mentioned in https://github.com/state-spaces/mamba/issues/101, one can feed in the state via InferenceParams during step-by-step inference, no? It's just that you can't do it during training, which is an issue for some people but doesn't interfere with what I'm trying to do here, I think.

Surely there must be a way to get three Trues above? If not, does that mean there's always a fundamental mismatch between training and inference? Because that sounds problematic.

CompRhys commented 7 months ago

I think I just read what I wanted to read before. You're right that this is a different use case, and it seems you can do it with InferenceParams at inference time. However, for this to work I think you will need to pass the values of ssm_state and conv_state into the key_value_memory_dict of the InferenceParams.

        if inference_params is not None:
            conv_state, ssm_state = self._get_states_from_cache(inference_params, batch)
            if inference_params.seqlen_offset > 0:
                # The states are updated inplace
                out, _, _ = self.step(hidden_states, conv_state, ssm_state)
                return out

In order to do your stepwise loop (y3) using InferenceParams, you would need to duplicate work and call model.step to get the states for the next step, in addition to the model call you already have.
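As a rough sketch of what I mean (reusing the variables from your example, and assuming the block stores a (conv_state, ssm_state) tuple under its layer_idx in key_value_memory_dict, which is what _get_states_from_cache reads, with allocate_inference_cache returning zero-initialized states):

infer_params = InferenceParams(max_batch_size=batch, max_seqlen=length)
# Pre-seed the per-layer cache; overwrite these tensors in place if you have a
# non-zero initial state to start from.
conv_state, ssm_state = model.allocate_inference_cache(batch, length)
infer_params.key_value_memory_dict[model.layer_idx] = (conv_state, ssm_state)
infer_params.seqlen_offset = 1  # > 0 takes the single-token step path shown above
out = model(x[:, :1, :], inference_params=infer_params)  # states updated in place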

miraodasilva commented 7 months ago

Actually, I don't think this is an issue - the states are updated in place, as commented in the code chunk you mentioned. Long story short, just feeding in the inference params should set and update the states you are referring to implicitly, and judging by how inference is done in utils/generation.py, this seems to be the intended way to do it.

To be sure, I decided to roll out what's happening in _get_states_from_cache like so:

import torch

from mamba_ssm import Mamba
from mamba_ssm.utils.generation import InferenceParams


@torch.inference_mode()
def run():
    # Training-style forward pass (full sequence in parallel)
    batch, length, dim = 2, 64, 16
    x = torch.randn(batch, length, dim).to("cuda")
    model = Mamba(
        # This module uses roughly 3 * expand * d_model^2 parameters
        d_model=dim,  # Model dimension d_model
        d_state=16,  # SSM state expansion factor
        d_conv=4,  # Local convolution width
        expand=2,  # Block expansion factor
        layer_idx=0,
    ).to("cuda")
    y1 = model(x)
    assert y1.shape == x.shape

    # Inference-style forward pass (full sequence in parallel)
    infer_params = InferenceParams(max_batch_size=batch, max_seqlen=length)
    y2 = model(x, inference_params=infer_params)

    # Inference-style forward pass (step by step using for loop)
    infer_params = InferenceParams(max_batch_size=batch, max_seqlen=length)
    outs = []
    # for i in range(length):
    #     out = model(x[:, i : i + 1, :])
    #     infer_params.seqlen_offset += 1
    #     outs.append(out)
    # conv_state: (batch, d_inner, d_conv) where d_inner = expand * d_model = 32
    conv_state = torch.zeros(
        batch,
        dim * 2,
        4,
        device=x.device,
        dtype=x.dtype,
    )
    # ssm_state: (batch, d_inner, d_state)
    ssm_state = torch.zeros(
        batch,
        dim * 2,
        16,
        device=x.device,
        dtype=x.dtype,
    )
    for i in range(length):
        out, conv_state, ssm_state = model.step(x[:, i : i + 1, :], conv_state, ssm_state)
        outs.append(out)
    y3 = torch.cat(outs, 1)

    print(torch.allclose(y1, y2))  # prints True
    print(torch.allclose(y2, y3))  # prints False
    print(torch.allclose(y1, y3))  # prints False

Same problem arises. Prints True False False as usual.

It would be super helpful to get some feedback here from the authors @tridao - I think this is a pretty basic use case and I'm sure there's a way to make it work, but it seems really hard to figure out what exactly I'm doing wrong. I think a simple, working inference example like this would help a lot of users out there. Cheers :)

CompRhys commented 7 months ago

y2.size()
>>> torch.Size([2, 64, 16])
torch.allclose(y3[:, :10, :], y2[:, :10, :])
>>> True
torch.allclose(y3[:, :11, :], y2[:, :11, :])
>>> False
torch.allclose(y3[:, -1, :], y2[:, -1, :])
>>> True
(y2-y3).abs().max()
>>> tensor(5.2154e-08, device='cuda:0')
(y2-y3).abs().var()
>>> tensor(3.0914e-17, device='cuda:0')
(y2-y3).abs().mean()
>>> tensor(5.8863e-09, device='cuda:0')
(y2-y3).abs().median()
>>> tensor(3.7253e-09, device='cuda:0')
((y2-y3).abs()<1e-8).sum()/torch.tensor(y2.size()).prod()
>>> tensor(0.8257, device='cuda:0')

Is it just a numerical tolerance issue then? When I run the above, it diverges at the 11th item in the sequence, and yet only ~17% of the values differ by more than the default atol.

miraodasilva commented 7 months ago

You're right! If I just switch to rtol=1e-4, all of them pass. I wonder if this is working as intended, though; the default rtol=1e-5 should also give True if what we're doing is correct, I think - after all, we are running in fp32. I suppose it's good enough to unblock me.
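For reference, the check that now passes is just the same comparison with a looser relative tolerance (torch.allclose defaults are rtol=1e-5, atol=1e-8):

print(torch.allclose(y1, y3, rtol=1e-4))  # True
print(torch.allclose(y2, y3, rtol=1e-4))  # True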

Interestingly, though, the rolled-out version gets three Trues while the non-rolled-out version gets True False False, which means I'm actually wrong and they are not equivalent. Do you have a good intuition for why this is the case? Are the states not being modified in place? If so, is the comment there wrong? Appreciate the patience :)

miraodasilva commented 7 months ago

Also, it would still be great to get confirmation from the authors that this minimal example is well constructed and that I'm not messing anything up. For peace of mind :)

albertfgu commented 7 months ago

This looks expected. Pure autoregressive models can have small numerical divergence due to compounding errors in the hidden state. In practice generation works fine.
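One quick way to see where the two paths drift apart, using the y2 / y3 tensors from the example above (just a diagnostic sketch, not part of the library):

# Max absolute difference at each sequence position, to see where the
# parallel and step-wise paths start to diverge.
per_step_err = (y2 - y3).abs().amax(dim=(0, 2))
print(per_step_err)  # in the run above, the first divergence shows up around position 10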

miraodasilva commented 7 months ago

Never mind - my first example does work with rtol=1e-4; I benchmarked the wrong code earlier. The small divergence accumulating due to compounding errors makes sense. Thanks @CompRhys for the helpful input, closing this now :)

miraodasilva commented 7 months ago

PS: Maybe such an example could be in the README to help others?

PlayerSAL commented 4 months ago

Did you find that Mamba breaks if the batch size equals 1?

tridao commented 4 months ago

That was fixed a while ago, I think.

Lyinggg commented 1 month ago

Has anyone tried this example with the Mamba2 model? I found that adding rtol=1e-4 only helps for Mamba, not for Mamba2 - I still get True, False, False even if I set rtol=1e-2. Is there anything wrong with inference_params in Mamba2?