Seeker98 opened 5 months ago
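I tried the following, and the per-iteration time does not seem to change. Maybe this is just an initial delay:

```python
# Assumes torch is imported and batch, length, dim and model2 are already defined (not shown here)
for i in range(10):
    x = torch.randn(batch, length, dim).to("cuda")
    y = model2(x)
```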
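A loop like this mixes one-off costs into the measurement: `torch.compile` compiles lazily on the first call(s), and CUDA kernels launch asynchronously, so wall-clock reads without a synchronize mostly capture launch time. A minimal stdlib sketch of the usual pattern (discard a few warm-up calls, synchronize before reading the clock) follows. `time_steady_state` and `fake_step` are hypothetical names; with a real model you would pass something like `lambda: model2(x)` and `sync=torch.cuda.synchronize`.

```python
import statistics
import time
from typing import Callable, List, Optional


def time_steady_state(
    fn: Callable[[], object],
    warmup: int = 3,
    iters: int = 10,
    sync: Optional[Callable[[], None]] = None,
) -> List[float]:
    """Call fn `warmup` times untimed, then return `iters` per-call wall times in seconds.

    `sync` (if given) runs after the warm-up and after each timed call, before the
    clock is read; for GPU work torch.cuda.synchronize would go here.
    """
    for _ in range(warmup):
        fn()  # first calls absorb one-off costs such as compilation
    if sync is not None:
        sync()
    times = []
    for _ in range(iters):
        start = time.perf_counter()
        fn()
        if sync is not None:
            sync()
        times.append(time.perf_counter() - start)
    return times


# Hypothetical stand-in for a compiled model: only the first call is slow.
calls = {"n": 0}


def fake_step() -> None:
    calls["n"] += 1
    if calls["n"] == 1:
        time.sleep(0.2)  # simulated one-off compilation cost


times = time_steady_state(fake_step, warmup=1, iters=5)
print(f"median {statistics.median(times) * 1e3:.3f} ms over {len(times)} timed calls")
```

Comparing the median of these steady-state times for the eager and compiled models gives a fairer picture than timing the first few iterations.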
Well, I'm wondering why adding `torch.compile` as discussed in #355 makes the code fail, given that the author mentioned it could speed things up a lot.
the same issue
Same issue.
the same issue
As in #355, I added `@torch.compile(options={"triton.cudagraphs": True}, fullgraph=True)` to the `mamba_chunk_scan_combined` function in `ssd_combined.py`, and running it failed with the following error:
```
Unsupported: autograd.Function with body that accepts non-Tensors as input. Got: <class 'tuple'>

from user code:
  File "/home/hit/.conda/envs/torch2/lib/python3.9/site-packages/mamba_ssm/ops/triton/ssd_combined.py", line 560, in mamba_chunk_scan_combined
    return MambaChunkScanCombinedFn.apply(x, dt, A, B, C, chunk_size, D, z, dt_bias, initial_states, seq_idx, dt_softplus, dt_limit, return_final_states)

Set TORCH_LOGS="+dynamo" and TORCHDYNAMO_VERBOSE=1 for more information

You can suppress this exception and fall back to eager by setting:
    import torch._dynamo
    torch._dynamo.config.suppress_errors = True
```
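The message says Dynamo could not trace `MambaChunkScanCombinedFn.apply` because one of its inputs is a tuple (presumably `dt_limit`; that is a reading of the call, not something confirmed in this thread). With `fullgraph=True`, any such graph break is raised as an error. Without it, Dynamo is expected to break the graph at the unsupported call and run that call eagerly, which avoids the crash but gives up compilation for that part. The sketch below shows this on a toy function; `ClampFn` and `block` are hypothetical stand-ins, and `backend="eager"` is used only so the demo needs no C++/Triton toolchain.

```python
import torch


class ClampFn(torch.autograd.Function):
    """Hypothetical stand-in for MambaChunkScanCombinedFn: like dt_limit, `limits` is a tuple."""

    @staticmethod
    def forward(ctx, x, limits):
        lo, hi = limits
        ctx.save_for_backward(x)
        ctx.limits = limits  # non-Tensor state kept on ctx for backward
        return x.clamp(lo, hi)

    @staticmethod
    def backward(ctx, grad_out):
        (x,) = ctx.saved_tensors
        lo, hi = ctx.limits
        # Gradient flows only where the input was inside [lo, hi]; None for the tuple input.
        return grad_out * ((x >= lo) & (x <= hi)), None


def block(x):
    return 2 * ClampFn.apply(x, (-1.0, 1.0))


# No fullgraph=True: if Dynamo cannot trace ClampFn.apply, it graph-breaks and runs
# that call eagerly instead of raising. A real run would use the default backend.
compiled_block = torch.compile(block, backend="eager")

x = torch.linspace(-2.0, 2.0, steps=9, requires_grad=True)
out = compiled_block(x)
out.sum().backward()
print(out)
print(x.grad)
```

The error text also offers `torch._dynamo.config.suppress_errors = True`, which similarly falls back to eager rather than failing.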
Reproduction code:
```python
import torch
from mamba_ssm import Mamba2

batch, length, dim = 8, 1024, 128
x = torch.randn(batch, length, dim).to("cuda")
model = Mamba2(
    # This module uses roughly 3 * expand * d_model^2 parameters
    d_model=dim,  # Model dimension d_model
    d_state=64,  # SSM state expansion factor, typically 64 or 128
    d_conv=4,  # Local convolution width
    expand=2,  # Block expansion factor
    headdim=32,
    use_mem_eff_path=False,
).to("cuda")
y = model(x)
assert y.shape == x.shape
```
I'm not sure what else to provide, but my package versions are:
- mamba-ssm 2.0.3
- causal-conv1d 1.2.2.post1
- pytorch 2.3.1 (build py39_cu121_cudnn8.9.2_0)
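For reports like this, listing the installed versions of the packages involved helps maintainers reproduce the problem (PyTorch also ships `python -m torch.utils.collect_env` for the full environment). A minimal stdlib sketch of collecting package versions; `package_versions` is a hypothetical helper, and the distribution names passed in are assumptions that may differ from how a given environment registers them.

```python
from importlib import metadata
from typing import Dict, Iterable, Optional


def package_versions(names: Iterable[str]) -> Dict[str, Optional[str]]:
    """Return {distribution name: installed version, or None if it is not installed}."""
    versions: Dict[str, Optional[str]] = {}
    for name in names:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            versions[name] = None
    return versions


if __name__ == "__main__":
    for name, ver in package_versions(["mamba-ssm", "causal-conv1d", "torch", "triton"]).items():
        print(f"{name}: {ver or 'not installed'}")
```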
Hi, I have the same problem. Have you solved it?
the same issue
Same here
Same here...
same issue here
Same issue here. Has anyone solved it? Many thanks!