MzeroMiko / VMamba

VMamba: Visual State Space Models, code is based on mamba
MIT License

ValueError: Pointer argument (at 0) cannot be accessed from Triton (cpu tensor?) #207

Open BranStarkkk opened 1 month ago

BranStarkkk commented 1 month ago

Hello, I recently used VMamba_tiny from your work as the backbone for a downstream classification task, loaded the pretrained weights, and trained it. My training environment was Python 3.8 with CUDA 11.8. I then moved the trained model to Kaggle and tested it in an environment with CUDA 12.2 and Python 3.10, where it fails every time with the following error:

```
ValueError                                Traceback (most recent call last)
Cell In[57], line 22
     20 for model in models:
     21     with torch.cuda.amp.autocast():
---> 22         ifr = model(ipt.cuda(1), len(ipt))
     23         # res.append(ifr[0].float32() * confidence_scales)
     24         res.append(ifr[0].float32())  # cell

File /opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py:1518, in Module._wrapped_call_impl(self, *args, **kwargs)
-> 1518     return self._call_impl(*args, **kwargs)

File /opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py:1527, in Module._call_impl(self, *args, **kwargs)
-> 1527     return forward_call(*args, **kwargs)

Cell In[40], line 93, in DualBranchMILVMamba.forward(self, x, cnt)
     91 def forward(self, x, cnt):
     92     self.model = self.model.cuda(1)
---> 93     features = self.model(x.cuda(1))
     94     # features = self.model(x)
     95     pooled = features

File /opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py:1518, in Module._wrapped_call_impl(self, *args, **kwargs)
-> 1518     return self._call_impl(*args, **kwargs)

File /opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py:1527, in Module._call_impl(self, *args, **kwargs)
-> 1527     return forward_call(*args, **kwargs)

File /opt/conda/lib/python3.10/site-packages/vmamba/classification/models/vmamba.py:1368, in VSSM.forward(self, x)
-> 1368     x = layer(x)

File /opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py:1518, in Module._wrapped_call_impl(self, *args, **kwargs)
-> 1518     return self._call_impl(*args, **kwargs)

File /opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py:1527, in Module._call_impl(self, *args, **kwargs)
-> 1527     return forward_call(*args, **kwargs)

File /opt/conda/lib/python3.10/site-packages/torch/nn/modules/container.py:215, in Sequential.forward(self, input)
--> 215     input = module(input)

File /opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py:1518, in Module._wrapped_call_impl(self, *args, **kwargs)
-> 1518     return self._call_impl(*args, **kwargs)

File /opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py:1527, in Module._call_impl(self, *args, **kwargs)
-> 1527     return forward_call(*args, **kwargs)

File /opt/conda/lib/python3.10/site-packages/torch/nn/modules/container.py:215, in Sequential.forward(self, input)
--> 215     input = module(input)

File /opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py:1518, in Module._wrapped_call_impl(self, *args, **kwargs)
-> 1518     return self._call_impl(*args, **kwargs)

File /opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py:1527, in Module._call_impl(self, *args, **kwargs)
-> 1527     return forward_call(*args, **kwargs)

File /opt/conda/lib/python3.10/site-packages/vmamba/classification/models/vmamba.py:1110, in VSSBlock.forward(self, input)
-> 1110     return checkpoint.checkpoint(self._forward, input)

File /opt/conda/lib/python3.10/site-packages/torch/_compile.py:24, in _disable_dynamo.<locals>.inner(*args, **kwargs)
---> 24     return torch._dynamo.disable(fn, recursive)(*args, **kwargs)

File /opt/conda/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py:328, in _TorchDynamoContext.__call__.<locals>._fn(*args, **kwargs)
--> 328     return fn(*args, **kwargs)

File /opt/conda/lib/python3.10/site-packages/torch/_dynamo/external_utils.py:17, in wrap_inline.<locals>.inner(*args, **kwargs)
---> 17     return fn(*args, **kwargs)

File /opt/conda/lib/python3.10/site-packages/torch/utils/checkpoint.py:451, in checkpoint(function, use_reentrant, context_fn, determinism_check, debug, *args, **kwargs)
--> 451     return CheckpointFunction.apply(function, preserve, *args)

File /opt/conda/lib/python3.10/site-packages/torch/autograd/function.py:539, in Function.apply(cls, *args, **kwargs)
--> 539     return super().apply(*args, **kwargs)  # type: ignore[misc]

File /opt/conda/lib/python3.10/site-packages/torch/utils/checkpoint.py:230, in CheckpointFunction.forward(ctx, run_function, preserve_rng_state, *args)
--> 230     outputs = run_function(*args)

File /opt/conda/lib/python3.10/site-packages/vmamba/classification/models/vmamba.py:1100, in VSSBlock._forward(self, input)
-> 1100     x = x + self.drop_path(self.op(self.norm(x)))

File /opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py:1518, in Module._wrapped_call_impl(self, *args, **kwargs)
-> 1518     return self._call_impl(*args, **kwargs)

File /opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py:1527, in Module._call_impl(self, *args, **kwargs)
-> 1527     return forward_call(*args, **kwargs)

File /opt/conda/lib/python3.10/site-packages/vmamba/classification/models/vmamba.py:704, in SS2Dv2.forwardv2(self, x, **kwargs)
--> 704     y = self.forward_core(x)

File /opt/conda/lib/python3.10/site-packages/vmamba/classification/models/vmamba.py:650, in SS2Dv2.forward_corev2(self, x, to_dtype, force_fp32, ssoflex, SelectiveScan, CrossScan, CrossMerge, no_einsum, cascade2d, **kwargs)
--> 650     xs = CrossScan.apply(x)

File /opt/conda/lib/python3.10/site-packages/torch/autograd/function.py:539, in Function.apply(cls, *args, **kwargs)
--> 539     return super().apply(*args, **kwargs)  # type: ignore[misc]

File /opt/conda/lib/python3.10/site-packages/vmamba/classification/models/csm_triton.py:174, in CrossScanTriton.forward(ctx, x)
    172     x = x.contiguous()
    173     y = x.new_empty((B, 4, C, H, W))
--> 174     triton_cross_scan[(NH * NW, NC, B)](x, y, BC, BH, BW, C, H, W, NH, NW)
    175     return y.view(B, 4, C, -1)

File /opt/conda/lib/python3.10/site-packages/triton/runtime/jit.py:167, in KernelInterface.__getitem__.<locals>.<lambda>(*args, **kwargs)
--> 167     return lambda *args, **kwargs: self.run(grid=grid, warmup=False, *args, **kwargs)

File /opt/conda/lib/python3.10/site-packages/triton/runtime/jit.py:425, in JITFunction.run(self, grid, warmup, *args, **kwargs)
--> 425     kernel.run(grid_0, grid_1, grid_2, kernel.num_warps, kernel.num_ctas,
    426                kernel.cluster_dims[0], kernel.cluster_dims[1], kernel.cluster_dims[2],
    427                kernel.shared, stream, kernel.function, CompiledKernel.launch_enter_hook,
    428                CompiledKernel.launch_exit_hook, kernel,
    429                driver.assemble_tensormap_to_arg(kernel.metadata["tensormaps_info"], args))

ValueError: Pointer argument (at 0) cannot be accessed from Triton (cpu tensor?)
```

Any advice or possible solution you could offer would be greatly appreciated!
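
The final line ("cpu tensor?") suggests the Triton kernel was handed a tensor that never reached the GPU. A minimal diagnostic sketch to check that before calling the model; `report_cpu_tensors` is a hypothetical helper, not code from the issue:

```python
import torch

def report_cpu_tensors(model: torch.nn.Module) -> None:
    """List every parameter and buffer of the model that is still on the CPU."""
    for name, p in model.named_parameters():
        if not p.is_cuda:
            print(f"parameter still on CPU: {name} {tuple(p.shape)}")
    for name, b in model.named_buffers():
        if not b.is_cuda:
            print(f"buffer still on CPU: {name} {tuple(b.shape)}")

# report_cpu_tensors(model)  # run this right before inference
```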

MzeroMiko commented 1 month ago

Have you transferred the model to CUDA?
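
For reference, a minimal sketch of what that looks like, assuming the model should run on cuda:1 as in the traceback (`model` and `ipt` stand in for the objects from the posted cells):

```python
import torch

device = torch.device("cuda:1")   # the device used in the traceback above

model = model.to(device).eval()   # moves all parameters and buffers to the GPU
ipt = ipt.to(device)              # the input must live on the same device

with torch.no_grad(), torch.cuda.amp.autocast():
    out = model(ipt, len(ipt))    # call signature follows the posted cell
```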

BranStarkkk commented 1 month ago

Have you transferred the model to CUDA?

I have now solved the problem. I think the cause may have had something to do with loading the pre-trained VMamba model on the CPU. Anyway, thanks for the answer!
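
For anyone hitting the same error: loading a checkpoint on the CPU (e.g. with `map_location="cpu"`) is fine, but the model must then be moved to the GPU explicitly before any Triton kernel runs. A hedged sketch; the checkpoint path and the `"model"` key are placeholders, not taken from the issue:

```python
import torch

# model = ...  # build the VMamba backbone exactly as it was built for training

state = torch.load("vmamba_tiny_ckpt.pth", map_location="cpu")  # placeholder path
model.load_state_dict(state.get("model", state), strict=False)

model = model.cuda(1)  # move everything to the GPU *after* loading the weights
assert all(p.is_cuda for p in model.parameters()), "some weights are still on the CPU"
```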