xiaolai-sqlai / mobilenetv3

MobileNetV3 in PyTorch, with pre-trained models provided
MIT License
1.6k stars, 340 forks

The SE module is wrong, it's not just squeeze and expand the channel. #12

Open longxianlei opened 5 years ago

longxianlei commented 5 years ago

You should apply avg_pool to the input x, pass the result through fc-->fc, and expand the output back to the dimensions of the input x. Then use the shortcut connection.
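To make the steps above concrete, here is a minimal sketch of a squeeze-and-excitation block written with fully connected layers. It is not the repository's code: the class name `SqueezeExcite`, the plain sigmoid gate, and reading "shortcut connection" as the element-wise product with x are all assumptions made for illustration.

```python
import torch
import torch.nn as nn


class SqueezeExcite(nn.Module):
    """Hypothetical illustration of pool -> fc -> fc -> expand -> multiply."""

    def __init__(self, channels, reduction=4):
        super().__init__()
        self.fc1 = nn.Linear(channels, channels // reduction)
        self.fc2 = nn.Linear(channels // reduction, channels)

    def forward(self, x):
        n, c, _, _ = x.shape
        # Squeeze: global average pool over H and W, (N, C, H, W) -> (N, C)
        s = x.mean(dim=(2, 3))
        # Excite: fc --> fc, ending in a gate in [0, 1]
        # (plain sigmoid here is an assumption, not the repo's hsigmoid)
        s = torch.relu(self.fc1(s))
        s = torch.sigmoid(self.fc2(s))
        # Expand to (N, C, 1, 1) so it broadcasts against the input
        s = s.view(n, c, 1, 1)
        # Rescale every channel of the original input
        return x * s
```

The point of the sketch is that the gate has one value per channel, computed from pooled statistics, rather than one value per spatial position.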

VincentChong123 commented 5 years ago

self.avg_pool below is defined in __init__ but never called in forward.

class SeModule(nn.Module):
    def __init__(self, in_size, reduction=4):
        super(SeModule, self).__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)

        self.se = nn.Sequential(
            nn.Conv2d(in_size, in_size // reduction, kernel_size=1, stride=1, padding=0, bias=False),
            nn.BatchNorm2d(in_size // reduction),
            nn.ReLU(inplace=True),
            nn.Conv2d(in_size // reduction, in_size, kernel_size=1, stride=1, padding=0, bias=False),
            nn.BatchNorm2d(in_size),
            hsigmoid()
        )

    def forward(self, x):
        return x * self.se(x)
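
One possible fix, as a sketch rather than the maintainer's actual patch, is to pool the input before passing it to `self.se`, so the gate has shape (N, C, 1, 1) and is broadcast over the spatial dimensions. The `HSigmoid` class below is a stand-in for the repository's `hsigmoid`, assumed here to compute relu6(x + 3) / 6.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class HSigmoid(nn.Module):
    """Stand-in for the repo's hsigmoid (assumed: relu6(x + 3) / 6)."""

    def forward(self, x):
        return F.relu6(x + 3.0) / 6.0


class SeModule(nn.Module):
    def __init__(self, in_size, reduction=4):
        super().__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.se = nn.Sequential(
            nn.Conv2d(in_size, in_size // reduction, kernel_size=1, bias=False),
            nn.BatchNorm2d(in_size // reduction),
            nn.ReLU(inplace=True),
            nn.Conv2d(in_size // reduction, in_size, kernel_size=1, bias=False),
            nn.BatchNorm2d(in_size),
            HSigmoid(),
        )

    def forward(self, x):
        # Squeeze first: (N, C, H, W) -> (N, C, 1, 1)
        scale = self.se(self.avg_pool(x))
        # Per-channel gate, broadcast over H and W
        return x * scale
```

Note that with the pooled input the BatchNorm2d layers see a single value per channel per sample, so in training mode this sketch needs a batch size greater than 1.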