Closed: xiakela closed this issue 3 months ago.
Sorry, I did not test the torch version. I'll try to fix this later. In the meantime, could you use the Triton implementation of SSD instead?
The bug in torch version has been fixed now.
Thanks for your reply, the problem has been solved.
When using the latest torch version with AMP, another problem occurs. This problem does not occur when using the Triton backend.
Traceback with AMP enabled:
...
File "/home/proj/VMamba/classification/models/mamba2/ssd_minimal.py", line 146, in selective_scan_chunk_fn
return fn(x, dt, A, B, C, chunk_size, D, z, dt_bias, initial_states, seq_idx, dt_softplus, dt_limit, return_final_states)
File "/home/proj/VMamba/classification/models/mamba2/ssd_minimal.py", line 121, in mamba_chunk_scan_combined_torch
y, state = ssd_minimal_discrete(u, w, B, C, block_len=chunk_size, initial_states=initial_states)
File "/home/proj/VMamba/classification/models/mamba2/ssd_minimal.py", line 44, in ssd_minimal_discrete
assert X.dtype == A.dtype == B.dtype == C.dtype
AssertionError
Process finished with exit code 1
When I try the mamba2 support with selective_scan_backend=='torch', I run into a problem:
Traceback (most recent call last):
File "/home/wzw/Vmamba/VMamba/classification/main.py", line 641, in <module>
main(config, args)
File "/home/wzw/Vmamba/VMamba/classification/main.py", line 204, in main
flops = model.flops()
File "/home/wzw/Vmamba/VMamba/classification/models/vmamba.py", line 2178, in flops
Gflops, unsupported = flop_count(model=model, inputs=(input,), supported_ops=supported_ops)
File "/home/wzw/miniconda3/envs/vmamba2/lib/python3.10/site-packages/fvcore/nn/flop_count.py", line 147, in flop_count
for op, flop in flop_counter.by_operator().items():
File "/home/wzw/miniconda3/envs/vmamba2/lib/python3.10/site-packages/fvcore/nn/jit_analysis.py", line 265, in by_operator
stats = self._analyze()
File "/home/wzw/miniconda3/envs/vmamba2/lib/python3.10/site-packages/fvcore/nn/jit_analysis.py", line 551, in _analyze
graph = _get_scoped_trace_graph(self._model, self._inputs, self._aliases)
File "/home/wzw/miniconda3/envs/vmamba2/lib/python3.10/site-packages/fvcore/nn/jit_analysis.py", line 176, in _get_scoped_trace_graph
graph, _ = _get_trace_graph(module, inputs)
File "/home/wzw/miniconda3/envs/vmamba2/lib/python3.10/site-packages/torch/jit/_trace.py", line 1296, in _get_trace_graph
outs = ONNXTracedModule(
File "/home/wzw/miniconda3/envs/vmamba2/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/home/wzw/miniconda3/envs/vmamba2/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
return forward_call(*args, **kwargs)
File "/home/wzw/miniconda3/envs/vmamba2/lib/python3.10/site-packages/torch/jit/_trace.py", line 138, in forward
graph, out = torch._C._create_graph_by_tracing(
File "/home/wzw/miniconda3/envs/vmamba2/lib/python3.10/site-packages/torch/jit/_trace.py", line 129, in wrapper
outs.append(self.inner(*trace_inputs))
File "/home/wzw/miniconda3/envs/vmamba2/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/home/wzw/miniconda3/envs/vmamba2/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1561, in _call_impl
result = forward_call(*args, **kwargs)
File "/home/wzw/miniconda3/envs/vmamba2/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _slow_forward
result = self.forward(*input, **kwargs)
File "/home/wzw/Vmamba/VMamba/classification/models/vmamba.py", line 2157, in forward
x = layer(x)
File "/home/wzw/miniconda3/envs/vmamba2/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/home/wzw/miniconda3/envs/vmamba2/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1561, in _call_impl
result = forward_call(*args, **kwargs)
File "/home/wzw/miniconda3/envs/vmamba2/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _slow_forward
result = self.forward(*input, **kwargs)
File "/home/wzw/miniconda3/envs/vmamba2/lib/python3.10/site-packages/torch/nn/modules/container.py", line 217, in forward
input = module(input)
File "/home/wzw/miniconda3/envs/vmamba2/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/home/wzw/miniconda3/envs/vmamba2/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1561, in _call_impl
result = forward_call(*args, **kwargs)
File "/home/wzw/miniconda3/envs/vmamba2/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _slow_forward
result = self.forward(*input, **kwargs)
File "/home/wzw/miniconda3/envs/vmamba2/lib/python3.10/site-packages/torch/nn/modules/container.py", line 217, in forward
input = module(input)
File "/home/wzw/miniconda3/envs/vmamba2/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/home/wzw/miniconda3/envs/vmamba2/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1561, in _call_impl
result = forward_call(*args, **kwargs)
File "/home/wzw/miniconda3/envs/vmamba2/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _slow_forward
result = self.forward(*input, **kwargs)
File "/home/wzw/Vmamba/VMamba/classification/models/vmamba.py", line 1238, in forward
return self._forward(input)
File "/home/wzw/Vmamba/VMamba/classification/models/vmamba.py", line 1226, in _forward
x = x + self.drop_path(self.op(self.norm(x)))
File "/home/wzw/miniconda3/envs/vmamba2/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/home/wzw/miniconda3/envs/vmamba2/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1561, in _call_impl
result = forward_call(*args, **kwargs)
File "/home/wzw/miniconda3/envs/vmamba2/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _slow_forward
result = self.forward(*input, **kwargs)
File "/home/wzw/Vmamba/VMamba/classification/models/vmamba.py", line 1099, in forwardm0
y = self.forward_core(x)
File "/home/wzw/Vmamba/VMamba/classification/models/vmamba.py", line 1068, in forward_corem0
ys, final_state = selective_scan_chunk_fn(
File "/home/wzw/Vmamba/VMamba/classification/models/mamba2/ssd_minimal.py", line 135, in selective_scan_chunk_fn
return fn(x, dt, A, B, C, chunk_size, D, z, dt_bias, initial_states, seq_idx, dt_softplus, dt_limit, return_final_states)
File "/home/wzw/Vmamba/VMamba/classification/models/mamba2/ssd_minimal.py", line 110, in mamba_chunk_scan_combined_torch
y, state = ssd_minimal_discrete(u, w, B, C, block_len=chunk_size, initial_states=initial_states)
File "/home/wzw/Vmamba/VMamba/classification/models/mamba2/ssd_minimal.py", line 55, in ssd_minimal_discrete
Y_diag = torch.einsum("bclhn,bcshn,bhcls,bcshp->bclhp", C, B, L, X)
File "/home/wzw/miniconda3/envs/vmamba2/lib/python3.10/site-packages/torch/functional.py", line 380, in einsum
return _VF.einsum(equation, operands) # type: ignore[attr-defined]
RuntimeError: einsum(): subscript h has size 64 for operand 2 which does not broadcast with previously seen size 4