state-spaces / mamba

Mamba SSM architecture
Apache License 2.0

TypeError: custom_fwd() takes from 0 to 1 positional arguments but 21 positional arguments (and 1 keyword-only argument) were given #609

Open · saurabh-kataria opened this issue 1 month ago

saurabh-kataria commented 1 month ago

I am unable to use the sample Mamba2 code. Even with the following simple snippet, the forward pass fails.

import torch
from mamba_ssm import Mamba2
batch, length, dim = 2, 64, 256
x = torch.randn(batch, length, dim).to("cuda")
model = Mamba2(d_model=dim, d_state=64, d_conv=4, expand=2).to('cuda')
y = model(x)

---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
Cell In[5], line 5
      3 x = torch.randn(batch, length, dim).to("cuda")
      4 model = Mamba2(d_model=dim, d_state=64, d_conv=4, expand=2).to('cuda')
----> 5 y = model(x)

File /scratch/skataria/anaconda3/envs/tmp5/lib/python3.10/site-packages/torch/nn/modules/module.py:1736, in Module._wrapped_call_impl(self, *args, **kwargs)
   1734     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1735 else:
-> 1736     return self._call_impl(*args, **kwargs)

File /scratch/skataria/anaconda3/envs/tmp5/lib/python3.10/site-packages/torch/nn/modules/module.py:1747, in Module._call_impl(self, *args, **kwargs)
   1742 # If we don't have any hooks, we want to skip the rest of the logic in
   1743 # this function, and just call forward.
   1744 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1745         or _global_backward_pre_hooks or _global_backward_hooks
   1746         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1747     return forward_call(*args, **kwargs)
   1749 result = None
   1750 called_always_called_hooks = set()

File /scratch/skataria/anaconda3/envs/tmp5/lib/python3.10/site-packages/mamba_ssm/modules/mamba2.py:185, in Mamba2.forward(self, u, seqlen, seq_idx, cu_seqlens, inference_params)
    183 dt_limit_kwargs = {} if self.dt_limit == (0.0, float("inf")) else dict(dt_limit=self.dt_limit)
    184 if self.use_mem_eff_path and inference_params is None:
--> 185     out = mamba_split_conv1d_scan_combined(
    186         zxbcdt,
    187         rearrange(self.conv1d.weight, "d 1 w -> d w"),
    188         self.conv1d.bias,
    189         self.dt_bias,
    190         A,
    191         D=rearrange(self.D, "(h p) -> h p", p=self.headdim) if self.D_has_hdim else self.D,
    192         chunk_size=self.chunk_size,
    193         seq_idx=seq_idx,
    194         activation=self.activation,
    195         rmsnorm_weight=self.norm.weight if self.rmsnorm else None,
    196         rmsnorm_eps=self.norm.eps if self.rmsnorm else 1e-6,
    197         outproj_weight=self.out_proj.weight,
    198         outproj_bias=self.out_proj.bias,
    199         headdim=None if self.D_has_hdim else self.headdim,
    200         ngroups=self.ngroups,
    201         norm_before_gate=self.norm_before_gate,
    202         **dt_limit_kwargs,
    203     )
    204 if seqlen_og is not None:
    205     out = rearrange(out, "b l d -> (b l) d")

File /scratch/skataria/anaconda3/envs/tmp5/lib/python3.10/site-packages/mamba_ssm/ops/triton/ssd_combined.py:930, in mamba_split_conv1d_scan_combined(zxbcdt, conv1d_weight, conv1d_bias, dt_bias, A, D, chunk_size, initial_states, seq_idx, dt_limit, return_final_states, activation, rmsnorm_weight, rmsnorm_eps, outproj_weight, outproj_bias, headdim, ngroups, norm_before_gate)
    911 def mamba_split_conv1d_scan_combined(zxbcdt, conv1d_weight, conv1d_bias, dt_bias, A, D, chunk_size, initial_states=None, seq_idx=None, dt_limit=(0.0, float("inf")), return_final_states=False, activation="silu", rmsnorm_weight=None, rmsnorm_eps=1e-6, outproj_weight=None, outproj_bias=None, headdim=None, ngroups=1, norm_before_gate=True):
    912     """
    913     Argument:
    914         zxbcdt: (batch, seqlen, 2 * dim + 2 * ngroups * dstate + nheads) where dim == nheads * headdim
    (...)
    928         out: (batch, seqlen, dim)
    929     """
--> 930     return MambaSplitConv1dScanCombinedFn.apply(zxbcdt, conv1d_weight, conv1d_bias, dt_bias, A, D, chunk_size, initial_states, seq_idx, dt_limit, return_final_states, activation, rmsnorm_weight, rmsnorm_eps, outproj_weight, outproj_bias, headdim, ngroups, norm_before_gate)

File /scratch/skataria/anaconda3/envs/tmp5/lib/python3.10/site-packages/torch/autograd/function.py:575, in Function.apply(cls, *args, **kwargs)
    572 if not torch._C._are_functorch_transforms_active():
    573     # See NOTE: [functorch vjp and autograd interaction]
    574     args = _functorch.utils.unwrap_dead_wrappers(args)
--> 575     return super().apply(*args, **kwargs)  # type: ignore[misc]
    577 if not is_setup_ctx_defined:
    578     raise RuntimeError(
    579         "In order to use an autograd.Function with functorch transforms "
    580         "(vmap, grad, jvp, jacrev, ...), it must override the setup_context "
    581         "staticmethod. For more details, please see "
    582         "https://pytorch.org/docs/main/notes/extending.func.html"
    583     )

TypeError: custom_fwd() takes from 0 to 1 positional arguments but 21 positional arguments (and 1 keyword-only argument) were given
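
The wording of the error matches the newer torch.amp.custom_fwd signature, which accepts at most one positional argument (the function being wrapped) plus a required keyword-only device_type, so it looks like the Function's forward ends up being custom_fwd itself and receives all of the forward arguments. A small diagnostic along these lines (my own sketch, not part of the original repro; it assumes a recent PyTorch where torch.amp.custom_fwd exists) shows which signature is in play:

# Optional diagnostic, separate from the failing example above: record the
# torch / mamba_ssm versions and the signature of the decorator named in the error.
import inspect
import torch
import mamba_ssm

print("torch:", torch.__version__)
print("mamba_ssm:", getattr(mamba_ssm, "__version__", "unknown"))

fwd = getattr(torch.amp, "custom_fwd", None)
if fwd is not None:
    # On PyTorch >= 2.4 this prints a single optional positional argument plus a
    # required keyword-only device_type, matching the TypeError above.
    print("torch.amp.custom_fwd:", inspect.signature(fwd))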

saurabh-kataria commented 1 month ago

Applying https://github.com/state-spaces/mamba/pull/608 works, BTW.
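
I have not gone through the exact diff in that PR, but the general fix pattern for this class of error is to adapt the decorator to both PyTorch APIs. A minimal sketch (my own illustration, not the actual patch) would be:

# Rough compatibility sketch (my own illustration, not the actual #608 patch):
# expose custom_fwd/custom_bwd so they can be used as bare decorators on both
# old (torch.cuda.amp) and new (torch.amp with keyword-only device_type) PyTorch.
from functools import partial

import torch

if hasattr(torch.amp, "custom_fwd"):  # PyTorch >= 2.4
    custom_fwd = partial(torch.amp.custom_fwd, device_type="cuda")
    custom_bwd = partial(torch.amp.custom_bwd, device_type="cuda")
else:  # older PyTorch
    from torch.cuda.amp import custom_fwd, custom_bwd


class _DoubleFn(torch.autograd.Function):
    @staticmethod
    @custom_fwd  # resolves to the right decorator on either PyTorch version
    def forward(ctx, x):
        return x * 2

    @staticmethod
    @custom_bwd
    def backward(ctx, grad_out):
        return grad_out * 2

With a guard like this, @custom_fwd can still be used as a bare decorator on Function.forward regardless of the installed PyTorch version; pinning PyTorch below 2.4 would presumably also sidestep the new signature, though I have not verified that.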

epicfilemcnulty commented 1 month ago

Got the same error with the latest master, can confirm that applying #608 solves the issue.