[Open] cuicuizhang1989 opened this issue 5 years ago
@excllent123 你好，需要怎么修改模型才能加载？我直接加载模型也会出错。
def load_checkpoint(model, checkpoint_PATH, map_location='cpu'):
    """Load weights from a checkpoint file into ``model`` and return it.

    Checkpoints saved from an ``nn.DataParallel`` /
    ``DistributedDataParallel`` wrapper have parameter names with a leading
    ``module.`` prefix. That prefix is stripped before loading, so the weights
    fit an unwrapped model.

    Args:
        model: The ``nn.Module`` to load weights into (modified in place).
        checkpoint_PATH: Path to the checkpoint file, or ``None`` to skip
            loading and return ``model`` unchanged.
        map_location: Passed to ``torch.load``. Defaults to ``'cpu'`` so a
            checkpoint saved on GPU can be read on a CPU-only machine.
            ``load_state_dict`` copies the values into the model's existing
            parameters, so the model's own device is unaffected.

    Returns:
        The same ``model`` object.
    """
    if checkpoint_PATH is None:
        return model
    checkpoint = torch.load(checkpoint_PATH, map_location=map_location)
    # Accept both {'state_dict': ..., ...} training checkpoints and bare state dicts.
    if isinstance(checkpoint, dict) and 'state_dict' in checkpoint:
        state_dict = checkpoint['state_dict']
    else:
        state_dict = checkpoint
    prefix = 'module.'
    # Strip only a *leading* prefix. str.replace would also mangle names such as
    # 'submodule.weight' -> 'subweight' and make load_state_dict fail.
    state_dict = {
        (key[len(prefix):] if key.startswith(prefix) else key): value
        for key, value in state_dict.items()
    }
    model.load_state_dict(state_dict)
    print('loading checkpoint!')
    return model
def mobilenet_large_v3(pretrained=False, **kwargs):
    """Build a ``MobileNetV3_Large`` model, optionally with pretrained weights.

    Args:
        pretrained: If true, load weights from ``'mbv3_large.pth.tar'`` via
            ``load_checkpoint``. The path is relative, so it is resolved
            against the current working directory.
        **kwargs: Forwarded unchanged to the ``MobileNetV3_Large`` constructor.

    Returns:
        The constructed model.
    """
    model = MobileNetV3_Large(**kwargs)
    if pretrained:
        model = load_checkpoint(model, 'mbv3_large.pth.tar')
    return model
class Finetune_MobileNetV3_Large(nn.Module):
    """MobileNetV3-Large backbone whose final classifier is replaced for fine-tuning.

    Args:
        class_nums: Number of outputs of the new final linear layer.
        pretrained: Whether ``mobilenet_large_v3`` should load its pretrained
            checkpoint. Defaults to ``True`` (the original behavior).
    """

    def __init__(self, class_nums, pretrained=True):
        super(Finetune_MobileNetV3_Large, self).__init__()
        self.class_nums = class_nums
        self.base = mobilenet_large_v3(pretrained=pretrained)
        # Reuse the existing head's input width when it exposes one, rather than
        # hard-coding it. 1280 is the width the original code assumed.
        old_head = getattr(self.base, 'linear4', None)
        in_features = getattr(old_head, 'in_features', 1280)
        self.base.linear4 = nn.Linear(in_features, self.class_nums)

    def forward(self, x):
        """Run ``x`` through the backbone, including the replaced ``linear4`` head."""
        return self.base(x)
感谢
直接加载模型即可。由于模型是用多 GPU 训练的，保存的参数名可能带有 `module.` 前缀，需要先去掉该前缀（或对模型做相应修改）后再加载。