MzeroMiko / VMamba

VMamba: Visual State Space Models; code is based on Mamba
MIT License
1.82k stars 98 forks source link

About the inference error #239

Open aifeixingdelv opened 1 week ago

aifeixingdelv commented 1 week ago

When I run inference with BackboneVSSM, I get the following error: /root/miniconda3/bin/python3 /root/autodl-tmp/monodepth2/test.py Successfully load ckpt models/vssm1_tiny_0230s_ckpt_epoch_264.pth _IncompatibleKeys(missing_keys=['outnorm0.weight', 'outnorm0.bias', 'outnorm1.weight', 'outnorm1.bias', 'outnorm2.weight', 'outnorm2.bias', 'outnorm3.weight', 'outnorm3.bias'], unexpected_keys=['classifier.norm.weight', 'classifier.norm.bias', 'classifier.head.weight', 'classifier.head.bias']) Traceback (most recent call last): File "<string>", line 21, in triton_cross_scan_flex KeyError: ('2-.-0-.-0--d6252949da17ceb5f3a278a70250af13-1af5134066c618146d2cd009138944a0-14de7de5c4da5794c8ca14e7e41a122d-3498c340fd4b6ee7805fd54b882a04f5-e1f133f98d04093da2078dfc51c36b72-b26258bf01f839199e39d64851821f26-d7c06e3b46e708006c15224aac7a1378-f585402118c8a136948ce0a49cfe122c', (torch.float32, torch.float32), (0, 0, 0, 0, 0, 1, 32, 32, 96, 56, 56, 2, 2), (True, True))

During handling of the above exception, another exception occurred:

Traceback (most recent call last): File "/root/miniconda3/lib/python3.8/site-packages/triton/compiler.py", line 937, in build_triton_ir generator.visit(fn.parse()) File "/root/miniconda3/lib/python3.8/site-packages/triton/compiler.py", line 855, in visit return super().visit(node) File "/root/miniconda3/lib/python3.8/ast.py", line 371, in visit return visitor(node) File "/root/miniconda3/lib/python3.8/site-packages/triton/compiler.py", line 183, in visit_Module ast.NodeVisitor.generic_visit(self, node) File "/root/miniconda3/lib/python3.8/ast.py", line 379, in generic_visit self.visit(item) File "/root/miniconda3/lib/python3.8/site-packages/triton/compiler.py", line 855, in visit return super().visit(node) File "/root/miniconda3/lib/python3.8/ast.py", line 371, in visit return visitor(node) File "/root/miniconda3/lib/python3.8/site-packages/triton/compiler.py", line 252, in visit_FunctionDef has_ret = self.visit_compound_statement(node.body) File "/root/miniconda3/lib/python3.8/site-packages/triton/compiler.py", line 177, in visit_compound_statement self.last_ret_type = self.visit(stmt) File "/root/miniconda3/lib/python3.8/site-packages/triton/compiler.py", line 855, in visit return super().visit(node) File "/root/miniconda3/lib/python3.8/ast.py", line 371, in visit return visitor(node) File "/root/miniconda3/lib/python3.8/site-packages/triton/compiler.py", line 301, in visit_Assign values = self.visit(node.value) File "/root/miniconda3/lib/python3.8/site-packages/triton/compiler.py", line 855, in visit return super().visit(node) File "/root/miniconda3/lib/python3.8/ast.py", line 371, in visit return visitor(node) File "/root/miniconda3/lib/python3.8/site-packages/triton/compiler.py", line 339, in visit_BinOp rhs = self.visit(node.right) File "/root/miniconda3/lib/python3.8/site-packages/triton/compiler.py", line 855, in visit return super().visit(node) File "/root/miniconda3/lib/python3.8/ast.py", line 371, in visit return visitor(node) File 
"/root/miniconda3/lib/python3.8/site-packages/triton/compiler.py", line 492, in visit_IfExp if cond.value: AttributeError: 'bool' object has no attribute 'value'

The above exception was the direct cause of the following exception:

Traceback (most recent call last): File "/root/autodl-tmp/monodepth2/test.py", line 7, in <module> outs = backbone(input) File "/root/miniconda3/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl return forward_call(*args, **kwargs) File "/root/autodl-tmp/monodepth2/networks/mamba_encoder.py", line 1836, in forward o, x = layer_forward(layer, x) # (B, H, W, C) File "/root/autodl-tmp/monodepth2/networks/mamba_encoder.py", line 1829, in layer_forward x = l.blocks(x) File "/root/miniconda3/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl return forward_call(*args, **kwargs) File "/root/miniconda3/lib/python3.8/site-packages/torch/nn/modules/container.py", line 217, in forward input = module(input) File "/root/miniconda3/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl return forward_call(*args, **kwargs) File "/root/autodl-tmp/monodepth2/networks/mamba_encoder.py", line 1238, in forward return self._forward(input) File "/root/autodl-tmp/monodepth2/networks/mamba_encoder.py", line 1226, in _forward x = x + self.drop_path(self.op(self.norm(x))) File "/root/miniconda3/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl return forward_call(*args, **kwargs) File "/root/autodl-tmp/monodepth2/networks/mamba_encoder.py", line 659, in forwardv2 y = self.forward_core(x) File "/root/autodl-tmp/monodepth2/networks/mamba_encoder.py", line 605, in forward_corev2 xs = cross_scan_fn(x, in_channel_first=True, out_channel_first=True, scans=_scan_mode, force_torch=scan_force_torch) File "/root/autodl-tmp/monodepth2/networks/csm_triton.py", line 495, in cross_scan_fn return CSF.apply(x, in_channel_first, out_channel_first, one_by_one, scans) File "/root/miniconda3/lib/python3.8/site-packages/torch/autograd/function.py", line 506, in apply return super().apply(*args, **kwargs) # type: ignore[misc] File "/root/autodl-tmp/monodepth2/networks/csm_triton.py", line 417, in forward 
triton_cross_scan_flex[(NH * NW, NC, B)]( File "<string>", line 41, in triton_cross_scan_flex File "/root/miniconda3/lib/python3.8/site-packages/triton/compiler.py", line 1620, in compile next_module = compile(module) File "/root/miniconda3/lib/python3.8/site-packages/triton/compiler.py", line 1549, in <lambda> lambda src: ast_to_ttir(src, signature, configs[0], constants)), File "/root/miniconda3/lib/python3.8/site-packages/triton/compiler.py", line 962, in ast_to_ttir mod, = build_triton_ir(fn, signature, specialization, constants) File "/root/miniconda3/lib/python3.8/site-packages/triton/compiler.py", line 942, in build_triton_ir raise CompilationError(fn.src, node) from e triton.compiler.CompilationError: at 45:74: def triton_cross_scan_flex( x, # (B, C, H, W) | (B, H, W, C) | (B, 4, C, H, W) | (B, H, W, 4, C) y, # (B, 4, C, H, W) | (B, H, W, 4, C) x_layout: tl.constexpr, y_layout: tl.constexpr, operation: tl.constexpr, onebyone: tl.constexpr, scans: tl.constexpr, BC: tl.constexpr, BH: tl.constexpr, BW: tl.constexpr, DC: tl.constexpr, DH: tl.constexpr, DW: tl.constexpr, NH: tl.constexpr, NW: tl.constexpr, ):

# x_layout = 0

# y_layout = 1 # 0 BCHW, 1 BHWC
# operation = 0 # 0 scan, 1 merge
# onebyone = 0 # 0 false, 1 true
# scans = 0 # 0 cross scan, 1 unidirectional, 2 bidirectional
i_hw, i_c, i_b = tl.program_id(0), tl.program_id(1), tl.program_id(2)
i_h, i_w = (i_hw // NW), (i_hw % NW)
_mask_h = (i_h * BH + tl.arange(0, BH)) < DH
_mask_w = (i_w * BW + tl.arange(0, BW)) < DW
_mask_hw = _mask_h[:, None] & _mask_w[None, :]
_for_C = min(DC - i_c * BC, BC)

HWRoute0 = i_h * BH * DW  + tl.arange(0, BH)[:, None] * DW + i_w * BW + tl.arange(0, BW)[None, :]
HWRoute1 = i_w * BW * DH + tl.arange(0, BW)[None, :] * DH + i_h * BH + tl.arange(0, BH)[:, None]  # trans
HWRoute2 = (NH - i_h - 1) * BH * DW  + (BH - 1 - tl.arange(0, BH)[:, None]) * DW + (NW - i_w - 1) * BW + (BW - 1 - tl.arange(0, BW)[None, :]) + (DH - NH * BH) * DW + (DW - NW * BW) # flip
HWRoute3 = (NW - i_w - 1) * BW * DH  + (BW - 1 - tl.arange(0, BW)[None, :]) * DH + (NH - i_h - 1) * BH + (BH - 1 - tl.arange(0, BH)[:, None]) + (DH - NH * BH) + (DW - NW * BW) * DH  # trans + flip

if scans == 1:
    HWRoute1 = HWRoute0
    HWRoute2 = HWRoute0
    HWRoute3 = HWRoute0
elif scans == 2:
    HWRoute1 = HWRoute0
    HWRoute3 = HWRoute2        

_tmp1 = DC * DH * DW

y_ptr_base = y + i_b * 4 * _tmp1 + (i_c * BC * DH * DW if y_layout == 0 else i_c * BC)
                                                                      ^

进程已结束,退出代码为 1 (Process finished with exit code 1)

aifeixingdelv commented 1 week ago

When I run `python csm_triton.py` (which calls `CHECK.check_csm_triton()`), I get the same error.

MzeroMiko commented 1 week ago

It seems that you passed a parameter which is of type bool into the function cross_scan_fn.

I can run csm_triton.py successfully. Are you sure the file you are using is exactly the same as the one in this repo?

aifeixingdelv commented 1 week ago

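I solved it. My Triton version was 2.0.0, and after updating to 2.3.1 the error went away. Triton 2.0.0's compiler fails in `visit_IfExp` on the `... if y_layout == 0 else ...` conditional expression inside the kernel, which is the line the `CompilationError` points to.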

russellllaputa commented 1 week ago

I just had the same problem; it was solved by simply creating a new virtual environment with python==3.10.1 and pytorch==2.3.1.