OpenGVLab / VideoMamba

[ECCV2024] VideoMamba: State Space Model for Efficient Video Understanding
https://arxiv.org/abs/2403.06977
Apache License 2.0
847 stars, 60 forks

ValueError("arange's arguments must be of type tl.constexpr") #52

Open Sine7812 opened 6 months ago

Sine7812 commented 6 months ago

Does anyone know how to view the model's architecture diagram? I run into errors whether I use TensorBoard or try to export to ONNX:

Traceback (most recent call last):
  File "/mnt/e/code/VideoMamba-main/videomamba/video_sm/run_class_finetuning.py", line 744, in <module>
    main(opts, ds_init)
  File "/mnt/e/code/VideoMamba-main/videomamba/video_sm/run_class_finetuning.py", line 674, in main
    save_model_and_onnx(args, model_without_ddp, epoch, model_ema)
  File "/mnt/e/code/VideoMamba-main/videomamba/video_sm/run_class_finetuning.py", line 249, in save_model_and_onnx
    torch.onnx.export(model, dummy_input, onnx_path, export_params=True, opset_version=11,
  File "/home/jcz/anaconda3/envs/mamba/lib/python3.10/site-packages/torch/onnx/utils.py", line 516, in export
    _export(
  File "/home/jcz/anaconda3/envs/mamba/lib/python3.10/site-packages/torch/onnx/utils.py", line 1596, in _export
    graph, params_dict, torch_out = _model_to_graph(
  File "/home/jcz/anaconda3/envs/mamba/lib/python3.10/site-packages/torch/onnx/utils.py", line 1135, in _model_to_graph
    graph, params, torch_out, module = _create_jit_graph(model, args)
  File "/home/jcz/anaconda3/envs/mamba/lib/python3.10/site-packages/torch/onnx/utils.py", line 1011, in _create_jit_graph
    graph, torch_out = _trace_and_get_graph_from_model(model, args)
  File "/home/jcz/anaconda3/envs/mamba/lib/python3.10/site-packages/torch/onnx/utils.py", line 915, in _trace_and_get_graph_from_model
    trace_graph, torch_out, inputs_states = torch.jit._get_trace_graph(
  File "/home/jcz/anaconda3/envs/mamba/lib/python3.10/site-packages/torch/jit/_trace.py", line 1285, in _get_trace_graph
    outs = ONNXTracedModule(
  File "/home/jcz/anaconda3/envs/mamba/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/jcz/anaconda3/envs/mamba/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/jcz/anaconda3/envs/mamba/lib/python3.10/site-packages/torch/jit/_trace.py", line 133, in forward
    graph, out = torch._C._create_graph_by_tracing(
  File "/home/jcz/anaconda3/envs/mamba/lib/python3.10/site-packages/torch/jit/_trace.py", line 124, in wrapper
    outs.append(self.inner(*trace_inputs))
  File "/home/jcz/anaconda3/envs/mamba/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/jcz/anaconda3/envs/mamba/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/jcz/anaconda3/envs/mamba/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1508, in _slow_forward
    result = self.forward(*input, **kwargs)
  File "/mnt/e/code/VideoMamba-main/videomamba/video_sm/models/videomamba.py", line 417, in forward
    x = self.forward_features(x, inference_params)
  File "/mnt/e/code/VideoMamba-main/videomamba/video_sm/models/videomamba.py", line 390, in forward_features
    hidden_states, residual = layer(
  File "/home/jcz/anaconda3/envs/mamba/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/jcz/anaconda3/envs/mamba/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/jcz/anaconda3/envs/mamba/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1508, in _slow_forward
    result = self.forward(*input, **kwargs)
  File "/mnt/e/code/VideoMamba-main/videomamba/video_sm/models/videomamba.py", line 123, in forward
    hidden_states, residual = fused_add_norm_fn(
  File "/mnt/e/code/VideoMamba-main/mamba/mamba_ssm/ops/triton/layernorm.py", line 478, in rms_norm_fn
    return LayerNormFn.apply(x, weight, bias, residual, eps, prenorm, residual_in_fp32, True)
  File "/home/jcz/anaconda3/envs/mamba/lib/python3.10/site-packages/torch/autograd/function.py", line 539, in apply
    return super().apply(*args, **kwargs)  # type: ignore[misc]
  File "/mnt/e/code/VideoMamba-main/mamba/mamba_ssm/ops/triton/layernorm.py", line 411, in forward
    y, mean, rstd, residual_out = _layer_norm_fwd(
  File "/mnt/e/code/VideoMamba-main/mamba/mamba_ssm/ops/triton/layernorm.py", line 155, in _layer_norm_fwd
    _layer_norm_fwd_1pass_kernel[(M,)](
  File "/home/jcz/anaconda3/envs/mamba/lib/python3.10/site-packages/triton/runtime/autotuner.py", line 100, in run
    timings = {config: self._bench(*args, config=config, **kwargs)
  File "/home/jcz/anaconda3/envs/mamba/lib/python3.10/site-packages/triton/runtime/autotuner.py", line 100, in <dictcomp>
    timings = {config: self._bench(*args, config=config, **kwargs)
  File "/home/jcz/anaconda3/envs/mamba/lib/python3.10/site-packages/triton/runtime/autotuner.py", line 83, in _bench
    return do_bench(kernel_call, warmup=self.warmup, rep=self.rep, quantiles=(0.5, 0.2, 0.8))
  File "/home/jcz/anaconda3/envs/mamba/lib/python3.10/site-packages/triton/testing.py", line 104, in do_bench
    fn()
  File "/home/jcz/anaconda3/envs/mamba/lib/python3.10/site-packages/triton/runtime/autotuner.py", line 81, in kernel_call
    self.fn.run(*args, num_warps=config.num_warps, num_stages=config.num_stages, **current)
  File "<string>", line 63, in _layer_norm_fwd_1pass_kernel
  File "/home/jcz/anaconda3/envs/mamba/lib/python3.10/site-packages/triton/compiler/compiler.py", line 476, in compile
    next_module = compile_kernel(module)
  File "/home/jcz/anaconda3/envs/mamba/lib/python3.10/site-packages/triton/compiler/compiler.py", line 381, in <lambda>
    lambda src: optimize_ttir(ast_to_ttir(src, signature, configs[0], constants, debug=debug, arch=arch), arch))
  File "/home/jcz/anaconda3/envs/mamba/lib/python3.10/site-packages/triton/compiler/code_generator.py", line 1133, in ast_to_ttir
    raise CompilationError(fn.src, node, repr(e)) from e
triton.compiler.errors.CompilationError: at 31:24: HAS_BIAS: tl.constexpr,
):
# Map the program id to the row of X and Y it should compute.

row = tl.program_id(0)
X += row * stride_x_row
Y += row * stride_y_row
if HAS_RESIDUAL:
    RESIDUAL += row * stride_res_row
if STORE_RESIDUAL_OUT:
    RESIDUAL_OUT += row * stride_res_out_row
# Compute mean and variance
cols = tl.arange(0, BLOCK_N)
                    ^

ValueError("arange's arguments must be of type tl.constexpr")
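A likely explanation, inferred from the traceback rather than confirmed in the thread: torch.onnx.export runs the model under torch.jit tracing (torch.jit._get_trace_graph is in the stack), and while tracing, PyTorch can hand back the entries of x.shape as 0-dim tensors instead of Python ints. N then flows through triton.next_power_of_2 and min(), so BLOCK_N reaches the kernel as a tensor rather than the compile-time value tl.arange requires. The pure-Python sketch below imitates that. TracedScalar is a hypothetical stand-in for a traced size (not a PyTorch class), and next_power_of_2 is a local helper assumed to match triton.next_power_of_2 for positive integers.

```python
class TracedScalar:
    """Hypothetical stand-in for a 0-dim traced tensor: carries an int value but is not an int."""

    def __init__(self, value: int) -> None:
        self.value = value

    def __int__(self) -> int:
        return self.value

    # Arithmetic keeps returning TracedScalar, the way tensor arithmetic keeps returning tensors.
    def __sub__(self, other) -> "TracedScalar":
        return TracedScalar(self.value - int(other))

    def __add__(self, other) -> "TracedScalar":
        return TracedScalar(self.value + int(other))

    def __rshift__(self, other) -> "TracedScalar":
        return TracedScalar(self.value >> int(other))

    def __or__(self, other) -> "TracedScalar":
        return TracedScalar(self.value | int(other))

    def __lt__(self, other) -> bool:
        return self.value < int(other)


def next_power_of_2(n):
    """Smallest power of two >= n for n >= 1, using only -, >>, | and + so it accepts TracedScalar."""
    n = n - 1
    for shift in (1, 2, 4, 8, 16, 32):
        n = n | (n >> shift)
    return n + 1


def block_n(n, element_size: int = 4):
    """Same arithmetic as _layer_norm_fwd; element_size is bytes per element (4 for float32)."""
    max_fused_size = 65536 // element_size  # cap the block at 64KB worth of elements
    return min(max_fused_size, next_power_of_2(n))


leaked = block_n(TracedScalar(576))
print(type(leaked).__name__)  # TracedScalar: not a Python int, so not usable as a constexpr
fixed = int(block_n(TracedScalar(576)))
print(type(fixed).__name__, fixed)  # int 1024
```

The non-int survives both the bit arithmetic and min(), while a single int() at the end yields a plain 1024, which is the same move as the int(BLOCK_N) cast proposed later in the thread.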

SophieOstmeier commented 6 months ago

I have the same issue. Does anyone have a solution? Help is greatly appreciated.

zengjie617789 commented 2 weeks ago

Has anyone else run into this problem? Has anyone solved it?

dmenig commented 2 weeks ago

Same problem here

dmenig commented 2 weeks ago

Found the solution. In mamba/mamba_ssm/ops/triton/layernorm.py, replace this:

def _layer_norm_fwd(
    x, weight, bias, eps, residual=None, out_dtype=None, residual_dtype=None, is_rms_norm=False
):
    if residual is not None:
        residual_dtype = residual.dtype
    M, N = x.shape
    assert x.stride(-1) == 1
    if residual is not None:
        assert residual.stride(-1) == 1
        assert residual.shape == (M, N)
    assert weight.shape == (N,)
    assert weight.stride(-1) == 1
    if bias is not None:
        assert bias.stride(-1) == 1
        assert bias.shape == (N,)
    # allocate output
    y = torch.empty_like(x, dtype=x.dtype if out_dtype is None else out_dtype)
    assert y.stride(-1) == 1
    if residual is not None or (residual_dtype is not None and residual_dtype != x.dtype):
        residual_out = torch.empty(M, N, device=x.device, dtype=residual_dtype)
        assert residual_out.stride(-1) == 1
    else:
        residual_out = None
    mean = torch.empty((M,), dtype=torch.float32, device="cuda") if not is_rms_norm else None
    rstd = torch.empty((M,), dtype=torch.float32, device="cuda")
    # Less than 64KB per feature: enqueue fused kernel
    MAX_FUSED_SIZE = 65536 // x.element_size()
    BLOCK_N = min(MAX_FUSED_SIZE, triton.next_power_of_2(N))
    if N > BLOCK_N:
        raise RuntimeError("This layer norm doesn't support feature dim >= 64KB.")
    # heuristics for number of warps
    with torch.cuda.device(x.device.index):
        _layer_norm_fwd_1pass_kernel[(M,)](
            x,
            y,
            weight,
            bias,
            residual,
            residual_out,
            mean,
            rstd,
            x.stride(0),
            y.stride(0),
            residual.stride(0) if residual is not None else 0,
            residual_out.stride(0) if residual_out is not None else 0,
            N,
            eps,
            is_rms_norm,
            BLOCK_N,
            residual is not None,
            residual_out is not None,
            bias is not None,
        )
    # residual_out is None if residual is None and residual_dtype == input_dtype
    return y, mean, rstd, residual_out if residual_out is not None else x

with this:

def _layer_norm_fwd(
    x, weight, bias, eps, residual=None, out_dtype=None, residual_dtype=None, is_rms_norm=False
):
    if residual is not None:
        residual_dtype = residual.dtype
    M, N = x.shape
    assert x.stride(-1) == 1
    if residual is not None:
        assert residual.stride(-1) == 1
        assert residual.shape == (M, N)
    assert weight.shape == (N,)
    assert weight.stride(-1) == 1
    if bias is not None:
        assert bias.stride(-1) == 1
        assert bias.shape == (N,)
    # allocate output
    y = torch.empty_like(x, dtype=x.dtype if out_dtype is None else out_dtype)
    assert y.stride(-1) == 1
    if residual is not None or (residual_dtype is not None and residual_dtype != x.dtype):
        residual_out = torch.empty(M, N, device=x.device, dtype=residual_dtype)
        assert residual_out.stride(-1) == 1
    else:
        residual_out = None
    mean = torch.empty((M,), dtype=torch.float32, device="cuda") if not is_rms_norm else None
    rstd = torch.empty((M,), dtype=torch.float32, device="cuda")
    # Less than 64KB per feature: enqueue fused kernel
    MAX_FUSED_SIZE = 65536 // x.element_size()
    BLOCK_N = min(MAX_FUSED_SIZE, triton.next_power_of_2(N))
    if N > BLOCK_N:
        raise RuntimeError("This layer norm doesn't support feature dim >= 64KB.")
    # heuristics for number of warps
    with torch.cuda.device(x.device.index):
        _layer_norm_fwd_1pass_kernel[(M,)](
            x,
            y,
            weight,
            bias,
            residual,
            residual_out,
            mean,
            rstd,
            x.stride(0),
            y.stride(0),
            residual.stride(0) if residual is not None else 0,
            residual_out.stride(0) if residual_out is not None else 0,
            int(N),
            eps,
            is_rms_norm,
            int(BLOCK_N),
            residual is not None,
            residual_out is not None,
            bias is not None,
        )
    # residual_out is None if residual is None and residual_dtype == input_dtype
    return y, mean, rstd, residual_out if residual_out is not None else x

In short, cast both N and BLOCK_N to int. The error in the original post only points at BLOCK_N, but if you cast only BLOCK_N you get a second error, which is fixed by also casting N to int.
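The fix amounts to turning the launch sizes into plain Python ints before Triton sees them. Below is a minimal sketch of that computation as a standalone helper, assuming element_size is the byte width of x (4 for float32). launch_sizes is a hypothetical name, not a mamba_ssm function, and next_power_of_2 is a local stand-in for triton.next_power_of_2. Casting once at the top rather than at the call site also keeps the N > BLOCK_N check in plain ints.

```python
def next_power_of_2(n: int) -> int:
    """Smallest power of two >= n, for n >= 1 (local stand-in for triton.next_power_of_2)."""
    return 1 << (n - 1).bit_length()


def launch_sizes(n_features, element_size: int) -> tuple[int, int]:
    """Return (N, BLOCK_N) as plain ints, following the patched _layer_norm_fwd.

    n_features may be an int or anything that supports int() (e.g. a traced 0-dim size);
    converting once up front means every later step, including the size check, sees ints.
    """
    n = int(n_features)
    max_fused_size = 65536 // element_size  # less than 64KB per feature row
    block_n = min(max_fused_size, next_power_of_2(n))
    if n > block_n:
        raise RuntimeError("This layer norm doesn't support feature dim >= 64KB.")
    return n, block_n


print(launch_sizes(576, 4))  # (576, 1024)
```

One side effect worth knowing: when int() is applied to a traced size, torch.jit.trace records the value as a constant (PyTorch emits a TracerWarning about this), so the trace is tied to that feature width. That is presumably harmless here, since N is the last (hidden) dimension of the normalized input and does not change between inputs.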

dmenig commented 2 weeks ago

Sorry, let me add one thing: even with these corrections, I still couldn't export to ONNX without corrupting the model. The changes only make the torch.jit.trace call succeed. Oddly, the traced model is fine, but for me the ONNX one is corrupt.
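The thread does not say how the ONNX model was judged corrupt. One way to make that concrete is to run the same input through the traced module and the exported ONNX model (for example with onnxruntime), flatten both outputs, and compare them element by element. Below is a dependency-free sketch of that comparison using the numpy.allclose tolerance rule |a - b| <= atol + rtol * |b|, with the traced output as b. max_mismatch is a hypothetical helper, and its default rtol/atol are assumptions, not values from the thread.

```python
import math
from typing import Sequence


def max_mismatch(reference: Sequence[float], candidate: Sequence[float],
                 rtol: float = 1e-3, atol: float = 1e-5) -> tuple[int, float]:
    """Compare two flattened outputs with the numpy.allclose rule |a - b| <= atol + rtol * |b|.

    reference plays the role of b (e.g. the traced output), candidate of a (e.g. the ONNX output).
    Returns (number of mismatching elements, largest finite absolute difference).
    A NaN difference counts as a mismatch.
    """
    if len(reference) != len(candidate):
        raise ValueError(f"length mismatch: {len(reference)} vs {len(candidate)}")
    bad = 0
    worst = 0.0
    for ref, cand in zip(reference, candidate):
        diff = abs(cand - ref)
        if math.isnan(diff) or diff > atol + rtol * abs(ref):
            bad += 1
        if not math.isnan(diff):
            worst = max(worst, diff)
    return bad, worst


traced_out = [0.10, -2.50, 3.0]
onnx_out = [0.10, -2.50, 0.0]
print(max_mismatch(traced_out, onnx_out))  # (1, 3.0)
```

A nonzero mismatch count on real outputs would locate the problem in the ONNX export step rather than in the tracing step that the int() casts fixed.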

fcchit commented 1 week ago

@SophieOstmeier @zengjie617789 Have you solved this problem? It has been bothering me for days.