lukemelas / EfficientNet-PyTorch

A PyTorch implementation of EfficientNet
Apache License 2.0

Bug for multitask learning and all the fc layer replacement situations #283

Open xingjchen opened 3 years ago

xingjchen commented 3 years ago

Hi, I have a problem with multi-task learning when using EfficientNet.

I want to load the pre-trained model, remove its original FC layer ('_fc'), and then add two new FC layers of my own, one for each of two different tasks.

However, it always fails with the following error:

File "/home/user/miniconda3/envs/pytorch/lib/python3.8/site-packages/torch/nn/modules/module.py", line 889, in _call_impl
    result = self.forward(*input, **kwargs)
TypeError: forward() takes 1 positional argument but 2 were given

Here is my code:

import torch
from efficientnet_pytorch import EfficientNet

class AuxNet(torch.nn.Module):
    def __init__(self, original_model, num_classes, num_classes_aux):
        super(AuxNet, self).__init__()

        input_features = original_model._fc.in_features

        # Rebuild the backbone from its child modules (this is the step that fails)
        self.network = torch.nn.Sequential(*list(original_model.children())[:-2])

        self.fc1 = torch.nn.Sequential( torch.nn.Linear(input_features, num_classes))

        self.fc2 = torch.nn.Sequential( torch.nn.Linear(input_features, num_classes_aux))

    def forward(self, x):

        f = self.network(x)
        y1 = self.fc1(f)
        y2 = self.fc2(f)

        return y1, y2

effi_model = EfficientNet.from_pretrained('efficientnet-b7')
model = AuxNet(effi_model,512,128)
print(model)
X = torch.rand(8, 3,224,224)
model(X)

Is there a solution? The same issue was raised here: https://stackoverflow.com/questions/62954999/flattening-efficientnet-model

Thank you very much!
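Why the Sequential rebuild breaks can be illustrated with a small toy model. As far as I can tell from the efficientnet_pytorch source, `EfficientNet.forward()` does more than chain its children: it runs the `_blocks` ModuleList in a Python loop (passing a drop-connect rate to each block) and applies the swish activation and a flatten inline. `nn.Sequential(*model.children())` cannot reproduce that. `ToyEffNet` below is a hypothetical stand-in, not the real EfficientNet; it only mimics that layout. In this toy the rebuild fails with `NotImplementedError` because an `nn.ModuleList` has no `forward()`; the exact exception raised by the real model may differ.

```python
import torch
from torch import nn


class ToyEffNet(nn.Module):
    """Hypothetical stand-in (not the real EfficientNet) that mimics its layout:
    blocks sit in an nn.ModuleList and forward() applies the activation,
    the block loop and a flatten that nn.Sequential cannot reproduce."""

    def __init__(self, num_classes: int = 10):
        super().__init__()
        self._conv_stem = nn.Conv2d(3, 8, kernel_size=3, padding=1)
        self._blocks = nn.ModuleList(
            [nn.Conv2d(8, 8, kernel_size=3, padding=1) for _ in range(2)]
        )
        self._avg_pooling = nn.AdaptiveAvgPool2d(1)
        self._fc = nn.Linear(8, num_classes)
        self._swish = nn.SiLU()  # registered last, but used throughout forward()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self._swish(self._conv_stem(x))
        for block in self._blocks:  # the loop lives in forward(), not in ModuleList
            x = self._swish(block(x))
        x = self._avg_pooling(x).flatten(start_dim=1)  # (N, 8, 1, 1) -> (N, 8)
        return self._fc(x)


def sequential_rebuild_fails(model: nn.Module, x: torch.Tensor) -> bool:
    """Return True if rebuilding the model from children() breaks on x."""
    # [:-2] drops the last two registered children, here _fc and _swish.
    rebuilt = nn.Sequential(*list(model.children())[:-2])
    try:
        rebuilt(x)
    except (NotImplementedError, TypeError, RuntimeError):
        return True
    return False


model = ToyEffNet()
x = torch.rand(2, 3, 32, 32)
print(model(x).shape)                      # torch.Size([2, 10])
print(sequential_rebuild_fails(model, x))  # True
```

Even if the ModuleList were unrolled, the rebuilt Sequential would still skip the activations and the flatten, so the new Linear heads would receive a 4D tensor.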

xingjchen commented 3 years ago

I think I have found the right solution now.

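Instead of rebuilding the network from its children, we just need to replace the '_fc' layer, so that EfficientNet's own forward() still handles everything before the final layer: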

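import torch
from efficientnet_pytorch import EfficientNet

# Two task-specific heads that together replace the original _fc layer
class AuxNet(torch.nn.Module):
    def __init__(self, input_features):
        super(AuxNet, self).__init__()

        self.fc1 = torch.nn.Sequential( torch.nn.Linear(input_features, 512))
        self.fc2 = torch.nn.Sequential( torch.nn.Linear(input_features, 128))

    def forward(self, x):

        y1 = self.fc1(x)
        y2 = self.fc2(x)

        return y1,y2

effi_model = EfficientNet.from_pretrained('efficientnet-b7')
feature = effi_model._fc.in_features

# EfficientNet's own forward() still runs the blocks, pooling, flatten and dropout;
# only the last layer is swapped, so the model now returns a tuple
effi_model._fc = AuxNet(feature)
X = torch.rand(8, 3,224,224)
y1, y2 = effi_model(X)  # y1: (8, 512), y2: (8, 128)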

The same approach (swap only the final layer and keep the model's own forward()) should work in all similar cases where the FC layer needs to be replaced.
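The general pattern can be sketched on any model whose forward() ends in a single Linear layer. This is a minimal sketch under assumptions: `MultiHead` and `ToyBackbone` are hypothetical names, the toy backbone's last layer is called `fc` rather than EfficientNet's `_fc`, and the summed cross-entropy loss is just one common way to train both heads, not something taken from this issue.

```python
import torch
from torch import nn


class MultiHead(nn.Module):
    """Hypothetical helper: one nn.Linear per task, all fed the same features."""

    def __init__(self, in_features: int, out_sizes: tuple):
        super().__init__()
        self.heads = nn.ModuleList([nn.Linear(in_features, n) for n in out_sizes])

    def forward(self, x: torch.Tensor) -> tuple:
        return tuple(head(x) for head in self.heads)


class ToyBackbone(nn.Module):
    """Hypothetical backbone whose final layer is named `fc` (not `_fc`)."""

    def __init__(self):
        super().__init__()
        self.features = nn.Sequential(
            nn.Conv2d(3, 16, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.AdaptiveAvgPool2d(1),
            nn.Flatten(),
        )
        self.fc = nn.Linear(16, 1000)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.fc(self.features(x))


model = ToyBackbone()
# Same trick as effi_model._fc: keep the backbone's forward(), swap only the last layer.
model.fc = MultiHead(model.fc.in_features, (512, 128))

x = torch.rand(4, 3, 32, 32)
y1, y2 = model(x)
print(y1.shape, y2.shape)  # torch.Size([4, 512]) torch.Size([4, 128])

# Assumed multi-task objective: sum of per-task cross-entropy losses,
# using hypothetical random targets.
t1 = torch.randint(0, 512, (4,))
t2 = torch.randint(0, 128, (4,))
loss = nn.functional.cross_entropy(y1, t1) + nn.functional.cross_entropy(y2, t2)
loss.backward()
```

Because the heads live inside the replaced layer, the backbone's forward() stays untouched and gradients from both task losses flow back through the shared features.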