File /scratch/skataria/anaconda3/envs/tmp5/lib/python3.10/site-packages/torch/nn/modules/module.py:1747, in Module._call_impl(self, *args, **kwargs)
1742 # If we don't have any hooks, we want to skip the rest of the logic in
1743 # this function, and just call forward.
1744 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
1745 or _global_backward_pre_hooks or _global_backward_hooks
1746 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1747 return forward_call(*args, **kwargs)
1749 result = None
1750 called_always_called_hooks = set()
File /scratch/skataria/anaconda3/envs/tmp5/lib/python3.10/site-packages/mamba_ssm/modules/mamba2.py:185, in Mamba2.forward(self, u, seqlen, seq_idx, cu_seqlens, inference_params)
183 dt_limit_kwargs = {} if self.dt_limit == (0.0, float("inf")) else dict(dt_limit=self.dt_limit)
184 if self.use_mem_eff_path and inference_params is None:
--> 185 out = mamba_split_conv1d_scan_combined(
186 zxbcdt,
187 rearrange(self.conv1d.weight, "d 1 w -> d w"),
188 self.conv1d.bias,
189 self.dt_bias,
190 A,
191 D=rearrange(self.D, "(h p) -> h p", p=self.headdim) if self.D_has_hdim else self.D,
192 chunk_size=self.chunk_size,
193 seq_idx=seq_idx,
194 activation=self.activation,
195 rmsnorm_weight=self.norm.weight if self.rmsnorm else None,
196 rmsnorm_eps=self.norm.eps if self.rmsnorm else 1e-6,
197 outproj_weight=self.out_proj.weight,
198 outproj_bias=self.out_proj.bias,
199 headdim=None if self.D_has_hdim else self.headdim,
200 ngroups=self.ngroups,
201 norm_before_gate=self.norm_before_gate,
202 **dt_limit_kwargs,
203 )
204 if seqlen_og is not None:
205 out = rearrange(out, "b l d -> (b l) d")
File /scratch/skataria/anaconda3/envs/tmp5/lib/python3.10/site-packages/torch/autograd/function.py:575, in Function.apply(cls, *args, **kwargs)
572 if not torch._C._are_functorch_transforms_active():
573 # See NOTE: [functorch vjp and autograd interaction]
574 args = _functorch.utils.unwrap_dead_wrappers(args)
--> 575 return super().apply(*args, **kwargs)  # type: ignore[misc]
577 if not is_setup_ctx_defined:
578 raise RuntimeError(
579 "In order to use an autograd.Function with functorch transforms "
580 "(vmap, grad, jvp, jacrev, ...), it must override the setup_context "
581 "staticmethod. For more details, please see "
582 "https://pytorch.org/docs/main/notes/extending.func.html"
583 )
TypeError: custom_fwd() takes from 0 to 1 positional arguments but 21 positional arguments (and 1 keyword-only argument) were given
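For what it's worth, the shape of the message suggests that `custom_fwd` (the AMP decorator) is somehow receiving the 21 arguments of the autograd `forward` itself instead of the function it is supposed to wrap. A toy snippet (hypothetical names, not mamba_ssm's or torch's actual code) raises the same TypeError:

```python
# Stand-in with the same calling convention as an AMP-style decorator:
# at most one positional argument (the function to wrap) plus a
# keyword-only option. Purely illustrative, not library code.
def custom_fwd(fwd=None, *, device_type="cuda"):
    return fwd if fwd is not None else (lambda f: f)

# Calling it with a forward pass's 21 positional arguments (plus one keyword)
# instead of decorating a function reproduces the message in the traceback:
custom_fwd(*range(21), device_type="cuda")
# TypeError: custom_fwd() takes from 0 to 1 positional arguments
# but 21 positional arguments (and 1 keyword-only argument) were given
```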
I am unable to use the sample Mamba2 code: even the following simple code fails on the forward pass.
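The "simple code" is essentially the Mamba2 example from the mamba_ssm README (reproduced below; the exact dimensions may differ from what I ran, but any small configuration hits the same error):

```python
import torch
from mamba_ssm import Mamba2

batch, length, dim = 2, 64, 256
x = torch.randn(batch, length, dim).to("cuda")

model = Mamba2(
    d_model=dim,  # model dimension
    d_state=64,   # SSM state expansion factor
    d_conv=4,     # local convolution width
    expand=2,     # block expansion factor
).to("cuda")

# The forward pass below is where the traceback above is raised.
y = model(x)
assert y.shape == x.shape
```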