realwenlongwang closed this issue 3 months ago.
I think there is a mistake somewhere in the code, because I also found that Mamba2 is quite slow :)
Mamba2 is written mostly in Triton, so there's a lot of CPU overhead if the layer is so small. Two ways to get around that: (1) CUDA graph (or torch compile) (2) use a large model.
@tridao Thank you for your help. I added the line "@torch.compile(options={"triton.cudagraphs": True}, fullgraph=True)" before the mamba_split_conv1d_scan_combined function in ssd_combined.py and got speed competitive with the original Mamba code.
Yes, CUDA graphs work!
Hi @Kiet0712, could you please provide some details about it? Here is my code
@torch.compile(options={"triton.cudagraphs": True}, fullgraph=True)
def mamba_split_conv1d_scan_combined(zxbcdt, conv1d_weight, conv1d_bias, dt_bias, A, D, chunk_size, initial_states=None, seq_idx=None, dt_limit=(0.0, float("inf")), return_final_states=False, activation="silu", rmsnorm_weight=None, rmsnorm_eps=1e-6, outproj_weight=None, outproj_bias=None, headdim=None, ngroups=1, norm_before_gate=True):
    """
    Argument:
        zxbcdt: (batch, seqlen, 2 * dim + 2 * ngroups * dstate + nheads) where dim == nheads * headdim
        conv1d_weight: (dim + 2 * ngroups * dstate, width)
        conv1d_bias: (dim + 2 * ngroups * dstate,)
        dt_bias: (nheads,)
        A: (nheads)
        D: (nheads, headdim) or (nheads,)
        initial_states: (batch, nheads, headdim, dstate)
        seq_idx: (batch, seqlen), int32
        rmsnorm_weight: (dim,)
        outproj_weight: (out_dim, dim)
        outproj_bias: (out_dim,)
        headdim: if D is 1D, headdim must be passed in
        norm_before_gate: if True, we do RMSNorm(x) * F.silu(z). If False, we do RMSNorm(x * F.silu(z))
    Return:
        out: (batch, seqlen, dim)
    """
    return MambaSplitConv1dScanCombinedFn.apply(zxbcdt, conv1d_weight, conv1d_bias, dt_bias, A, D, chunk_size, initial_states, seq_idx, dt_limit, return_final_states, activation, rmsnorm_weight, rmsnorm_eps, outproj_weight, outproj_bias, headdim, ngroups, norm_before_gate)
But it doesn't work. The output is:
/home/anaconda3/envs/mamba/bin/python /home/mamba/main_test.py
Mamba model parameters: 511488
Mamba2 model parameters: 433840
Mamba model time: 109.65401458740234 ms
torch.Size([2, 64, 256])
[2024-06-06 01:39:06,033] [0/0] torch._dynamo.variables.higher_order_ops: [WARNING] speculate_subgraph: while introspecting the user-defined autograd.Function, we were unable to trace function `trampoline_autograd_fwd` into a single graph. This means that Dynamo was unable to prove safety for this API and will fall back to eager-mode PyTorch, which could lead to a slowdown.
[2024-06-06 01:39:06,034] [0/0] torch._dynamo.variables.higher_order_ops: [ERROR] HigherOrderOperator with body that accepts non-Tensors as input. Got: <class 'tuple'>
Traceback (most recent call last):
File "/home/anaconda3/envs/mamba/lib/python3.10/site-packages/torch/_dynamo/variables/higher_order_ops.py", line 136, in speculate_subgraph
args = validate_args_and_maybe_create_graph_inputs(
File "/home/anaconda3/envs/mamba/lib/python3.10/site-packages/torch/_dynamo/variables/higher_order_ops.py", line 94, in validate_args_and_maybe_create_graph_inputs
raise unimplemented(
File "/home/anaconda3/envs/mamba/lib/python3.10/site-packages/torch/_dynamo/exc.py", line 172, in unimplemented
raise Unsupported(msg)
torch._dynamo.exc.Unsupported: HigherOrderOperator with body that accepts non-Tensors as input. Got: <class 'tuple'>
The above exception was the direct cause of the following exception:
Traceback (most recent call last):
File "/home/mamba/main_test.py", line 58, in <module>
y = model2(x)
File "/home/anaconda3/envs/mamba/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/home/anaconda3/envs/mamba/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
return forward_call(*args, **kwargs)
File "/home/mamba/mamba_ssm/modules/mamba2.py", line 176, in forward
out = mamba_split_conv1d_scan_combined(
File "/home/anaconda3/envs/mamba/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py", line 328, in _fn
return fn(*args, **kwargs)
File "/home/anaconda3/envs/mamba/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py", line 490, in catch_errors
return callback(frame, cache_entry, hooks, frame_state)
File "/home/anaconda3/envs/mamba/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 133, in _fn
return fn(*args, **kwargs)
File "/home/anaconda3/envs/mamba/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 389, in _convert_frame_assert
return _compile(
File "/home/anaconda3/envs/mamba/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 569, in _compile
guarded_code = compile_inner(code, one_graph, hooks, transform)
File "/home/anaconda3/envs/mamba/lib/python3.10/site-packages/torch/_dynamo/utils.py", line 189, in time_wrapper
r = func(*args, **kwargs)
File "/home/anaconda3/envs/mamba/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 491, in compile_inner
out_code = transform_code_object(code, transform)
File "/home/anaconda3/envs/mamba/lib/python3.10/site-packages/torch/_dynamo/bytecode_transformation.py", line 1028, in transform_code_object
transformations(instructions, code_options)
File "/home/anaconda3/envs/mamba/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 458, in transform
tracer.run()
File "/home/anaconda3/envs/mamba/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 2069, in run
super().run()
File "/home/anaconda3/envs/mamba/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 719, in run
and self.step()
File "/home/anaconda3/envs/mamba/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 683, in step
getattr(self, inst.opname)(inst)
File "/home/anaconda3/envs/mamba/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 392, in wrapper
return inner_fn(self, inst)
File "/home/anaconda3/envs/mamba/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 1110, in CALL_FUNCTION
self.call_function(fn, args, {})
File "/home/anaconda3/envs/mamba/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 557, in call_function
self.push(fn.call_function(self, args, kwargs))
File "/home/anaconda3/envs/mamba/lib/python3.10/site-packages/torch/_dynamo/variables/misc.py", line 583, in call_function
return self.obj.call_apply(tx, args, kwargs).add_options(self)
File "/home/anaconda3/envs/mamba/lib/python3.10/site-packages/torch/_dynamo/variables/misc.py", line 333, in call_apply
speculated_fwd_result = higher_order_autograd_fn.call_function(
File "/home/anaconda3/envs/mamba/lib/python3.10/site-packages/torch/_dynamo/variables/higher_order_ops.py", line 966, in call_function
) = speculate_subgraph(
File "/home/anaconda3/envs/mamba/lib/python3.10/site-packages/torch/_dynamo/variables/higher_order_ops.py", line 211, in speculate_subgraph
raise Unsupported(
torch._dynamo.exc.Unsupported: speculate_subgraph: while introspecting the user-defined autograd.Function, we were unable to trace function `trampoline_autograd_fwd` into a single graph. This means that Dynamo was unable to prove safety for this API and will fall back to eager-mode PyTorch, which could lead to a slowdown. Scroll up for the stack trace of the initial exception. The reason was: HigherOrderOperator with body that accepts non-Tensors as input. Got: <class 'tuple'>
from user code:
File "/home/mamba/mamba_ssm/ops/triton/ssd_combined.py", line 908, in mamba_split_conv1d_scan_combined
return MambaSplitConv1dScanCombinedFn.apply(zxbcdt, conv1d_weight, conv1d_bias, dt_bias, A, D, chunk_size, initial_states, seq_idx, dt_limit, return_final_states, activation, rmsnorm_weight, rmsnorm_eps, outproj_weight, outproj_bias, headdim, ngroups, norm_before_gate)
Set TORCH_LOGS="+dynamo" and TORCHDYNAMO_VERBOSE=1 for more information
Is this a bug, or is it my mistake?
Thanks for your kind response.
What does your main_test.py file look like, and can you give me some details about your environment?
@Kiet0712 Thanks for your response. Here is main_test.py:
import torch
from mamba_ssm import Mamba, Mamba2
batch, length, dim = 2, 64, 256
x = torch.randn(batch, length, dim).to("cuda")
# Initialize the Mamba model
model = Mamba(
    # This module uses roughly 3 * expand * d_model^2 parameters
    d_model=dim,  # Model dimension d_model
    d_state=64,  # SSM state expansion factor
    d_conv=4,  # Local convolution width
    expand=2,  # Block expansion factor
).to("cuda")
# Initialize the Mamba2 model
model2 = Mamba2(
    # This module uses roughly 3 * expand * d_model^2 parameters
    d_model=dim,  # Model dimension d_model
    d_state=64,  # SSM state expansion factor, typically 64 or 128
    d_conv=4,  # Local convolution width
    expand=2,  # Block expansion factor
    headdim=32,  # Additional parameter for Mamba2
    ngroups=1,  # Number of groups for group normalization
    sequence_parallel=False,  # Whether to use sequence parallelism
).to("cuda")
# Function to count model parameters
def count_parameters(model):
    return sum(p.numel() for p in model.parameters())
# Print the number of parameters for each model
print(f"Mamba model parameters: {count_parameters(model)}")
print(f"Mamba2 model parameters: {count_parameters(model2)}")
# Measure inference time for Mamba model
start_event = torch.cuda.Event(enable_timing=True)
end_event = torch.cuda.Event(enable_timing=True)
start_event.record()
y = model(x)
end_event.record()
# Wait for all CUDA operations to finish
torch.cuda.synchronize()
mamba_time = start_event.elapsed_time(end_event) # Time in milliseconds
print(f"\nMamba model time: {mamba_time} ms")
print(y.shape)
assert y.shape == x.shape
# Measure inference time for Mamba2 model
start_event = torch.cuda.Event(enable_timing=True)
end_event = torch.cuda.Event(enable_timing=True)
start_event.record()
y = model2(x)
end_event.record()
# Wait for all CUDA operations to finish
torch.cuda.synchronize()
mamba2_time = start_event.elapsed_time(end_event) # Time in milliseconds
print(f"\nMamba2 model time: {mamba2_time} ms")
print(y.shape)
assert y.shape == x.shape
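As a side check on the "roughly 3 * expand * d_model^2" comment, here is a minimal sketch that recomputes the two parameter counts printed by this script. The per-layer breakdown (bias-free in/out projections, a depthwise conv with bias, dt_rank = ceil(d_model / 16) for Mamba, and per-head dt_bias, A_log and D plus an RMSNorm weight for Mamba2) is my assumption about the default layer layout, not something stated in this thread; the function names are hypothetical.

```python
import math

def mamba1_params(d_model, d_state=16, d_conv=4, expand=2):
    """Estimated parameter count of one Mamba block (assumed default layout)."""
    d_inner = expand * d_model
    dt_rank = math.ceil(d_model / 16)            # assumed dt_rank="auto"
    in_proj = d_model * 2 * d_inner              # x and z branches, no bias
    conv = d_inner * d_conv + d_inner            # depthwise conv weight + bias
    x_proj = d_inner * (dt_rank + 2 * d_state)   # projections to dt, B, C
    dt_proj = dt_rank * d_inner + d_inner        # weight + bias
    a_log = d_inner * d_state
    d_skip = d_inner
    out_proj = d_inner * d_model                 # no bias
    return in_proj + conv + x_proj + dt_proj + a_log + d_skip + out_proj

def mamba2_params(d_model, d_state=64, d_conv=4, expand=2, headdim=32, ngroups=1):
    """Estimated parameter count of one Mamba2 block (assumed default layout)."""
    d_inner = expand * d_model
    nheads = d_inner // headdim
    # Same widths as the zxbcdt and conv1d_weight shapes in the docstring above.
    in_proj = d_model * (2 * d_inner + 2 * ngroups * d_state + nheads)
    conv_dim = d_inner + 2 * ngroups * d_state
    conv = conv_dim * d_conv + conv_dim          # conv weight + bias
    per_head = 3 * nheads                        # dt_bias, A_log, D
    norm = d_inner                               # gated RMSNorm weight
    out_proj = d_inner * d_model
    return in_proj + conv + per_head + norm + out_proj

rule_of_thumb = 3 * 2 * 256 ** 2
print(rule_of_thumb)                                # 393216
print(mamba1_params(256, d_state=64))               # 511488
print(mamba2_params(256, d_state=64, headdim=32))   # 433840
```

Under those assumptions the two results match the counts in the log above (511488 and 433840), so the Mamba2 layer here is slightly smaller than the Mamba one, and the slowdown is not explained by model size.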
I still see this issue even when compiling with CUDA graphs.
I applied the ".contiguous()" patch to fix stride issues, and I also used the decorator to compile with CUDA graphs.
My test runs on an H100 with:
import torch
import timeit
from mamba_ssm import Mamba, Mamba2
batch, length, dim = 2, 64, 64
x = torch.randn(batch, length, dim).to("cuda")
def try_mamba1(batch, length, dim, x):
    model = Mamba(
        # This module uses roughly 3 * expand * d_model^2 parameters
        d_model=dim,  # Model dimension d_model
        d_state=16,  # SSM state expansion factor
        d_conv=4,  # Local convolution width
        expand=2,  # Block expansion factor
    ).to("cuda")
    y = model(x)
    assert y.shape == x.shape
def try_mamba2(batch, length, dim, x):
    model = Mamba2(
        # This module uses roughly 3 * expand * d_model^2 parameters
        d_model=dim,  # Model dimension d_model
        d_state=64,  # SSM state expansion factor, typically 64 or 128
        d_conv=4,  # Local convolution width
        expand=2,  # Block expansion factor
    ).to("cuda")
    y = model(x)
    assert y.shape == x.shape
mamba1_time = timeit.timeit('try_mamba1(batch, length, dim, x)', number=10, globals=globals())
print(f"Mamba 1 took {mamba1_time} seconds")
mamba2_time = timeit.timeit('try_mamba2(batch, length, dim, x)', number=10, globals=globals())
print(f"Mamba 2 took {mamba2_time} seconds")
Package versions:
torch 2.3.0
causal-conv1d 1.2.2.post1
mamba-ssm 2.0.3
nvidia-cuda-runtime-cu12 12.1.105
LOG:
Traceback (most recent call last):
File "/home/albertini/mamba/test2.py", line 33, in <module>
mamba2_time = timeit.timeit('try_mamba2(batch, length, dim, x)', number=10, globals=globals())
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.11/timeit.py", line 237, in timeit
return Timer(stmt, setup, timer, globals).timeit(number)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.11/timeit.py", line 180, in timeit
timing = self.inner(it, self.timer)
^^^^^^^^^^^^^^^^^^^^^^^^^^
File "<timeit-src>", line 6, in inner
File "/home/albertini/mamba/test2.py", line 27, in try_mamba2
y = model(x)
^^^^^^^^
File "/home/albertini/mamba/venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/albertini/mamba/venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/albertini/mamba/mamba_ssm/modules/mamba2.py", line 176, in forward
out = mamba_split_conv1d_scan_combined(
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/albertini/mamba/venv/lib/python3.11/site-packages/torch/_dynamo/eval_frame.py", line 451, in _fn
return fn(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^
File "/home/albertini/mamba/venv/lib/python3.11/site-packages/torch/_dynamo/convert_frame.py", line 921, in catch_errors
return callback(frame, cache_entry, hooks, frame_state, skip=1)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/albertini/mamba/venv/lib/python3.11/site-packages/torch/_dynamo/convert_frame.py", line 400, in _convert_frame_assert
return _compile(
^^^^^^^^^
File "/usr/local/lib/python3.11/contextlib.py", line 81, in inner
return func(*args, **kwds)
^^^^^^^^^^^^^^^^^^^
File "/home/albertini/mamba/venv/lib/python3.11/site-packages/torch/_dynamo/convert_frame.py", line 676, in _compile
guarded_code = compile_inner(code, one_graph, hooks, transform)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/albertini/mamba/venv/lib/python3.11/site-packages/torch/_dynamo/utils.py", line 262, in time_wrapper
r = func(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^
File "/home/albertini/mamba/venv/lib/python3.11/site-packages/torch/_dynamo/convert_frame.py", line 535, in compile_inner
out_code = transform_code_object(code, transform)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/albertini/mamba/venv/lib/python3.11/site-packages/torch/_dynamo/bytecode_transformation.py", line 1036, in transform_code_object
transformations(instructions, code_options)
File "/home/albertini/mamba/venv/lib/python3.11/site-packages/torch/_dynamo/convert_frame.py", line 165, in _fn
return fn(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^
File "/home/albertini/mamba/venv/lib/python3.11/site-packages/torch/_dynamo/convert_frame.py", line 500, in transform
tracer.run()
File "/home/albertini/mamba/venv/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 2149, in run
super().run()
File "/home/albertini/mamba/venv/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 810, in run
and self.step()
^^^^^^^^^^^
File "/home/albertini/mamba/venv/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 773, in step
getattr(self, inst.opname)(inst)
File "/home/albertini/mamba/venv/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 489, in wrapper
return inner_fn(self, inst)
^^^^^^^^^^^^^^^^^^^^
File "/home/albertini/mamba/venv/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 1802, in CALL
self.call_function(fn, args, kwargs)
File "/home/albertini/mamba/venv/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 674, in call_function
self.push(fn.call_function(self, args, kwargs))
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/albertini/mamba/venv/lib/python3.11/site-packages/torch/_dynamo/variables/misc.py", line 562, in call_function
return self.obj.call_method(tx, self.name, args, kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/albertini/mamba/venv/lib/python3.11/site-packages/torch/_dynamo/variables/misc.py", line 420, in call_method
return self.call_apply(tx, args, kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/albertini/mamba/venv/lib/python3.11/site-packages/torch/_dynamo/variables/misc.py", line 368, in call_apply
).call_function(tx, args, kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/albertini/mamba/venv/lib/python3.11/site-packages/torch/_dynamo/variables/higher_order_ops.py", line 1533, in call_function
(fwd_out, _), fwd_graph, fwd_freevars = speculate_subgraph(
^^^^^^^^^^^^^^^^^^^
File "/home/albertini/mamba/venv/lib/python3.11/site-packages/torch/_dynamo/variables/higher_order_ops.py", line 467, in speculate_subgraph
raise ex
File "/home/albertini/mamba/venv/lib/python3.11/site-packages/torch/_dynamo/variables/higher_order_ops.py", line 359, in speculate_subgraph
args = validate_args_and_maybe_create_graph_inputs(
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/albertini/mamba/venv/lib/python3.11/site-packages/torch/_dynamo/variables/higher_order_ops.py", line 200, in validate_args_and_maybe_create_graph_inputs
raise unimplemented(
^^^^^^^^^^^^^^
File "/home/albertini/mamba/venv/lib/python3.11/site-packages/torch/_dynamo/exc.py", line 190, in unimplemented
raise Unsupported(msg)
torch._dynamo.exc.Unsupported: autograd.Function with body that accepts non-Tensors as input. Got: <class 'tuple'>
from user code:
File "/home/albertini/mamba/mamba_ssm/ops/triton/ssd_combined.py", line 909, in mamba_split_conv1d_scan_combined
return MambaSplitConv1dScanCombinedFn.apply(zxbcdt, conv1d_weight, conv1d_bias, dt_bias, A, D, chunk_size, initial_states, seq_idx, dt_limit, return_final_states, activation, rmsnorm_weight, rmsnorm_eps, outproj_weight, outproj_bias, headdim, ngroups, norm_before_gate)
Set TORCH_LOGS="+dynamo" and TORCHDYNAMO_VERBOSE=1 for more information
You can suppress this exception and fall back to eager by setting:
import torch._dynamo
torch._dynamo.config.suppress_errors = True
In my case, I added this code to ssd_combined.py:
@torch.compile(options={"triton.cudagraphs": True}, fullgraph=True)
def mamba_split_conv1d_scan_combined(zxbcdt, conv1d_weight, conv1d_bias, dt_bias, A, D, chunk_size, initial_states=None, seq_idx=None, dt_limit=(0.0, float("inf")), return_final_states=False, activation="silu", rmsnorm_weight=None, rmsnorm_eps=1e-6, outproj_weight=None, outproj_bias=None, headdim=None, ngroups=1, norm_before_gate=True):
and got this error:
LLVM ERROR: pthread_join failed: Invalid argument
LLVM ERROR: pthread_join failed: Invalid argument
...
raise Unsupported(msg)
torch._dynamo.exc.Unsupported: 'skip function PyCapsule.causal_conv1d_fwd in file Builtin causal_conv1d_fwd'
...
File "/home/hp/app/anaconda3.10/lib/python3.10/site-packages/mamba_ssm/ops/triton/ssd_combined.py", line 757, in forward
causal_conv1d_cuda.causal_conv1d_fwd(rearrange(xBC, "b s d -> b d s"),
@dwgan I don't actually know what your problem is. In my case, I simply added "@torch.compile(options={"triton.cudagraphs": True}, fullgraph=True)" before the mamba_split_conv1d_scan_combined function, then used Mamba2 in my task, and it works.
I have the same problem (the failure at causal_conv1d_cuda.causal_conv1d_fwd in ssd_combined.py, line 757, in forward) and am trying to work through it now. If I find a solution I'll let you know. In the meantime, any help is very much appreciated!
May I ask about your Torch and Triton versions?
Torch==2.1.2 Triton==2.1.0 python==3.10 ubuntu18.04
@Kiet0712 Could you tell us your Torch and Triton versions? Thanks.
@Baijiong-Lin I use triton 2.1.0 and torch 2.1.1
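Since the decorator seems to work on some installs and fail on others, it may help if everyone posts the same version report. A minimal sketch using only the standard library; the package list is my guess at the relevant distributions, and `collect_versions` is a hypothetical helper name:

```python
from importlib.metadata import PackageNotFoundError, version

def collect_versions(packages):
    """Map each distribution name to its installed version, or 'not installed'."""
    report = {}
    for name in packages:
        try:
            report[name] = version(name)
        except PackageNotFoundError:
            report[name] = "not installed"
    return report

# Assumed set of packages that matter for this issue.
PACKAGES = ("torch", "triton", "mamba-ssm", "causal-conv1d")

for name, ver in collect_versions(PACKAGES).items():
    print(f"{name}: {ver}")
```

Pasting this output next to the traceback makes it easier to compare a working environment with a failing one.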
@Kiet0712 Thanks, but it does not work for me. I still get an error after adding "@torch.compile(options={"triton.cudagraphs": True}, fullgraph=True)" before the mamba_split_conv1d_scan_combined function.
Has anyone solved this?
Regarding "Yes, CUDA graphs work!": I've tried this but I'm still getting an error. I'd appreciate it if you could share demo code.
Try warming up by running it once first. The first call invokes the Triton compiler and autotuner, so it will be slow.
I think the problem is solved.
See my code here:
import time
import torch
from mamba_ssm import Mamba2
from mamba_ssm import Mamba
repeat_num = 1000
batch, length, dim = 2, 256*256*2, 256
x = torch.randn(batch, length, dim).to("cuda")
model = Mamba(
    # This module uses roughly 3 * expand * d_model^2 parameters
    d_model=dim,  # Model dimension d_model
    d_state=64,  # SSM state expansion factor
    d_conv=4,  # Local convolution width
    expand=2,  # Block expansion factor
).to("cuda")
y = model(x)  # warm up first
t1 = time.time()
for i in range(repeat_num):
    y = model(x)
    assert y.shape == x.shape
print(f"Time of mamba taken: {time.time() - t1:.3f} s")
model = Mamba2(
    # This module uses roughly 3 * expand * d_model^2 parameters
    d_model=dim,  # Model dimension d_model
    d_state=64,  # SSM state expansion factor, typically 64 or 128
    d_conv=4,  # Local convolution width
    expand=2,  # Block expansion factor
).to("cuda")
y = model(x)  # warm up first
t1 = time.time()
for i in range(repeat_num):
    y = model(x)
    assert y.shape == x.shape
print(f"Time of mamba2 taken: {time.time() - t1:.3f} s")
The output log
Time of mamba taken: 24.061 s
Time of mamba2 taken: 14.011 s
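The warm-up advice can be wrapped in a small timing helper. This is only a sketch: `benchmark` and `SlowFirstCall` are hypothetical names, and the sleep stands in for the one-time Triton compile and autotune cost. The optional `sync` argument is there because GPU work is queued asynchronously; for a real model you would pass something like `torch.cuda.synchronize`, so the timer stops only after the queued kernels have finished.

```python
import time

def benchmark(fn, warmup=3, iters=100, sync=None):
    """Return mean seconds per call of fn(), excluding warm-up calls.

    sync: optional zero-argument callable that blocks until queued
    asynchronous work is done (for example torch.cuda.synchronize).
    """
    for _ in range(warmup):
        fn()                      # pays one-time compile/autotune cost
    if sync is not None:
        sync()                    # make sure warm-up work has finished
    start = time.perf_counter()
    for _ in range(iters):
        fn()
    if sync is not None:
        sync()                    # wait for the timed work to finish
    return (time.perf_counter() - start) / iters

class SlowFirstCall:
    """Stand-in workload: the first call is slow, later calls are fast."""
    def __init__(self):
        self.calls = 0
    def __call__(self):
        self.calls += 1
        if self.calls == 1:
            time.sleep(0.05)      # simulated one-time compilation
        return self.calls

fn = SlowFirstCall()
mean_s = benchmark(fn, warmup=1, iters=10)
print(f"mean time per call after warm-up: {mean_s:.6f} s")
```

With the models from this thread, the call would look like `benchmark(lambda: model(x), sync=torch.cuda.synchronize)`, which keeps the first slow call out of the measurement for both Mamba and Mamba2.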
After compiling the model with 'torch.compile', I actually needed a warm-up to get good results. But why can you solve it here without using compile?
I use the original version, without adding 'torch.compile'.
see #389
After changing d_model, Mamba2 worked in the simple test environment provided in the README. But I noticed that Mamba2 is much slower than Mamba1. I tested it; here is my code. The result I got is this.
I don't know whether it is a bug or whether I made a mistake. Please feel free to share your thoughts.