Alibaba-AAIG / SSL-FEW-SHOT

MIT License

can't load pre-trained model in eval_protonet #16

Open YIYIZH opened 3 years ago

YIYIZH commented 3 years ago

Hi, I got an error on this line of code:

model.load_state_dict(torch.load(args.model_path)['params'])

The error says that the loaded model (downloaded from the Google Drive link provided in the repo) has no 'params' key. I checked and found that the key is 'model' instead, so I changed the line to model.load_state_dict(torch.load(args.model_path)['model']). Then another error occurred, saying there are missing keys. Is something wrong with the pre-trained model or with the code?
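A quick way to narrow down this kind of error is to list the checkpoint's top-level keys and compare its parameter names against the ones the model expects. The sketch below does that comparison on plain dictionaries so it runs without PyTorch. The helper name `diff_state_dict_keys` and all layer names are hypothetical and are not taken from the repository.

```python
def diff_state_dict_keys(model_keys, checkpoint_keys):
    """Return (missing, unexpected) key lists, the two groups a strict load complains about."""
    model_keys = set(model_keys)
    checkpoint_keys = set(checkpoint_keys)
    missing = sorted(model_keys - checkpoint_keys)     # expected by the model, absent from the file
    unexpected = sorted(checkpoint_keys - model_keys)  # present in the file, unknown to the model
    return missing, unexpected


# Hypothetical key names, for illustration only.
model_keys = ["encoder.conv1.weight", "encoder.conv1.bias"]
checkpoint = {"model": {"conv1.weight": 0, "conv1.bias": 0, "head.weight": 0}}

print(list(checkpoint.keys()))  # top-level keys stored in the file
missing, unexpected = diff_state_dict_keys(model_keys, checkpoint["model"].keys())
print("missing:", missing)
print("unexpected:", unexpected)
```

With PyTorch, the same inputs would come from `torch.load(args.model_path, map_location="cpu")` and `model.state_dict().keys()`. Calling `model.load_state_dict(state, strict=False)` also returns an object with `missing_keys` and `unexpected_keys` fields, which reports the same information without raising.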

eladmeir commented 3 years ago

Hey @YIYIZH, the pre-trained models provided here are the SSL pre-trained weights for the general-purpose embedding network. You cannot evaluate them as is. Instead, you should first fine-tune the network with train_protonet.py, using these pre-trained weights as initialization, on one of the 3 provided few-shot datasets (or on your own custom dataset).

For me, it took ~15 hours to train on a single 8GB RTX 2080, and it was really tight on GPU memory. I'd consider a GPU with 12+ GB for this task.

Once the training is done, you'll have a new model weights file to use for evaluation.
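Using pre-trained weights as initialization usually means copying over only the parameters whose names match the new model and leaving the rest at their fresh values. The sketch below shows that selection step on plain dictionaries. It is not how train_protonet.py does it, and the helper name `select_init_weights`, the layer names, and the `module.` prefix handling are assumptions for illustration (`torch.nn.DataParallel` is one common source of that prefix).

```python
def select_init_weights(pretrained, model_state, strip_prefix="module."):
    """Keep only pre-trained entries whose (prefix-stripped) key exists in the model."""
    selected = {}
    for key, value in pretrained.items():
        # Drop a wrapper prefix such as "module." if the checkpoint has one.
        if strip_prefix and key.startswith(strip_prefix):
            key = key[len(strip_prefix):]
        if key in model_state:
            selected[key] = value
    return selected


# Hypothetical parameter names and values, for illustration only.
pretrained = {"module.encoder.conv1.weight": 1.0, "module.fc.weight": 2.0}
model_state = {"encoder.conv1.weight": 0.0, "encoder.conv1.bias": 0.0}

init = select_init_weights(pretrained, model_state)
merged = {**model_state, **init}  # matching weights overwritten, the rest left untouched
print(merged)
```

In PyTorch, `model_state` would be `model.state_dict()` and the merged dictionary would be passed back to `model.load_state_dict(merged)` before fine-tuning starts. A real version would also want to compare tensor shapes before copying.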