mit-han-lab / temporal-shift-module

[ICCV 2019] TSM: Temporal Shift Module for Efficient Video Understanding
https://arxiv.org/abs/1811.08383
MIT License

What is the correct input shape expected by TSM? #204

Status: Open. mbilalshaikh opened this issue 2 years ago.

mbilalshaikh commented 2 years ago

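  import torch.nn as nn

  params = []
  # collect the parameters of every child module except the last one
  for child in list(model.children())[:-1]:
      params.extend(list(child.parameters()))
  # replace the model's avg_pool layer with a pass-through (no pooling)
  model.avg_pool = nn.Identity()
  print(len(X), X.shape)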

shape of x: [N, T, C, H, W]

  d = torch.rand((64, 8, 3, 256, 256), dtype=torch.float32).to(device)
  preds = model(d, 8)

  File "/home/muhammadbsheikh/anaconda3/envs/open-mmlab2/lib/python3.7/site-packages/torch/nn/functional.py", line 1370, in linear
    ret = torch.addmm(bias, input, weight.t())
  RuntimeError: size mismatch, m1: [64 x 1572864], m2: [2048 x 101] at /pytorch/aten/src/THC/generic/THCTensorMathBlas.cu:290
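
The numbers in the error can be checked directly. The second dimension of m1 equals T*C*H*W for the tensor built above, while m2 ([2048 x 101], the transposed weight of the final linear layer) expects 2048 features per row. A small arithmetic check using only the sizes shown in the question:

```python
# Sizes from the question: d = torch.rand((64, 8, 3, 256, 256))
N, T, C, H, W = 64, 8, 3, 256, 256

per_sample = T * C * H * W    # width of one clip if flattened into a single row
print(per_sample)             # 1572864, the second dimension of m1
print(per_sample == 2048)     # False: the final linear layer expects 2048 inputs
```

The match suggests that each of the 64 clips reached the final linear layer as one flat row of raw values rather than as a pooled 2048-dimensional feature.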

junjun1023 commented 2 years ago

Hi, according to the implementation of TSM:

  @staticmethod
  def shift(x, n_segment, fold_div=3, inplace=False):
      nt, c, h, w = x.size()  # x arrives with shape (N*T, C, H, W)
      n_batch = nt // n_segment
      x = x.view(n_batch, n_segment, c, h, w)  # unfold time: (N, T, C, H, W)

So the input tensor should have shape (n_batch * n_segment, C, H, W), which is (N*T, C, H, W) in your case.
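
The reshape described in the reply, and the way the excerpted shift unpacks it, can be sketched as follows. This is a minimal sketch assuming plain PyTorch tensors: `temporal_shift` is a hypothetical standalone helper, not the repository's function. It reuses the excerpt's first lines and then, as an assumption, moves one fold of channels backward in time, one fold forward, and leaves the rest unchanged. The sizes are small placeholders for the question's (64, 8, 3, 256, 256).

```python
import torch


def temporal_shift(x: torch.Tensor, n_segment: int, fold_div: int = 3) -> torch.Tensor:
    """Hypothetical sketch of a temporal shift on an (N*T, C, H, W) tensor."""
    nt, c, h, w = x.size()
    n_batch = nt // n_segment
    x = x.view(n_batch, n_segment, c, h, w)  # recover (N, T, C, H, W)

    fold = c // fold_div
    out = torch.zeros_like(x)
    out[:, :-1, :fold] = x[:, 1:, :fold]                  # fold 1: values from the next frame
    out[:, 1:, fold:2 * fold] = x[:, :-1, fold:2 * fold]  # fold 2: values from the previous frame
    out[:, :, 2 * fold:] = x[:, :, 2 * fold:]             # remaining channels: unchanged
    return out.view(nt, c, h, w)                          # back to (N*T, C, H, W)


# Clip batch in the layout from the question: (N, T, C, H, W), small sizes.
N, T, C, H, W = 2, 8, 6, 4, 4
clip = torch.rand(N, T, C, H, W)

# Merge batch and time before the call: (N*T, C, H, W).
x = clip.view(N * T, C, H, W)
shifted = temporal_shift(x, n_segment=T)
print(x.shape, shifted.shape)
```

`view` needs a contiguous tensor; for a non-contiguous input (for example after a `permute`), `reshape` or `.contiguous().view(...)` performs the same merge. For the tensor in the question, the merged shape would be (512, 3, 256, 256).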