wilile26811249 / MobileViT

Unofficial PyTorch implementation of MobileViT based on paper "MobileViT: Light-weight, General-purpose, and Mobile-friendly Vision Transformer".
MIT License
112 stars, 20 forks

Loading pretrained weights failed #6

Closed. wilbur-caper closed this issue 2 years ago

wilbur-caper commented 2 years ago
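import torch
import models

model = models.MobileViT_S()
PATH = "./MobileVit-S.pth.tar"
# Load the checkpoint; this map_location keeps every tensor on the CPU.
weights = torch.load(PATH, map_location=lambda storage, loc: storage)
# Loading the state dict is the step reported as failing.
model.load_state_dict(weights['state_dict'])
model.eval()
# Save the whole model object (not just its state dict).
torch.save(model, './model.pt')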


wilile26811249 commented 2 years ago

Hi @hererookie, sorry for the late reply. The load fails because I wrapped the model in DataParallel during training, which stores the model under 'module', so every key in the saved state_dict starts with 'module.'.

Sol: model.load_state_dict({k.replace('module.', ''): v for k, v in weights['state_dict'].items()})

Ref: https://discuss.pytorch.org/t/solved-keyerror-unexpected-key-module-encoder-embedding-weight-in-state-dict/1686/3
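
A slightly more defensive variant of the one-liner above, as a minimal sketch rather than the repository's own code: it strips 'module.' only when it appears at the start of a key, so a name that contains 'module.' elsewhere is left alone. The helper name strip_module_prefix and the demo keys are made up for illustration; the helper only touches dictionary keys, so it runs without PyTorch.

```python
from collections import OrderedDict

PREFIX = "module."  # prefix that nn.DataParallel puts in front of every key


def strip_module_prefix(state_dict, prefix=PREFIX):
    """Return a copy of state_dict with a leading `prefix` removed from each key.

    Hypothetical helper for illustration; not part of PyTorch or this repo.
    """
    cleaned = OrderedDict()
    for key, value in state_dict.items():
        # Strip only at the start of the key, so "module." appearing
        # later in a parameter name is left untouched.
        new_key = key[len(prefix):] if key.startswith(prefix) else key
        cleaned[new_key] = value
    return cleaned


# Stand-in for weights['state_dict']; key names and values are invented.
demo = OrderedDict([
    ("module.stem.weight", 1),
    ("module.fc.bias", 2),
])
print(list(strip_module_prefix(demo)))  # ['stem.weight', 'fc.bias']
```

With the checkpoint from this issue, the call would presumably become model.load_state_dict(strip_module_prefix(weights['state_dict'])).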

wilbur-caper commented 2 years ago

Thanks for your reply, I will try it.