shinianzhihou closed this issue 3 years ago.
Solved as follows:
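```python
import torch
# Load the full checkpoint on CPU (it may also hold optimizer state, epoch, etc.)
stat = torch.load('/path/to/80.6_T2T_ViT_14.pth.tar', map_location='cpu')
# Drop the pretrained classification head so the weights fit a model with a different number of classes
del stat['state_dict_ema']['head.weight']
del stat['state_dict_ema']['head.bias']
# Save only the EMA weights as a plain state_dict
torch.save(stat['state_dict_ema'], '/path/to/80.6_T2T_ViT_14_single.pth.tar')
```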
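If the head keys might be absent (for example, when the script is re-run on an already stripped file), `del` raises `KeyError`. A minimal sketch of a more tolerant alternative follows. `strip_prefixes` is a hypothetical helper, not part of PyTorch or timm, and it assumes the head parameters are the keys starting with `head.`, as in the snippet above:

```python
from typing import Any, Dict, Iterable, Mapping


def strip_prefixes(state_dict: Mapping[str, Any],
                   prefixes: Iterable[str] = ("head.",)) -> Dict[str, Any]:
    """Return a copy of state_dict without keys that start with any given prefix.

    Hypothetical helper: unlike `del`, it never raises KeyError and
    leaves the input mapping untouched.
    """
    prefixes = tuple(prefixes)  # str.startswith accepts a tuple of prefixes
    return {k: v for k, v in state_dict.items() if not k.startswith(prefixes)}
```

Usage would be something like `torch.save(strip_prefixes(stat['state_dict_ema']), '/path/to/80.6_T2T_ViT_14_single.pth.tar')`. The trailing dot in `head.` keeps it from matching unrelated keys that merely begin with the letters "head".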
Some of the remaining errors (missing or unexpected keys) can be solved by modifying `load_checkpoint` in `timm.models.helpers` so that it loads non-strictly:
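```python
# In timm/models/helpers.py (load_state_dict is defined in the same module)
def load_checkpoint(model, checkpoint_path, use_ema=False, strict=True):
    state_dict = load_state_dict(checkpoint_path, use_ema)
    model.load_state_dict(state_dict, strict=False)  # here: always load non-strictly, ignoring `strict`
```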
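To see what `strict=False` actually does, here is a minimal sketch with a hypothetical `TinyNet` standing in for T2T-ViT. It has only a `backbone` and a `head`. The head is deleted from a 1000-class checkpoint as in the snippet above, and the result is then loaded into a 10-class model. `load_state_dict` returns the keys it could not match instead of raising:

```python
import torch
import torch.nn as nn


class TinyNet(nn.Module):
    # Hypothetical stand-in for T2T-ViT: a "backbone" followed by a "head" classifier
    def __init__(self, num_classes):
        super().__init__()
        self.backbone = nn.Linear(8, 4)
        self.head = nn.Linear(4, num_classes)

    def forward(self, x):
        return self.head(self.backbone(x))


# "Pretrained" weights for 1000 classes, with the head removed
pretrained = TinyNet(num_classes=1000).state_dict()
del pretrained['head.weight']
del pretrained['head.bias']
backbone_w = pretrained['backbone.weight'].clone()

# New model for a 10-class dataset; its head keeps its random initialization
model = TinyNet(num_classes=10)
result = model.load_state_dict(pretrained, strict=False)
print(result.missing_keys)     # ['head.weight', 'head.bias']
print(result.unexpected_keys)  # []
```

Note that `strict=False` only tolerates missing or extra keys. If the 1000-class head were left in the checkpoint, PyTorch would still raise a size-mismatch `RuntimeError`. This is why the head is deleted first.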
If I want to initialize the weights from the provided model.pt.tar and then train on my own dataset, what should I do? Could you please give me some suggestions?