aliencaocao closed this issue 1 year ago.
Hi, could you paste the code of the function that contains your modified `self.state_dict()[i].copy_(param_dict[i])` line?
make_model.py, line 190:
```python
convert_weights = True if pretrain_choice == 'imagenet' else False
self.base = factory[cfg.MODEL.TRANSFORMER_TYPE](img_size=cfg.INPUT.SIZE_TRAIN, drop_path_rate=cfg.MODEL.DROP_PATH, drop_rate=cfg.MODEL.DROP_OUT, attn_drop_rate=cfg.MODEL.ATT_DROP_RATE, pretrained=model_path, convert_weights=convert_weights, semantic_weight=semantic_weight)
if model_path != '':
    print(self.state_dict().keys())
    param_dict = torch.load(model_path)
    for i in param_dict:
        if 'classifier' in i:
            continue
        self.state_dict()[i].copy_(param_dict[i])
self.in_planes = self.base.num_features[-1]
self.num_classes = num_classes
self.ID_LOSS_TYPE = cfg.MODEL.ID_LOSS_TYPE
```
odict_keys(['base.patch_embed.projection.weight', 'base.patch_embed.projection.bias', 'base.patch_embed.norm.weight', 'base.patch_embed.norm.bias', 'base.stages.0.blocks.0.norm1.weight', 'base.stages.0.blocks.0.norm1.bias', 'base.stages.0.blocks.0.attn.w_msa.relative_position_bias_table', 'base.stages.0.blocks.0.attn.w_msa.relative_position_index', 'base.stages.0.blocks.0.attn.w_msa.qkv.weight', 'base.stages.0.blocks.0.attn.w_msa.qkv.bias', 'base.stages.0.blocks.0.attn.w_msa.proj.weight', 'base.stages.0.blocks.0.attn.w_msa.proj.bias', 'base.stages.0.blocks.0.norm2.weight', 'base.stages.0.blocks.0.norm2.bias', 'base.stages.0.blocks.0.ffn.layers.0.0.weight', 'base.stages.0.blocks.0.ffn.layers.0.0.bias', 'base.stages.0.blocks.0.ffn.layers.1.weight', 'base.stages.0.blocks.0.ffn.layers.1.bias', 'base.stages.0.blocks.1.norm1.weight', 'base.stages.0.blocks.1.norm1.bias', 'base.stages.0.blocks.1.attn.w_msa.relative_position_bias_table', 'base.stages.0.blocks.1.attn.w_msa.relative_position_index', 'base.stages.0.blocks.1.attn.w_msa.qkv.weight', 'base.stages.0.blocks.1.attn.w_msa.qkv.bias', 'base.stages.0.blocks.1.attn.w_msa.proj.weight', 'base.stages.0.blocks.1.attn.w_msa.proj.bias', 'base.stages.0.blocks.1.norm2.weight', 'base.stages.0.blocks.1.norm2.bias', 'base.stages.0.blocks.1.ffn.layers.0.0.weight', 'base.stages.0.blocks.1.ffn.layers.0.0.bias', 'base.stages.0.blocks.1.ffn.layers.1.weight', 'base.stages.0.blocks.1.ffn.layers.1.bias', 'base.stages.0.downsample.norm.weight', 'base.stages.0.downsample.norm.bias', 'base.stages.0.downsample.reduction.weight', 'base.stages.1.blocks.0.norm1.weight', 'base.stages.1.blocks.0.norm1.bias', 'base.stages.1.blocks.0.attn.w_msa.relative_position_bias_table', 'base.stages.1.blocks.0.attn.w_msa.relative_position_index', 'base.stages.1.blocks.0.attn.w_msa.qkv.weight', 'base.stages.1.blocks.0.attn.w_msa.qkv.bias', 'base.stages.1.blocks.0.attn.w_msa.proj.weight', 'base.stages.1.blocks.0.attn.w_msa.proj.bias', 'base.stages.1.blocks.0.norm2.weight', 'base.stages.1.blocks.0.norm2.bias', 'base.stages.1.blocks.0.ffn.layers.0.0.weight', 'base.stages.1.blocks.0.ffn.layers.0.0.bias', 'base.stages.1.blocks.0.ffn.layers.1.weight', 'base.stages.1.blocks.0.ffn.layers.1.bias', 'base.stages.1.blocks.1.norm1.weight', 'base.stages.1.blocks.1.norm1.bias', 'base.stages.1.blocks.1.attn.w_msa.relative_position_bias_table', 'base.stages.1.blocks.1.attn.w_msa.relative_position_index', 'base.stages.1.blocks.1.attn.w_msa.qkv.weight', 'base.stages.1.blocks.1.attn.w_msa.qkv.bias', 'base.stages.1.blocks.1.attn.w_msa.proj.weight', 'base.stages.1.blocks.1.attn.w_msa.proj.bias', 'base.stages.1.blocks.1.norm2.weight', 'base.stages.1.blocks.1.norm2.bias', 'base.stages.1.blocks.1.ffn.layers.0.0.weight', 'base.stages.1.blocks.1.ffn.layers.0.0.bias', 'base.stages.1.blocks.1.ffn.layers.1.weight', 'base.stages.1.blocks.1.ffn.layers.1.bias', 'base.stages.1.downsample.norm.weight', 'base.stages.1.downsample.norm.bias', 'base.stages.1.downsample.reduction.weight', 'base.stages.2.blocks.0.norm1.weight', 'base.stages.2.blocks.0.norm1.bias', 'base.stages.2.blocks.0.attn.w_msa.relative_position_bias_table', 'base.stages.2.blocks.0.attn.w_msa.relative_position_index', 'base.stages.2.blocks.0.attn.w_msa.qkv.weight', 'base.stages.2.blocks.0.attn.w_msa.qkv.bias', 'base.stages.2.blocks.0.attn.w_msa.proj.weight', 'base.stages.2.blocks.0.attn.w_msa.proj.bias', 'base.stages.2.blocks.0.norm2.weight', 'base.stages.2.blocks.0.norm2.bias', 'base.stages.2.blocks.0.ffn.layers.0.0.weight', 
'base.stages.2.blocks.0.ffn.layers.0.0.bias', 'base.stages.2.blocks.0.ffn.layers.1.weight', 'base.stages.2.blocks.0.ffn.layers.1.bias', 'base.stages.2.blocks.1.norm1.weight', 'base.stages.2.blocks.1.norm1.bias', 'base.stages.2.blocks.1.attn.w_msa.relative_position_bias_table', 'base.stages.2.blocks.1.attn.w_msa.relative_position_index', 'base.stages.2.blocks.1.attn.w_msa.qkv.weight', 'base.stages.2.blocks.1.attn.w_msa.qkv.bias', 'base.stages.2.blocks.1.attn.w_msa.proj.weight', 'base.stages.2.blocks.1.attn.w_msa.proj.bias', 'base.stages.2.blocks.1.norm2.weight', 'base.stages.2.blocks.1.norm2.bias', 'base.stages.2.blocks.1.ffn.layers.0.0.weight', 'base.stages.2.blocks.1.ffn.layers.0.0.bias', 'base.stages.2.blocks.1.ffn.layers.1.weight', 'base.stages.2.blocks.1.ffn.layers.1.bias', 'base.stages.2.blocks.2.norm1.weight', 'base.stages.2.blocks.2.norm1.bias', 'base.stages.2.blocks.2.attn.w_msa.relative_position_bias_table', 'base.stages.2.blocks.2.attn.w_msa.relative_position_index', 'base.stages.2.blocks.2.attn.w_msa.qkv.weight', 'base.stages.2.blocks.2.attn.w_msa.qkv.bias', 'base.stages.2.blocks.2.attn.w_msa.proj.weight', 'base.stages.2.blocks.2.attn.w_msa.proj.bias', 'base.stages.2.blocks.2.norm2.weight', 'base.stages.2.blocks.2.norm2.bias', 'base.stages.2.blocks.2.ffn.layers.0.0.weight', 'base.stages.2.blocks.2.ffn.layers.0.0.bias', 'base.stages.2.blocks.2.ffn.layers.1.weight', 'base.stages.2.blocks.2.ffn.layers.1.bias', 'base.stages.2.blocks.3.norm1.weight', 'base.stages.2.blocks.3.norm1.bias', 'base.stages.2.blocks.3.attn.w_msa.relative_position_bias_table', 'base.stages.2.blocks.3.attn.w_msa.relative_position_index', 'base.stages.2.blocks.3.attn.w_msa.qkv.weight', 'base.stages.2.blocks.3.attn.w_msa.qkv.bias', 'base.stages.2.blocks.3.attn.w_msa.proj.weight', 'base.stages.2.blocks.3.attn.w_msa.proj.bias', 'base.stages.2.blocks.3.norm2.weight', 'base.stages.2.blocks.3.norm2.bias', 'base.stages.2.blocks.3.ffn.layers.0.0.weight', 'base.stages.2.blocks.3.ffn.layers.0.0.bias', 'base.stages.2.blocks.3.ffn.layers.1.weight', 'base.stages.2.blocks.3.ffn.layers.1.bias', 'base.stages.2.blocks.4.norm1.weight', 'base.stages.2.blocks.4.norm1.bias', 'base.stages.2.blocks.4.attn.w_msa.relative_position_bias_table', 'base.stages.2.blocks.4.attn.w_msa.relative_position_index', 'base.stages.2.blocks.4.attn.w_msa.qkv.weight', 'base.stages.2.blocks.4.attn.w_msa.qkv.bias', 'base.stages.2.blocks.4.attn.w_msa.proj.weight', 'base.stages.2.blocks.4.attn.w_msa.proj.bias', 'base.stages.2.blocks.4.norm2.weight', 'base.stages.2.blocks.4.norm2.bias', 'base.stages.2.blocks.4.ffn.layers.0.0.weight', 'base.stages.2.blocks.4.ffn.layers.0.0.bias', 'base.stages.2.blocks.4.ffn.layers.1.weight', 'base.stages.2.blocks.4.ffn.layers.1.bias', 'base.stages.2.blocks.5.norm1.weight', 'base.stages.2.blocks.5.norm1.bias', 'base.stages.2.blocks.5.attn.w_msa.relative_position_bias_table', 'base.stages.2.blocks.5.attn.w_msa.relative_position_index', 'base.stages.2.blocks.5.attn.w_msa.qkv.weight', 'base.stages.2.blocks.5.attn.w_msa.qkv.bias', 'base.stages.2.blocks.5.attn.w_msa.proj.weight', 'base.stages.2.blocks.5.attn.w_msa.proj.bias', 'base.stages.2.blocks.5.norm2.weight', 'base.stages.2.blocks.5.norm2.bias', 'base.stages.2.blocks.5.ffn.layers.0.0.weight', 'base.stages.2.blocks.5.ffn.layers.0.0.bias', 'base.stages.2.blocks.5.ffn.layers.1.weight', 'base.stages.2.blocks.5.ffn.layers.1.bias', 'base.stages.2.blocks.6.norm1.weight', 'base.stages.2.blocks.6.norm1.bias', 'base.stages.2.blocks.6.attn.w_msa.relative_position_bias_table', 
'base.stages.2.blocks.6.attn.w_msa.relative_position_index', 'base.stages.2.blocks.6.attn.w_msa.qkv.weight', 'base.stages.2.blocks.6.attn.w_msa.qkv.bias', 'base.stages.2.blocks.6.attn.w_msa.proj.weight', 'base.stages.2.blocks.6.attn.w_msa.proj.bias', 'base.stages.2.blocks.6.norm2.weight', 'base.stages.2.blocks.6.norm2.bias', 'base.stages.2.blocks.6.ffn.layers.0.0.weight', 'base.stages.2.blocks.6.ffn.layers.0.0.bias', 'base.stages.2.blocks.6.ffn.layers.1.weight', 'base.stages.2.blocks.6.ffn.layers.1.bias', 'base.stages.2.blocks.7.norm1.weight', 'base.stages.2.blocks.7.norm1.bias', 'base.stages.2.blocks.7.attn.w_msa.relative_position_bias_table', 'base.stages.2.blocks.7.attn.w_msa.relative_position_index', 'base.stages.2.blocks.7.attn.w_msa.qkv.weight', 'base.stages.2.blocks.7.attn.w_msa.qkv.bias', 'base.stages.2.blocks.7.attn.w_msa.proj.weight', 'base.stages.2.blocks.7.attn.w_msa.proj.bias', 'base.stages.2.blocks.7.norm2.weight', 'base.stages.2.blocks.7.norm2.bias', 'base.stages.2.blocks.7.ffn.layers.0.0.weight', 'base.stages.2.blocks.7.ffn.layers.0.0.bias', 'base.stages.2.blocks.7.ffn.layers.1.weight', 'base.stages.2.blocks.7.ffn.layers.1.bias', 'base.stages.2.blocks.8.norm1.weight', 'base.stages.2.blocks.8.norm1.bias', 'base.stages.2.blocks.8.attn.w_msa.relative_position_bias_table', 'base.stages.2.blocks.8.attn.w_msa.relative_position_index', 'base.stages.2.blocks.8.attn.w_msa.qkv.weight', 'base.stages.2.blocks.8.attn.w_msa.qkv.bias', 'base.stages.2.blocks.8.attn.w_msa.proj.weight', 'base.stages.2.blocks.8.attn.w_msa.proj.bias', 'base.stages.2.blocks.8.norm2.weight', 'base.stages.2.blocks.8.norm2.bias', 'base.stages.2.blocks.8.ffn.layers.0.0.weight', 'base.stages.2.blocks.8.ffn.layers.0.0.bias', 'base.stages.2.blocks.8.ffn.layers.1.weight', 'base.stages.2.blocks.8.ffn.layers.1.bias', 'base.stages.2.blocks.9.norm1.weight', 'base.stages.2.blocks.9.norm1.bias', 'base.stages.2.blocks.9.attn.w_msa.relative_position_bias_table', 'base.stages.2.blocks.9.attn.w_msa.relative_position_index', 'base.stages.2.blocks.9.attn.w_msa.qkv.weight', 'base.stages.2.blocks.9.attn.w_msa.qkv.bias', 'base.stages.2.blocks.9.attn.w_msa.proj.weight', 'base.stages.2.blocks.9.attn.w_msa.proj.bias', 'base.stages.2.blocks.9.norm2.weight', 'base.stages.2.blocks.9.norm2.bias', 'base.stages.2.blocks.9.ffn.layers.0.0.weight', 'base.stages.2.blocks.9.ffn.layers.0.0.bias', 'base.stages.2.blocks.9.ffn.layers.1.weight', 'base.stages.2.blocks.9.ffn.layers.1.bias', 'base.stages.2.blocks.10.norm1.weight', 'base.stages.2.blocks.10.norm1.bias', 'base.stages.2.blocks.10.attn.w_msa.relative_position_bias_table', 'base.stages.2.blocks.10.attn.w_msa.relative_position_index', 'base.stages.2.blocks.10.attn.w_msa.qkv.weight', 'base.stages.2.blocks.10.attn.w_msa.qkv.bias', 'base.stages.2.blocks.10.attn.w_msa.proj.weight', 'base.stages.2.blocks.10.attn.w_msa.proj.bias', 'base.stages.2.blocks.10.norm2.weight', 'base.stages.2.blocks.10.norm2.bias', 'base.stages.2.blocks.10.ffn.layers.0.0.weight', 'base.stages.2.blocks.10.ffn.layers.0.0.bias', 'base.stages.2.blocks.10.ffn.layers.1.weight', 'base.stages.2.blocks.10.ffn.layers.1.bias', 'base.stages.2.blocks.11.norm1.weight', 'base.stages.2.blocks.11.norm1.bias', 'base.stages.2.blocks.11.attn.w_msa.relative_position_bias_table', 'base.stages.2.blocks.11.attn.w_msa.relative_position_index', 'base.stages.2.blocks.11.attn.w_msa.qkv.weight', 'base.stages.2.blocks.11.attn.w_msa.qkv.bias', 'base.stages.2.blocks.11.attn.w_msa.proj.weight', 'base.stages.2.blocks.11.attn.w_msa.proj.bias', 
'base.stages.2.blocks.11.norm2.weight', 'base.stages.2.blocks.11.norm2.bias', 'base.stages.2.blocks.11.ffn.layers.0.0.weight', 'base.stages.2.blocks.11.ffn.layers.0.0.bias', 'base.stages.2.blocks.11.ffn.layers.1.weight', 'base.stages.2.blocks.11.ffn.layers.1.bias', 'base.stages.2.blocks.12.norm1.weight', 'base.stages.2.blocks.12.norm1.bias', 'base.stages.2.blocks.12.attn.w_msa.relative_position_bias_table', 'base.stages.2.blocks.12.attn.w_msa.relative_position_index', 'base.stages.2.blocks.12.attn.w_msa.qkv.weight', 'base.stages.2.blocks.12.attn.w_msa.qkv.bias', 'base.stages.2.blocks.12.attn.w_msa.proj.weight', 'base.stages.2.blocks.12.attn.w_msa.proj.bias', 'base.stages.2.blocks.12.norm2.weight', 'base.stages.2.blocks.12.norm2.bias', 'base.stages.2.blocks.12.ffn.layers.0.0.weight', 'base.stages.2.blocks.12.ffn.layers.0.0.bias', 'base.stages.2.blocks.12.ffn.layers.1.weight', 'base.stages.2.blocks.12.ffn.layers.1.bias', 'base.stages.2.blocks.13.norm1.weight', 'base.stages.2.blocks.13.norm1.bias', 'base.stages.2.blocks.13.attn.w_msa.relative_position_bias_table', 'base.stages.2.blocks.13.attn.w_msa.relative_position_index', 'base.stages.2.blocks.13.attn.w_msa.qkv.weight', 'base.stages.2.blocks.13.attn.w_msa.qkv.bias', 'base.stages.2.blocks.13.attn.w_msa.proj.weight', 'base.stages.2.blocks.13.attn.w_msa.proj.bias', 'base.stages.2.blocks.13.norm2.weight', 'base.stages.2.blocks.13.norm2.bias', 'base.stages.2.blocks.13.ffn.layers.0.0.weight', 'base.stages.2.blocks.13.ffn.layers.0.0.bias', 'base.stages.2.blocks.13.ffn.layers.1.weight', 'base.stages.2.blocks.13.ffn.layers.1.bias', 'base.stages.2.blocks.14.norm1.weight', 'base.stages.2.blocks.14.norm1.bias', 'base.stages.2.blocks.14.attn.w_msa.relative_position_bias_table', 'base.stages.2.blocks.14.attn.w_msa.relative_position_index', 'base.stages.2.blocks.14.attn.w_msa.qkv.weight', 'base.stages.2.blocks.14.attn.w_msa.qkv.bias', 'base.stages.2.blocks.14.attn.w_msa.proj.weight', 'base.stages.2.blocks.14.attn.w_msa.proj.bias', 'base.stages.2.blocks.14.norm2.weight', 'base.stages.2.blocks.14.norm2.bias', 'base.stages.2.blocks.14.ffn.layers.0.0.weight', 'base.stages.2.blocks.14.ffn.layers.0.0.bias', 'base.stages.2.blocks.14.ffn.layers.1.weight', 'base.stages.2.blocks.14.ffn.layers.1.bias', 'base.stages.2.blocks.15.norm1.weight', 'base.stages.2.blocks.15.norm1.bias', 'base.stages.2.blocks.15.attn.w_msa.relative_position_bias_table', 'base.stages.2.blocks.15.attn.w_msa.relative_position_index', 'base.stages.2.blocks.15.attn.w_msa.qkv.weight', 'base.stages.2.blocks.15.attn.w_msa.qkv.bias', 'base.stages.2.blocks.15.attn.w_msa.proj.weight', 'base.stages.2.blocks.15.attn.w_msa.proj.bias', 'base.stages.2.blocks.15.norm2.weight', 'base.stages.2.blocks.15.norm2.bias', 'base.stages.2.blocks.15.ffn.layers.0.0.weight', 'base.stages.2.blocks.15.ffn.layers.0.0.bias', 'base.stages.2.blocks.15.ffn.layers.1.weight', 'base.stages.2.blocks.15.ffn.layers.1.bias', 'base.stages.2.blocks.16.norm1.weight', 'base.stages.2.blocks.16.norm1.bias', 'base.stages.2.blocks.16.attn.w_msa.relative_position_bias_table', 'base.stages.2.blocks.16.attn.w_msa.relative_position_index', 'base.stages.2.blocks.16.attn.w_msa.qkv.weight', 'base.stages.2.blocks.16.attn.w_msa.qkv.bias', 'base.stages.2.blocks.16.attn.w_msa.proj.weight', 'base.stages.2.blocks.16.attn.w_msa.proj.bias', 'base.stages.2.blocks.16.norm2.weight', 'base.stages.2.blocks.16.norm2.bias', 'base.stages.2.blocks.16.ffn.layers.0.0.weight', 'base.stages.2.blocks.16.ffn.layers.0.0.bias', 
'base.stages.2.blocks.16.ffn.layers.1.weight', 'base.stages.2.blocks.16.ffn.layers.1.bias', 'base.stages.2.blocks.17.norm1.weight', 'base.stages.2.blocks.17.norm1.bias', 'base.stages.2.blocks.17.attn.w_msa.relative_position_bias_table', 'base.stages.2.blocks.17.attn.w_msa.relative_position_index', 'base.stages.2.blocks.17.attn.w_msa.qkv.weight', 'base.stages.2.blocks.17.attn.w_msa.qkv.bias', 'base.stages.2.blocks.17.attn.w_msa.proj.weight', 'base.stages.2.blocks.17.attn.w_msa.proj.bias', 'base.stages.2.blocks.17.norm2.weight', 'base.stages.2.blocks.17.norm2.bias', 'base.stages.2.blocks.17.ffn.layers.0.0.weight', 'base.stages.2.blocks.17.ffn.layers.0.0.bias', 'base.stages.2.blocks.17.ffn.layers.1.weight', 'base.stages.2.blocks.17.ffn.layers.1.bias', 'base.stages.2.downsample.norm.weight', 'base.stages.2.downsample.norm.bias', 'base.stages.2.downsample.reduction.weight', 'base.stages.3.blocks.0.norm1.weight', 'base.stages.3.blocks.0.norm1.bias', 'base.stages.3.blocks.0.attn.w_msa.relative_position_bias_table', 'base.stages.3.blocks.0.attn.w_msa.relative_position_index', 'base.stages.3.blocks.0.attn.w_msa.qkv.weight', 'base.stages.3.blocks.0.attn.w_msa.qkv.bias', 'base.stages.3.blocks.0.attn.w_msa.proj.weight', 'base.stages.3.blocks.0.attn.w_msa.proj.bias', 'base.stages.3.blocks.0.norm2.weight', 'base.stages.3.blocks.0.norm2.bias', 'base.stages.3.blocks.0.ffn.layers.0.0.weight', 'base.stages.3.blocks.0.ffn.layers.0.0.bias', 'base.stages.3.blocks.0.ffn.layers.1.weight', 'base.stages.3.blocks.0.ffn.layers.1.bias', 'base.stages.3.blocks.1.norm1.weight', 'base.stages.3.blocks.1.norm1.bias', 'base.stages.3.blocks.1.attn.w_msa.relative_position_bias_table', 'base.stages.3.blocks.1.attn.w_msa.relative_position_index', 'base.stages.3.blocks.1.attn.w_msa.qkv.weight', 'base.stages.3.blocks.1.attn.w_msa.qkv.bias', 'base.stages.3.blocks.1.attn.w_msa.proj.weight', 'base.stages.3.blocks.1.attn.w_msa.proj.bias', 'base.stages.3.blocks.1.norm2.weight', 'base.stages.3.blocks.1.norm2.bias', 'base.stages.3.blocks.1.ffn.layers.0.0.weight', 'base.stages.3.blocks.1.ffn.layers.0.0.bias', 'base.stages.3.blocks.1.ffn.layers.1.weight', 'base.stages.3.blocks.1.ffn.layers.1.bias', 'base.norm0.weight', 'base.norm0.bias', 'base.norm1.weight', 'base.norm1.bias', 'base.norm2.weight', 'base.norm2.bias', 'base.norm3.weight', 'base.norm3.bias', 'base.semantic_embed_w.0.weight', 'base.semantic_embed_w.0.bias', 'base.semantic_embed_w.1.weight', 'base.semantic_embed_w.1.bias', 'base.semantic_embed_w.2.weight', 'base.semantic_embed_w.2.bias', 'base.semantic_embed_w.3.weight', 'base.semantic_embed_w.3.bias', 'base.semantic_embed_b.0.weight', 'base.semantic_embed_b.0.bias', 'base.semantic_embed_b.1.weight', 'base.semantic_embed_b.1.bias', 'base.semantic_embed_b.2.weight', 'base.semantic_embed_b.2.bias', 'base.semantic_embed_b.3.weight', 'base.semantic_embed_b.3.bias'])
These are the keys of self.state_dict(); they are all base.* keys, and there is no bottleneck.
Hi, this is because the bottleneck layers are created later in the constructor, so the added code needs to run after they have been created, i.e. move this block:
```python
if model_path != '':
    print(self.state_dict().keys())
    param_dict = torch.load(model_path)
    for i in param_dict:
        if 'classifier' in i:
            continue
        self.state_dict()[i].copy_(param_dict[i])
```
That worked, but now there is a new problem:
```
Traceback (most recent call last):
  File "C:\Users\alien\Documents\PyCharm-Projects\TIL-2023\CV\SOLIDER-REID\train.py", line 84, in <module>
    do_train(
  File "C:\Users\alien\Documents\PyCharm-Projects\TIL-2023\CV\SOLIDER-REID\processor\processor.py", line 94, in do_train
    time_per_batch = (end_time - start_time) / (n_iter + 1)
UnboundLocalError: local variable 'n_iter' referenced before assignment
```
Please check whether the train_loader that reads your data is the problem, causing execution never to enter the for loop below: https://github.com/tinyvision/SOLIDER-REID/blob/8c08e1c3255e8e1e51e006bf189e52cc57b009ed/processor/processor.py#L51
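The error is consistent with an empty loader: n_iter is only assigned inside that for loop, so if the loop yields zero batches, the later division references an unbound local. A minimal, self-contained illustration (hypothetical function, not the repo's code):

```python
def do_train_sketch(train_loader):
    # Mirrors the pattern in processor.py: n_iter is only bound inside the loop.
    for n_iter, batch in enumerate(train_loader):
        pass                                  # training step would go here
    time_per_batch = 1.0 / (n_iter + 1)       # fails if the loop body never ran
    return time_per_batch

do_train_sketch([])  # UnboundLocalError: local variable 'n_iter' referenced before assignment
```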
print(len(train_loader.dataset)) shows 3. I'm just testing at the moment, so I only put in three images.
It indeed never gets into the loop, though; I'm looking into why.
Found the cause: I only have three images in total, but the batch size is 64. Dropping the batch size to 3 fixed it. Thanks a lot!
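That matches a loader that drops the last incomplete batch (whether SOLIDER-REID's train loader sets drop_last=True is an assumption here, but it is common for ReID training). A quick way to see the effect:

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

dataset = TensorDataset(torch.zeros(3, 1))                       # only 3 samples
print(len(DataLoader(dataset, batch_size=64, drop_last=True)))   # 0 -> the training loop never runs
print(len(DataLoader(dataset, batch_size=3, drop_last=True)))    # 1 -> the loop runs once
```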
I'm trying to fine-tune on my own data and loaded the MSMT17 pretrained model. The command used is:
```
python train.py --config_file TIL.yml MODEL.PRETRAIN_CHOICE 'self' MODEL.PRETRAIN_PATH 'swin_base_msmt17.pth' OUTPUT_DIR './log' SOLVER.BASE_LR 0.0002 SOLVER.OPTIMIZER_NAME 'SGD' MODEL.SEMANTIC_WEIGHT 0.2
```
The yml file:
I have already replaced the code according to https://github.com/tinyvision/SOLIDER-REID/issues/5#issuecomment-1528767312.