state-spaces / mamba

Mamba SSM architecture

Mamba2 inference 9 times slower than Mamba1 #355

Closed realwenlongwang closed 3 months ago

realwenlongwang commented 3 months ago

After changing d_model, Mamba2 worked in the simple test environment provided in the README. But I noticed that Mamba2 is much slower than Mamba1. I tested it; here is my code:

import torch
from mamba_ssm import Mamba2 as Mamba
# from mamba_ssm import Mamba

start = torch.cuda.Event(enable_timing=True)
end = torch.cuda.Event(enable_timing=True)

start.record()
batch, length, dim = 2, 64, 256
x = torch.randn(batch, length, dim).to("cuda")
model = Mamba(
    # This module uses roughly 3 * expand * d_model^2 parameters
    d_model=dim, # Model dimension d_model
    d_state=64,  # SSM state expansion factor
    d_conv=4,    # Local convolution width
    expand=2,    # Block expansion factor
).to("cuda")
y = model(x)
end.record()
torch.cuda.synchronize()
inference_time = start.elapsed_time(end)
assert y.shape == x.shape
print(f'parameter number: {sum([p.numel() for p in model.parameters()])}')
print(f'inference time: {inference_time}')

The result I got is this

Mamba1 parameter number: 511488
Mamba1 inference time: 539.1769409179688
Mamba2 parameter number: 431768
Mamba2 inference time: 4322.52294921875

I don't know whether this is a bug or whether I made a mistake. Please feel free to share your thoughts.

Kiet0712 commented 3 months ago

I think there is some mistake in the code, because I also found that Mamba2 is quite slow :)

tridao commented 3 months ago

Mamba2 is written mostly in Triton, so there's a lot of CPU overhead if the layer is so small. Two ways to get around that: (1) CUDA graph (or torch compile) (2) use a large model.
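
A minimal sketch of option (1), assuming the small test setup from this thread; the warm-up count and loop length are illustrative, and on some torch versions Dynamo falls back to eager for the custom autograd.Function (see the tracebacks further down):

import torch
from mamba_ssm import Mamba2

batch, length, dim = 2, 64, 256
x = torch.randn(batch, length, dim, device="cuda")

model = Mamba2(d_model=dim, d_state=64, d_conv=4, expand=2, headdim=32).to("cuda")
compiled_model = torch.compile(model)  # option (1): compile the whole block

# Warm up: the first calls trigger Triton compilation and autotuning.
for _ in range(3):
    y = compiled_model(x)
torch.cuda.synchronize()

# Time only steady-state forward passes with CUDA events.
start = torch.cuda.Event(enable_timing=True)
end = torch.cuda.Event(enable_timing=True)
start.record()
for _ in range(100):
    y = compiled_model(x)
end.record()
torch.cuda.synchronize()
print(f"avg forward time: {start.elapsed_time(end) / 100:.3f} ms")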

Kiet0712 commented 3 months ago

@tridao Thank you for your help. I added the line "@torch.compile(options={"triton.cudagraphs": True}, fullgraph=True)" before the mamba_split_conv1d_scan_combined function in ssd_combined.py and got speed competitive with the original Mamba code.
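
For reference, a minimal sketch of where that decorator goes (the full signature appears in the next comment); this assumes a torch/triton combination where Dynamo can trace the underlying custom autograd.Function, e.g. the torch 2.1.1 / triton 2.1.0 setup reported later in this thread:

# In mamba_ssm/ops/triton/ssd_combined.py, directly above the existing definition.
# The real signature is much longer; **kwargs here is only a placeholder for this sketch.
import torch

@torch.compile(options={"triton.cudagraphs": True}, fullgraph=True)
def mamba_split_conv1d_scan_combined(zxbcdt, conv1d_weight, conv1d_bias, dt_bias, A, D, chunk_size, **kwargs):
    ...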

realwenlongwang commented 3 months ago

Yes, CUDA graphs work!

dwgan commented 3 months ago

Hi @Kiet0712, could you please provide some details about it? Here is my code

@torch.compile(options={"triton.cudagraphs": True}, fullgraph=True)
def mamba_split_conv1d_scan_combined(zxbcdt, conv1d_weight, conv1d_bias, dt_bias, A, D, chunk_size, initial_states=None, seq_idx=None, dt_limit=(0.0, float("inf")), return_final_states=False, activation="silu", rmsnorm_weight=None, rmsnorm_eps=1e-6, outproj_weight=None, outproj_bias=None, headdim=None, ngroups=1, norm_before_gate=True):
    """
    Argument:
        zxbcdt: (batch, seqlen, 2 * dim + 2 * ngroups * dstate + nheads) where dim == nheads * headdim
        conv1d_weight: (dim + 2 * ngroups * dstate, width)
        conv1d_bias: (dim + 2 * ngroups * dstate,)
        dt_bias: (nheads,)
        A: (nheads)
        D: (nheads, headdim) or (nheads,)
        initial_states: (batch, nheads, headdim, dstate)
        seq_idx: (batch, seqlen), int32
        rmsnorm_weight: (dim,)
        outproj_weight: (out_dim, dim)
        outproj_bias: (out_dim,)
        headdim: if D is 1D, headdim must be passed in
        norm_before_gate: if True, we do RMSNorm(x) * F.silu(z). If False, we do RMSNorm(x * F.silu(z))
    Return:
        out: (batch, seqlen, dim)
    """
    return MambaSplitConv1dScanCombinedFn.apply(zxbcdt, conv1d_weight, conv1d_bias, dt_bias, A, D, chunk_size, initial_states, seq_idx, dt_limit, return_final_states, activation, rmsnorm_weight, rmsnorm_eps, outproj_weight, outproj_bias, headdim, ngroups, norm_before_gate)

But it doesn't work. The output is:

/home/anaconda3/envs/mamba/bin/python /home/mamba/main_test.py 
Mamba model parameters: 511488
Mamba2 model parameters: 433840

Mamba model time: 109.65401458740234 ms
torch.Size([2, 64, 256])
[2024-06-06 01:39:06,033] [0/0] torch._dynamo.variables.higher_order_ops: [WARNING] speculate_subgraph: while introspecting the user-defined autograd.Function, we were unable to trace function `trampoline_autograd_fwd` into a single graph. This means that Dynamo was unable to prove safety for this API and will fall back to eager-mode PyTorch, which could lead to a slowdown.
[2024-06-06 01:39:06,034] [0/0] torch._dynamo.variables.higher_order_ops: [ERROR] HigherOrderOperator with body that accepts non-Tensors as input. Got: <class 'tuple'>
Traceback (most recent call last):
  File "/home/anaconda3/envs/mamba/lib/python3.10/site-packages/torch/_dynamo/variables/higher_order_ops.py", line 136, in speculate_subgraph
    args = validate_args_and_maybe_create_graph_inputs(
  File "/home/anaconda3/envs/mamba/lib/python3.10/site-packages/torch/_dynamo/variables/higher_order_ops.py", line 94, in validate_args_and_maybe_create_graph_inputs
    raise unimplemented(
  File "/home/anaconda3/envs/mamba/lib/python3.10/site-packages/torch/_dynamo/exc.py", line 172, in unimplemented
    raise Unsupported(msg)
torch._dynamo.exc.Unsupported: HigherOrderOperator with body that accepts non-Tensors as input. Got: <class 'tuple'>

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/home/mamba/main_test.py", line 58, in <module>
    y = model2(x)
  File "/home/anaconda3/envs/mamba/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/anaconda3/envs/mamba/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/mamba/mamba_ssm/modules/mamba2.py", line 176, in forward
    out = mamba_split_conv1d_scan_combined(
  File "/home/anaconda3/envs/mamba/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py", line 328, in _fn
    return fn(*args, **kwargs)
  File "/home/anaconda3/envs/mamba/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py", line 490, in catch_errors
    return callback(frame, cache_entry, hooks, frame_state)
  File "/home/anaconda3/envs/mamba/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 133, in _fn
    return fn(*args, **kwargs)
  File "/home/anaconda3/envs/mamba/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 389, in _convert_frame_assert
    return _compile(
  File "/home/anaconda3/envs/mamba/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 569, in _compile
    guarded_code = compile_inner(code, one_graph, hooks, transform)
  File "/home/anaconda3/envs/mamba/lib/python3.10/site-packages/torch/_dynamo/utils.py", line 189, in time_wrapper
    r = func(*args, **kwargs)
  File "/home/anaconda3/envs/mamba/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 491, in compile_inner
    out_code = transform_code_object(code, transform)
  File "/home/anaconda3/envs/mamba/lib/python3.10/site-packages/torch/_dynamo/bytecode_transformation.py", line 1028, in transform_code_object
    transformations(instructions, code_options)
  File "/home/anaconda3/envs/mamba/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 458, in transform
    tracer.run()
  File "/home/anaconda3/envs/mamba/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 2069, in run
    super().run()
  File "/home/anaconda3/envs/mamba/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 719, in run
    and self.step()
  File "/home/anaconda3/envs/mamba/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 683, in step
    getattr(self, inst.opname)(inst)
  File "/home/anaconda3/envs/mamba/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 392, in wrapper
    return inner_fn(self, inst)
  File "/home/anaconda3/envs/mamba/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 1110, in CALL_FUNCTION
    self.call_function(fn, args, {})
  File "/home/anaconda3/envs/mamba/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 557, in call_function
    self.push(fn.call_function(self, args, kwargs))
  File "/home/anaconda3/envs/mamba/lib/python3.10/site-packages/torch/_dynamo/variables/misc.py", line 583, in call_function
    return self.obj.call_apply(tx, args, kwargs).add_options(self)
  File "/home/anaconda3/envs/mamba/lib/python3.10/site-packages/torch/_dynamo/variables/misc.py", line 333, in call_apply
    speculated_fwd_result = higher_order_autograd_fn.call_function(
  File "/home/anaconda3/envs/mamba/lib/python3.10/site-packages/torch/_dynamo/variables/higher_order_ops.py", line 966, in call_function
    ) = speculate_subgraph(
  File "/home/anaconda3/envs/mamba/lib/python3.10/site-packages/torch/_dynamo/variables/higher_order_ops.py", line 211, in speculate_subgraph
    raise Unsupported(
torch._dynamo.exc.Unsupported: speculate_subgraph: while introspecting the user-defined autograd.Function, we were unable to trace function `trampoline_autograd_fwd` into a single graph. This means that Dynamo was unable to prove safety for this API and will fall back to eager-mode PyTorch, which could lead to a slowdown. Scroll up for the stack trace of the initial exception. The reason was: HigherOrderOperator with body that accepts non-Tensors as input. Got: <class 'tuple'>

from user code:
   File "/home/mamba/mamba_ssm/ops/triton/ssd_combined.py", line 908, in mamba_split_conv1d_scan_combined
    return MambaSplitConv1dScanCombinedFn.apply(zxbcdt, conv1d_weight, conv1d_bias, dt_bias, A, D, chunk_size, initial_states, seq_idx, dt_limit, return_final_states, activation, rmsnorm_weight, rmsnorm_eps, outproj_weight, outproj_bias, headdim, ngroups, norm_before_gate)

Set TORCH_LOGS="+dynamo" and TORCHDYNAMO_VERBOSE=1 for more information

Is it a bug or my mistake?

Thanks for your kind response.

Kiet0712 commented 3 months ago

What is in your main_test.py file, and can you give me some details about your environment?

dwgan commented 3 months ago

@Kiet0712 Thanks for your response. Here is my main_test.py:

import torch
from mamba_ssm import Mamba, Mamba2

batch, length, dim = 2, 64, 256
x = torch.randn(batch, length, dim).to("cuda")

# Initialize the Mamba model
model = Mamba(
    # This module uses roughly 3 * expand * d_model^2 parameters
    d_model=dim, # Model dimension d_model
    d_state=64,  # SSM state expansion factor
    d_conv=4,    # Local convolution width
    expand=2,    # Block expansion factor
).to("cuda")

# Initialize the Mamba2 model
model2 = Mamba2(
    # This module uses roughly 3 * expand * d_model^2 parameters
    d_model=dim, # Model dimension d_model
    d_state=64,  # SSM state expansion factor, typically 64 or 128
    d_conv=4,    # Local convolution width
    expand=2,    # Block expansion factor
    headdim=32,  # Additional parameter for Mamba2
    ngroups=1,   # Number of groups for group normalization
    sequence_parallel=False, # Whether to use sequence parallelism
).to("cuda")

# Function to count model parameters
def count_parameters(model):
    return sum(p.numel() for p in model.parameters())

# Print the number of parameters for each model
print(f"Mamba model parameters: {count_parameters(model)}")
print(f"Mamba2 model parameters: {count_parameters(model2)}")

# Measure inference time for Mamba model
start_event = torch.cuda.Event(enable_timing=True)
end_event = torch.cuda.Event(enable_timing=True)

start_event.record()
y = model(x)
end_event.record()

# Wait for all CUDA operations to finish
torch.cuda.synchronize()

mamba_time = start_event.elapsed_time(end_event) # Time in milliseconds

print(f"\nMamba model time: {mamba_time} ms")
print(y.shape)
assert y.shape == x.shape

# Measure inference time for Mamba2 model
start_event = torch.cuda.Event(enable_timing=True)
end_event = torch.cuda.Event(enable_timing=True)

start_event.record()
y = model2(x)
end_event.record()

# Wait for all CUDA operations to finish
torch.cuda.synchronize()

mamba2_time = start_event.elapsed_time(end_event) # Time in milliseconds

print(f"\nMamba2 model time: {mamba2_time} ms")
print(y.shape)
assert y.shape == x.shape

arelkeselbri commented 3 months ago

I still see this issue even with the CUDA graphs compile.

I applied the ".contiguous()" patch to fix stride issues, and I also used the annotation to compile with CUDA graphs.

My test is on an H100 with:

import torch
import timeit
from mamba_ssm import Mamba, Mamba2

batch, length, dim = 2, 64, 64
x = torch.randn(batch, length, dim).to("cuda")

def try_mamba1(batch, length, dim, x):
    model = Mamba(
        # This module uses roughly 3 * expand * d_model^2 parameters
        d_model=dim, # Model dimension d_model
        d_state=16,  # SSM state expansion factor
        d_conv=4,    # Local convolution width
        expand=2,    # Block expansion factor
    ).to("cuda")
    y = model(x)
    assert y.shape == x.shape

def try_mamba2(batch, length, dim, x):
    model = Mamba2(
        # This module uses roughly 3 * expand * d_model^2 parameters
        d_model=dim, # Model dimension d_model
        d_state=64,  # SSM state expansion factor, typically 64 or 128
        d_conv=4,    # Local convolution width
        expand=2,    # Block expansion factor
    ).to("cuda")
    y = model(x)
    assert y.shape == x.shape

mamba1_time = timeit.timeit('try_mamba1(batch, length, dim, x)', number=10, globals=globals())
print(f"Mamba 1 took {mamba1_time} seconds")

mamba2_time = timeit.timeit('try_mamba2(batch, length, dim, x)', number=10, globals=globals())
print(f"Mamba 2 took {mamba2_time} seconds")

Package versions:

torch 2.3.0, causal-conv1d 1.2.2.post1, mamba-ssm 2.0.3, nvidia-cuda-runtime-cu12 12.1.105

LOG:

Traceback (most recent call last):
  File "/home/albertini/mamba/test2.py", line 33, in <module>
    mamba2_time = timeit.timeit('try_mamba2(batch, length, dim, x)', number=10, globals=globals())
                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.11/timeit.py", line 237, in timeit
    return Timer(stmt, setup, timer, globals).timeit(number)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.11/timeit.py", line 180, in timeit
    timing = self.inner(it, self.timer)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "<timeit-src>", line 6, in inner
  File "/home/albertini/mamba/test2.py", line 27, in try_mamba2
    y = model(x)
        ^^^^^^^^
  File "/home/albertini/mamba/venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/albertini/mamba/venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/albertini/mamba/mamba_ssm/modules/mamba2.py", line 176, in forward
    out = mamba_split_conv1d_scan_combined(
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/albertini/mamba/venv/lib/python3.11/site-packages/torch/_dynamo/eval_frame.py", line 451, in _fn
    return fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^
  File "/home/albertini/mamba/venv/lib/python3.11/site-packages/torch/_dynamo/convert_frame.py", line 921, in catch_errors
    return callback(frame, cache_entry, hooks, frame_state, skip=1)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/albertini/mamba/venv/lib/python3.11/site-packages/torch/_dynamo/convert_frame.py", line 400, in _convert_frame_assert
    return _compile(
           ^^^^^^^^^
  File "/usr/local/lib/python3.11/contextlib.py", line 81, in inner
    return func(*args, **kwds)
           ^^^^^^^^^^^^^^^^^^^
  File "/home/albertini/mamba/venv/lib/python3.11/site-packages/torch/_dynamo/convert_frame.py", line 676, in _compile
    guarded_code = compile_inner(code, one_graph, hooks, transform)
                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/albertini/mamba/venv/lib/python3.11/site-packages/torch/_dynamo/utils.py", line 262, in time_wrapper
    r = func(*args, **kwargs)
        ^^^^^^^^^^^^^^^^^^^^^
  File "/home/albertini/mamba/venv/lib/python3.11/site-packages/torch/_dynamo/convert_frame.py", line 535, in compile_inner
    out_code = transform_code_object(code, transform)
               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/albertini/mamba/venv/lib/python3.11/site-packages/torch/_dynamo/bytecode_transformation.py", line 1036, in transform_code_object
    transformations(instructions, code_options)
  File "/home/albertini/mamba/venv/lib/python3.11/site-packages/torch/_dynamo/convert_frame.py", line 165, in _fn
    return fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^
  File "/home/albertini/mamba/venv/lib/python3.11/site-packages/torch/_dynamo/convert_frame.py", line 500, in transform
    tracer.run()
  File "/home/albertini/mamba/venv/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 2149, in run
    super().run()
  File "/home/albertini/mamba/venv/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 810, in run
    and self.step()
        ^^^^^^^^^^^
  File "/home/albertini/mamba/venv/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 773, in step
    getattr(self, inst.opname)(inst)
  File "/home/albertini/mamba/venv/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 489, in wrapper
    return inner_fn(self, inst)
           ^^^^^^^^^^^^^^^^^^^^
  File "/home/albertini/mamba/venv/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 1802, in CALL
    self.call_function(fn, args, kwargs)
  File "/home/albertini/mamba/venv/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 674, in call_function
    self.push(fn.call_function(self, args, kwargs))
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/albertini/mamba/venv/lib/python3.11/site-packages/torch/_dynamo/variables/misc.py", line 562, in call_function
    return self.obj.call_method(tx, self.name, args, kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/albertini/mamba/venv/lib/python3.11/site-packages/torch/_dynamo/variables/misc.py", line 420, in call_method
    return self.call_apply(tx, args, kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/albertini/mamba/venv/lib/python3.11/site-packages/torch/_dynamo/variables/misc.py", line 368, in call_apply
    ).call_function(tx, args, kwargs)
      ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/albertini/mamba/venv/lib/python3.11/site-packages/torch/_dynamo/variables/higher_order_ops.py", line 1533, in call_function
    (fwd_out, _), fwd_graph, fwd_freevars = speculate_subgraph(
                                            ^^^^^^^^^^^^^^^^^^^
  File "/home/albertini/mamba/venv/lib/python3.11/site-packages/torch/_dynamo/variables/higher_order_ops.py", line 467, in speculate_subgraph
    raise ex
  File "/home/albertini/mamba/venv/lib/python3.11/site-packages/torch/_dynamo/variables/higher_order_ops.py", line 359, in speculate_subgraph
    args = validate_args_and_maybe_create_graph_inputs(
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/albertini/mamba/venv/lib/python3.11/site-packages/torch/_dynamo/variables/higher_order_ops.py", line 200, in validate_args_and_maybe_create_graph_inputs
    raise unimplemented(
          ^^^^^^^^^^^^^^
  File "/home/albertini/mamba/venv/lib/python3.11/site-packages/torch/_dynamo/exc.py", line 190, in unimplemented
    raise Unsupported(msg)
torch._dynamo.exc.Unsupported: autograd.Function with body that accepts non-Tensors as input. Got: <class 'tuple'>

from user code:
   File "/home/albertini/mamba/mamba_ssm/ops/triton/ssd_combined.py", line 909, in mamba_split_conv1d_scan_combined
    return MambaSplitConv1dScanCombinedFn.apply(zxbcdt, conv1d_weight, conv1d_bias, dt_bias, A, D, chunk_size, initial_states, seq_idx, dt_limit, return_final_states, activation, rmsnorm_weight, rmsnorm_eps, outproj_weight, outproj_bias, headdim, ngroups, norm_before_gate)

Set TORCH_LOGS="+dynamo" and TORCHDYNAMO_VERBOSE=1 for more information

You can suppress this exception and fall back to eager by setting:
    import torch._dynamo
    torch._dynamo.config.suppress_errors = True

hengck23 commented 3 months ago

In my case:

[added this code to ssd_combined.py]
@torch.compile(options={"triton.cudagraphs": True}, fullgraph=True)
def mamba_split_conv1d_scan_combined(zxbcdt, conv1d_weight, conv1d_bias, dt_bias, A, D, chunk_size, initial_states=None, seq_idx=None, dt_limit=(0.0, float("inf")), return_final_states=False, activation="silu", rmsnorm_weight=None, rmsnorm_eps=1e-6, outproj_weight=None, outproj_bias=None, headdim=None, ngroups=1, norm_before_gate=True):
[error]
LLVM ERROR: pthread_join failed: Invalid argument
LLVM ERROR: pthread_join failed: Invalid argument
...

    raise Unsupported(msg)
torch._dynamo.exc.Unsupported: 'skip function PyCapsule.causal_conv1d_fwd in file Builtin causal_conv1d_fwd'
...

  File "/home/hp/app/anaconda3.10/lib/python3.10/site-packages/mamba_ssm/ops/triton/ssd_combined.py", line 757, in forward
    causal_conv1d_cuda.causal_conv1d_fwd(rearrange(xBC, "b s d -> b d s"),

Kiet0712 commented 3 months ago

@dwgan I don't actually know what your problem is. In my case, I simply added "@torch.compile(options={"triton.cudagraphs": True}, fullgraph=True)" before the mamba_split_conv1d_scan_combined function and then used Mamba2 in my task, and it works.

DustinEwan commented 3 months ago
  File "/home/hp/app/anaconda3.10/lib/python3.10/site-packages/mamba_ssm/ops/triton/ssd_combined.py", line 757, in forward
    causal_conv1d_cuda.causal_conv1d_fwd(rearrange(xBC, "b s d -> b d s"),

I have the same problem and am trying to work through it now. If I find a solution I'll let you know; in the meantime, any help is very much appreciated!

AlwaysFHao commented 3 months ago

@dwgan I don't actually know what your problem is. In my case, I simply added "@torch.compile(options={"triton.cudagraphs": True}, fullgraph=True)" before the mamba_split_conv1d_scan_combined function and then used Mamba2 in my task, and it works.

May I ask about your Torch and Triton versions?

dwgan commented 3 months ago

@dwgan I don't actually know what your problem is. In my case, I simply added "@torch.compile(options={"triton.cudagraphs": True}, fullgraph=True)" before the mamba_split_conv1d_scan_combined function and then used Mamba2 in my task, and it works.

May I ask about your Torch and Triton versions?

Torch==2.1.2, Triton==2.1.0, Python==3.10, Ubuntu 18.04

Baijiong-Lin commented 3 months ago

I don't actually know what your problem is. In my case, I simply added "@torch.compile(options={"triton.cudagraphs": True}, fullgraph=True)" before the mamba_split_conv1d_scan_combined function and then used Mamba2 in my task, and it works.

@Kiet0712 Could you tell us your torch and triton versions? Thanks.

Kiet0712 commented 3 months ago

@Baijiong-Lin I use triton 2.1.0 and torch 2.1.1

Baijiong-Lin commented 3 months ago

@Baijiong-Lin I use triton 2.1.0 and torch 2.1.1

@Kiet0712 Thanks, but it does not work for me. It still raises an error after adding "@torch.compile(options={"triton.cudagraphs": True}, fullgraph=True)" before the mamba_split_conv1d_scan_combined function.

TimothyChen225 commented 2 months ago

Has anyone solved this?

TimothyChen225 commented 2 months ago

Yes, CUDA graphs work!

I've tried this but I'm still getting an error, and I'd appreciate it if you could show me the demo code

dwgan commented 2 months ago

Mamba2 is written mostly in Triton, so there's a lot of CPU overhead if the layer is so small. Two ways to get around that: (1) CUDA graph (or torch compile) (2) use a large model.

Try warming up by running it once first. The first time will invoke the triton compiler & autotune so it'll be slow.

I think the problem was solved.

See my code here

import time
import torch
from mamba_ssm import Mamba2
from mamba_ssm import Mamba

repeat_num = 1000
batch, length, dim = 2, 256*256*2, 256
x = torch.randn(batch, length, dim).to("cuda")
model = Mamba(
    # This module uses roughly 3 * expand * d_model^2 parameters
    d_model=dim, # Model dimension d_model
    d_state=64,  # SSM state expansion factor
    d_conv=4,    # Local convolution width
    expand=2,    # Block expansion factor
).to("cuda")
y = model(x) # warm up first
t1 = time.time()
for i in range(repeat_num):
    y = model(x)
assert y.shape == x.shape
print(f"Time of mamba taken: {time.time() - t1:.3f} s")

model = Mamba2(
    # This module uses roughly 3 * expand * d_model^2 parameters
    d_model=dim, # Model dimension d_model
    d_state=64,  # SSM state expansion factor, typically 64 or 128
    d_conv=4,    # Local convolution width
    expand=2,    # Block expansion factor
).to("cuda")
y = model(x) # warm up first
t1 = time.time()
for i in range(repeat_num):
    y = model(x)
assert y.shape == x.shape
print(f"Time of mamba2 taken: {time.time() - t1:.3f} s")

The output log

Time of mamba taken: 24.061 s
Time of mamba2 taken: 14.011 s

AlwaysFHao commented 2 months ago

Mamba2 is written mostly in Triton, so there's a lot of CPU overhead if the layer is so small. Two ways to get around that: (1) CUDA graph (or torch compile) (2) use a large model.

Try warming up by running it once first. The first time will invoke the triton compiler & autotune so it'll be slow.

I think the problem was solved.

After adding 'torch.compile' to pre-compile the model, I actually still need a warm-up to achieve good results. But why can you solve it without using compile here?

dwgan commented 2 months ago

After adding 'torch.compile' to pre-compile the model, I actually still need a warm-up to achieve good results. But why can you solve it without using compile here?

I use the original version, without adding 'torch.compile'.
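
A minimal per-call timing sketch (shapes are illustrative assumptions, not from this thread) showing why a warm-up run matters even without torch.compile: the first Mamba2 forward pass absorbs the one-time Triton compilation and autotuning cost that tridao mentioned above, and later calls are much faster:

import torch
from mamba_ssm import Mamba2

x = torch.randn(2, 4096, 256, device="cuda")
model = Mamba2(d_model=256, d_state=64, d_conv=4, expand=2, headdim=32).to("cuda")

for i in range(5):
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    y = model(x)
    end.record()
    torch.cuda.synchronize()
    # Expect call 0 to be much slower than calls 1-4 (Triton JIT + autotune).
    print(f"call {i}: {start.elapsed_time(end):.2f} ms")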

TimothyChen225 commented 2 months ago

see #389