OpenGVLab / VideoMamba

[ECCV2024] VideoMamba: State Space Model for Efficient Video Understanding
https://arxiv.org/abs/2403.06977
Apache License 2.0

raise ValueError("arange's arguments must be of type tl.constexpr") #8

Closed: cyinen closed this issue 6 months ago

cyinen commented 6 months ago

I followed your installation guide to create an environment, but when I run "python videomamba.py", I get the following error. Can you give me some suggestions? Thank you!

(mamba) ➜  models git:(main)  python videomamba.py
Traceback (most recent call last):
  File "/home/chenyin/miniconda3/envs/mamba/lib/python3.10/site-packages/triton/compiler/code_generator.py", line 1124, in ast_to_ttir
    generator.visit(fn.parse())
  File "/home/chenyin/miniconda3/envs/mamba/lib/python3.10/site-packages/triton/compiler/code_generator.py", line 1017, in visit
    ret = super().visit(node)
  File "/home/chenyin/miniconda3/envs/mamba/lib/python3.10/ast.py", line 418, in visit
    return visitor(node)
  File "/home/chenyin/miniconda3/envs/mamba/lib/python3.10/site-packages/triton/compiler/code_generator.py", line 293, in visit_Module
    ast.NodeVisitor.generic_visit(self, node)
  File "/home/chenyin/miniconda3/envs/mamba/lib/python3.10/ast.py", line 426, in generic_visit
    self.visit(item)
  File "/home/chenyin/miniconda3/envs/mamba/lib/python3.10/site-packages/triton/compiler/code_generator.py", line 1017, in visit
    ret = super().visit(node)
  File "/home/chenyin/miniconda3/envs/mamba/lib/python3.10/ast.py", line 418, in visit
    return visitor(node)
  File "/home/chenyin/miniconda3/envs/mamba/lib/python3.10/site-packages/triton/compiler/code_generator.py", line 362, in visit_FunctionDef
    self.visit_compound_statement(node.body)
  File "/home/chenyin/miniconda3/envs/mamba/lib/python3.10/site-packages/triton/compiler/code_generator.py", line 288, in visit_compound_statement
    ret_type = self.visit(stmt)
  File "/home/chenyin/miniconda3/envs/mamba/lib/python3.10/site-packages/triton/compiler/code_generator.py", line 1017, in visit
    ret = super().visit(node)
  File "/home/chenyin/miniconda3/envs/mamba/lib/python3.10/ast.py", line 418, in visit
    return visitor(node)
  File "/home/chenyin/miniconda3/envs/mamba/lib/python3.10/site-packages/triton/compiler/code_generator.py", line 414, in visit_Assign
    values = self.visit(node.value)
  File "/home/chenyin/miniconda3/envs/mamba/lib/python3.10/site-packages/triton/compiler/code_generator.py", line 1017, in visit
    ret = super().visit(node)
  File "/home/chenyin/miniconda3/envs/mamba/lib/python3.10/ast.py", line 418, in visit
    return visitor(node)
  File "/home/chenyin/miniconda3/envs/mamba/lib/python3.10/site-packages/triton/compiler/code_generator.py", line 946, in visit_Call
    return fn(*args, **extra_kwargs, **kws)
  File "/home/chenyin/miniconda3/envs/mamba/lib/python3.10/site-packages/triton/language/core.py", line 30, in wrapper
    return fn(*args, **kwargs)
  File "/home/chenyin/miniconda3/envs/mamba/lib/python3.10/site-packages/triton/language/core.py", line 813, in arange
    return semantic.arange(start, end, _builder)
  File "/home/chenyin/miniconda3/envs/mamba/lib/python3.10/site-packages/triton/language/semantic.py", line 485, in arange
    raise ValueError("arange's arguments must be of type tl.constexpr")
ValueError: arange's arguments must be of type tl.constexpr

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/data/chenyin/project/VideoMamba/videomamba/image_sm/models/videomamba.py", line 405, in <module>
    print(flop_count_table(flops, max_depth=1))
  File "/home/chenyin/miniconda3/envs/mamba/lib/python3.10/site-packages/fvcore/nn/print_model_statistics.py", line 632, in flop_count_table
    stats = {params_header: params, flops_header: flops.by_module()}
  File "/home/chenyin/miniconda3/envs/mamba/lib/python3.10/site-packages/fvcore/nn/jit_analysis.py", line 291, in by_module
    stats = self._analyze()
  File "/home/chenyin/miniconda3/envs/mamba/lib/python3.10/site-packages/fvcore/nn/jit_analysis.py", line 551, in _analyze
    graph = _get_scoped_trace_graph(self._model, self._inputs, self._aliases)
  File "/home/chenyin/miniconda3/envs/mamba/lib/python3.10/site-packages/fvcore/nn/jit_analysis.py", line 176, in _get_scoped_trace_graph
    graph, _ = _get_trace_graph(module, inputs)
  File "/home/chenyin/miniconda3/envs/mamba/lib/python3.10/site-packages/torch/jit/_trace.py", line 1285, in _get_trace_graph
    outs = ONNXTracedModule(
  File "/home/chenyin/miniconda3/envs/mamba/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/chenyin/miniconda3/envs/mamba/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/chenyin/miniconda3/envs/mamba/lib/python3.10/site-packages/torch/jit/_trace.py", line 133, in forward
    graph, out = torch._C._create_graph_by_tracing(
  File "/home/chenyin/miniconda3/envs/mamba/lib/python3.10/site-packages/torch/jit/_trace.py", line 124, in wrapper
    outs.append(self.inner(*trace_inputs))
  File "/home/chenyin/miniconda3/envs/mamba/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/chenyin/miniconda3/envs/mamba/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1568, in _call_impl
    result = forward_call(*args, **kwargs)
  File "/home/chenyin/miniconda3/envs/mamba/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1508, in _slow_forward
    result = self.forward(*input, **kwargs)
  File "/data/chenyin/project/VideoMamba/videomamba/image_sm/models/videomamba.py", line 320, in forward
    x = self.forward_features(x, inference_params)
  File "/data/chenyin/project/VideoMamba/videomamba/image_sm/models/videomamba.py", line 293, in forward_features
    hidden_states, residual = layer(
  File "/home/chenyin/miniconda3/envs/mamba/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/chenyin/miniconda3/envs/mamba/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1568, in _call_impl
    result = forward_call(*args, **kwargs)
  File "/home/chenyin/miniconda3/envs/mamba/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1508, in _slow_forward
    result = self.forward(*input, **kwargs)
  File "/data/chenyin/project/VideoMamba/videomamba/image_sm/models/videomamba.py", line 98, in forward
    hidden_states, residual = fused_add_norm_fn(
  File "/data/chenyin/project/VideoMamba/mamba/mamba_ssm/ops/triton/layernorm.py", line 478, in rms_norm_fn
    return LayerNormFn.apply(x, weight, bias, residual, eps, prenorm, residual_in_fp32, True)
  File "/home/chenyin/miniconda3/envs/mamba/lib/python3.10/site-packages/torch/autograd/function.py", line 539, in apply
    return super().apply(*args, **kwargs)  # type: ignore[misc]
  File "/data/chenyin/project/VideoMamba/mamba/mamba_ssm/ops/triton/layernorm.py", line 411, in forward
    y, mean, rstd, residual_out = _layer_norm_fwd(
  File "/data/chenyin/project/VideoMamba/mamba/mamba_ssm/ops/triton/layernorm.py", line 155, in _layer_norm_fwd
    _layer_norm_fwd_1pass_kernel[(M,)](
  File "/home/chenyin/miniconda3/envs/mamba/lib/python3.10/site-packages/triton/runtime/autotuner.py", line 100, in run
    timings = {config: self._bench(*args, config=config, **kwargs)
  File "/home/chenyin/miniconda3/envs/mamba/lib/python3.10/site-packages/triton/runtime/autotuner.py", line 100, in <dictcomp>
    timings = {config: self._bench(*args, config=config, **kwargs)
  File "/home/chenyin/miniconda3/envs/mamba/lib/python3.10/site-packages/triton/runtime/autotuner.py", line 83, in _bench
    return do_bench(kernel_call, warmup=self.warmup, rep=self.rep, quantiles=(0.5, 0.2, 0.8))
  File "/home/chenyin/miniconda3/envs/mamba/lib/python3.10/site-packages/triton/testing.py", line 104, in do_bench
    fn()
  File "/home/chenyin/miniconda3/envs/mamba/lib/python3.10/site-packages/triton/runtime/autotuner.py", line 81, in kernel_call
    self.fn.run(*args, num_warps=config.num_warps, num_stages=config.num_stages, **current)
  File "<string>", line 63, in _layer_norm_fwd_1pass_kernel
  File "/home/chenyin/miniconda3/envs/mamba/lib/python3.10/site-packages/triton/compiler/compiler.py", line 476, in compile
    next_module = compile_kernel(module)
  File "/home/chenyin/miniconda3/envs/mamba/lib/python3.10/site-packages/triton/compiler/compiler.py", line 381, in <lambda>
    lambda src: optimize_ttir(ast_to_ttir(src, signature, configs[0], constants, debug=debug, arch=arch), arch))
  File "/home/chenyin/miniconda3/envs/mamba/lib/python3.10/site-packages/triton/compiler/code_generator.py", line 1133, in ast_to_ttir
    raise CompilationError(fn.src, node, repr(e)) from e
triton.compiler.errors.CompilationError: at 31:24:    HAS_BIAS: tl.constexpr,
):
    # Map the program id to the row of X and Y it should compute.
    row = tl.program_id(0)
    X += row * stride_x_row
    Y += row * stride_y_row
    if HAS_RESIDUAL:
        RESIDUAL += row * stride_res_row
    if STORE_RESIDUAL_OUT:
        RESIDUAL_OUT += row * stride_res_out_row
    # Compute mean and variance
    cols = tl.arange(0, BLOCK_N)
                        ^
ValueError("arange's arguments must be of type tl.constexpr")

Andy1621 commented 6 months ago

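Hi, fvcore's jit_analysis (used here for FLOP counting) does not work when rms_norm=True and fused_add_norm=True: when the model is traced with torch.jit, the fused Triton RMSNorm kernel (rms_norm_fn) fails to compile with the error above. Please set both flags to False: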

https://github.com/OpenGVLab/VideoMamba/blob/ea4d3f545ac8ba414dd474595cdf5450f55e4c7d/videomamba/video_sm/models/videomamba.py#L472-L477
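For context, the kernel that fails in the traceback, `_layer_norm_fwd_1pass_kernel` reached through `rms_norm_fn`, fuses the residual add and the normalization into a single Triton pass. With fused_add_norm=False the block instead performs the add and the norm as separate, ordinary operations (assuming it follows the usual mamba_ssm Block pattern), and with rms_norm=False the norm is a standard LayerNorm instead of RMSNorm. Below is a minimal pure-Python sketch of that per-row computation, for intuition only. The function `add_norm_row` is hypothetical, it is not the mamba_ssm implementation, and it ignores drop-path, dtype casting and residual_in_fp32.

```python
import math
from typing import List, Optional, Tuple


def add_norm_row(
    x: List[float],
    residual: Optional[List[float]],
    weight: List[float],
    bias: Optional[List[float]] = None,
    eps: float = 1e-5,
    is_rms_norm: bool = True,
) -> Tuple[List[float], List[float]]:
    """Hypothetical unfused reference: residual add, then LayerNorm or RMSNorm on one row."""
    # Residual add (the fused Triton kernel does this in the same pass as the norm).
    if residual is not None:
        x = [a + b for a, b in zip(x, residual)]
    residual_out = list(x)  # the pre-norm sum is carried forward as the next residual

    n = len(x)
    # LayerNorm subtracts the row mean; RMSNorm does not center the row.
    mean = 0.0 if is_rms_norm else sum(x) / n
    # For RMSNorm (mean = 0) this is the mean of squares.
    var = sum((v - mean) ** 2 for v in x) / n
    rstd = 1.0 / math.sqrt(var + eps)

    # Normalize, then apply the per-feature scale and optional shift.
    y = [(v - mean) * rstd * w for v, w in zip(x, weight)]
    if bias is not None:
        y = [v + b for v, b in zip(y, bias)]
    return y, residual_out
```

For example, `add_norm_row([1.0, 2.0, 3.0], [1.0, 1.0, 1.0], [1.0, 1.0, 1.0], eps=0.0)` returns the residual `[2.0, 3.0, 4.0]` together with that row scaled to unit root-mean-square. Keeping the add and the norm as separate steps is what makes the unfused path easy for a tracer to follow, at the cost of an extra pass over memory.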

For training, you can keep both set to True, since jit_analysis is not called.

cyinen commented 6 months ago

@Andy1621 Thanks, I have found it!