pytorch / vision

Datasets, Transforms and Models specific to Computer Vision
https://pytorch.org/vision
BSD 3-Clause "New" or "Revised" License

Simplify transfer learning by modifying get_model() #8631

Open david-csnmedia opened 1 week ago

david-csnmedia commented 1 week ago

🚀 The feature

Currently, torchvision.models.get_model() doesn't let you build a model with a different number of classes while keeping the pre-trained backbone weights for certain model types (namely image classification models such as EfficientNet).

Could something like the following be incorporated into get_model(), or could a separate function be added to support this?


import torch
import torchvision

# Build the architecture and load the pre-trained weights.
model = torchvision.models.get_model(self.model_type, weights=self.weights_backbone)

# Replace the final layer of the classifier so its output size matches num_classes.
# This has to happen after get_model() so that the pre-trained weights are retained
# while the model architecture is adapted to our use case.
original_linear_layer = model.classifier[-1]

new_linear_layer = torch.nn.Linear(
    in_features=original_linear_layer.in_features,
    out_features=self.num_classes,
)
model.classifier[-1] = new_linear_layer

Motivation, pitch

Raising an error about a mismatch with the backbone weights points users in a direction that isn't helpful.

Alternatives

No response

Additional context

No response

NicolasHug commented 1 week ago

Hi @david-csnmedia,

That kind of model surgery is probably too specific to each model for it to be reliably implemented within get_model(). Note that some model builders allow num_classes to be passed.