tristandeleu / pytorch-meta

A collection of extensions and data-loaders for few-shot learning & meta-learning in PyTorch
https://tristandeleu.github.io/pytorch-meta/
MIT License
1.98k stars, 256 forks

How do I use a pre-trained model in Torchmeta? #109

Open: ximinng opened this issue 3 years ago

ximinng commented 3 years ago

Hello! I replaced the conv, BN, Linear, and other layers in PyTorch's official ResNet model with the corresponding Torchmeta types, and an error occurred when I loaded the PyTorch pretrained weights.

Details:

Loading the pretrained model:

from torch.hub import load_state_dict_from_url

def _resnet(arch, block, layers, pretrained, progress, **kwargs):
    model = ResNet(arch, block, layers, **kwargs)
    if pretrained:
        # state dict of the Torchmeta-based model we just built
        now_state_dict = model.state_dict()
        print('model before update len:', len(now_state_dict))

        # download the official torchvision pretrained weights
        pretrained_state_dict = load_state_dict_from_url(model_urls[arch], progress=progress)
        print('raw pretrained params len: ', len(pretrained_state_dict))

        # keep only the pretrained params whose names also exist in our model
        pretrained_state_dict = {k: v for k, v in pretrained_state_dict.items() if k in now_state_dict}
        print('selected pretrained params len: ', len(pretrained_state_dict))

        # overwrite the matching entries with the pretrained values
        now_state_dict.update(pretrained_state_dict)
        print('model after update len:', len(now_state_dict))

        # load the merged state dict back into the model
        model.load_state_dict(now_state_dict)

    # original torchvision loading code:
    # if pretrained:
    #     state_dict = load_state_dict_from_url(model_urls[arch], progress=progress)
    #     model.load_state_dict(state_dict)
    return model
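The error screenshot is not preserved, so the following is only an assumption about one common cause: the comprehension above filters by key name alone, so a pretrained tensor whose shape differs from the model's (for example an `fc` head built with a different number of classes) is still copied in, and `load_state_dict` then raises a size-mismatch error. A minimal sketch of a helper that also checks shapes; `filter_compatible` is a hypothetical name, not part of Torchmeta or torchvision:

```python
from typing import Any, Dict, List, Mapping, Tuple


def filter_compatible(
    pretrained: Mapping[str, Any], current: Mapping[str, Any]
) -> Tuple[Dict[str, Any], List[str]]:
    """Hypothetical helper: keep pretrained entries whose key exists in
    `current` AND whose shape matches the current entry.

    Works with any values exposing `.shape` (e.g. torch.Tensor).
    Returns (kept_entries, skipped_keys).
    """
    kept: Dict[str, Any] = {}
    skipped: List[str] = []
    for key, value in pretrained.items():
        # a name match is not enough: shapes must agree too
        if key in current and tuple(current[key].shape) == tuple(value.shape):
            kept[key] = value
        else:
            skipped.append(key)
    return kept, skipped
```

In `_resnet` this would replace the dict comprehension, e.g. `pretrained_state_dict, skipped = filter_compatible(pretrained_state_dict, now_state_dict)`; printing `skipped` shows which weights stay at their random initialization. Note that `model.load_state_dict(..., strict=False)` reports missing and unexpected keys but still raises on size mismatches, which is why the shape check matters here.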

Result:

[screenshot of the error output]

This is the forward pass of the torchvision ResNet after I incorporated Torchmeta into it:

    def _forward_impl(self, x, params=None, get_feat=False):
        x = self.conv1(x, params=self.get_subdict(params, 'conv1'))
        x = self.bn1(x, params=self.get_subdict(params, 'bn1'))
        x = self.relu(x)
        x = self.maxpool(x)

        x = self.layer1(x, params=self.get_subdict(params, 'layer1'))
        x = self.layer2(x, params=self.get_subdict(params, 'layer2'))
        x = self.layer3(x, params=self.get_subdict(params, 'layer3'))
        x = self.layer4(x, params=self.get_subdict(params, 'layer4'))

        x = self.avgpool(x)
        x = torch.flatten(x, 1)
        feat = x
        x = self.fc(x, params=self.get_subdict(params, 'fc'))

        if get_feat:
            return x, feat
        else:
            return x
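For context on how `params` is routed in the code above: each `self.get_subdict(params, 'layer1')` call selects the entries of `params` whose names start with `layer1.` and strips that prefix before handing them to the submodule, so the names in `params` must match the model's own parameter names (the same names that appear in `state_dict()`). The pure-Python function below is a simplified sketch of that behavior for illustration only, not Torchmeta's actual implementation (the real `MetaModule.get_subdict` also caches the lookup and treats a key with no matching entries differently):

```python
import re
from collections import OrderedDict
from typing import Any, Mapping, Optional


def get_subdict(params: Optional[Mapping[str, Any]], key: str) -> Optional["OrderedDict[str, Any]"]:
    """Simplified sketch of MetaModule.get_subdict: keep the entries under
    `key.` and strip that prefix from their names."""
    if params is None:
        # no override given: the submodule falls back to its own parameters
        return None
    # require the dot so that e.g. "layer1" does not match "layer10.weight"
    prefix = re.compile(r"^{}\.(.+)".format(re.escape(key)))
    sub: "OrderedDict[str, Any]" = OrderedDict()
    for name, value in params.items():
        match = prefix.match(name)
        if match:
            sub[match.group(1)] = value
    return sub
```

In a meta-learning inner loop, `params` would typically start as `OrderedDict(model.meta_named_parameters())` and then be updated by gradient steps before calling `model(x, params=params)`, assuming `forward` passes `params` through to `_forward_impl`.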

:) Thanks!!

Jibanul commented 3 years ago

Did you figure it out?