MzeroMiko / VMamba

VMamba: Visual State Space Models, code is based on mamba
MIT License
2.06k stars, 123 forks

some problems about mamba2 support #237

Closed: xiakela closed this issue 3 months ago

xiakela commented 3 months ago

When I try the mamba2 support with selective_scan_backend == 'torch', I run into the following problem:

Traceback (most recent call last):
  File "/home/wzw/Vmamba/VMamba/classification/main.py", line 641, in <module>
    main(config, args)
  File "/home/wzw/Vmamba/VMamba/classification/main.py", line 204, in main
    flops = model.flops()
  File "/home/wzw/Vmamba/VMamba/classification/models/vmamba.py", line 2178, in flops
    Gflops, unsupported = flop_count(model=model, inputs=(input,), supported_ops=supported_ops)
  File "/home/wzw/miniconda3/envs/vmamba2/lib/python3.10/site-packages/fvcore/nn/flop_count.py", line 147, in flop_count
    for op, flop in flop_counter.by_operator().items():
  File "/home/wzw/miniconda3/envs/vmamba2/lib/python3.10/site-packages/fvcore/nn/jit_analysis.py", line 265, in by_operator
    stats = self._analyze()
  File "/home/wzw/miniconda3/envs/vmamba2/lib/python3.10/site-packages/fvcore/nn/jit_analysis.py", line 551, in _analyze
    graph = _get_scoped_trace_graph(self._model, self._inputs, self._aliases)
  File "/home/wzw/miniconda3/envs/vmamba2/lib/python3.10/site-packages/fvcore/nn/jit_analysis.py", line 176, in _get_scoped_trace_graph
    graph, _ = _get_trace_graph(module, inputs)
  File "/home/wzw/miniconda3/envs/vmamba2/lib/python3.10/site-packages/torch/jit/_trace.py", line 1296, in _get_trace_graph
    outs = ONNXTracedModule(
  File "/home/wzw/miniconda3/envs/vmamba2/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/wzw/miniconda3/envs/vmamba2/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/wzw/miniconda3/envs/vmamba2/lib/python3.10/site-packages/torch/jit/_trace.py", line 138, in forward
    graph, out = torch._C._create_graph_by_tracing(
  File "/home/wzw/miniconda3/envs/vmamba2/lib/python3.10/site-packages/torch/jit/_trace.py", line 129, in wrapper
    outs.append(self.inner(*trace_inputs))
  File "/home/wzw/miniconda3/envs/vmamba2/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/wzw/miniconda3/envs/vmamba2/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1561, in _call_impl
    result = forward_call(*args, **kwargs)
  File "/home/wzw/miniconda3/envs/vmamba2/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _slow_forward
    result = self.forward(*input, **kwargs)
  File "/home/wzw/Vmamba/VMamba/classification/models/vmamba.py", line 2157, in forward
    x = layer(x)
  File "/home/wzw/miniconda3/envs/vmamba2/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/wzw/miniconda3/envs/vmamba2/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1561, in _call_impl
    result = forward_call(*args, **kwargs)
  File "/home/wzw/miniconda3/envs/vmamba2/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _slow_forward
    result = self.forward(*input, **kwargs)
  File "/home/wzw/miniconda3/envs/vmamba2/lib/python3.10/site-packages/torch/nn/modules/container.py", line 217, in forward
    input = module(input)
  File "/home/wzw/miniconda3/envs/vmamba2/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/wzw/miniconda3/envs/vmamba2/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1561, in _call_impl
    result = forward_call(*args, **kwargs)
  File "/home/wzw/miniconda3/envs/vmamba2/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _slow_forward
    result = self.forward(*input, **kwargs)
  File "/home/wzw/miniconda3/envs/vmamba2/lib/python3.10/site-packages/torch/nn/modules/container.py", line 217, in forward
    input = module(input)
  File "/home/wzw/miniconda3/envs/vmamba2/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/wzw/miniconda3/envs/vmamba2/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1561, in _call_impl
    result = forward_call(*args, **kwargs)
  File "/home/wzw/miniconda3/envs/vmamba2/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _slow_forward
    result = self.forward(*input, **kwargs)
  File "/home/wzw/Vmamba/VMamba/classification/models/vmamba.py", line 1238, in forward
    return self._forward(input)
  File "/home/wzw/Vmamba/VMamba/classification/models/vmamba.py", line 1226, in _forward
    x = x + self.drop_path(self.op(self.norm(x)))
  File "/home/wzw/miniconda3/envs/vmamba2/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/wzw/miniconda3/envs/vmamba2/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1561, in _call_impl
    result = forward_call(*args, **kwargs)
  File "/home/wzw/miniconda3/envs/vmamba2/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _slow_forward
    result = self.forward(*input, **kwargs)
  File "/home/wzw/Vmamba/VMamba/classification/models/vmamba.py", line 1099, in forwardm0
    y = self.forward_core(x)
  File "/home/wzw/Vmamba/VMamba/classification/models/vmamba.py", line 1068, in forward_corem0
    ys, final_state = selective_scan_chunk_fn(
  File "/home/wzw/Vmamba/VMamba/classification/models/mamba2/ssd_minimal.py", line 135, in selective_scan_chunk_fn
    return fn(x, dt, A, B, C, chunk_size, D, z, dt_bias, initial_states, seq_idx, dt_softplus, dt_limit, return_final_states)
  File "/home/wzw/Vmamba/VMamba/classification/models/mamba2/ssd_minimal.py", line 110, in mamba_chunk_scan_combined_torch
    y, state = ssd_minimal_discrete(u, w, B, C, block_len=chunk_size, initial_states=initial_states)
  File "/home/wzw/Vmamba/VMamba/classification/models/mamba2/ssd_minimal.py", line 55, in ssd_minimal_discrete
    Y_diag = torch.einsum("bclhn,bcshn,bhcls,bcshp->bclhp", C, B, L, X)
  File "/home/wzw/miniconda3/envs/vmamba2/lib/python3.10/site-packages/torch/functional.py", line 380, in einsum
    return _VF.einsum(equation, operands)  # type: ignore[attr-defined]
RuntimeError: einsum(): subscript h has size 64 for operand 2 which does not broadcast with previously seen size 4
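
For anyone reading this error later: assuming einsum counts operands from zero, "operand 2" is L (subscripts bhcls), whose h dimension is 64, while C and B, seen earlier, have h = 4. The sketch below is a pure-Python helper, not part of VMamba or PyTorch, that checks operand shapes against an einsum equation and lists every conflicting subscript. The helper name and all shapes are hypothetical; only the equation and the sizes 4 and 64 come from the traceback.

```python
def check_einsum_shapes(equation, *shapes):
    """Return (subscript, operand_index, size, first_seen_size) for each conflict."""
    inputs = equation.replace(" ", "").split("->")[0].split(",")
    if len(inputs) != len(shapes):
        raise ValueError("number of operands does not match the equation")
    seen = {}        # subscript -> first size observed (size-1 dims are skipped)
    conflicts = []
    for idx, (subs, shape) in enumerate(zip(inputs, shapes)):
        if len(subs) != len(shape):
            raise ValueError(f"operand {idx} has {len(shape)} dims, expected {len(subs)}")
        for sub, size in zip(subs, shape):
            if size == 1:
                continue  # a size-1 dimension can broadcast against any size
            if sub in seen and seen[sub] != size:
                conflicts.append((sub, idx, size, seen[sub]))
            seen.setdefault(sub, size)
    return conflicts


# Hypothetical shapes, chosen only to mirror the error message above.
equation = "bclhn,bcshn,bhcls,bcshp->bclhp"
C_shape = (2, 3, 8, 4, 16)    # b c l h n, with h = 4
B_shape = (2, 3, 8, 4, 16)    # b c s h n, with h = 4
L_shape = (2, 64, 3, 8, 8)    # b h c l s, with h = 64
X_shape = (2, 3, 8, 64, 32)   # b c s h p, with h = 64

conflicts = check_einsum_shapes(equation, C_shape, B_shape, L_shape, X_shape)
for sub, idx, size, expected in conflicts:
    print(f"subscript {sub}: operand {idx} has size {size}, previously seen {expected}")
```

With real tensors, the same check can be run by passing `tuple(t.shape)` for each operand just before the einsum call. A mismatch of this kind suggests that two operands need their h dimension expanded or repeated to match the others, but the thread does not say how the fix was made.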

MzeroMiko commented 3 months ago

Sorry, I did not test the torch version. I'll try to fix this later. Can you use the SSD written in triton?

MzeroMiko commented 3 months ago

The bug in the torch version has now been fixed.

xiakela commented 3 months ago

Thanks for your reply, the problem has been solved.

Journey7331 commented 3 months ago

When using the latest torch version with AMP, another problem occurs. The triton version does not have this problem.

AMP
...
  File "/home/proj/VMamba/classification/models/mamba2/ssd_minimal.py", line 146, in selective_scan_chunk_fn
    return fn(x, dt, A, B, C, chunk_size, D, z, dt_bias, initial_states, seq_idx, dt_softplus, dt_limit, return_final_states)
  File "/home/proj/VMamba/classification/models/mamba2/ssd_minimal.py", line 121, in mamba_chunk_scan_combined_torch
    y, state = ssd_minimal_discrete(u, w, B, C, block_len=chunk_size, initial_states=initial_states)
  File "/home/proj/VMamba/classification/models/mamba2/ssd_minimal.py", line 44, in ssd_minimal_discrete
    assert X.dtype == A.dtype == B.dtype == C.dtype
AssertionError

Process finished with exit code 1
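
The bare assert above does not say which tensor has the odd dtype. Under AMP it is common for the outputs of autocast-eligible ops to be in half precision while parameters stay in float32, which would trip this check, though the thread does not confirm that this is the cause here. The sketch below is a hypothetical helper, not part of VMamba, that raises the same AssertionError but names every dtype. It relies only on a `.dtype` attribute, so the demo uses stand-in objects instead of real tensors, and the dtypes shown are assumptions for illustration.

```python
from types import SimpleNamespace


def dtype_report(**tensors):
    """Map each tensor name to the string form of its .dtype attribute."""
    return {name: str(t.dtype) for name, t in tensors.items()}


def assert_same_dtype(**tensors):
    """Raise an AssertionError listing every dtype unless they are all equal."""
    report = dtype_report(**tensors)
    if len(set(report.values())) > 1:
        raise AssertionError(f"dtype mismatch: {report}")


# Stand-ins for tensors; these dtypes are hypothetical, not taken from the thread.
X = SimpleNamespace(dtype="torch.float16")
A = SimpleNamespace(dtype="torch.float32")
B = SimpleNamespace(dtype="torch.float16")
C = SimpleNamespace(dtype="torch.float16")

try:
    assert_same_dtype(X=X, A=A, B=B, C=C)
    message = ""
except AssertionError as err:
    message = str(err)

print(message)
```

Once the mismatched tensor is known, one possible workaround is to cast all inputs to a common dtype (for example with `.float()`) before calling the scan, at the cost of losing the memory savings of half precision inside that function.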