class GhostNet(nn.Module):
    """GhostNet feature extractor returning three intermediate feature maps.

    Builds the ``ghostnet()`` classifier and optionally loads a checkpoint into
    it. It then deletes the classification head (``global_pool``,
    ``conv_head``, ``act2``, ``classifier``) and the last stage ``blocks[9]``,
    so only the feature-extraction layers remain.

    A checkpoint from a classifier whose head differs (e.g. a different number
    of classes) can still supply the backbone weights. Pass ``strict=False``
    so that entries whose name or shape do not match are skipped, with a
    warning, instead of aborting the load. The head is deleted anyway.

    Args:
        pretrained: If True, load ``weights_path`` into the full model before
            the head is removed.
        weights_path: Checkpoint file holding a ``state_dict`` for the full
            ``ghostnet()`` model, or a dict wrapping one under ``"state_dict"``,
            ``"model_state_dict"`` or ``"model"``.
        strict: Forwarded to ``load_state_dict``. Keep True to surface any
            mismatch; set False to load only the compatible entries.
    """

    # forward() returns the outputs of these ``model.blocks`` indices,
    # shallowest first.
    _OUT_INDICES = (4, 6, 8)

    def __init__(
        self,
        pretrained=True,
        weights_path="model_data/ghostnet_weights.pth",
        strict=True,
    ):
        # Must be ``__init__`` (the paste read ``init``, which PyTorch never
        # calls, so nn.Module would never be initialised).
        super().__init__()
        model = ghostnet()
        if pretrained:
            self._load_pretrained(model, weights_path, strict)
        # Drop the classification head; only the feature extractor is kept.
        del model.global_pool
        del model.conv_head
        del model.act2
        del model.classifier
        # Drop the final stage: forward() only needs outputs up to blocks[8].
        del model.blocks[9]
        self.model = model

    @classmethod
    def _load_pretrained(cls, model, weights_path, strict):
        """Load the checkpoint at ``weights_path`` into the full ``model``."""
        # map_location="cpu" lets GPU-saved checkpoints load on CPU-only hosts;
        # load_state_dict then copies the values into the model's own tensors.
        # torch.load relies on pickle: only load checkpoints you trust.
        checkpoint = torch.load(weights_path, map_location="cpu")
        state_dict = cls._normalize_state_dict(checkpoint)
        if not strict:
            state_dict = cls._filter_compatible(model, state_dict)
        model.load_state_dict(state_dict, strict=strict)

    @staticmethod
    def _normalize_state_dict(checkpoint):
        """Unwrap common checkpoint layouts into a plain ``state_dict``.

        Non-dict inputs are returned unchanged, so ``load_state_dict`` reports
        its own error for them. A dict that needs no change is returned as the
        same object, which preserves any ``_metadata`` attached to it.
        """
        if not isinstance(checkpoint, dict):
            return checkpoint
        # Training scripts often save {"state_dict": ..., "optimizer": ...}.
        for key in ("state_dict", "model_state_dict", "model"):
            if isinstance(checkpoint.get(key), dict):
                checkpoint = checkpoint[key]
                break
        # nn.DataParallel / DistributedDataParallel prefix every key with
        # "module.", which makes the names differ from a bare ghostnet().
        prefix = "module."
        if any(name.startswith(prefix) for name in checkpoint):
            checkpoint = {
                (name[len(prefix):] if name.startswith(prefix) else name): value
                for name, value in checkpoint.items()
            }
        return checkpoint

    @staticmethod
    def _filter_compatible(model, state_dict):
        """Keep only entries of ``state_dict`` whose name and shape match ``model``.

        Emits a warning that lists the skipped checkpoint entries and the model
        entries left at their initial values, so a partial load is never silent.
        """
        import warnings

        expected = model.state_dict()
        compatible = {
            name: value
            for name, value in state_dict.items()
            if name in expected
            and getattr(value, "shape", None) == expected[name].shape
        }
        skipped = sorted(set(state_dict) - set(compatible))
        missing = sorted(set(expected) - set(compatible))
        if skipped or missing:
            warnings.warn(
                f"GhostNet: loaded {len(compatible)}/{len(expected)} entries; "
                f"skipped {len(skipped)} incompatible checkpoint entries "
                f"(e.g. {skipped[:5]}), {len(missing)} model entries not found "
                f"(e.g. {missing[:5]})",
                stacklevel=2,
            )
        return compatible

    def forward(self, x):
        """Run the stem and the remaining blocks, collecting selected outputs.

        Args:
            x: Input batch for ``conv_stem`` (presumably N x C x H x W; confirm
                against the caller).

        Returns:
            A list with the outputs of ``model.blocks[4]``, ``[6]`` and ``[8]``,
            in that order.
        """
        x = self.model.conv_stem(x)
        x = self.model.bn1(x)
        x = self.model.act1(x)
        feature_maps = []
        for idx, block in enumerate(self.model.blocks):
            x = block(x)
            if idx in self._OUT_INDICES:
                feature_maps.append(x)
        return feature_maps
需要加载 model_data/ghostnet_weights.pth，但发现权重与模型不匹配。我的 GhostNet 分类模型的特征提取网络是否需要重新训练？