Hello! I replaced the Conv, BatchNorm, Linear, and other layers in PyTorch's official ResNet model with their Torchmeta counterparts. An error occurs when I load PyTorch's pretrained weights into this modified model.
Details:
Loading the pretrained model:
def _resnet(arch, block, layers, pretrained, progress, **kwargs):
    """Build a ResNet and optionally initialise it from a pretrained checkpoint.

    Args:
        arch: Architecture key; passed to ``ResNet`` and used to look up the
            checkpoint URL in ``model_urls``.
        block: Residual block class forwarded to ``ResNet``.
        layers: Per-stage block counts forwarded to ``ResNet``.
        pretrained: If true, download the checkpoint for ``arch`` and copy
            every compatible tensor into the freshly built model.
        progress: Forwarded to ``load_state_dict_from_url`` to toggle the
            download progress bar.
        **kwargs: Extra keyword arguments forwarded to ``ResNet``.

    Returns:
        The constructed model.
    """
    model = ResNet(arch, block, layers, **kwargs)
    if pretrained:
        # Start from the model's own state so entries the checkpoint does not
        # provide (or cannot provide compatibly) keep their fresh values.
        now_state_dict = model.state_dict()
        print('model before update len:', len(now_state_dict))

        # NOTE(review): assumes model_urls[arch] points at a checkpoint whose
        # key names follow the same module layout as this Torchmeta ResNet.
        pretrained_state_dict = load_state_dict_from_url(model_urls[arch], progress=progress)
        print('raw pretrained params len: ', len(pretrained_state_dict))

        # Keep an entry only if both its name AND its shape match the model.
        # A name match alone still lets through tensors whose shape differs
        # (e.g. a classifier head built for a different number of outputs).
        # load_state_dict would then raise a size-mismatch error.
        selected_state_dict = {
            k: v
            for k, v in pretrained_state_dict.items()
            if k in now_state_dict and v.shape == now_state_dict[k].shape
        }
        print('selected pretrained params len: ', len(selected_state_dict))

        # Surface what was dropped so silent partial loading is visible.
        skipped = sorted(set(pretrained_state_dict) - set(selected_state_dict))
        if skipped:
            print('skipped pretrained params:', skipped)

        now_state_dict.update(selected_state_dict)
        print('model after update len:', len(now_state_dict))
        model.load_state_dict(now_state_dict)
    return model
Result:
Below is the torchvision ResNet code into which I incorporated Torchmeta:
def _forward_impl(self, x, params=None, get_feat=False):
    """Run the ResNet forward pass, optionally with externally supplied weights.

    Each parameterised submodule (``conv1``, ``bn1``, ``layer1``..``layer4``,
    ``fc``) receives the slice of ``params`` selected by ``self.get_subdict``
    under its own name. ``relu``, ``maxpool`` and ``avgpool`` are called
    without parameters.

    Args:
        x: Input batch; presumably an (N, C, H, W) image tensor -- confirm
            against callers.
        params: Optional mapping of parameters forwarded through
            ``get_subdict``; ``None`` passes ``None`` down.
        get_feat: When true, also return the flattened pooled features that
            feed the final ``fc`` layer.

    Returns:
        The ``fc`` output, or a ``(fc_output, features)`` tuple when
        ``get_feat`` is true.
    """
    # Stem: convolution -> batch norm -> activation -> pooling.
    out = self.conv1(x, params=self.get_subdict(params, 'conv1'))
    out = self.bn1(out, params=self.get_subdict(params, 'bn1'))
    out = self.maxpool(self.relu(out))

    # The four residual stages, applied in order with their own parameter slice.
    for stage_name in ('layer1', 'layer2', 'layer3', 'layer4'):
        stage = getattr(self, stage_name)
        out = stage(out, params=self.get_subdict(params, stage_name))

    # Pool and flatten everything after the batch dimension.
    features = torch.flatten(self.avgpool(out), 1)
    logits = self.fc(features, params=self.get_subdict(params, 'fc'))
    return (logits, features) if get_feat else logits
Hello! I replaced the Conv, BatchNorm, Linear, and other layers in PyTorch's official ResNet model with their Torchmeta counterparts. An error occurs when I load PyTorch's pretrained weights into this modified model.
Details:
Loading the pretrained model:
Result:
Below is the torchvision ResNet code into which I incorporated Torchmeta:
:) Thanks!!