Closed zengxunli closed 1 year ago
Also, in your base_model.py the base model used for CAD is hard-coded, and its backbone is Inception-v3: https://github.com/JacobYuan7/DIN-Group-Activity-Recognition-Benchmark/blob/main/base_model.py#:~:text=cfg.num_features_gcn-,self.backbone%3DMyInception_v3(transform_input%3DFalse%2Cpretrained,%23%20%20%20%20%20%20%20%20%20self.backbone%3DMyVGG16(pretrained%3DTrue),-if%20not%20self
When I changed it to ResNet18 and restarted stage-1 training, it produced the following error:
File "/home/disk1/zxl/model/DIN-Group-Activity-Recognition-Benchmark-main/./base_model.py", line 260, in forward
boxes_features_all=self.fc_emb_1(boxes_features_all) # B*T,MAX_N, NFB
File "/home/disk1/zxl/anaconda3/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1110, in _call_impl
return forward_call(*input, **kwargs)
File "/home/disk1/zxl/anaconda3/lib/python3.9/site-packages/torch/nn/modules/linear.py", line 103, in forward
return F.linear(input, self.weight, self.bias)
RuntimeError: mat1 and mat2 shapes cannot be multiplied (52x12800 and 26400x1024)
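The two sizes in this error can be read as flattened K x K x D RoI features, which makes the mismatch easier to see. The sketch below is only that arithmetic; the 5 x 5 crop grid is an assumption (K is not stated in this thread), and `channels_per_cell` is a hypothetical helper, not a function from the repository.

```python
# Sizes reported by the RuntimeError: mat1 is (52 x 12800), mat2 is (26400 x 1024).
actual_in = 12800    # per-box feature length produced after switching to ResNet18
expected_in = 26400  # input width that fc_emb_1 was built for

# Assumption: each box is cropped to a K x K grid with K = 5.
K = 5


def channels_per_cell(flat_features: int, k: int) -> int:
    """Recover the channel count D from a flattened k*k*D feature vector."""
    cells = k * k
    if flat_features % cells != 0:
        raise ValueError(f"{flat_features} is not divisible by {cells}")
    return flat_features // cells


resnet_channels = channels_per_cell(actual_in, K)
inception_channels = channels_per_cell(expected_in, K)
print(resnet_channels, inception_channels)
```

Under that assumption the ResNet18 run yields 512 channels per cell, which matches the width of the last ResNet-18 stage, while fc_emb_1 still expects 1056 channels per cell. That suggests the embedding dimension fed to fc_emb_1 was left at the Inception-v3 value when the backbone was swapped.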
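Hello, and thank you for open-sourcing the code. Stage 1 and stage 2 training on VD both went smoothly. When training on CAD, however, I used resnet18 for stage 1 and then hit a weight-loading error in stage 2. The model config file is the unmodified default. Could you help me figure out why?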
raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format( RuntimeError: Error(s) in loading state_dict for MyRes18: Missing key(s) in state_dict: "features.0.weight", "features.1.weight", "features.1.bias", "features.1.running_mean", "features.1.running_var", "features.4.0.conv1.weight", "features.4.0.bn1.weight", "features.4.0.bn1.bias", "features.4.0.bn1.running_mean", "features.4.0.bn1.running_var", "features.4.0.conv2.weight", "features.4.0.bn2.weight", "features.4.0.bn2.bias", "features.4.0.bn2.running_mean", "features.4.0.bn2.running_var", "features.4.1.conv1.weight", "features.4.1.bn1.weight", "features.4.1.bn1.bias", "features.4.1.bn1.running_mean", "features.4.1.bn1.running_var", "features.4.1.conv2.weight", "features.4.1.bn2.weight", "features.4.1.bn2.bias", "features.4.1.bn2.running_mean", "features.4.1.bn2.running_var", "features.5.0.conv1.weight", "features.5.0.bn1.weight", "features.5.0.bn1.bias", "features.5.0.bn1.running_mean", "features.5.0.bn1.running_var", "features.5.0.conv2.weight", "features.5.0.bn2.weight", "features.5.0.bn2.bias", "features.5.0.bn2.running_mean", "features.5.0.bn2.running_var", "features.5.0.downsample.0.weight", "features.5.0.downsample.1.weight", "features.5.0.downsample.1.bias", "features.5.0.downsample.1.running_mean", "features.5.0.downsample.1.running_var", "features.5.1.conv1.weight", "features.5.1.bn1.weight", "features.5.1.bn1.bias", "features.5.1.bn1.running_mean", "features.5.1.bn1.running_var", "features.5.1.conv2.weight", "features.5.1.bn2.weight", "features.5.1.bn2.bias", "features.5.1.bn2.running_mean", "features.5.1.bn2.running_var", "features.6.0.conv1.weight", "features.6.0.bn1.weight", "features.6.0.bn1.bias", "features.6.0.bn1.running_mean", "features.6.0.bn1.running_var", "features.6.0.conv2.weight", "features.6.0.bn2.weight", "features.6.0.bn2.bias", "features.6.0.bn2.running_mean", "features.6.0.bn2.running_var", "features.6.0.downsample.0.weight", "features.6.0.downsample.1.weight", 
"features.6.0.downsample.1.bias", "features.6.0.downsample.1.running_mean", "features.6.0.downsample.1.running_var", "features.6.1.conv1.weight", "features.6.1.bn1.weight", "features.6.1.bn1.bias", "features.6.1.bn1.running_mean", "features.6.1.bn1.running_var", "features.6.1.conv2.weight", "features.6.1.bn2.weight", "features.6.1.bn2.bias", "features.6.1.bn2.running_mean", "features.6.1.bn2.running_var", "features.7.0.conv1.weight", "features.7.0.bn1.weight", "features.7.0.bn1.bias", "features.7.0.bn1.running_mean", "features.7.0.bn1.running_var", "features.7.0.conv2.weight", "features.7.0.bn2.weight", "features.7.0.bn2.bias", "features.7.0.bn2.running_mean", "features.7.0.bn2.running_var", "features.7.0.downsample.0.weight", "features.7.0.downsample.1.weight", "features.7.0.downsample.1.bias", "features.7.0.downsample.1.running_mean", "features.7.0.downsample.1.running_var", "features.7.1.conv1.weight", "features.7.1.bn1.weight", "features.7.1.bn1.bias", "features.7.1.bn1.running_mean", "features.7.1.bn1.running_var", "features.7.1.conv2.weight", "features.7.1.bn2.weight", "features.7.1.bn2.bias", "features.7.1.bn2.running_mean", "features.7.1.bn2.running_var". 
Unexpected key(s) in state_dict: "Conv2d_1a_3x3.conv.weight", "Conv2d_1a_3x3.bn.weight", "Conv2d_1a_3x3.bn.bias", "Conv2d_1a_3x3.bn.running_mean", "Conv2d_1a_3x3.bn.running_var", "Conv2d_1a_3x3.bn.num_batches_tracked", "Conv2d_2a_3x3.conv.weight", "Conv2d_2a_3x3.bn.weight", "Conv2d_2a_3x3.bn.bias", "Conv2d_2a_3x3.bn.running_mean", "Conv2d_2a_3x3.bn.running_var", "Conv2d_2a_3x3.bn.num_batches_tracked", "Conv2d_2b_3x3.conv.weight", "Conv2d_2b_3x3.bn.weight", "Conv2d_2b_3x3.bn.bias", "Conv2d_2b_3x3.bn.running_mean", "Conv2d_2b_3x3.bn.running_var", "Conv2d_2b_3x3.bn.num_batches_tracked", "Conv2d_3b_1x1.conv.weight", "Conv2d_3b_1x1.bn.weight", "Conv2d_3b_1x1.bn.bias", "Conv2d_3b_1x1.bn.running_mean", "Conv2d_3b_1x1.bn.running_var", "Conv2d_3b_1x1.bn.num_batches_tracked", "Conv2d_4a_3x3.conv.weight", "Conv2d_4a_3x3.bn.weight", "Conv2d_4a_3x3.bn.bias", "Conv2d_4a_3x3.bn.running_mean", "Conv2d_4a_3x3.bn.running_var", "Conv2d_4a_3x3.bn.num_batches_tracked", "Mixed_5b.branch1x1.conv.weight", "Mixed_5b.branch1x1.bn.weight", "Mixed_5b.branch1x1.bn.bias", "Mixed_5b.branch1x1.bn.running_mean", "Mixed_5b.branch1x1.bn.running_var", "Mixed_5b.branch1x1.bn.num_batches_tracked", "Mixed_5b.branch5x5_1.conv.weight", "Mixed_5b.branch5x5_1.bn.weight", "Mixed_5b.branch5x5_1.bn.bias", "Mixed_5b.branch5x5_1.bn.running_mean", "Mixed_5b.branch5x5_1.bn.running_var", "Mixed_5b.branch5x5_1.bn.num_batches_tracked", "Mixed_5b.branch5x5_2.conv.weight", "Mixed_5b.branch5x5_2.bn.weight", "Mixed_5b.branch5x5_2.bn.bias", "Mixed_5b.branch5x5_2.bn.running_mean", "Mixed_5b.branch5x5_2.bn.running_var", "Mixed_5b.branch5x5_2.bn.num_batches_tracked", "Mixed_5b.branch3x3dbl_1.conv.weight", "Mixed_5b.branch3x3dbl_1.bn.weight", "Mixed_5b.branch3x3dbl_1.bn.bias", "Mixed_5b.branch3x3dbl_1.bn.running_mean", "Mixed_5b.branch3x3dbl_1.bn.running_var", "Mixed_5b.branch3x3dbl_1.bn.num_batches_tracked", "Mixed_5b.branch3x3dbl_2.conv.weight", "Mixed_5b.branch3x3dbl_2.bn.weight", "Mixed_5b.branch3x3dbl_2.bn.bias", 
"Mixed_5b.branch3x3dbl_2.bn.running_mean", "Mixed_5b.branch3x3dbl_2.bn.running_var", "Mixed_5b.branch3x3dbl_2.bn.num_batches_tracked", "Mixed_5b.branch3x3dbl_3.conv.weight", "Mixed_5b.branch3x3dbl_3.bn.weight", "Mixed_5b.branch3x3dbl_3.bn.bias", "Mixed_5b.branch3x3dbl_3.bn.running_mean", "Mixed_5b.branch3x3dbl_3.bn.running_var", "Mixed_5b.branch3x3dbl_3.bn.num_batches_tracked", "Mixed_5b.branch_pool.conv.weight", "Mixed_5b.branch_pool.bn.weight", "Mixed_5b.branch_pool.bn.bias", "Mixed_5b.branch_pool.bn.running_mean", "Mixed_5b.branch_pool.bn.running_var", "Mixed_5b.branch_pool.bn.num_batches_tracked", "Mixed_5c.branch1x1.conv.weight", "Mixed_5c.branch1x1.bn.weight", "Mixed_5c.branch1x1.bn.bias", "Mixed_5c.branch1x1.bn.running_mean", "Mixed_5c.branch1x1.bn.running_var", "Mixed_5c.branch1x1.bn.num_batches_tracked", "Mixed_5c.branch5x5_1.conv.weight", "Mixed_5c.branch5x5_1.bn.weight", "Mixed_5c.branch5x5_1.bn.bias", "Mixed_5c.branch5x5_1.bn.running_mean", "Mixed_5c.branch5x5_1.bn.running_var", "Mixed_5c.branch5x5_1.bn.num_batches_tracked", "Mixed_5c.branch5x5_2.conv.weight", "Mixed_5c.branch5x5_2.bn.weight", "Mixed_5c.branch5x5_2.bn.bias", "Mixed_5c.branch5x5_2.bn.running_mean", "Mixed_5c.branch5x5_2.bn.running_var", "Mixed_5c.branch5x5_2.bn.num_batches_tracked", "Mixed_5c.branch3x3dbl_1.conv.weight", "Mixed_5c.branch3x3dbl_1.bn.weight", "Mixed_5c.branch3x3dbl_1.bn.bias", "Mixed_5c.branch3x3dbl_1.bn.running_mean", "Mixed_5c.branch3x3dbl_1.bn.running_var", "Mixed_5c.branch3x3dbl_1.bn.num_batches_tracked", "Mixed_5c.branch3x3dbl_2.conv.weight", "Mixed_5c.branch3x3dbl_2.bn.weight", "Mixed_5c.branch3x3dbl_2.bn.bias", "Mixed_5c.branch3x3dbl_2.bn.running_mean", "Mixed_5c.branch3x3dbl_2.bn.running_var", "Mixed_5c.branch3x3dbl_2.bn.num_batches_tracked", "Mixed_5c.branch3x3dbl_3.conv.weight", "Mixed_5c.branch3x3dbl_3.bn.weight", "Mixed_5c.branch3x3dbl_3.bn.bias", "Mixed_5c.branch3x3dbl_3.bn.running_mean", "Mixed_5c.branch3x3dbl_3.bn.running_var", 
"Mixed_5c.branch3x3dbl_3.bn.num_batches_tracked", "Mixed_5c.branch_pool.conv.weight", "Mixed_5c.branch_pool.bn.weight", "Mixed_5c.branch_pool.bn.bias", "Mixed_5c.branch_pool.bn.running_mean", "Mixed_5c.branch_pool.bn.running_var", "Mixed_5c.branch_pool.bn.num_batches_tracked", "Mixed_5d.branch1x1.conv.weight", "Mixed_5d.branch1x1.bn.weight", "Mixed_5d.branch1x1.bn.bias", "Mixed_5d.branch1x1.bn.running_mean", "Mixed_5d.branch1x1.bn.running_var", "Mixed_5d.branch1x1.bn.num_batches_tracked", "Mixed_5d.branch5x5_1.conv.weight", "Mixed_5d.branch5x5_1.bn.weight", "Mixed_5d.branch5x5_1.bn.bias", "Mixed_5d.branch5x5_1.bn.running_mean", "Mixed_5d.branch5x5_1.bn.running_var", "Mixed_5d.branch5x5_1.bn.num_batches_tracked", "Mixed_5d.branch5x5_2.conv.weight", "Mixed_5d.branch5x5_2.bn.weight", "Mixed_5d.branch5x5_2.bn.bias", "Mixed_5d.branch5x5_2.bn.running_mean", "Mixed_5d.branch5x5_2.bn.running_var", "Mixed_5d.branch5x5_2.bn.num_batches_tracked", "Mixed_5d.branch3x3dbl_1.conv.weight", "Mixed_5d.branch3x3dbl_1.bn.weight", "Mixed_5d.branch3x3dbl_1.bn.bias", "Mixed_5d.branch3x3dbl_1.bn.running_mean", "Mixed_5d.branch3x3dbl_1.bn.running_var", "Mixed_5d.branch3x3dbl_1.bn.num_batches_tracked", "Mixed_5d.branch3x3dbl_2.conv.weight", "Mixed_5d.branch3x3dbl_2.bn.weight", "Mixed_5d.branch3x3dbl_2.bn.bias", "Mixed_5d.branch3x3dbl_2.bn.running_mean", "Mixed_5d.branch3x3dbl_2.bn.running_var", "Mixed_5d.branch3x3dbl_2.bn.num_batches_tracked", "Mixed_5d.branch3x3dbl_3.conv.weight", "Mixed_5d.branch3x3dbl_3.bn.weight", "Mixed_5d.branch3x3dbl_3.bn.bias", "Mixed_5d.branch3x3dbl_3.bn.running_mean", "Mixed_5d.branch3x3dbl_3.bn.running_var", "Mixed_5d.branch3x3dbl_3.bn.num_batches_tracked", "Mixed_5d.branch_pool.conv.weight", "Mixed_5d.branch_pool.bn.weight", "Mixed_5d.branch_pool.bn.bias", "Mixed_5d.branch_pool.bn.running_mean", "Mixed_5d.branch_pool.bn.running_var", "Mixed_5d.branch_pool.bn.num_batches_tracked", "Mixed_6a.branch3x3.conv.weight", "Mixed_6a.branch3x3.bn.weight", 
"Mixed_6a.branch3x3.bn.bias", "Mixed_6a.branch3x3.bn.running_mean", "Mixed_6a.branch3x3.bn.running_var", "Mixed_6a.branch3x3.bn.num_batches_tracked", "Mixed_6a.branch3x3dbl_1.conv.weight", "Mixed_6a.branch3x3dbl_1.bn.weight", "Mixed_6a.branch3x3dbl_1.bn.bias", "Mixed_6a.branch3x3dbl_1.bn.running_mean", "Mixed_6a.branch3x3dbl_1.bn.running_var", "Mixed_6a.branch3x3dbl_1.bn.num_batches_tracked", "Mixed_6a.branch3x3dbl_2.conv.weight", "Mixed_6a.branch3x3dbl_2.bn.weight", "Mixed_6a.branch3x3dbl_2.bn.bias", "Mixed_6a.branch3x3dbl_2.bn.running_mean", "Mixed_6a.branch3x3dbl_2.bn.running_var", "Mixed_6a.branch3x3dbl_2.bn.num_batches_tracked", "Mixed_6a.branch3x3dbl_3.conv.weight", "Mixed_6a.branch3x3dbl_3.bn.weight", "Mixed_6a.branch3x3dbl_3.bn.bias", "Mixed_6a.branch3x3dbl_3.bn.running_mean", "Mixed_6a.branch3x3dbl_3.bn.running_var", "Mixed_6a.branch3x3dbl_3.bn.num_batches_tracked", "Mixed_6b.branch1x1.conv.weight", "Mixed_6b.branch1x1.bn.weight", "Mixed_6b.branch1x1.bn.bias", "Mixed_6b.branch1x1.bn.running_mean", "Mixed_6b.branch1x1.bn.running_var", "Mixed_6b.branch1x1.bn.num_batches_tracked", "Mixed_6b.branch7x7_1.conv.weight", "Mixed_6b.branch7x7_1.bn.weight", "Mixed_6b.branch7x7_1.bn.bias", "Mixed_6b.branch7x7_1.bn.running_mean", "Mixed_6b.branch7x7_1.bn.running_var", "Mixed_6b.branch7x7_1.bn.num_batches_tracked", "Mixed_6b.branch7x7_2.conv.weight", "Mixed_6b.branch7x7_2.bn.weight", "Mixed_6b.branch7x7_2.bn.bias", "Mixed_6b.branch7x7_2.bn.running_mean", "Mixed_6b.branch7x7_2.bn.running_var", "Mixed_6b.branch7x7_2.bn.num_batches_tracked", "Mixed_6b.branch7x7_3.conv.weight", "Mixed_6b.branch7x7_3.bn.weight", "Mixed_6b.branch7x7_3.bn.bias", "Mixed_6b.branch7x7_3.bn.running_mean", "Mixed_6b.branch7x7_3.bn.running_var", "Mixed_6b.branch7x7_3.bn.num_batches_tracked", "Mixed_6b.branch7x7dbl_1.conv.weight", "Mixed_6b.branch7x7dbl_1.bn.weight", "Mixed_6b.branch7x7dbl_1.bn.bias", "Mixed_6b.branch7x7dbl_1.bn.running_mean", "Mixed_6b.branch7x7dbl_1.bn.running_var", 
"Mixed_6b.branch7x7dbl_1.bn.num_batches_tracked", "Mixed_6b.branch7x7dbl_2.conv.weight", "Mixed_6b.branch7x7dbl_2.bn.weight", "Mixed_6b.branch7x7dbl_2.bn.bias", "Mixed_6b.branch7x7dbl_2.bn.running_mean", "Mixed_6b.branch7x7dbl_2.bn.running_var", "Mixed_6b.branch7x7dbl_2.bn.num_batches_tracked", "Mixed_6b.branch7x7dbl_3.conv.weight", "Mixed_6b.branch7x7dbl_3.bn.weight", "Mixed_6b.branch7x7dbl_3.bn.bias", "Mixed_6b.branch7x7dbl_3.bn.running_mean", "Mixed_6b.branch7x7dbl_3.bn.running_var", "Mixed_6b.branch7x7dbl_3.bn.num_batches_tracked", "Mixed_6b.branch7x7dbl_4.conv.weight", "Mixed_6b.branch7x7dbl_4.bn.weight", "Mixed_6b.branch7x7dbl_4.bn.bias", "Mixed_6b.branch7x7dbl_4.bn.running_mean", "Mixed_6b.branch7x7dbl_4.bn.running_var", "Mixed_6b.branch7x7dbl_4.bn.num_batches_tracked", "Mixed_6b.branch7x7dbl_5.conv.weight", "Mixed_6b.branch7x7dbl_5.bn.weight", "Mixed_6b.branch7x7dbl_5.bn.bias", "Mixed_6b.branch7x7dbl_5.bn.running_mean", "Mixed_6b.branch7x7dbl_5.bn.running_var", "Mixed_6b.branch7x7dbl_5.bn.num_batches_tracked", "Mixed_6b.branch_pool.conv.weight", "Mixed_6b.branch_pool.bn.weight", "Mixed_6b.branch_pool.bn.bias", "Mixed_6b.branch_pool.bn.running_mean", "Mixed_6b.branch_pool.bn.running_var", "Mixed_6b.branch_pool.bn.num_batches_tracked", "Mixed_6c.branch1x1.conv.weight", "Mixed_6c.branch1x1.bn.weight", "Mixed_6c.branch1x1.bn.bias", "Mixed_6c.branch1x1.bn.running_mean", "Mixed_6c.branch1x1.bn.running_var", "Mixed_6c.branch1x1.bn.num_batches_tracked", "Mixed_6c.branch7x7_1.conv.weight", "Mixed_6c.branch7x7_1.bn.weight", "Mixed_6c.branch7x7_1.bn.bias", "Mixed_6c.branch7x7_1.bn.running_mean", "Mixed_6c.branch7x7_1.bn.running_var", "Mixed_6c.branch7x7_1.bn.num_batches_tracked", "Mixed_6c.branch7x7_2.conv.weight", "Mixed_6c.branch7x7_2.bn.weight", "Mixed_6c.branch7x7_2.bn.bias", "Mixed_6c.branch7x7_2.bn.running_mean", "Mixed_6c.branch7x7_2.bn.running_var", "Mixed_6c.branch7x7_2.bn.num_batches_tracked", "Mixed_6c.branch7x7_3.conv.weight", 
"Mixed_6c.branch7x7_3.bn.weight", "Mixed_6c.branch7x7_3.bn.bias", "Mixed_6c.branch7x7_3.bn.running_mean", "Mixed_6c.branch7x7_3.bn.running_var", "Mixed_6c.branch7x7_3.bn.num_batches_tracked", "Mixed_6c.branch7x7dbl_1.conv.weight", "Mixed_6c.branch7x7dbl_1.bn.weight", "Mixed_6c.branch7x7dbl_1.bn.bias", "Mixed_6c.branch7x7dbl_1.bn.running_mean", "Mixed_6c.branch7x7dbl_1.bn.running_var", "Mixed_6c.branch7x7dbl_1.bn.num_batches_tracked", "Mixed_6c.branch7x7dbl_2.conv.weight", "Mixed_6c.branch7x7dbl_2.bn.weight", "Mixed_6c.branch7x7dbl_2.bn.bias", "Mixed_6c.branch7x7dbl_2.bn.running_mean", "Mixed_6c.branch7x7dbl_2.bn.running_var", "Mixed_6c.branch7x7dbl_2.bn.num_batches_tracked", "Mixed_6c.branch7x7dbl_3.conv.weight", "Mixed_6c.branch7x7dbl_3.bn.weight", "Mixed_6c.branch7x7dbl_3.bn.bias", "Mixed_6c.branch7x7dbl_3.bn.running_mean", "Mixed_6c.branch7x7dbl_3.bn.running_var", "Mixed_6c.branch7x7dbl_3.bn.num_batches_tracked", "Mixed_6c.branch7x7dbl_4.conv.weight", "Mixed_6c.branch7x7dbl_4.bn.weight", "Mixed_6c.branch7x7dbl_4.bn.bias", "Mixed_6c.branch7x7dbl_4.bn.running_mean", "Mixed_6c.branch7x7dbl_4.bn.running_var", "Mixed_6c.branch7x7dbl_4.bn.num_batches_tracked", "Mixed_6c.branch7x7dbl_5.conv.weight", "Mixed_6c.branch7x7dbl_5.bn.weight", "Mixed_6c.branch7x7dbl_5.bn.bias", "Mixed_6c.branch7x7dbl_5.bn.running_mean", "Mixed_6c.branch7x7dbl_5.bn.running_var", "Mixed_6c.branch7x7dbl_5.bn.num_batches_tracked", "Mixed_6c.branch_pool.conv.weight", "Mixed_6c.branch_pool.bn.weight", "Mixed_6c.branch_pool.bn.bias", "Mixed_6c.branch_pool.bn.running_mean", "Mixed_6c.branch_pool.bn.running_var", "Mixed_6c.branch_pool.bn.num_batches_tracked", "Mixed_6d.branch1x1.conv.weight", "Mixed_6d.branch1x1.bn.weight", "Mixed_6d.branch1x1.bn.bias", "Mixed_6d.branch1x1.bn.running_mean", "Mixed_6d.branch1x1.bn.running_var", "Mixed_6d.branch1x1.bn.num_batches_tracked", "Mixed_6d.branch7x7_1.conv.weight", "Mixed_6d.branch7x7_1.bn.weight", "Mixed_6d.branch7x7_1.bn.bias", 
"Mixed_6d.branch7x7_1.bn.running_mean", "Mixed_6d.branch7x7_1.bn.running_var", "Mixed_6d.branch7x7_1.bn.num_batches_tracked", "Mixed_6d.branch7x7_2.conv.weight", "Mixed_6d.branch7x7_2.bn.weight", "Mixed_6d.branch7x7_2.bn.bias", "Mixed_6d.branch7x7_2.bn.running_mean", "Mixed_6d.branch7x7_2.bn.running_var", "Mixed_6d.branch7x7_2.bn.num_batches_tracked", "Mixed_6d.branch7x7_3.conv.weight", "Mixed_6d.branch7x7_3.bn.weight", "Mixed_6d.branch7x7_3.bn.bias", "Mixed_6d.branch7x7_3.bn.running_mean", "Mixed_6d.branch7x7_3.bn.running_var", "Mixed_6d.branch7x7_3.bn.num_batches_tracked", "Mixed_6d.branch7x7dbl_1.conv.weight", "Mixed_6d.branch7x7dbl_1.bn.weight", "Mixed_6d.branch7x7dbl_1.bn.bias", "Mixed_6d.branch7x7dbl_1.bn.running_mean", "Mixed_6d.branch7x7dbl_1.bn.running_var", "Mixed_6d.branch7x7dbl_1.bn.num_batches_tracked", "Mixed_6d.branch7x7dbl_2.conv.weight", "Mixed_6d.branch7x7dbl_2.bn.weight", "Mixed_6d.branch7x7dbl_2.bn.bias", "Mixed_6d.branch7x7dbl_2.bn.running_mean", "Mixed_6d.branch7x7dbl_2.bn.running_var", "Mixed_6d.branch7x7dbl_2.bn.num_batches_tracked", "Mixed_6d.branch7x7dbl_3.conv.weight", "Mixed_6d.branch7x7dbl_3.bn.weight", "Mixed_6d.branch7x7dbl_3.bn.bias", "Mixed_6d.branch7x7dbl_3.bn.running_mean", "Mixed_6d.branch7x7dbl_3.bn.running_var", "Mixed_6d.branch7x7dbl_3.bn.num_batches_tracked", "Mixed_6d.branch7x7dbl_4.conv.weight", "Mixed_6d.branch7x7dbl_4.bn.weight", "Mixed_6d.branch7x7dbl_4.bn.bias", "Mixed_6d.branch7x7dbl_4.bn.running_mean", "Mixed_6d.branch7x7dbl_4.bn.running_var", "Mixed_6d.branch7x7dbl_4.bn.num_batches_tracked", "Mixed_6d.branch7x7dbl_5.conv.weight", "Mixed_6d.branch7x7dbl_5.bn.weight", "Mixed_6d.branch7x7dbl_5.bn.bias", "Mixed_6d.branch7x7dbl_5.bn.running_mean", "Mixed_6d.branch7x7dbl_5.bn.running_var", "Mixed_6d.branch7x7dbl_5.bn.num_batches_tracked", "Mixed_6d.branch_pool.conv.weight", "Mixed_6d.branch_pool.bn.weight", "Mixed_6d.branch_pool.bn.bias", "Mixed_6d.branch_pool.bn.running_mean", "Mixed_6d.branch_pool.bn.running_var", 
"Mixed_6d.branch_pool.bn.num_batches_tracked", "Mixed_6e.branch1x1.conv.weight", "Mixed_6e.branch1x1.bn.weight", "Mixed_6e.branch1x1.bn.bias", "Mixed_6e.branch1x1.bn.running_mean", "Mixed_6e.branch1x1.bn.running_var", "Mixed_6e.branch1x1.bn.num_batches_tracked", "Mixed_6e.branch7x7_1.conv.weight", "Mixed_6e.branch7x7_1.bn.weight", "Mixed_6e.branch7x7_1.bn.bias", "Mixed_6e.branch7x7_1.bn.running_mean", "Mixed_6e.branch7x7_1.bn.running_var", "Mixed_6e.branch7x7_1.bn.num_batches_tracked", "Mixed_6e.branch7x7_2.conv.weight", "Mixed_6e.branch7x7_2.bn.weight", "Mixed_6e.branch7x7_2.bn.bias", "Mixed_6e.branch7x7_2.bn.running_mean", "Mixed_6e.branch7x7_2.bn.running_var", "Mixed_6e.branch7x7_2.bn.num_batches_tracked", "Mixed_6e.branch7x7_3.conv.weight", "Mixed_6e.branch7x7_3.bn.weight", "Mixed_6e.branch7x7_3.bn.bias", "Mixed_6e.branch7x7_3.bn.running_mean", "Mixed_6e.branch7x7_3.bn.running_var", "Mixed_6e.branch7x7_3.bn.num_batches_tracked", "Mixed_6e.branch7x7dbl_1.conv.weight", "Mixed_6e.branch7x7dbl_1.bn.weight", "Mixed_6e.branch7x7dbl_1.bn.bias", "Mixed_6e.branch7x7dbl_1.bn.running_mean", "Mixed_6e.branch7x7dbl_1.bn.running_var", "Mixed_6e.branch7x7dbl_1.bn.num_batches_tracked", "Mixed_6e.branch7x7dbl_2.conv.weight", "Mixed_6e.branch7x7dbl_2.bn.weight", "Mixed_6e.branch7x7dbl_2.bn.bias", "Mixed_6e.branch7x7dbl_2.bn.running_mean", "Mixed_6e.branch7x7dbl_2.bn.running_var", "Mixed_6e.branch7x7dbl_2.bn.num_batches_tracked", "Mixed_6e.branch7x7dbl_3.conv.weight", "Mixed_6e.branch7x7dbl_3.bn.weight", "Mixed_6e.branch7x7dbl_3.bn.bias", "Mixed_6e.branch7x7dbl_3.bn.running_mean", "Mixed_6e.branch7x7dbl_3.bn.running_var", "Mixed_6e.branch7x7dbl_3.bn.num_batches_tracked", "Mixed_6e.branch7x7dbl_4.conv.weight", "Mixed_6e.branch7x7dbl_4.bn.weight", "Mixed_6e.branch7x7dbl_4.bn.bias", "Mixed_6e.branch7x7dbl_4.bn.running_mean", "Mixed_6e.branch7x7dbl_4.bn.running_var", "Mixed_6e.branch7x7dbl_4.bn.num_batches_tracked", "Mixed_6e.branch7x7dbl_5.conv.weight", 
"Mixed_6e.branch7x7dbl_5.bn.weight", "Mixed_6e.branch7x7dbl_5.bn.bias", "Mixed_6e.branch7x7dbl_5.bn.running_mean", "Mixed_6e.branch7x7dbl_5.bn.running_var", "Mixed_6e.branch7x7dbl_5.bn.num_batches_tracked", "Mixed_6e.branch_pool.conv.weight", "Mixed_6e.branch_pool.bn.weight", "Mixed_6e.branch_pool.bn.bias", "Mixed_6e.branch_pool.bn.running_mean", "Mixed_6e.branch_pool.bn.running_var", "Mixed_6e.branch_pool.bn.num_batches_tracked".
I think the model and the parameters you are loading do not match. The model is ResNet-18, while the saved parameters are for Inception-v3.
Yes. After reading the code I found that stage 2 loads a resnet18 model, while the stage-1 model is fixed to Inception-v3: https://github.com/JacobYuan7/DIN-Group-Activity-Recognition-Benchmark/blob/4648310a42ca7b66013da9d623e9f856a483f30c/base_model.py#L158
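A quick way to spot this kind of backbone mismatch before calling load_state_dict is to compare the two sets of parameter names. The following is a minimal sketch in plain Python; `diff_keys` and `top_level_prefixes` are hypothetical helpers, and the key lists are a small sample copied from the error message above rather than a full checkpoint.

```python
def diff_keys(model_keys, checkpoint_keys):
    """Return (missing, unexpected) key lists, mirroring the load_state_dict report."""
    model_keys, checkpoint_keys = set(model_keys), set(checkpoint_keys)
    missing = sorted(model_keys - checkpoint_keys)       # in the model, not in the file
    unexpected = sorted(checkpoint_keys - model_keys)    # in the file, not in the model
    return missing, unexpected


def top_level_prefixes(keys):
    """Collapse parameter names to their first component for a compact summary."""
    return sorted({key.split(".", 1)[0] for key in keys})


# Sample keys taken from the error message (not the complete lists).
model_keys = ["features.0.weight", "features.1.weight", "features.4.0.conv1.weight"]
checkpoint_keys = [
    "Conv2d_1a_3x3.conv.weight",
    "Mixed_5b.branch1x1.conv.weight",
    "Mixed_6e.branch_pool.bn.bias",
]

missing, unexpected = diff_keys(model_keys, checkpoint_keys)
print("model expects:", top_level_prefixes(missing))
print("checkpoint has:", top_level_prefixes(unexpected))
```

In a real run the two lists would come from `model.state_dict().keys()` and from the keys of the loaded checkpoint dictionary; how the backbone weights are nested inside the saved file is an assumption to verify. When the two prefix sets share nothing, as here, the checkpoint was saved from a different architecture.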
You may need to add the stage-1 code to this part by copying it from base_volleyball. It should work.
If you are training with resnet-18, you also need to change out_size in the script so that it matches the actual output size of resnet-18.
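To estimate what out_size should be for resnet-18, the standard convolution output-size formula can be applied to the strided layers of a torchvision-style ResNet-18. This is a sketch under assumptions: it supposes the backbone runs through the last residual stage (which the features.0 to features.7 keys in the error above suggest), and the 480 x 720 input is a hypothetical example, not a value taken from this thread.

```python
def conv_out(n: int, kernel: int, stride: int, padding: int) -> int:
    """Output length of a conv/pool layer along one spatial dimension."""
    return (n + 2 * padding - kernel) // stride + 1


# (kernel, stride, padding) of the layers that change resolution in a
# standard ResNet-18; every other layer keeps the spatial size.
RESNET18_STRIDED = [
    (7, 2, 3),  # conv1
    (3, 2, 1),  # maxpool
    (3, 2, 1),  # first conv of layer2
    (3, 2, 1),  # first conv of layer3
    (3, 2, 1),  # first conv of layer4
]


def resnet18_out_size(height: int, width: int):
    """Spatial size of the final feature map for a given input size."""
    for kernel, stride, padding in RESNET18_STRIDED:
        height = conv_out(height, kernel, stride, padding)
        width = conv_out(width, kernel, stride, padding)
    return height, width


image_size = (480, 720)  # hypothetical input resolution
print(resnet18_out_size(*image_size))
```

The estimate is worth confirming by printing the shape of the backbone output on one batch, since a backbone wrapper that stops at an earlier stage would give a larger feature map.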
Okay, thank you for your reply. I will try it right away.
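Hello, I have made the code changes described above. With resnet18 as the backbone, the best CAD stage-1 result is 93.02%, but stage-2 training overfits severely: training accuracy reaches essentially 100% at around epoch 15, while the best test accuracy is 89.41%. What could be causing this? The only hyperparameter I changed was batch_size, which I reduced because of insufficient GPU memory. Incidentally, on VD the stage-1 and stage-2 results match the numbers in the paper, also with a reduced batch_size. Looking forward to your reply.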
CAD is a small dataset, so it overfits easily and its performance fluctuates from run to run. Two-stage pre-training makes this even worse. Usually, you need to use early stopping to achieve high performance. The performance in the 2nd stage should be better than in the 1st stage.
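The early stopping mentioned here can be sketched as a small loop over per-epoch accuracies. This is a generic illustration, not code from the repository: `best_epoch_with_early_stopping` and `patience` are hypothetical names, and the accuracy history is made up to show the behaviour.

```python
def best_epoch_with_early_stopping(accuracies, patience=5):
    """Scan per-epoch accuracies and stop once no improvement is seen
    for `patience` consecutive epochs. Returns (best_epoch, best_accuracy)."""
    best_acc = float("-inf")
    best_epoch = -1
    for epoch, acc in enumerate(accuracies):
        if acc > best_acc:
            # New best: in a training loop, this is where a checkpoint would be saved.
            best_acc, best_epoch = acc, epoch
        elif epoch - best_epoch >= patience:
            # No improvement for `patience` epochs: stop training.
            break
    return best_epoch, best_acc


# Made-up accuracy history, for illustration only.
history = [0.80, 0.86, 0.89, 0.88, 0.87, 0.88, 0.86, 0.90]
print(best_epoch_with_early_stopping(history, patience=3))
```

With a patience of 3 the loop stops at epoch 5 and keeps the epoch-2 result, so the later 0.90 is never reached. A larger patience trades extra training time for a chance to catch such late improvements, which matters on a dataset whose accuracy fluctuates.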