Hello,
Thank you for your interesting work. I want to use one of your pretrained models, but I face an error. I use the videomamba.py file to load the model inside my own model, and I removed the classification head in videomamba.py:
```
.
.
.
def forward_features(self, x, inference_params=None):
    """Encode a video clip and return the final [CLS]-token features.

    Args:
        x: input video batch; after ``self.patch_embed`` it has shape
            (B, C, T, H, W) — batch, embed dim, frames, patch grid.
            # assumes patch_embed emits a channels-first 5D tensor — confirm
        inference_params: forwarded unchanged to every Mamba layer
            (e.g. for stateful step-wise decoding); may be None.

    Returns:
        Tensor of shape (B, C): the normalized [CLS] token per clip.
    """
    x = self.patch_embed(x)
    B, C, T, H, W = x.shape
    # Fold time into the batch: one "image" of H*W tokens per frame.
    x = x.permute(0, 2, 3, 4, 1).reshape(B * T, H * W, C)
    cls_token = self.cls_token.expand(x.shape[0], -1, -1) # stole cls_tokens impl from Phil Wang, thanks
    x = torch.cat((cls_token, x), dim=1)
    # Spatial positional embedding (covers the cls slot too).
    x = x + self.pos_embed
    # temporal pos
    # Take one cls token per clip.  At this point every row's cls slot is
    # identical (same expanded parameter + the same pos_embed added to all
    # rows), so slicing the first B rows is safe even though those rows are
    # not "one per clip" in the (b t) row ordering.
    cls_tokens = x[:B, :1, :]
    x = x[:, 1:]
    # Regroup so each spatial token sees its T temporal copies, add temporal pos.
    x = rearrange(x, '(b t) n m -> (b n) t m', b=B, t=T)
    x = x + self.temporal_pos_embedding
    # Flatten back to one sequence of T*N tokens per clip, then prepend cls.
    x = rearrange(x, '(b n) t m -> b (t n) m', b=B, t=T)
    x = torch.cat((cls_tokens, x), dim=1)
    x = self.pos_drop(x)
    # mamba impl
    residual = None
    hidden_states = x
    for idx, layer in enumerate(self.layers):
        # Activation checkpointing only for the first `checkpoint_num` layers.
        if self.use_checkpoint and idx < self.checkpoint_num:
            hidden_states, residual = layer(
                hidden_states, residual, inference_params=inference_params,
                use_checkpoint=True
            )
        else:
            hidden_states, residual = layer(
                hidden_states, residual, inference_params=inference_params
            )
    if not self.fused_add_norm:
        # Eager path: perform the final residual add, then apply norm_f.
        if residual is None:
            residual = hidden_states
        else:
            residual = residual + self.drop_path(hidden_states)
        hidden_states = self.norm_f(residual.to(dtype=self.norm_f.weight.dtype))
    else:
        # Set prenorm=False here since we don't need the residual
        # Fused kernel: drop_path(hidden) + residual and the norm in one op.
        fused_add_norm_fn = rms_norm_fn if isinstance(self.norm_f, RMSNorm) else layer_norm_fn
        hidden_states = fused_add_norm_fn(
            self.drop_path(hidden_states),
            self.norm_f.weight,
            self.norm_f.bias,
            eps=self.norm_f.eps,
            residual=residual,
            prenorm=False,
            residual_in_fp32=self.residual_in_fp32,
        )
    # return only cls token
    return hidden_states[:, 0, :]
def forward(self, x, inference_params=None):
    """Return the backbone features for `x` (classification head removed)."""
    # Pure delegation: the pooled [CLS] features from forward_features
    # are the model output.
    return self.forward_features(x, inference_params)
```
Hello, thank you for your interesting work. I want to use one of your pretrained models, but I face an error. I use the videomamba.py file to load the model inside my own model, and I removed the classification head in videomamba.py:
What is wrong?