Issue (open) — opened by aifeixingdelv 1 week ago
When I run `python csm_triton.py` (which calls `CHECK.check_csm_triton()`), I get the same error.
It seems that you passed a parameter of type `bool` into the function `cross_scan_fn`.
I can run `csm_triton.py` successfully. Are you sure the file you are using is exactly the same as the one in this repo?
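I solved it. My Triton version was 2.0.0; after upgrading to 2.3.1, the error went away. (The Triton 2.0.0 compiler fails in `visit_IfExp` when the condition of a conditional expression is a plain Python `bool`. This produces the `AttributeError: 'bool' object has no attribute 'value'` shown in the traceback below.)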
I just had the same problem; it was solved by simply creating a new virtual environment with python==3.10.1 and pytorch==2.3.1.
When I run inference with `BackboneVSSM`, I get the above error:
/root/miniconda3/bin/python3 /root/autodl-tmp/monodepth2/test.py
Successfully load ckpt models/vssm1_tiny_0230s_ckpt_epoch_264.pth
_IncompatibleKeys(missing_keys=['outnorm0.weight', 'outnorm0.bias', 'outnorm1.weight', 'outnorm1.bias', 'outnorm2.weight', 'outnorm2.bias', 'outnorm3.weight', 'outnorm3.bias'], unexpected_keys=['classifier.norm.weight', 'classifier.norm.bias', 'classifier.head.weight', 'classifier.head.bias'])
Traceback (most recent call last):
  File "<string>", line 21, in triton_cross_scan_flex
KeyError: ('2-.-0-.-0--d6252949da17ceb5f3a278a70250af13-1af5134066c618146d2cd009138944a0-14de7de5c4da5794c8ca14e7e41a122d-3498c340fd4b6ee7805fd54b882a04f5-e1f133f98d04093da2078dfc51c36b72-b26258bf01f839199e39d64851821f26-d7c06e3b46e708006c15224aac7a1378-f585402118c8a136948ce0a49cfe122c', (torch.float32, torch.float32), (0, 0, 0, 0, 0, 1, 32, 32, 96, 56, 56, 2, 2), (True, True))
During handling of the above exception, another exception occurred:
Traceback (most recent call last):
  File "/root/miniconda3/lib/python3.8/site-packages/triton/compiler.py", line 937, in build_triton_ir
    generator.visit(fn.parse())
  File "/root/miniconda3/lib/python3.8/site-packages/triton/compiler.py", line 855, in visit
    return super().visit(node)
  File "/root/miniconda3/lib/python3.8/ast.py", line 371, in visit
    return visitor(node)
  File "/root/miniconda3/lib/python3.8/site-packages/triton/compiler.py", line 183, in visit_Module
    ast.NodeVisitor.generic_visit(self, node)
  File "/root/miniconda3/lib/python3.8/ast.py", line 379, in generic_visit
    self.visit(item)
  File "/root/miniconda3/lib/python3.8/site-packages/triton/compiler.py", line 855, in visit
    return super().visit(node)
  File "/root/miniconda3/lib/python3.8/ast.py", line 371, in visit
    return visitor(node)
  File "/root/miniconda3/lib/python3.8/site-packages/triton/compiler.py", line 252, in visit_FunctionDef
    has_ret = self.visit_compound_statement(node.body)
  File "/root/miniconda3/lib/python3.8/site-packages/triton/compiler.py", line 177, in visit_compound_statement
    self.last_ret_type = self.visit(stmt)
  File "/root/miniconda3/lib/python3.8/site-packages/triton/compiler.py", line 855, in visit
    return super().visit(node)
  File "/root/miniconda3/lib/python3.8/ast.py", line 371, in visit
    return visitor(node)
  File "/root/miniconda3/lib/python3.8/site-packages/triton/compiler.py", line 301, in visit_Assign
    values = self.visit(node.value)
  File "/root/miniconda3/lib/python3.8/site-packages/triton/compiler.py", line 855, in visit
    return super().visit(node)
  File "/root/miniconda3/lib/python3.8/ast.py", line 371, in visit
    return visitor(node)
  File "/root/miniconda3/lib/python3.8/site-packages/triton/compiler.py", line 339, in visit_BinOp
    rhs = self.visit(node.right)
  File "/root/miniconda3/lib/python3.8/site-packages/triton/compiler.py", line 855, in visit
    return super().visit(node)
  File "/root/miniconda3/lib/python3.8/ast.py", line 371, in visit
    return visitor(node)
  File "/root/miniconda3/lib/python3.8/site-packages/triton/compiler.py", line 492, in visit_IfExp
    if cond.value:
AttributeError: 'bool' object has no attribute 'value'
The above exception was the direct cause of the following exception:
Traceback (most recent call last):
  File "/root/autodl-tmp/monodepth2/test.py", line 7, in <module>
    outs = backbone(input)
  File "/root/miniconda3/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/root/autodl-tmp/monodepth2/networks/mamba_encoder.py", line 1836, in forward
    o, x = layer_forward(layer, x) # (B, H, W, C)
  File "/root/autodl-tmp/monodepth2/networks/mamba_encoder.py", line 1829, in layer_forward
    x = l.blocks(x)
  File "/root/miniconda3/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/root/miniconda3/lib/python3.8/site-packages/torch/nn/modules/container.py", line 217, in forward
    input = module(input)
  File "/root/miniconda3/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/root/autodl-tmp/monodepth2/networks/mamba_encoder.py", line 1238, in forward
    return self._forward(input)
  File "/root/autodl-tmp/monodepth2/networks/mamba_encoder.py", line 1226, in _forward
    x = x + self.drop_path(self.op(self.norm(x)))
  File "/root/miniconda3/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/root/autodl-tmp/monodepth2/networks/mamba_encoder.py", line 659, in forwardv2
    y = self.forward_core(x)
  File "/root/autodl-tmp/monodepth2/networks/mamba_encoder.py", line 605, in forward_corev2
    xs = cross_scan_fn(x, in_channel_first=True, out_channel_first=True, scans=_scan_mode, force_torch=scan_force_torch)
  File "/root/autodl-tmp/monodepth2/networks/csm_triton.py", line 495, in cross_scan_fn
    return CSF.apply(x, in_channel_first, out_channel_first, one_by_one, scans)
  File "/root/miniconda3/lib/python3.8/site-packages/torch/autograd/function.py", line 506, in apply
    return super().apply(*args, **kwargs) # type: ignore[misc]
  File "/root/autodl-tmp/monodepth2/networks/csm_triton.py", line 417, in forward
    triton_cross_scan_flex[(NH * NW, NC, B)](
  File "<string>", line 41, in triton_cross_scan_flex
  File "/root/miniconda3/lib/python3.8/site-packages/triton/compiler.py", line 1620, in compile
    next_module = compile(module)
  File "/root/miniconda3/lib/python3.8/site-packages/triton/compiler.py", line 1549, in <lambda>
    lambda src: ast_to_ttir(src, signature, configs[0], constants)),
  File "/root/miniconda3/lib/python3.8/site-packages/triton/compiler.py", line 962, in ast_to_ttir
    mod, _ = build_triton_ir(fn, signature, specialization, constants)
  File "/root/miniconda3/lib/python3.8/site-packages/triton/compiler.py", line 942, in build_triton_ir
    raise CompilationError(fn.src, node) from e
triton.compiler.CompilationError: at 45:74:
def triton_cross_scan_flex(
x, # (B, C, H, W) | (B, H, W, C) | (B, 4, C, H, W) | (B, H, W, 4, C)
y, # (B, 4, C, H, W) | (B, H, W, 4, C)
x_layout: tl.constexpr,
y_layout: tl.constexpr,
operation: tl.constexpr,
onebyone: tl.constexpr,
scans: tl.constexpr,
BC: tl.constexpr,
BH: tl.constexpr,
BW: tl.constexpr,
DC: tl.constexpr,
DH: tl.constexpr,
DW: tl.constexpr,
NH: tl.constexpr,
NW: tl.constexpr,
):
x_layout = 0
进程已结束,退出代码为 1 (Process finished with exit code 1)