OpenGVLab / VideoMamba

VideoMamba: State Space Model for Efficient Video Understanding
https://arxiv.org/abs/2403.06977
Apache License 2.0

Question about training a custom dataset based on the Video Mamba middle model #55

Closed kksoo1769 closed 1 month ago

kksoo1769 commented 1 month ago

Hello. I am a beginner interested in computer vision. I would like to train the VideoMamba middle model on my own video data. Could you please explain in detail how to do this?

Thank you.

Andy1621 commented 1 month ago

Good question! But I'm too busy to give detailed instructions.

You can follow our training scripts and prepare your dataset in the same format as ours.

Don't hesitate to ask questions if you run into trouble.
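As a rough illustration of preparing a custom dataset, the sketch below builds annotation files from a folder tree of the form `root/<class_name>/<video>`. Everything in it is an assumption rather than the repository's documented format: the space-separated `relative/path label` line layout, the `train.csv` and `val.csv` file names, and the helper names `build_annotations` and `write_annotations` are hypothetical. Check the repository's training scripts and dataset loaders for the exact format they expect.

```python
import random
from pathlib import Path

# Assumed set of video extensions; adjust to your data.
VIDEO_EXTS = {".mp4", ".avi", ".mkv", ".webm"}


def build_annotations(root, val_fraction=0.2, seed=0):
    """Scan root/<class_name>/<video> and return (class_names, train_lines, val_lines).

    Each line is "<path relative to root> <integer label>" (assumed format).
    Paths containing spaces would break a space-separated file, so avoid them.
    """
    root = Path(root)
    class_names = sorted(p.name for p in root.iterdir() if p.is_dir())
    label_of = {name: idx for idx, name in enumerate(class_names)}

    lines = []
    for name in class_names:
        for video in sorted((root / name).iterdir()):
            if video.suffix.lower() in VIDEO_EXTS:
                rel = video.relative_to(root).as_posix()
                lines.append(f"{rel} {label_of[name]}")

    # Deterministic shuffle so the split is reproducible.
    random.Random(seed).shuffle(lines)
    n_val = int(len(lines) * val_fraction)
    return class_names, lines[n_val:], lines[:n_val]


def write_annotations(root, out_dir, val_fraction=0.2, seed=0):
    """Write train.csv and val.csv (hypothetical names) and return the class list."""
    class_names, train, val = build_annotations(root, val_fraction, seed)
    out_dir = Path(out_dir)
    out_dir.mkdir(parents=True, exist_ok=True)
    (out_dir / "train.csv").write_text("".join(line + "\n" for line in train))
    (out_dir / "val.csv").write_text("".join(line + "\n" for line in val))
    return class_names
```

Calling `write_annotations("my_videos", "my_annotations")` would produce the two files and return the sorted class names, whose positions define the integer labels.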

kksoo1769 commented 1 month ago

Thank you!

I encountered an unexpected error while running videomamba.py:

(video_mamba) (base) kks@xvoice:~/workspace/models/VideoMamba/official$ /home/kks/anaconda3/envs/video_mamba/bin/python /home/kks/workspace/models/VideoMamba/my_model/videomamba.py
Use checkpoint: False
Checkpoint number: 0
Traceback (most recent call last):
  File "/home/kks/anaconda3/envs/video_mamba/lib/python3.10/site-packages/triton/compiler/code_generator.py", line 1124, in ast_to_ttir
    generator.visit(fn.parse())
  File "/home/kks/anaconda3/envs/video_mamba/lib/python3.10/site-packages/triton/compiler/code_generator.py", line 1017, in visit
    ret = super().visit(node)
  File "/home/kks/anaconda3/envs/video_mamba/lib/python3.10/ast.py", line 418, in visit
    return visitor(node)
  File "/home/kks/anaconda3/envs/video_mamba/lib/python3.10/site-packages/triton/compiler/code_generator.py", line 293, in visit_Module
    ast.NodeVisitor.generic_visit(self, node)
  File "/home/kks/anaconda3/envs/video_mamba/lib/python3.10/ast.py", line 426, in generic_visit
    self.visit(item)
  File "/home/kks/anaconda3/envs/video_mamba/lib/python3.10/site-packages/triton/compiler/code_generator.py", line 1017, in visit
    ret = super().visit(node)
  File "/home/kks/anaconda3/envs/video_mamba/lib/python3.10/ast.py", line 418, in visit
    return visitor(node)
  File "/home/kks/anaconda3/envs/video_mamba/lib/python3.10/site-packages/triton/compiler/code_generator.py", line 362, in visit_FunctionDef
    self.visit_compound_statement(node.body)
  File "/home/kks/anaconda3/envs/video_mamba/lib/python3.10/site-packages/triton/compiler/code_generator.py", line 288, in visit_compound_statement
    ret_type = self.visit(stmt)
  File "/home/kks/anaconda3/envs/video_mamba/lib/python3.10/site-packages/triton/compiler/code_generator.py", line 1017, in visit
    ret = super().visit(node)
  File "/home/kks/anaconda3/envs/video_mamba/lib/python3.10/ast.py", line 418, in visit
    return visitor(node)
  File "/home/kks/anaconda3/envs/video_mamba/lib/python3.10/site-packages/triton/compiler/code_generator.py", line 414, in visit_Assign
    values = self.visit(node.value)
  File "/home/kks/anaconda3/envs/video_mamba/lib/python3.10/site-packages/triton/compiler/code_generator.py", line 1017, in visit
    ret = super().visit(node)
  File "/home/kks/anaconda3/envs/video_mamba/lib/python3.10/ast.py", line 418, in visit
    return visitor(node)
  File "/home/kks/anaconda3/envs/video_mamba/lib/python3.10/site-packages/triton/compiler/code_generator.py", line 946, in visit_Call
    return fn(*args, **extra_kwargs, **kws)
  File "/home/kks/anaconda3/envs/video_mamba/lib/python3.10/site-packages/triton/language/core.py", line 30, in wrapper
    return fn(*args, **kwargs)
  File "/home/kks/anaconda3/envs/video_mamba/lib/python3.10/site-packages/triton/language/core.py", line 813, in arange
    return semantic.arange(start, end, _builder)
  File "/home/kks/anaconda3/envs/video_mamba/lib/python3.10/site-packages/triton/language/semantic.py", line 485, in arange
    raise ValueError("arange's arguments must be of type tl.constexpr")
ValueError: arange's arguments must be of type tl.constexpr

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/home/kks/workspace/models/VideoMamba/my_model/videomamba.py", line 476, in <module>
    print(flop_count_table(flops, max_depth=1))
  File "/home/kks/anaconda3/envs/video_mamba/lib/python3.10/site-packages/fvcore/nn/print_model_statistics.py", line 632, in flop_count_table
    stats = {params_header: params, flops_header: flops.by_module()}
  File "/home/kks/anaconda3/envs/video_mamba/lib/python3.10/site-packages/fvcore/nn/jit_analysis.py", line 291, in by_module
    stats = self._analyze()
  File "/home/kks/anaconda3/envs/video_mamba/lib/python3.10/site-packages/fvcore/nn/jit_analysis.py", line 551, in _analyze
    graph = _get_scoped_trace_graph(self._model, self._inputs, self._aliases)
  File "/home/kks/anaconda3/envs/video_mamba/lib/python3.10/site-packages/fvcore/nn/jit_analysis.py", line 176, in _get_scoped_trace_graph
    graph, _ = _get_trace_graph(module, inputs)
  File "/home/kks/anaconda3/envs/video_mamba/lib/python3.10/site-packages/torch/jit/_trace.py", line 1285, in _get_trace_graph
    outs = ONNXTracedModule(
  File "/home/kks/anaconda3/envs/video_mamba/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/kks/anaconda3/envs/video_mamba/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/kks/anaconda3/envs/video_mamba/lib/python3.10/site-packages/torch/jit/_trace.py", line 133, in forward
    graph, out = torch._C._create_graph_by_tracing(
  File "/home/kks/anaconda3/envs/video_mamba/lib/python3.10/site-packages/torch/jit/_trace.py", line 124, in wrapper
    outs.append(self.inner(*trace_inputs))
  File "/home/kks/anaconda3/envs/video_mamba/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/kks/anaconda3/envs/video_mamba/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1568, in _call_impl
    result = forward_call(*args, **kwargs)
  File "/home/kks/anaconda3/envs/video_mamba/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1508, in _slow_forward
    result = self.forward(*input, **kwargs)
  File "/home/kks/workspace/models/VideoMamba/my_model/videomamba.py", line 366, in forward
    x = self.forward_features(x, inference_params)
  File "/home/kks/workspace/models/VideoMamba/my_model/videomamba.py", line 339, in forward_features
    hidden_states, residual = layer(
  File "/home/kks/anaconda3/envs/video_mamba/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/kks/anaconda3/envs/video_mamba/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1568, in _call_impl
    result = forward_call(*args, **kwargs)
  File "/home/kks/anaconda3/envs/video_mamba/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1508, in _slow_forward
    result = self.forward(*input, **kwargs)
  File "/home/kks/workspace/models/VideoMamba/my_model/videomamba.py", line 82, in forward
    hidden_states, residual = fused_add_norm_fn(
  File "/home/kks/workspace/models/VideoMamba/official/mamba/mamba_ssm/ops/triton/layernorm.py", line 478, in rms_norm_fn
    return LayerNormFn.apply(x, weight, bias, residual, eps, prenorm, residual_in_fp32, True)
  File "/home/kks/anaconda3/envs/video_mamba/lib/python3.10/site-packages/torch/autograd/function.py", line 539, in apply
    return super().apply(*args, **kwargs)  # type: ignore[misc]
  File "/home/kks/workspace/models/VideoMamba/official/mamba/mamba_ssm/ops/triton/layernorm.py", line 411, in forward
    y, mean, rstd, residual_out = _layer_norm_fwd(
  File "/home/kks/workspace/models/VideoMamba/official/mamba/mamba_ssm/ops/triton/layernorm.py", line 155, in _layer_norm_fwd
    _layer_norm_fwd_1pass_kernel[(M,)](
  File "/home/kks/anaconda3/envs/video_mamba/lib/python3.10/site-packages/triton/runtime/autotuner.py", line 100, in run
    timings = {config: self._bench(*args, config=config, **kwargs)
  File "/home/kks/anaconda3/envs/video_mamba/lib/python3.10/site-packages/triton/runtime/autotuner.py", line 100, in <dictcomp>
    timings = {config: self._bench(*args, config=config, **kwargs)
  File "/home/kks/anaconda3/envs/video_mamba/lib/python3.10/site-packages/triton/runtime/autotuner.py", line 83, in _bench
    return do_bench(kernel_call, warmup=self.warmup, rep=self.rep, quantiles=(0.5, 0.2, 0.8))
  File "/home/kks/anaconda3/envs/video_mamba/lib/python3.10/site-packages/triton/testing.py", line 104, in do_bench
    fn()
  File "/home/kks/anaconda3/envs/video_mamba/lib/python3.10/site-packages/triton/runtime/autotuner.py", line 81, in kernel_call
    self.fn.run(*args, num_warps=config.num_warps, num_stages=config.num_stages, **current)
  File "<string>", line 63, in _layer_norm_fwd_1pass_kernel
  File "/home/kks/anaconda3/envs/video_mamba/lib/python3.10/site-packages/triton/compiler/compiler.py", line 476, in compile
    next_module = compile_kernel(module)
  File "/home/kks/anaconda3/envs/video_mamba/lib/python3.10/site-packages/triton/compiler/compiler.py", line 381, in <lambda>
    lambda src: optimize_ttir(ast_to_ttir(src, signature, configs[0], constants, debug=debug, arch=arch), arch))
  File "/home/kks/anaconda3/envs/video_mamba/lib/python3.10/site-packages/triton/compiler/code_generator.py", line 1133, in ast_to_ttir
    raise CompilationError(fn.src, node, repr(e)) from e
triton.compiler.errors.CompilationError: at 31:24:    HAS_BIAS: tl.constexpr,
):
    # Map the program id to the row of X and Y it should compute.
    row = tl.program_id(0)
    X += row * stride_x_row
    Y += row * stride_y_row
    if HAS_RESIDUAL:
        RESIDUAL += row * stride_res_row
    if STORE_RESIDUAL_OUT:
        RESIDUAL_OUT += row * stride_res_out_row
    # Compute mean and variance
    cols = tl.arange(0, BLOCK_N)
                        ^
ValueError("arange's arguments must be of type tl.constexpr")

How can I solve this?

Andy1621 commented 1 month ago

Can you provide your environment?
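A short environment report is usually enough for this kind of question. The sketch below uses only the Python standard library to print the Python version and the installed versions of the packages that appear in the traceback; the package selection and the helper names `collect_versions` and `format_report` are illustrative choices, not part of the repository.

```python
import platform
from importlib import metadata

# Packages that show up in the traceback above; extend as needed.
PACKAGES = ("torch", "torchvision", "triton", "mamba-ssm", "causal-conv1d", "fvcore")


def collect_versions(packages=PACKAGES):
    """Map each distribution name to its installed version, or 'not installed'."""
    versions = {}
    for name in packages:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            versions[name] = "not installed"
    return versions


def format_report(versions):
    """Render the versions as a few lines that can be pasted into an issue."""
    lines = [f"python {platform.python_version()} on {platform.system()}"]
    lines += [f"{name}=={version}" for name, version in versions.items()]
    return "\n".join(lines)


if __name__ == "__main__":
    print(format_report(collect_versions()))
```

Running the script inside the same conda environment that produced the error keeps the report consistent with the failing run.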

kksoo1769 commented 1 month ago

I'm using Linux and Python 3.10.14.

Here is my pip list output.

Package Version Editable project location

absl-py 2.1.0
accelerate 0.30.1
aiohttp 3.9.5
aiosignal 1.3.1
antlr4-python3-runtime 4.9.3
appdirs 1.4.4
async-timeout 4.0.3
attrs 23.2.0
av 11.0.0
blis 0.7.11
catalogue 2.0.10
causal-conv1d 1.0.0 /home/kks/workspace/models/VideoMamba/VideoMamba/causal-conv1d
certifi 2022.12.7
chardet 5.2.0
charset-normalizer 2.1.1
click 8.1.7
cloudpathlib 0.16.0
cloudpickle 3.0.0
colorama 0.4.6
confection 0.1.4
cymem 2.0.8
DataProperty 1.0.1
datasets 2.19.1
de-core-news-sm 3.0.0
decord 0.6.0
deepspeed 0.13.1
dill 0.3.8
docker-pycreds 0.4.0
einops 0.7.0
en-core-web-sm 3.0.0
evaluate 0.4.2
exceptiongroup 1.2.1
filelock 3.13.1
frozenlist 1.4.1
fsspec 2024.2.0
ftfy 6.1.3
fvcore 0.1.5.post20221221
gitdb 4.0.11
GitPython 3.1.43
hjson 3.1.0
huggingface-hub 0.23.0
idna 3.4
imageio 2.33.1
iniconfig 2.0.0
iopath 0.1.10
Jinja2 3.1.3
joblib 1.4.2
jsonlines 4.0.0
langcodes 3.3.0
lazy_loader 0.4
lm_eval 0.4.1
lxml 5.2.2
mamba-ssm 1.0.1 /home/kks/workspace/models/VideoMamba/VideoMamba/mamba
MarkupSafe 2.1.5
mbstrdecoder 1.1.3
mpmath 1.3.0
multidict 6.0.5
multiprocess 0.70.16
murmurhash 1.0.10
networkx 3.2.1
ninja 1.11.1.1
nltk 3.8.1
numexpr 2.10.0
numpy 1.26.4
nvidia-cublas-cu12 12.1.3.1
nvidia-cuda-cupti-cu12 12.1.105
nvidia-cuda-nvrtc-cu12 12.1.105
nvidia-cuda-runtime-cu12 12.1.105
nvidia-cudnn-cu12 8.9.2.26
nvidia-cufft-cu12 11.0.2.54
nvidia-curand-cu12 10.3.2.106
nvidia-cusolver-cu12 11.4.5.107
nvidia-cusparse-cu12 12.1.0.106
nvidia-nccl-cu12 2.19.3
nvidia-nvjitlink-cu12 12.4.127
nvidia-nvtx-cu12 12.1.105
omegaconf 2.3.0
opencv-python 4.8.1.78
packaging 24.0
pandas 2.2.1
pathlib_abc 0.1.1
pathvalidate 3.2.0
pathy 0.11.0
peft 0.10.0
Pillow 10.1.0
pip 24.0
pluggy 1.5.0
portalocker 2.8.2
preshed 3.0.9
protobuf 4.25.3
psutil 5.9.8
py-cpuinfo 9.0.0
pyarrow 16.1.0
pyarrow-hotfix 0.6
pybind11 2.12.0
pydantic 1.8.2
pynvml 11.5.0
pytablewriter 1.2.0
pytest 8.1.1
python-dateutil 2.9.0.post0
pytz 2024.1
PyYAML 6.0.1
regex 2023.10.3
requests 2.31.0
responses 0.18.0
rouge_score 0.1.2
sacrebleu 2.4.2
safetensors 0.4.3
scikit-image 0.23.2
scikit-learn 1.4.2
scipy 1.12.0
sentry-sdk 2.1.1
setproctitle 1.3.3
setuptools 68.2.2
six 1.16.0
smart-open 6.4.0
smmap 5.0.1
spacy 3.7.4
spacy-legacy 3.0.12
spacy-loggers 1.0.5
sqlitedict 2.1.0
srsly 2.4.8
submitit 1.5.1
sympy 1.12
tabledata 1.3.3
tabulate 0.9.0
tcolorpy 0.1.6
tensorboardX 2.6.2.2
termcolor 2.4.0
thinc 8.2.3
threadpoolctl 3.5.0
tifffile 2024.5.10
timm 0.4.12
tokenizers 0.15.2
tomli 2.0.1
torch 2.1.1+cu118
torchaudio 2.1.1+cu118
torchtext 0.12.0
torchvision 0.16.1+cu118
tqdm 4.66.1
tqdm-multiprocess 0.0.11
transformers 4.36.1
triton 2.1.0
typepy 1.3.2
typer 0.3.2
typing_extensions 4.9.0
tzdata 2024.1
urllib3 1.26.13
wandb 0.16.2
wasabi 0.10.1
wcwidth 0.2.13
weasel 0.3.4
wheel 0.42.0
xformers 0.0.23+cu118
xxhash 3.4.1
yacs 0.1.8
yarl 1.9.4
zstandard 0.22.0

Andy1621 commented 1 month ago

It seems that you ran videomamba.py directly.

Please change the code as shown here.

kksoo1769 commented 1 month ago

Thank you!! That solved it.