Hi @skfaysal, I am not sure about CALF, but I can give you some insights for the pooling approaches. All you need to do is reset the self.fc layers in models.py and fine-tune with your new data.
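A minimal sketch of that recipe in PyTorch follows. It assumes a hypothetical stand-in model (`PoolingNet`) whose head is named `fc` like the one in the repository. The layer sizes, the `BCELoss` placeholder loss and the random data are illustrative assumptions. Freezing the other layers is optional and appears here only as one common choice, not as the repository's method.

```python
import torch
import torch.nn as nn


class PoolingNet(nn.Module):
    """Hypothetical stand-in: `body` plays the pre-trained layers, `fc` the head."""

    def __init__(self, in_features=16, num_classes=17):
        super().__init__()
        self.body = nn.Linear(in_features, in_features)
        # One extra output, mirroring nn.Linear(..., self.num_classes+1)
        self.fc = nn.Linear(in_features, num_classes + 1)

    def forward(self, x):
        return torch.sigmoid(self.fc(torch.relu(self.body(x))))


model = PoolingNet(num_classes=17)  # pre-trained weights would be loaded here

# Reset the head: a freshly initialised layer sized for the new classes.
new_num_classes = 4
model.fc = nn.Linear(model.fc.in_features, new_num_classes + 1)

# Optional: keep the pre-trained layers fixed and train only the new head.
for param in model.body.parameters():
    param.requires_grad = False

optimizer = torch.optim.Adam(
    [p for p in model.parameters() if p.requires_grad], lr=1e-3
)
criterion = nn.BCELoss()  # placeholder loss for 0/1 multi-label targets

# One fine-tuning step on dummy data standing in for the new dataset.
x = torch.randn(8, 16)
y = torch.randint(0, 2, (8, new_num_classes + 1)).float()
body_before = model.body.weight.clone()
fc_before = model.fc.weight.clone()

loss = criterion(model(x), y)
optimizer.zero_grad()
loss.backward()
optimizer.step()

print(tuple(model(x).shape))                        # (8, 5)
print(torch.equal(body_before, model.body.weight))  # True: frozen layers unchanged
```

The optimizer is built after the head is swapped, so it tracks the new `fc` parameters rather than the discarded ones.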
However, note that:
I hope that helps!
Hi @SilvioGiancola, your quick response is much appreciated. For NetVLAD++ pooling, the self.fc layer is: self.fc = nn.Linear(input_size*self.vlad_k, self.num_classes+1). I changed self.num_classes to my desired number of classes. Then I passed the trained model "model.pth.tar" as the --load_weights parameter and got the following error:
We can see that this error occurs because of the mismatch between the fc layer of the pre-trained model, which is (18, 32768), and that of the custom model, which is (3, 32768). Here 18 and 3 are the numbers of outputs. I'm not sure how to change the fc layer of the pre-trained model and start training the custom model from the pre-trained weights.
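To see where those shapes come from, the sketch below rebuilds only the head. It assumes 512-dimensional input features and `vlad_k = 64` (512 × 64 = 32768, matching the error), and it infers `num_classes = 2` for the custom model from its 3 outputs; all three values are assumptions, not read from the repository.

```python
import torch.nn as nn

# Assumed values: 512-d features x 64 VLAD clusters = 32768 inputs to fc.
input_size, vlad_k = 512, 64


def make_head(num_classes):
    # Mirrors nn.Linear(input_size*self.vlad_k, self.num_classes+1)
    return nn.Linear(input_size * vlad_k, num_classes + 1)


pretrained_head = make_head(17)  # 17 classes + 1 -> 18 outputs
custom_head = make_head(2)       # assumed 2 classes + 1 -> 3 outputs

print(tuple(pretrained_head.weight.shape))  # (18, 32768)
print(tuple(custom_head.weight.shape))      # (3, 32768)

# The default strict=True load fails because the weight shapes differ.
error_message = ""
try:
    custom_head.load_state_dict(pretrained_head.state_dict())
except RuntimeError as err:
    error_message = str(err)
print("size mismatch" in error_message)  # True
```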
Please help!
You should be able to load the weights of the pre-trained model by setting the argument strict=False of the function load_state_dict in L91 of model.py.
I saw you were playing around with that option already, so that error should not appear anymore. Maybe make sure you have the same version of PyTorch I was using, or at least a version that handles strict=False.
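For reference, `strict=False` makes `load_state_dict` tolerate keys that are missing from (or unexpected in) the checkpoint, and it returns those keys so they can be inspected. One common way to use it in this situation, sketched below with a hypothetical two-layer `Net`, is to drop the `fc.*` entries from the pre-trained state dict before loading, so the new head keeps its fresh initialisation. The layer names and sizes are illustrative assumptions.

```python
import torch
import torch.nn as nn


class Net(nn.Module):
    """Hypothetical model: `body` stands in for the pre-trained layers."""

    def __init__(self, num_outputs):
        super().__init__()
        self.body = nn.Linear(16, 16)
        self.fc = nn.Linear(16, num_outputs)


pretrained = Net(num_outputs=18)
custom = Net(num_outputs=3)

# Keep every entry except the classification head, whose shape differs.
filtered = {
    key: value
    for key, value in pretrained.state_dict().items()
    if not key.startswith("fc.")
}

# strict=False accepts the now-missing fc.* keys and reports them.
result = custom.load_state_dict(filtered, strict=False)
print(result.missing_keys)     # ['fc.weight', 'fc.bias']
print(result.unexpected_keys)  # []
print(torch.equal(custom.body.weight, pretrained.body.weight))  # True
```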
You can find the complete documentation for load_state_dict at https://pytorch.org/docs/stable/generated/torch.nn.Module.html?highlight=load_state_dict#torch.nn.Module.load_state_dict. Check the note there; maybe you are in the particular case that raises a RuntimeError.
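One detail worth checking against that documentation: `strict=False` only relaxes the key-matching check. If a key exists in both the checkpoint and the model but the tensor shapes differ, as with the (18, 32768) and (3, 32768) fc weights here, `load_state_dict` still raises a `RuntimeError`. A small sketch with illustrative layer sizes:

```python
import torch.nn as nn

# Illustrative heads: same key names ("weight", "bias"), different shapes.
pretrained_fc = nn.Linear(8, 18)
custom_fc = nn.Linear(8, 3)

raised = False
try:
    custom_fc.load_state_dict(pretrained_fc.state_dict(), strict=False)
except RuntimeError as err:
    raised = True
    print(err)  # the message reports a size mismatch for weight and bias

print(raised)  # True
```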
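This issue has been resolved with the following steps.
First, load the pre-trained weights into the model while it still has the original fc layer, so that all shapes match:
checkpoint = torch.load(os.path.join("models", args.load_weights, "model.pth.tar"))
model.load_state_dict(checkpoint['state_dict'])
Then, replace the fc layer with a new one sized for the custom classes:
num_ftrs = model.fc.in_features
model.fc = nn.Linear(num_ftrs, num_classes).cuda()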
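Those two steps can be tried end to end with the self-contained sketch below. It uses a hypothetical `Net` and a temporary checkpoint file with the same `{'state_dict': ...}` layout as model.pth.tar; the layer sizes are illustrative, and `.to(device)` stands in for the hard-coded `.cuda()` so the sketch also runs without a GPU.

```python
import os
import tempfile

import torch
import torch.nn as nn


class Net(nn.Module):
    """Hypothetical stand-in for the pooling model; only `fc` mirrors the thread."""

    def __init__(self, num_outputs):
        super().__init__()
        self.backbone = nn.Linear(16, 16)
        self.fc = nn.Linear(16, num_outputs)

    def forward(self, x):
        return self.fc(torch.relu(self.backbone(x)))


device = "cuda" if torch.cuda.is_available() else "cpu"
pretrained = Net(num_outputs=18)

with tempfile.TemporaryDirectory() as ckpt_dir:
    # Simulate a checkpoint saved as {'state_dict': ...}, like model.pth.tar.
    ckpt_path = os.path.join(ckpt_dir, "model.pth.tar")
    torch.save({"state_dict": pretrained.state_dict()}, ckpt_path)

    # Step 1: build the model with the ORIGINAL head size so every shape matches.
    model = Net(num_outputs=18)
    checkpoint = torch.load(ckpt_path, map_location="cpu")
    model.load_state_dict(checkpoint["state_dict"])

# Step 2: swap in a new head sized for the custom classes, then move to device.
num_classes = 3
model.fc = nn.Linear(model.fc.in_features, num_classes)
model = model.to(device)

out = model(torch.randn(4, 16, device=device))
print(tuple(out.shape))  # (4, 3)
```

The order matters: loading happens while the original 18-output head is still in place, so the default `strict=True` succeeds. If the rest of the training code expects the extra output created by `self.num_classes+1`, the new head presumably needs `num_classes + 1` outputs as well.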
@SilvioGiancola Thanks for your response.
How can I fine-tune the "model.pth.tar" model on a dataset that has a different number of output classes? Let's say the pre-trained model ("model.pth.tar") has 17 output classes, but I want to fine-tune it on a dataset that has 4 classes.
Originally posted by @skfaysal in https://github.com/SilvioGiancola/SoccerNetv2-DevKit/issues/37#issuecomment-1186591705