state-spaces / mamba

Mamba SSM architecture
Apache License 2.0
12.47k stars 1.05k forks source link

mamba2 has the error #384

Open liangaomng opened 3 months ago

liangaomng commented 3 months ago

Thanks for your wonderful work! When I run mamba1, it is OK, but when I run mamba2 from your README, it shows: File "", line 21, in _chunk_cumsum_fwd_kernel KeyError: ('2-.-0-.-0-d82511111ad128294e9d31a6ac684238-d6252949da17ceb5f3a278a70250af13-1af5134066c618146d2cd009138944a0-bf290056f5caf73914ee917acc2e7230-3498c340fd4b6ee7805fd54b882a04f5-e1f133f98d04093da2078dfc51c36b72-b26258bf01f839199e39d64851821f26-d7c06e3b46e708006c15224aac7a1378-f585402118c8a136948ce0a49cfe122c', (torch.float32, torch.float32, torch.float32, torch.float32, torch.float32, 'i32', 'i32', 'i32', 'i32', 'fp32', 'fp32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32'), (True, True, 1, 256), (True, True, True, True, True, (False, False), (True, False), (False, False), (True, False), (False,), (False,), (True, False), (False, False), (False, True), (False, True), (False, True), (True, False), (True, False), (True, False), (False, True), (True, False), (True, False), (True, False), (False, True)))

During handling of the above exception, another exception occurred:

Traceback (most recent call last): File "/liujinxin/lam/le_pde/pytorch_net/mamba.py", line 15, in y = model(x) File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl return forward_call(*args, kwargs) File "/opt/conda/lib/python3.10/site-packages/mamba_ssm/modules/mamba2.py", line 176, in forward out = mamba_split_conv1d_scan_combined( File "/opt/conda/lib/python3.10/site-packages/mamba_ssm/ops/triton/ssd_combined.py", line 908, in mamba_split_conv1d_scan_combined return MambaSplitConv1dScanCombinedFn.apply(zxbcdt, conv1d_weight, conv1d_bias, dt_bias, A, D, chunk_size, initial_states, seq_idx, dt_limit, return_final_states, activation, rmsnorm_weight, rmsnorm_eps, outproj_weight, outproj_bias, headdim, ngroups, norm_before_gate) File "/opt/conda/lib/python3.10/site-packages/torch/autograd/function.py", line 506, in apply return super().apply(*args, *kwargs) # type: ignore[misc] File "/opt/conda/lib/python3.10/site-packages/torch/cuda/amp/autocast_mode.py", line 98, in decorate_fwd return fwd(args, kwargs) File "/opt/conda/lib/python3.10/site-packages/mamba_ssm/ops/triton/ssd_combined.py", line 773, in forward outx, , dt_out, dA_cumsum, states, final_states = _mamba_chunk_scan_combined_fwd(x, dt, A, B, C, chunk_size=chunk_size, D=D, z=None, dt_bias=dt_bias, initial_states=initial_states, seq_idx=seq_idx, dt_softplus=True, dt_limit=dt_limit) File "/opt/conda/lib/python3.10/site-packages/mamba_ssm/ops/triton/ssd_combined.py", line 307, in _mamba_chunk_scan_combined_fwd dA_cumsum, dt = _chunk_cumsum_fwd(dt, A, chunk_size, dt_bias=dt_bias, dt_softplus=dt_softplus, dt_limit=dt_limit) File "/opt/conda/lib/python3.10/site-packages/mamba_ssm/ops/triton/ssd_chunk_state.py", line 582, in _chunk_cumsum_fwd _chunk_cumsum_fwd_kernel[grid_chunk_cs]( File "/opt/conda/lib/python3.10/site-packages/triton/runtime/autotuner.py", line 77, in run timings = {config: self._bench(*args, config=config, kwargs) File 
"/opt/conda/lib/python3.10/site-packages/triton/runtime/autotuner.py", line 77, in timings = {config: self._bench(*args, config=config, *kwargs) File "/opt/conda/lib/python3.10/site-packages/triton/runtime/autotuner.py", line 65, in _bench return do_bench(kernel_call) File "/opt/conda/lib/python3.10/site-packages/triton/testing.py", line 143, in do_bench fn() File "/opt/conda/lib/python3.10/site-packages/triton/runtime/autotuner.py", line 63, in kernel_call self.fn.run(args, num_warps=config.num_warps, num_stages=config.num_stages, current) File "", line 41, in _chunk_cumsum_fwd_kernel File "/opt/conda/lib/python3.10/site-packages/triton/compiler.py", line 1589, in compile fn_cache_manager = CacheManager(make_hash(fn, **kwargs)) File "/opt/conda/lib/python3.10/site-packages/triton/compiler.py", line 1499, in make_hash key = f"{fn.cache_key}-{''.join(signature.values())}-{configs_key}-{constants}-{num_warps}-{num_stages}" File "/opt/conda/lib/python3.10/site-packages/triton/runtime/jit.py", line 333, in cache_key dependencies_finder.visit(self.parse()) File "/opt/conda/lib/python3.10/ast.py", line 418, in visit return visitor(node) File "/opt/conda/lib/python3.10/ast.py", line 426, in generic_visit self.visit(item) File "/opt/conda/lib/python3.10/ast.py", line 418, in visit return visitor(node) File "/opt/conda/lib/python3.10/ast.py", line 426, in generic_visit self.visit(item) File "/opt/conda/lib/python3.10/ast.py", line 418, in visit return visitor(node) File "/opt/conda/lib/python3.10/ast.py", line 428, in generic_visit self.visit(value) File "/opt/conda/lib/python3.10/ast.py", line 418, in visit return visitor(node) File "/opt/conda/lib/python3.10/site-packages/triton/runtime/jit.py", line 55, in visit_Call func = self.visit(node.func) File "/opt/conda/lib/python3.10/ast.py", line 418, in visit return visitor(node) File "/opt/conda/lib/python3.10/site-packages/triton/runtime/jit.py", line 52, in visit_Attribute return getattr(lhs, node.attr) AttributeError: 
module 'triton.language' has no attribute 'cumsum'

tridao commented 3 months ago

Please use triton >= 2.1.0