OpenGVLab / VideoMamba

VideoMamba: State Space Model for Efficient Video Understanding
https://arxiv.org/abs/2403.06977
Apache License 2.0

How to use VideoMamba as a backbone network #45

Status: Open. edofazza opened this issue 2 months ago.

edofazza commented 2 months ago

Hi,

I would like to use VideoMamba as a backbone in the same way TimeSformer is commonly used, with the videomamba_m16_k400_f16_res224 weights. The idea is to replace

    self.backbone = TimesformerModel.from_pretrained("facebook/timesformer-base-finetuned-k400", num_frames=num_frames, ignore_mismatched_sizes=True)

with something from your library that produces a 3-dimensional output (batch, features1, features2) that I can use further in the network.

Could you please point me to the file to use, or explain how to do this?

Thank you.

Andy1621 commented 2 months ago

You can use videomamba_middle here.

Andy1621 commented 2 months ago

Note that you need to change pretrained_path, since by default it loads the ImageNet-pretrained weights.
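For orientation, here is a rough sketch of the 3D shape (batch, tokens, dim) you would get by keeping the whole token sequence rather than a single pooled feature. It assumes the video model prepends one class token to T·N patch tokens, with T = num_frames // kernel_size and N = (img_size // patch_size)², and that videomamba_middle uses embed_dim=576 (consistent with the 576 in the shape error later in this thread). The function name is hypothetical, not part of the repo.

```python
def videomamba_token_shape(batch, num_frames, img_size=224, patch_size=16,
                           kernel_size=1, embed_dim=576):
    """Hypothetical helper: expected (batch, tokens, dim) of the full token sequence.

    Assumes one class token plus T * N patch tokens, where
    T = num_frames // kernel_size (temporal patching) and
    N = (img_size // patch_size) ** 2 (spatial patches per frame).
    embed_dim=576 is assumed to match videomamba_middle.
    """
    t = num_frames // kernel_size
    n = (img_size // patch_size) ** 2
    return (batch, 1 + t * n, embed_dim)


# 16 frames at 224x224 with 16x16 patches: 1 + 16 * 196 = 3137 tokens
print(videomamba_token_shape(batch=2, num_frames=16))  # (2, 3137, 576)
```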

Huangmr0719 commented 2 months ago

> You can use videomamba_middle here.

May I ask something? I want to use VideoMamba-Base (dim=768) as the visual backbone, but I'm not sure how to load your open-source weights file videomamba_b16_in1k_res224.pth. Here's the code, which I modified slightly:

import torch
# VisionMamba and _cfg are assumed to be available in the repo's videomamba.py

def videomamba_base(pretrained=True, **kwargs):
    model = VisionMamba(
        patch_size=16,
        embed_dim=768,
        depth=24,
        rms_norm=True,
        residual_in_fp32=True,
        fused_add_norm=True,
        **kwargs
    )
    model.default_cfg = _cfg()
    if pretrained:
        print("loading pretrained model")
        checkpoint = torch.hub.load_state_dict_from_url(
            url="https://huggingface.co/OpenGVLab/VideoMamba/resolve/main/videomamba_b16_in1k_res224.pth?download=true",
            model_dir="./ckpt",
            map_location="cpu",
            check_hash=True,
        )
        # weights are stored under the "model" key; strict=False skips mismatched keys
        model.load_state_dict(checkpoint["model"], strict=False)
    return model

I get the following error when I run this code.

Traceback (most recent call last):
  File "/root/data/visfeature_extract_m/run_extract_vm.py", line 90, in <module>
    model.load_pretrained("/root/data/visfeature_extract_m/ckpt/videomamba_b16_in1k_res224.pth")
  File "/root/data/visfeature_extract_m/model/videomamba.py", line 279, in load_pretrained
    _load_weights(self, checkpoint_path, prefix)
  File "/opt/anaconda3/envs/mamba/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
  File "/opt/anaconda3/envs/mamba/lib/python3.10/site-packages/timm/models/vision_transformer.py", line 440, in _load_weights
    model.patch_embed.proj.weight.shape[1], _n2p(w[f'{prefix}embedding/kernel']))
  File "/opt/anaconda3/envs/mamba/lib/python3.10/site-packages/numpy/lib/npyio.py", line 263, in __getitem__
    raise KeyError(f"{key} is not a file in the archive")
KeyError: 'embedding/kernel is not a file in the archive'

I also tried the load_state_dict and inflate_weight functions from the code, like this:

    if pretrained:
        print('load pretrained weights')
        state_dict = torch.load("./ckpt/videomamba_b16_in1k_res224.pth", map_location='cpu')
        load_state_dict(model, state_dict, center=True)

but it fails with:

Traceback (most recent call last):
  File "/root/data/visfeature_extract_m/run_extract_vm.py", line 89, in <module>
    model= videomamba_base()
  File "/root/data/visfeature_extract_m/model/videomamba.py", line 431, in videomamba_base
    load_state_dict(model, state_dict, center=True)
  File "/root/data/visfeature_extract_m/model/videomamba.py", line 347, in load_state_dict
    del state_dict['head.weight']
KeyError: 'head.weight'
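
In the first traceback (the 'embedding/kernel' KeyError), the .pth file goes through model.load_pretrained, which calls timm's _load_weights; that function reads NumPy .npz archives with Flax/JAX-style keys such as embedding/kernel. Because the default torch.save format is also a zip archive, np.load opens it as an archive and then cannot find the key. A minimal stdlib sketch for telling the two formats apart before choosing a loader; the function name and the classification rules are assumptions, not part of VideoMamba or timm:

```python
import zipfile


def checkpoint_kind(path):
    """Hypothetical helper: guess whether `path` is a PyTorch zip checkpoint or a NumPy .npz.

    Both formats are zip archives, so the member names are inspected:
    torch.save's zip format stores a pickle named '<prefix>/data.pkl',
    while np.savez stores one '<name>.npy' member per array.
    """
    if not zipfile.is_zipfile(path):
        # e.g. the legacy (pre-1.6) torch.save pickle format, or something else entirely
        return "not-zip"
    with zipfile.ZipFile(path) as zf:
        names = zf.namelist()
    if any(name.endswith("data.pkl") for name in names):
        return "torch"
    if names and all(name.endswith(".npy") for name in names):
        return "npz"
    return "unknown"
```

A "torch" result suggests loading the file with torch.load and passing the resulting state dict to load_state_dict, rather than calling load_pretrained.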

Andy1621 commented 2 months ago

Hi! You should use state_dict['model']. Also, the line del state_dict['head.weight'] should be commented out.
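Both fixes can be made defensive, so the same loading code works whether or not a checkpoint nests its weights under 'model' and whether or not it contains a classification head. A minimal sketch; the helper names and the tuple of wrapper keys are assumptions, not part of the repo:

```python
from collections.abc import Mapping


def unwrap_state_dict(checkpoint, wrapper_keys=("model", "module", "state_dict")):
    """Hypothetical helper: return the inner state dict if the checkpoint nests it under a wrapper key."""
    for key in wrapper_keys:
        inner = checkpoint.get(key) if isinstance(checkpoint, Mapping) else None
        if isinstance(inner, Mapping):
            return inner
    return checkpoint


def drop_keys(state_dict, keys):
    """Return a copy of state_dict without `keys`; absent keys are ignored instead of raising KeyError."""
    cleaned = dict(state_dict)
    for key in keys:
        cleaned.pop(key, None)
    return cleaned
```

For example: state_dict = drop_keys(unwrap_state_dict(torch.load(path, map_location="cpu")), ["head.weight", "head.bias"]).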

Huangmr0719 commented 2 months ago

> Hi! You should use state_dict['model']. Besides, del state_dict['head.weight'] should be commented.

Thank you very much! I've been debugging it all afternoon and it's finally working!
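Because the loading code passes strict=False, mismatched keys are skipped silently instead of raising an error. PyTorch's load_state_dict(..., strict=False) returns the missing_keys and unexpected_keys, which are worth printing; alternatively the two key sets can be compared directly, as in this small sketch (the helper name is hypothetical):

```python
def diff_keys(model_keys, checkpoint_keys):
    """Hypothetical helper: compare parameter names of a model and a checkpoint.

    Returns (missing, unexpected): keys the model expects but the checkpoint lacks,
    and keys the checkpoint has but the model does not use.
    """
    model_keys, checkpoint_keys = set(model_keys), set(checkpoint_keys)
    missing = sorted(model_keys - checkpoint_keys)
    unexpected = sorted(checkpoint_keys - model_keys)
    return missing, unexpected
```

Usage: missing, unexpected = diff_keys(model.state_dict().keys(), checkpoint["model"].keys()). If nearly every model key shows up as missing and unexpected contains wrapper keys such as 'model', the wrong dictionary level was passed.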

edofazza commented 1 month ago

I tried videomamba_middle, changing the pretrained path accordingly. However, I would like to use the model only up to the DropPath-99 layer, the last one inside layers, so that norm_f is excluded. If I try something like this:

backbone = videomamba_middle(num_frames=num_frames)
features = []
for name, layer in backbone.named_children():
    features.append(layer)
    print(name)
    if name == 'layers':
        break
self.backbone = nn.Sequential(*features)

And in forward():

x = self.backbone(images.reshape(b, c, t, h, w))

I receive the following error:

Traceback (most recent call last):
  File "main.py", line 116, in <module>
    main(args)
  File "main.py", line 86, in main
    executor.train(args.epoch_start, args.epochs)
  File "VideoMambaCLIPInitVideoGuideMamba.py", line 133, in train
    self._train_epoch(epoch)
  File "VideoMambaCLIPInitVideoGuideMamba.py", line 122, in _train_epoch
    loss_this = self._train_batch(data, label)
  File "VideoMambaCLIPInitVideoGuideMamba.py", line 110, in _train_batch
    output = self.model(data)
  File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "VideoMambaCLIPInitVideoGuideMamba.py", line 50, in forward
    x = self.backbone(images.reshape(b, c, t, h, w))
  File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/container.py", line 217, in forward
    input = module(input)
  File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/linear.py", line 116, in forward
    return F.linear(input, self.weight, self.bias)
RuntimeError: mat1 and mat2 shapes cannot be multiplied (1032192x14 and 576x1000)
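A likely explanation: nn.Sequential passes each child's single output to the next child in registration order, and judging from the traceback, head (a Linear(576, 1000)) is registered before layers. The 5D patch_embed output (B, 576, T, 14, 14) therefore skips the class token and positional embeddings and reaches head, which operates on the last dimension (14); F.linear flattens the leading dimensions, and B·576·T·14 = 1032192 when B·T = 128 (for example B = 8 with 16 frames). In addition, the Mamba blocks take and return a (hidden_states, residual) pair, which nn.Sequential cannot thread. Below is a minimal sketch of the loop you would write instead, stopping before norm_f. It assumes each block is called as block(hidden, residual) and returns (hidden, residual), and that the un-normalized output is residual + drop_path(hidden); the function and the toy block are hypothetical.

```python
from typing import Any, Callable, Iterable, Optional, Tuple

# A block maps (hidden, residual) -> (hidden, residual); the Mamba blocks in `layers`
# are assumed to follow this convention.
Block = Callable[[Any, Optional[Any]], Tuple[Any, Any]]


def run_blocks_without_final_norm(x, blocks: Iterable[Block],
                                  drop_path: Callable[[Any], Any] = lambda t: t):
    """Thread (hidden, residual) through pre-norm residual blocks and stop before the final norm.

    `x` must already be the embedded token sequence (patch embedding, class token and
    positional embeddings applied). Returns the un-normalized residual stream, i.e. what
    the final norm would receive. `drop_path` defaults to identity, as in eval mode.
    """
    hidden, residual = x, None
    for block in blocks:
        hidden, residual = block(hidden, residual)
    if residual is None:  # no blocks were run
        return hidden
    return residual + drop_path(hidden)


def toy_block(hidden, residual):
    # Toy pre-norm residual block: identity "norm" and a "mixer" that halves its input.
    residual = hidden if residual is None else residual + hidden
    return 0.5 * residual, residual


print(run_blocks_without_final_norm(1.0, [toy_block, toy_block]))  # 2.25
```

With real tensors, blocks would be backbone.layers and drop_path backbone.drop_path, while x would come from reproducing the embedding steps at the start of the model's own forward_features; the result keeps every token, giving the 3D (batch, tokens, dim) output asked about at the top of the thread.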