RangiLyu / nanodet

NanoDet-Plus ⚡ Super fast and lightweight anchor-free object detection model. 🔥 Only 980 KB (int8) / 1.8 MB (fp16), and runs at 97 FPS on a cellphone 🔥
Apache License 2.0
5.71k stars 1.04k forks source link

How to add a MobileNetV3 backbone? #534

Closed nijatmursali closed 12 months ago

nijatmursali commented 1 year ago

I have trained my model using MobileNetV2 and now want to train it with MobileNetV3. My backbone.py file is:

import torch
import torch.nn as nn
import math

class Hardswish(nn.Module):
    """Hard-swish activation, computed as ``x * clamp(x + 3, 0, 6) / 6``.

    The output is 0 for ``x <= -3``, equals ``x`` for ``x >= 3``, and is
    ``x * (x + 3) / 6`` in between. The module has no parameters.
    """

    def forward(self, x):
        # Tensor.clamp is the method form of torch.clamp; the gate lies in [0, 6].
        gate = (x + 3).clamp(0, 6)
        return x * gate / 6

class Bottleneck(nn.Module):
    """Inverted-residual block with squeeze-and-excitation (SE).

    Pipeline: 1x1 expansion conv -> BN -> ``nl`` -> kxk depthwise conv -> BN ->
    ``nl`` -> SE channel gate -> 1x1 projection conv -> BN (no activation).
    An identity shortcut is added when ``stride == 1`` and
    ``in_channels == out_channels``.

    Args:
        in_channels: Channels of the input tensor.
        out_channels: Channels produced by the projection conv.
        exp_channels: Absolute number of expanded (hidden) channels. This is a
            channel count, not an expansion ratio.
        kernel_size: Depthwise kernel size; padding is ``kernel_size // 2``.
        stride: Stride of the depthwise conv.
        se_ratio: Fraction of ``exp_channels`` used as the SE squeeze width.
        nl: Activation *class*. It is instantiated once, and that instance is
            applied after both the expansion BN and the depthwise BN.
    """

    def __init__(self, in_channels, out_channels, exp_channels, kernel_size, stride, se_ratio, nl):
        super(Bottleneck, self).__init__()
        # forward() needs these to decide whether the shortcut applies. They were
        # previously never stored, so every forward() raised AttributeError.
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.stride = stride

        mid_channels = int(exp_channels)
        # int() truncation can give 0 for narrow blocks, which would build an
        # empty squeeze conv. Keep at least one SE channel.
        se_channels = max(1, int(mid_channels * se_ratio))

        # Expansion convolution
        self.conv1 = nn.Conv2d(in_channels, mid_channels, 1, 1, 0, bias=False)
        self.bn1 = nn.BatchNorm2d(mid_channels)
        self.nl = nl()

        # Depthwise convolution (groups == channels)
        self.conv2 = nn.Conv2d(mid_channels, mid_channels, kernel_size, stride, kernel_size // 2, groups=mid_channels,
                               bias=False)
        self.bn2 = nn.BatchNorm2d(mid_channels)

        # Squeeze-and-Excitation: global pool -> reduce -> ReLU -> expand -> sigmoid gate
        self.se = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(mid_channels, se_channels, 1),
            nn.ReLU(inplace=True),
            nn.Conv2d(se_channels, mid_channels, 1),
            nn.Sigmoid()
        )

        # Linear projection (no activation after bn3)
        self.conv3 = nn.Conv2d(mid_channels, out_channels, 1, 1, 0, bias=False)
        self.bn3 = nn.BatchNorm2d(out_channels)

    def forward(self, x):
        out = self.nl(self.bn1(self.conv1(x)))
        out = self.nl(self.bn2(self.conv2(out)))

        # Per-channel SE gate, broadcast over the spatial dimensions.
        out = out * self.se(out)

        out = self.bn3(self.conv3(out))

        # The shortcut is only shape-compatible when resolution and width are unchanged.
        if self.stride == 1 and self.in_channels == self.out_channels:
            out = out + x

        return out

class MobileNetV3(nn.Module):
    """Image classifier built from stacked ``Bottleneck`` blocks.

    Layout: 3x3 stride-2 stem conv (16 channels) -> BN -> ReLU -> the
    ``Bottleneck`` stages described by ``self.cfg`` (h-swish inside each
    block) -> 1x1 conv to 1280 channels -> BN -> ReLU -> global average pool
    -> linear classifier.

    Args:
        in_channels: Channels of the input image tensor.
        num_classes: Output size of the final linear layer.
        model_type: ``'small'`` selects the small stage table; any other value
            selects the large one.
        se_ratio: SE squeeze ratio forwarded to every ``Bottleneck``.

    NOTE(review): the stage tables use 3x3 kernels and SE in every block. They
    are a simplified approximation and do not reproduce the MobileNetV3 paper's
    layer tables.
    NOTE(review): ``forward`` returns classification logits. A NanoDet backbone
    presumably must return a tuple of multi-scale feature maps for the neck.
    Confirm against nanodet's existing backbones before plugging this in. A
    checkpoint trained with the MobileNetV2 backbone will also not load
    strictly into this model, because parameter names and shapes differ.
    """

    def __init__(self, in_channels, num_classes, model_type='small', se_ratio=0.25):
        super(MobileNetV3, self).__init__()
        self.model_type = model_type

        if model_type == 'small':
            # Small version
            self.cfg = [
                # t (expansion ratio), c (output channels), n (repeats), s (stride of first block)
                [1, 16, 1, 1],
                [4, 24, 2, 2],
                [3, 40, 2, 2],
                [3, 80, 3, 2],
                [6, 112, 3, 1],
                [6, 160, 1, 2],
            ]
        else:
            # Large version
            self.cfg = [
                # t (expansion ratio), c (output channels), n (repeats), s (stride of first block)
                [1, 16, 1, 1],
                [4, 24, 2, 2],
                [3, 48, 3, 2],
                [3, 96, 3, 2],
                [6, 160, 3, 1],
                [6, 320, 1, 2],
            ]

        # Stem convolution
        self.conv1 = nn.Conv2d(in_channels, 16, 3, 2, 1, bias=False)
        self.bn1 = nn.BatchNorm2d(16)
        self.nl = nn.ReLU(inplace=True)

        # Bottleneck stages. ``width`` tracks the channel count entering the
        # next block, kept separate from the ``in_channels`` argument.
        layers = []
        width = 16
        for t, c, n, s in self.cfg:
            for i in range(n):
                # Only the first block of each stage downsamples.
                stride = s if i == 0 else 1
                # Bottleneck takes an absolute hidden width, so scale by the
                # ratio t. Passing t itself gave 1-6 hidden channels.
                layers.append(Bottleneck(width, c, width * t, 3, stride, se_ratio, Hardswish))
                width = c

        self.layers = nn.Sequential(*layers)

        # Final 1x1 convolution
        self.conv2 = nn.Conv2d(width, 1280, 1, 1, 0, bias=False)
        self.bn2 = nn.BatchNorm2d(1280)

        # Global average pooling and classifier
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Linear(1280, num_classes)

    def forward(self, x):
        out = self.nl(self.bn1(self.conv1(x)))

        out = self.layers(out)

        out = self.nl(self.bn2(self.conv2(out)))

        # Pool to 1x1, flatten to (batch, 1280), then classify.
        out = self.avg_pool(out)
        out = out.view(out.size(0), -1)
        out = self.fc(out)

        return out

but it fails with the following error:

RuntimeError: Error(s) in loading state_dict for TrainingTask:
        Missing key(s) in state_dict: "model.backbone.conv1.weight", ...