SilvioGiancola / SoccerNetv2-DevKit

Development Kit for the SoccerNet Challenge
MIT License

Fine-tuning a pretrained model for a different number of classes #50

Closed skfaysal closed 2 years ago

skfaysal commented 2 years ago

How can I fine-tune the "model.pth.tar" model for a dataset that has a different number of output classes? Let's say the pretrained model ("model.pth.tar") has 17 output classes, but I want to fine-tune it for a dataset which has 4 classes.

Originally posted by @skfaysal in https://github.com/SilvioGiancola/SoccerNetv2-DevKit/issues/37#issuecomment-1186591705

SilvioGiancola commented 2 years ago

Hi @skfaysal, I am not sure about CALF, but I can give you some insights for the pooling approaches. All you need to do is reset the self.fc layer in models.py and fine-tune with your new data.
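As a rough illustration of that idea (not the exact DevKit code; the layer definition follows the NetVLAD++ snippet quoted later in this thread, and the model and sizes below are only placeholders), resetting the head could look like this:

```python
import torch.nn as nn

# Stand-in for an instantiated pooling model (e.g. NetVLAD++); in the DevKit the
# head is defined as `self.fc = nn.Linear(input_size*self.vlad_k, num_classes+1)`.
model = nn.Module()
model.fc = nn.Linear(32768, 17 + 1)  # matches the (18, 32768) fc weight discussed below

new_num_classes = 4  # number of action classes in the new dataset

# Re-create the final classifier with the new output size; the "+1" mirrors the
# extra output of the original head.
in_features = model.fc.in_features  # input width stays the same
model.fc = nn.Linear(in_features, new_num_classes + 1)

# Only this new layer starts from random initialization; all other pretrained
# weights stay as loaded and are fine-tuned on the new data.
```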

However, note that:

I hope that helps!

skfaysal commented 2 years ago

Hi @SilvioGiancola, your quick response is much appreciated. For NetVLAD++ pooling, in the self.fc layer, self.fc = nn.Linear(input_size*self.vlad_k, self.num_classes+1), I've changed self.num_classes to my desired number of classes. Then I pass the trained model "model.pth.tar" as the --load_weights parameter, and I get the error below:

[Screenshot (2022-07-19): error traceback showing a size mismatch for the fc layer when loading the pretrained weights]

We can see this error occurs because of the mismatch between the fc layer of the pretrained model, which is (18, 32768), and the custom model, which is (3, 32768); here 18 and 3 are the numbers of output classes. I'm not sure how to change the fc layer of the pretrained model and start training the custom model using the weights of the pretrained model.

Please help!

SilvioGiancola commented 2 years ago

You should be able to load the weights of the pre-trained model by setting the argument strict=False of the function load_state_dict in L91 of model.py.

I saw you were already playing around with that option, so the error should not appear anymore. Maybe make sure you have the same version of PyTorch I was using, or at least a version that handles strict=False.

You can find the complete documentation for load_state_dict here: https://pytorch.org/docs/stable/generated/torch.nn.Module.html?highlight=load_state_dict#torch.nn.Module.load_state_dict. Check the note; maybe you are in the particular case that returns a RuntimeError.
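To make that concrete, here is a minimal sketch of the strict=False call; the stand-in model, checkpoint path, and folder name are illustrative, not the exact DevKit code:

```python
import os
import torch
import torch.nn as nn

# Stand-in for the model built with the NEW number of classes;
# only the fc shape matters for this illustration.
model = nn.Module()
model.fc = nn.Linear(32768, 3)  # 3 outputs, as in the (3, 32768) fc weight above

load_weights = "NetVLAD++_run"  # hypothetical experiment folder under models/
checkpoint = torch.load(os.path.join("models", load_weights, "model.pth.tar"))

# strict=False makes load_state_dict tolerate missing and unexpected keys and
# report them instead of raising. Depending on the PyTorch version, a *shape*
# mismatch on a key present in both state_dicts (like the resized fc layer)
# may still raise a RuntimeError, which is the case flagged in the docs' note.
result = model.load_state_dict(checkpoint['state_dict'], strict=False)
print("missing keys:", result.missing_keys)
print("unexpected keys:", result.unexpected_keys)
```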

skfaysal commented 2 years ago

This issue has been resolved by doing the following steps:

  1. Keep the model architecture as it is.
  2. Load the pretrained weights so that the state_dict matches the architecture: checkpoint = torch.load(os.path.join("models", args.load_weights, "model.pth.tar")) followed by model.load_state_dict(checkpoint['state_dict']).
  3. Get the number of input features of the fc layer and change the number of output nodes to match the dataset: num_ftrs = model.fc.in_features and then model.fc = nn.Linear(num_ftrs, num_classes).cuda() (a consolidated sketch follows this list).
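Putting those three steps together, a consolidated sketch; DummyModel is only a stand-in for the real DevKit model, and the checkpoint folder and class counts are illustrative:

```python
import os
import torch
import torch.nn as nn

class DummyModel(nn.Module):
    """Stand-in for the DevKit pooling model; only the final fc layer matters here."""
    def __init__(self, feat_dim=32768, num_classes=17):
        super().__init__()
        # Step 1: keep the ORIGINAL architecture (original class count, +1 for the
        # extra output), so every tensor shape matches the pretrained checkpoint.
        self.fc = nn.Linear(feat_dim, num_classes + 1)

    def forward(self, x):
        return self.fc(x)

model = DummyModel().cuda()

# Step 2: load the pretrained weights into the matching architecture.
load_weights = "NetVLAD++_run"  # hypothetical experiment folder under models/
checkpoint = torch.load(os.path.join("models", load_weights, "model.pth.tar"))
model.load_state_dict(checkpoint['state_dict'])

# Step 3: only now replace the classifier head with one sized for the new dataset;
# this layer is freshly initialized, while everything else stays pretrained.
num_classes = 4  # classes in the new dataset (use num_classes + 1 here to keep
                 # an extra output like the original head, if desired)
num_ftrs = model.fc.in_features
model.fc = nn.Linear(num_ftrs, num_classes).cuda()

# Fine-tune `model` on the new dataset as usual.
```

Since the loading order is what avoids the size-mismatch error, the fc replacement deliberately happens after load_state_dict.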

@SilvioGiancola Thanks for your response.