yitu-opensource / T2T-ViT

ICCV2021, Tokens-to-Token ViT: Training Vision Transformers from Scratch on ImageNet

How to load the provided model.pt.tar to train on my own dataset? #17

Closed shinianzhihou closed 3 years ago

shinianzhihou commented 3 years ago

If I want to initialize the weights from the provided model.pt.tar and then train on my own dataset, what should I do? Could you please give me some suggestions?

shinianzhihou commented 3 years ago

Solved as follows:

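import torch

# Load the released checkpoint on the CPU so no GPU is required.
stat = torch.load('/path/to/80.6_T2T_ViT_14.pth.tar', map_location='cpu')

# Drop the pretrained classification head so a new head can be trained.
del stat['state_dict_ema']['head.weight']
del stat['state_dict_ema']['head.bias']

# Save only the EMA weights as a plain state dict.
torch.save(stat['state_dict_ema'], '/path/to/80.6_T2T_ViT_14_single.pth.tar')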
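
The same stripping can also be written as a small helper that filters by key prefix instead of deleting each key by name. This is only a minimal sketch: `strip_head` is a hypothetical name, the toy checkpoint and its tensor shapes are made up for illustration, and only the `state_dict_ema` key and the `head.` prefix come from the snippet above.

```python
from collections import OrderedDict

import torch


def strip_head(state_dict, prefix="head."):
    """Return a copy of state_dict without entries whose key starts with prefix."""
    return OrderedDict(
        (k, v) for k, v in state_dict.items() if not k.startswith(prefix)
    )


# Toy stand-in for the checkpoint's 'state_dict_ema' entry (shapes are made up).
toy_ckpt = {
    "state_dict_ema": OrderedDict(
        [
            ("blocks.0.weight", torch.zeros(4, 4)),
            ("head.weight", torch.zeros(1000, 4)),
            ("head.bias", torch.zeros(1000)),
        ]
    )
}

# The original dict is left untouched; only the copy loses the head entries.
backbone_only = strip_head(toy_ckpt["state_dict_ema"])
print(list(backbone_only))
```

Because the helper returns a copy, the loaded checkpoint stays intact if you still need the original head weights later.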

Some errors (for example, missing-key errors for the deleted head.weight and head.bias) can be solved by modifying load_checkpoint in timm.models.helpers:

def load_checkpoint(model, checkpoint_path, use_ema=False, strict=True):
    state_dict = load_state_dict(checkpoint_path, use_ema)
    # Changed: force non-strict loading so the deleted head.* keys do not raise.
    model.load_state_dict(state_dict, strict=False)
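
If you would rather not edit the installed timm package, the same effect can probably be had by calling `load_state_dict(..., strict=False)` on the model yourself and checking what was skipped. A minimal sketch under stated assumptions: `TinyNet` is a hypothetical stand-in for the real T2T-ViT model (it only mimics having a backbone plus a `head` layer), and the class counts are arbitrary.

```python
import torch
import torch.nn as nn


class TinyNet(nn.Module):
    """Hypothetical stand-in for T2T-ViT: a backbone plus a classification head."""

    def __init__(self, num_classes):
        super().__init__()
        self.backbone = nn.Linear(8, 16)
        self.head = nn.Linear(16, num_classes)

    def forward(self, x):
        return self.head(self.backbone(x))


# Pretend these are the pretrained weights, with the 1000-class head removed.
pretrained = TinyNet(num_classes=1000).state_dict()
stripped = {k: v for k, v in pretrained.items() if not k.startswith("head.")}

# New model for a dataset with a different number of classes.
model = TinyNet(num_classes=10)
result = model.load_state_dict(stripped, strict=False)

# With strict=False, load_state_dict reports what it could not match
# instead of raising; only the head should be listed as missing.
print("missing:", result.missing_keys)
print("unexpected:", result.unexpected_keys)
```

Checking `missing_keys` and `unexpected_keys` after a non-strict load is worthwhile, since `strict=False` will also silently skip keys that are misnamed by accident.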