magenta2n opened this issue 2 years ago.
Hi, thanks for your interest.
Our pre-trained model was saved with:
torch.save(model_to_save.state_dict(), model_checkpoint)
Hence, you cannot load our trained model the same way you load the pretrained ViT weights (.npz).
Use model.load_state_dict(torch.load(PATH)) to load our trained models instead. The PyTorch tutorial linked below explains how to save and load weights; hope it helps: How to save and load models
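A minimal sketch of that state_dict round trip, assuming a hypothetical TinyModel in place of the FFVT model and an in-memory buffer in place of a .bin file on disk:

```python
import io

import torch
import torch.nn as nn


class TinyModel(nn.Module):
    """Hypothetical stand-in for the trained FFVT model."""

    def __init__(self) -> None:
        super().__init__()
        self.fc = nn.Linear(4, 2)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.fc(x)


def save_checkpoint(model: nn.Module, path) -> None:
    # Same pattern as quoted above: only the parameters (the state_dict) are saved,
    # not the model class or a numpy archive.
    torch.save(model.state_dict(), path)


def load_checkpoint(model: nn.Module, path) -> nn.Module:
    # torch.load returns the saved dict of tensors; map_location="cpu" lets a
    # checkpoint written on a GPU machine be opened without CUDA.
    state_dict = torch.load(path, map_location="cpu")
    # Copies each tensor into the matching parameter of the freshly built model.
    model.load_state_dict(state_dict)
    return model
```

With the real model you would build it exactly as train.py does and pass the .bin path instead of a buffer. load_state_dict raises an error if the parameter names or shapes (for example the classifier head size) do not match the checkpoint.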
I would also recommend writing a script that runs the test, based on the main and valid functions in train.py.
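A rough sketch of the core of such a test script, assuming the model returns class logits when called on a batch of images and the loader yields (images, labels) pairs. evaluate is a hypothetical helper, not the valid function from train.py; check what the FFVT model's forward actually returns before reusing it.

```python
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset


@torch.no_grad()  # no gradients are needed for testing
def evaluate(model: nn.Module, loader: DataLoader, device: str = "cpu") -> float:
    """Return top-1 accuracy of `model` on `loader` (hypothetical helper)."""
    model.eval()  # switch dropout and similar layers to inference mode
    correct, total = 0, 0
    for images, labels in loader:
        images, labels = images.to(device), labels.to(device)
        logits = model(images)          # assumed shape: (batch, num_classes)
        preds = logits.argmax(dim=-1)   # predicted class index per sample
        correct += (preds == labels).sum().item()
        total += labels.numel()
    return correct / max(total, 1)      # guard against an empty loader
```

In a full script you would build the model, load the checkpoint with load_state_dict, build the test loader the same way train.py does, and print the returned accuracy.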
I want to load the pretrained weights, but when I set --pretrained_dir to the absolute path of the weights file soyloc.bin, I get the following error:
WARNING - __main__ - Process rank: -1, device: cuda, n_gpu: 1, distributed training: False, 16-bits training: False
Traceback (most recent call last):
  File "/FFVT-main/train.py", line 416, in <module>
    main()
  File "/home/code/classifiers/FFVT-main/train.py", line 409, in main
    args, model = setup(args)
  File "/FFVT-main/train.py", line 93, in setup
    model.load_from(np.load(args.pretrained_dir, allow_pickle=True))
  File "/FFVT-main/models/modeling.py", line 379, in load_from
    self.transformer.embeddings.patch_embeddings.weight.copy_(np2th(weights["embedding/kernel"], conv=True))
TypeError: 'int' object is not subscriptable
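One plausible reading of this error: np.load decides what to do from the file's first bytes. A zip archive is opened as an NpzFile, a .npy file as an array, and anything else is unpickled when allow_pickle=True. A legacy (pre-1.6) torch.save file begins with a pickled integer magic number, so np.load would return that int, which matches "'int' object is not subscriptable". A torch.save file in the newer zip format would instead open as an NpzFile with no "embedding/kernel" member and fail with a KeyError. The sketch below checks which case a file falls into using only the standard library; sniff_checkpoint is a hypothetical helper and its checks are heuristics.

```python
import zipfile

ZIP_MAGIC = b"PK\x03\x04"  # .npz files and zip-format torch.save files
NPY_MAGIC = b"\x93NUMPY"   # a single .npy array


def sniff_checkpoint(path: str) -> str:
    """Guess how a weights file was written (heuristic, hypothetical helper)."""
    with open(path, "rb") as f:
        head = f.read(6)
    if head.startswith(NPY_MAGIC):
        return "npy"
    if head.startswith(ZIP_MAGIC):
        try:
            with zipfile.ZipFile(path) as zf:
                names = zf.namelist()
        except zipfile.BadZipFile:
            return "unknown"
        if names and all(n.endswith(".npy") for n in names):
            return "npz"        # ViT-style weights: np.load(...) + model.load_from(...)
        if any(n.endswith("data.pkl") for n in names):
            return "torch-zip"  # torch.save state_dict: model.load_state_dict(torch.load(...))
        return "zip-unknown"
    if head[:1] == b"\x80":
        return "pickle"         # pickle stream, e.g. legacy torch.save output
    return "unknown"
```

Anything other than "npz" should go through torch.load and load_state_dict rather than np.load and load_from.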
And when I try to load the weights I generated after training on a custom dataset (--pretrained_dir=my_ckpt.bin), I get:
Traceback (most recent call last):
  File "/FFVT-main/train.py", line 416, in <module>
    main()
  File "/FFVT-main/train.py", line 409, in main
    args, model = setup(args)
  File "/FFVT-main/train.py", line 93, in setup
    model.load_from(np.load(args.pretrained_dir, allow_pickle=True))
  File "/FFVT-main/models/modeling.py", line 379, in load_from
    self.transformer.embeddings.patch_embeddings.weight.copy_(np2th(weights["embedding/kernel"], conv=True))
  File "/PyEnvs/base/lib/python3.8/site-packages/numpy/lib/npyio.py", line 260, in __getitem__
    raise KeyError("%s is not a file in the archive" % key)
KeyError: 'embedding/kernel is not a file in the archive'
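Because the checkpoints produced by training are presumably written with torch.save(model_to_save.state_dict(), ...), they need the state_dict path rather than np.load. A minimal sketch of how the loading line in setup() could branch on the file type, assuming that only .npz files are original ViT weights. load_pretrained is a hypothetical name; model.load_from is the existing method seen in the traceback.

```python
import numpy as np
import torch
import torch.nn as nn


def load_pretrained(model: nn.Module, path: str) -> nn.Module:
    """Hypothetical replacement for the np.load(...) line in setup()."""
    if path.endswith(".npz"):
        # Original ViT weights (.npz): keep the existing load_from code path.
        model.load_from(np.load(path))
    else:
        # Checkpoints written with torch.save(model.state_dict(), ...),
        # e.g. soyloc.bin or a checkpoint from your own training run.
        state_dict = torch.load(path, map_location="cpu")
        model.load_state_dict(state_dict)
    return model
```

load_state_dict only succeeds if the model is built with the same architecture as the checkpoint, including the number of classes and image size used when it was trained.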