Markin-Wang / FFVT

[BMVC 2021] The official PyTorch implementation of Feature Fusion Vision Transformer for Fine-Grained Visual Categorization

How to load pretrained weights? #2

Status: Open. magenta2n opened this issue 2 years ago.

magenta2n commented 2 years ago

I wanted to load the pretrained weights, but when I set --pretrained_dir to the absolute path of the weights file soyloc.bin, I get the following error:

```
WARNING - __main__ - Process rank: -1, device: cuda, n_gpu: 1, distributed training: False, 16-bits training: False
Traceback (most recent call last):
  File "/FFVT-main/train.py", line 416, in <module>
    main()
  File "/home/code/classifiers/FFVT-main/train.py", line 409, in main
    args, model = setup(args)
  File "/FFVT-main/train.py", line 93, in setup
    model.load_from(np.load(args.pretrained_dir, allow_pickle=True))
  File "/FFVT-main/models/modeling.py", line 379, in load_from
    self.transformer.embeddings.patch_embeddings.weight.copy_(np2th(weights["embedding/kernel"], conv=True))
TypeError: 'int' object is not subscriptable
```

And when I try to load the weights I generated after training on a custom dataset (--pretrained_dir=my_ckpt.bin), I get:

```
Traceback (most recent call last):
  File "/FFVT-main/train.py", line 416, in <module>
    main()
  File "/FFVT-main/train.py", line 409, in main
    args, model = setup(args)
  File "/FFVT-main/train.py", line 93, in setup
    model.load_from(np.load(args.pretrained_dir, allow_pickle=True))
  File "/FFVT-main/models/modeling.py", line 379, in load_from
    self.transformer.embeddings.patch_embeddings.weight.copy_(np2th(weights["embedding/kernel"], conv=True))
  File "/PyEnvs/base/lib/python3.8/site-packages/numpy/lib/npyio.py", line 260, in __getitem__
    raise KeyError("%s is not a file in the archive" % key)
KeyError: 'embedding/kernel is not a file in the archive'
```
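A plausible explanation for why the two files fail differently, offered as an assumption rather than a confirmed diagnosis: both tracebacks go through `np.load`, which only understands `.npy`/`.npz` files and plain pickles. A checkpoint written by `torch.save` in the legacy format (the default before PyTorch 1.6) begins with a pickled integer magic number, so `np.load(..., allow_pickle=True)` likely returns that `int`, and indexing it raises the `TypeError`. A checkpoint in the newer zip-based format looks like a zip archive, so `np.load` returns an `NpzFile` with no `embedding/kernel` member, hence the `KeyError`. The sketch below reproduces both cases with a tiny stand-in `nn.Linear`; the file names are hypothetical.

```python
import os
import tempfile

import numpy as np
import torch
from torch import nn

# Tiny stand-in for a trained model; only its state_dict matters here.
state = nn.Linear(4, 2).state_dict()

with tempfile.TemporaryDirectory() as tmp:
    legacy_path = os.path.join(tmp, "legacy.bin")  # hypothetical file names
    zip_path = os.path.join(tmp, "zipfile.bin")

    # Legacy serialization: the first pickled object is torch's integer magic number.
    torch.save(state, legacy_path, _use_new_zipfile_serialization=False)
    # Default serialization (PyTorch >= 1.6): the file is a zip archive.
    torch.save(state, zip_path)

    # What setup() effectively does with each file:
    legacy_obj = np.load(legacy_path, allow_pickle=True)
    print(type(legacy_obj).__name__)  # int, so weights["embedding/kernel"] raises TypeError

    with np.load(zip_path, allow_pickle=True) as archive:
        archive_type = type(archive).__name__
        print(archive_type)  # NpzFile
        # The archive holds torch's internal files, not the ViT .npz keys.
        key_missing = "embedding/kernel" not in archive.files
```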

Markin-Wang commented 2 years ago

Hi, thanks for your interest. The pre-trained model is saved with `torch.save(model_to_save.state_dict(), model_checkpoint)`, i.e. as a PyTorch state dict. Hence, you cannot load our trained model the same way the pretrained ViT weights (.npz) are loaded.

You should use `model.load_state_dict(torch.load(PATH))` to load our trained models. The link below is a PyTorch tutorial on how to save and load weights; hope it helps. How to save and load models
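A minimal save/load round trip along those lines, as a sketch: a stand-in `nn.Linear` replaces the FFVT model (in practice you would first build the model with the same config and number of classes used for training). The `hasattr(model, "module")` unwrap is a common pattern for `DataParallel`/`DistributedDataParallel` and is an assumption about how `model_to_save` is obtained; `save_checkpoint`/`load_checkpoint` are hypothetical names, and `map_location="cpu"` simply lets a GPU-trained checkpoint load on any machine.

```python
import os
import tempfile

import torch
from torch import nn


def save_checkpoint(model, path):
    # Unwrap a DataParallel/DDP-style wrapper so keys carry no "module." prefix
    # (assumed to match how model_to_save is obtained before torch.save).
    model_to_save = model.module if hasattr(model, "module") else model
    torch.save(model_to_save.state_dict(), path)


def load_checkpoint(model, path):
    # The file holds a plain state dict, so torch.load + load_state_dict is enough.
    state_dict = torch.load(path, map_location="cpu")
    model.load_state_dict(state_dict)
    return model


# Stand-ins for the FFVT model: both must have the same architecture.
trained = nn.Linear(8, 3)
fresh = nn.Linear(8, 3)

with tempfile.TemporaryDirectory() as tmp:
    ckpt = os.path.join(tmp, "my_ckpt.bin")  # hypothetical file name
    save_checkpoint(trained, ckpt)
    load_checkpoint(fresh, ckpt)

print(torch.equal(trained.weight, fresh.weight))  # True
```

`load_state_dict` is strict by default, so a mismatched architecture (for example a different number of classes) fails loudly instead of silently loading partial weights.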

In addition, I would recommend writing a script that runs the test based on the `main` and `valid` functions in train.py.
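A rough sketch of such a test script, not the repository's own code: `evaluate` is a hypothetical name, and it assumes the model returns class logits for a batch of images (taking the first element if it returns a tuple). For FFVT you would build the model as `setup()` does, load the trained weights with `load_state_dict`, and pass the test `DataLoader` from train.py, using `valid` as the reference for details such as the exact forward call. Stand-in data and an `nn.Linear` are used here so the snippet runs on its own.

```python
import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset


@torch.no_grad()  # no gradients needed during evaluation
def evaluate(model, loader, device="cpu"):
    """Hypothetical test loop: returns top-1 accuracy over `loader`."""
    model.eval()  # switch off dropout and similar training-only behavior
    model.to(device)
    correct, total = 0, 0
    for images, labels in loader:
        images, labels = images.to(device), labels.to(device)
        outputs = model(images)
        # Assumption: logits are the output itself, or its first element if a tuple.
        logits = outputs[0] if isinstance(outputs, (tuple, list)) else outputs
        preds = logits.argmax(dim=-1)
        correct += (preds == labels).sum().item()
        total += labels.numel()
    return correct / max(total, 1)


# Usage with stand-ins; replace with the FFVT model and its test DataLoader.
torch.manual_seed(0)
data = TensorDataset(torch.randn(32, 10), torch.randint(0, 4, (32,)))
loader = DataLoader(data, batch_size=8)
model = nn.Linear(10, 4)
acc = evaluate(model, loader)
print(f"accuracy: {acc:.3f}")
```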