lukemelas / EfficientNet-PyTorch

A PyTorch implementation of EfficientNet
Apache License 2.0

How to add an additional layer to a pre-trained model? #207

Open talhaanwarch opened 4 years ago

talhaanwarch commented 4 years ago

Can you please guide me on how to add some extra fully connected layers on top of a pre-trained model?

from efficientnet_pytorch import EfficientNet
model = EfficientNet.from_pretrained('efficientnet-b0')

I am confused about how to access the last layer and connect it to another layer.

lukemelas commented 4 years ago

You can do something like

import torch.nn as nn

# `classes` is the number of output classes for your task
model._fc = nn.Sequential(nn.Linear(model._fc.in_features, 512),
                          nn.ReLU(),
                          nn.Dropout(0.25),
                          nn.Linear(512, 128),
                          nn.ReLU(),
                          nn.Dropout(0.50),
                          nn.Linear(128, classes))
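
As a quick sanity check (a sketch, assuming torch is available and classes is defined), you can confirm the new head produces the expected output shape:

import torch

model.eval()
with torch.no_grad():
    out = model(torch.randn(2, 3, 224, 224))  # 224 is the default input size for b0
print(out.shape)  # expected: torch.Size([2, classes])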

or if you want to make bigger changes:

import torch.nn as nn
from efficientnet_pytorch import EfficientNet

class MyEfficientNet(nn.Module):

    def __init__(self, classes):
        super().__init__()

        # Pretrained EfficientNet backbone
        self.network = EfficientNet.from_pretrained("efficientnet-b0")

        # Replace the last layer with a new classification head
        self.network._fc = nn.Sequential(nn.Linear(self.network._fc.in_features, 512),
                                         nn.ReLU(),
                                         nn.Dropout(0.25),
                                         nn.Linear(512, 128),
                                         nn.ReLU(),
                                         nn.Dropout(0.50),
                                         nn.Linear(128, classes))

    def forward(self, x):
        return self.network(x)

model = MyEfficientNet(classes)
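
If you only want to train the new head rather than fine-tune the whole network (an assumption about your workflow, not something asked above), a common option is to freeze the pretrained backbone, e.g.:

# Freeze everything in the backbone except the new _fc head
# (with the first, non-class approach, use model.named_parameters() instead)
for name, param in model.network.named_parameters():
    if not name.startswith("_fc"):
        param.requires_grad = False

# Then give the optimizer only the trainable parameters, e.g.
# torch.optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=1e-3)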

Look good?

sachinruk commented 4 years ago

Just wondering if the last layer will still have a swish activation? When I print out the model, that seems to be the case. If so, how do you remove that last layer?

The last few lines of the output of print(model):

    (_bn1): BatchNorm2d(1280, eps=0.001, momentum=0.010000000000000009, affine=True, track_running_stats=True)
    (_avg_pooling): AdaptiveAvgPool2d(output_size=1)
    (_dropout): Dropout(p=0.2, inplace=False)
    (_fc): Sequential(
      (0): Linear(in_features=1280, out_features=512, bias=True)
      (1): ReLU()
      (2): Dropout(p=0.25, inplace=False)
      (3): Linear(in_features=512, out_features=128, bias=True)
      (4): ReLU()
      (5): Dropout(p=0.25, inplace=False)
      (6): Linear(in_features=128, out_features=1, bias=True)
    )
    (_swish): MemoryEfficientSwish()
  )
)
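
Note that print(model) lists registered submodules in the order they were created, not necessarily the order forward applies them, so _swish appearing last does not by itself mean it runs after _fc. A minimal way to check (a sketch, assuming the model with the replaced _fc head from above; use model.network._fc if you wrapped it in MyEfficientNet) is to hook _fc and compare its output with the model's output:

import torch

captured = {}

def grab_fc_output(module, inputs, output):
    # Store whatever _fc returns during the forward pass
    captured["fc_out"] = output.detach()

handle = model._fc.register_forward_hook(grab_fc_output)

model.eval()
with torch.no_grad():
    out = model(torch.randn(1, 3, 224, 224))

handle.remove()

# True here means the model's final output is exactly _fc's output,
# i.e. no extra activation is applied after the new layers.
print(torch.equal(out, captured["fc_out"]))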

sachinruk commented 4 years ago

I've expanded on the question above in my SO question.