yjxiong / tsn-pytorch

Temporal Segment Networks (TSN) in PyTorch
BSD 2-Clause "Simplified" License
1.06k stars, 311 forks

size mismatch? #100

Open HiIcy opened 4 years ago

HiIcy commented 4 years ago

I have already downloaded bn_inception.pth, and part of the error output is as follows:

    File "D:\hiicy\documents\files\tsn-pytorch\tf_model_zoo\bninception\pytorch_load.py", line 35, in __init__
        self.load_state_dict(model_zoo.load_url(weight_url))
    File "D:\hiicy\Anaconda\envs\red\lib\site-packages\torch\nn\modules\module.py", line 845, in load_state_dict
        self.__class__.__name__, "\n\t".join(error_msgs)))
    RuntimeError: Error(s) in loading state_dict for BNInception:
        size mismatch for conv1_7x7_s2_bn.weight: copying a param with shape torch.Size([1, 64]) from checkpoint, the shape in current model is torch.Size([64]).
        size mismatch for conv1_7x7_s2_bn.bias: copying a param with shape torch.Size([1, 64]) from checkpoint, the shape in current model is torch.Size([64]).
        size mismatch for conv1_7x7_s2_bn.running_mean: copying a param with shape torch.Size([1, 64]) from checkpoint, the shape in current model is torch.Size([64]).
        size mismatch for conv1_7x7_s2_bn.running_var: copying a param with shape torch.Size([1, 64]) from checkpoint, the shape in current model is torch.Size([64]).
        size mismatch for conv2_3x3_reduce_bn.weight: copying a param with shape torch.Size([1, 64]) from checkpoint, the shape in current model is torch.Size([64]).
        size mismatch for conv2_3x3_reduce_bn.bias: copying a param with shape torch.Size([1, 64]) from checkpoint, the shape in current model is torch.Size([64]).
        size mismatch for conv2_3x3_reduce_bn.running_mean: copying a param with shape torch.Size([1, 64]) from checkpoint, the shape in current model is torch.Size([64]).

I don't know how to solve it. I'd appreciate any help!

immaping commented 4 years ago

You should use torch == 0.3.1 and torchvision == 0.2.1.

HiIcy commented 4 years ago

> you should use torch == 0.3.1 and torchvision == 0.2.1

Sorry for the late reply! Okay, I'll try it, thank you!

mxl1990 commented 2 years ago

For PyTorch versions newer than 0.3.1, you can convert the checkpoint in code to avoid this error. For example, since you chose BNInception, open tf_model_zoo\bninception\pytorch_load.py and, in the `__init__` function, load the weights like this:

        state_dict = torch.utils.model_zoo.load_url(weight_url)
        state_dict = self.convert_state_dict(state_dict)
        self.load_state_dict(state_dict)

Then add this method after `__init__`:

    def convert_state_dict(self, state_dict):
        # The checkpoint stores some parameters as [1, C]; the model expects [C].
        cv_state_dict = {}
        for key in state_dict:
            current_tensor = state_dict[key]
            # print(current_tensor.dim())
            if current_tensor.dim() == 2:
                cv_state_dict[key] = current_tensor.squeeze()
            else:
                cv_state_dict[key] = current_tensor
        return cv_state_dict

This resolves the error on newer PyTorch versions.
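For anyone who prefers not to add a method to the model class, here is a minimal standalone sketch of the same conversion. The function name `squeeze_leading_singleton` is hypothetical and not part of tsn-pytorch. It assumes the only mismatch is the `[1, C]` versus `[C]` shape shown in the error message, so it squeezes only 2-D tensors whose first dimension is 1. A genuine 2-D weight of shape `[1, N]` would also be squeezed, so check that your model has none.

```python
from collections import OrderedDict


def squeeze_leading_singleton(state_dict):
    """Return a copy of state_dict with [1, C] tensors turned into [C].

    Hypothetical helper: works on any mapping of parameter name to a
    tensor-like object exposing .dim(), .size(i) and .squeeze(i),
    such as torch.Tensor.
    """
    converted = OrderedDict()
    for name, tensor in state_dict.items():
        # Only touch 2-D tensors whose first dimension is 1, which is how
        # the BatchNorm entries appear in the error message in this thread.
        if tensor.dim() == 2 and tensor.size(0) == 1:
            converted[name] = tensor.squeeze(0)
        else:
            # Convolution and other weights are passed through unchanged.
            converted[name] = tensor
    return converted
```

With PyTorch installed, it could be used in place of the direct load, for example `self.load_state_dict(squeeze_leading_singleton(torch.utils.model_zoo.load_url(weight_url)))`.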