supernotman / RetinaFace_Pytorch

Reimplement RetinaFace with Pytorch

About context module #12

Closed: luyao777 closed this issue 5 years ago

luyao777 commented 5 years ago

I think there may be some mistakes in the channels of the context module:

x1 = self.det_conv1(x) # 256 channels
x_ = self.det_context_conv1(x) # 128 channels
x2 = self.det_context_conv2(x_) # 128 channels
x3_ = self.det_context_conv3_1(x_) # 128 channels
x3 = self.det_context_conv3_2(x3_) # 128 channels

After concatenating x1, x2, and x3 I get 512 channels, which is inconsistent with the paper (256 channels). Is there anything wrong on my side?
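
For a quick sanity check, here is a minimal sketch with dummy tensors (not the repo's actual layers) confirming the concatenated channel count:

import torch

# dummy feature maps with the channel counts listed above (N, C, H, W)
x1 = torch.randn(1, 256, 20, 20)   # 256 channels, as from det_conv1
x2 = torch.randn(1, 128, 20, 20)   # 128 channels, as from det_context_conv2
x3 = torch.randn(1, 128, 20, 20)   # 128 channels, as from det_context_conv3_2
print(torch.cat((x1, x2, x3), 1).shape)   # torch.Size([1, 512, 20, 20])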

luyao777 commented 5 years ago

I think the context module should be:

import torch
import torch.nn as nn

class ContextModule(nn.Module):
    def __init__(self,in_channels=256):
        super(ContextModule,self).__init__()
        self.det_context_conv1 = nn.Sequential(
            nn.Conv2d(in_channels,in_channels,kernel_size=3,stride=1,padding=1),
            nn.BatchNorm2d(in_channels),
        )
        self.det_context_conv2 = nn.Sequential(
            nn.Conv2d(in_channels,in_channels//2,kernel_size=3,stride=1,padding=1),
            nn.BatchNorm2d(in_channels//2), 
        )
        self.det_context_conv3 = nn.Sequential(
            nn.Conv2d(in_channels//2,in_channels//4,kernel_size=3,stride=1,padding=1),
            nn.BatchNorm2d(in_channels//4),
        )
        self.det_context_conv4 = nn.Sequential(
            nn.Conv2d(in_channels//4,in_channels//4,kernel_size=3,stride=1,padding=1),
            nn.BatchNorm2d(in_channels//4),
        )
        self.relu = nn.ReLU(inplace=True)

    def forward(self,x):
        x = self.det_context_conv1(x) # 256
        x1 = self.det_context_conv2(self.relu(x)) # 128
        x2 = self.det_context_conv3(self.relu(x1)) # 64
        x3 = self.det_context_conv4(self.relu(x2)) # 64

        out = torch.cat((x1,x2,x3),1)
        act_out = self.relu(out)

        return act_out
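
For reference, a quick shape check of this version with a dummy feature map:

module = ContextModule(in_channels=256)
feat = torch.randn(1, 256, 20, 20)   # dummy FPN feature map
out = module(feat)
print(out.shape)   # torch.Size([1, 256, 20, 20]) -> 128 + 64 + 64 = 256 channels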

Thanks for your great work again.

supernotman commented 5 years ago


There is a version I tried with 256 output channels. But if you visualize the R50 network, the channels after the context module are indeed 512, which is different from the paper, and I am confused about this too. By the way, another difference I found is that R50 does not use a shared loss head as described in the paper. In my tests, though, this had only a limited influence on the results. You can also check the code or visualize the module yourself, since I am not sure whether these are truly different.
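
For reference, a minimal sketch of the two head layouts being compared (hypothetical module names, not the actual model.py code):

import torch.nn as nn

class SharedHead(nn.Module):
    # one head whose weights are reused on every pyramid level (as described in the paper)
    def __init__(self, in_channels=256, out_channels=2):
        super(SharedHead, self).__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1)

    def forward(self, feats):          # feats: list of pyramid feature maps
        return [self.conv(f) for f in feats]

class SeparateHeads(nn.Module):
    # independent weights for each pyramid level (what the R50 graph appears to show)
    def __init__(self, num_levels=5, in_channels=256, out_channels=2):
        super(SeparateHeads, self).__init__()
        self.convs = nn.ModuleList(
            [nn.Conv2d(in_channels, out_channels, kernel_size=1) for _ in range(num_levels)])

    def forward(self, feats):
        return [conv(f) for conv, f in zip(self.convs, feats)]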

luyao777 commented 5 years ago

It's weird. I tried R50 with the FPN in mmdetection. Perhaps only the latter stages of the network are used for the outputs.

neck=dict(
        type='FPN',
        in_channels=[256, 512, 1024, 2048],
        out_channels=256,
        start_level=1,
        add_extra_convs=True,
        num_outs=5),

Its default number of output channels is 256. You can refer to that.

supernotman commented 5 years ago
in_channels=[256, 512, 1024, 2048]

Yes, the input feature channels are [256, 512, 1024, 2048], which is normal for ResNets with 50 or more layers, such as ResNet-50/101/152. First, all of them are resized to 256 channels; after that, for every feature, the 256 channels are separately resized to [256, 128, 128]. Finally, concatenating them gives 256 + 128 + 128 = 512, which is what I see in the module visualization.
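
For example, a minimal sketch of that per-level pipeline (dummy layer names, and simplified single 3x3 branch convs as an assumption) shows where the 512 comes from:

import torch
import torch.nn as nn

in_channels = [256, 512, 1024, 2048]     # ResNet-50 stage outputs (C2..C5)
laterals = nn.ModuleList(
    [nn.Conv2d(c, 256, kernel_size=1) for c in in_channels])   # FPN: everything -> 256

# per-level detection branches: 256 -> [256, 128, 128]
branch1 = nn.Conv2d(256, 256, kernel_size=3, padding=1)
branch2 = nn.Conv2d(256, 128, kernel_size=3, padding=1)
branch3 = nn.Conv2d(256, 128, kernel_size=3, padding=1)

feat = laterals[0](torch.randn(1, 256, 80, 80))   # one pyramid level, 256 channels
out = torch.cat((branch1(feat), branch2(feat), branch3(feat)), dim=1)
print(out.shape[1])                               # 256 + 128 + 128 = 512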

foocker commented 5 years ago

yes!

luyao777 commented 5 years ago

So are there any mistakes in my implementation? I think it may be closer to the paper's implementation. :bowtie:


supernotman commented 5 years ago


Sorry for the late reply. If you want a context module with 256 output channels, at least an activation function (ReLU) is needed after each BN. You can refer to the Context class at line 231 of model.py. Good luck.
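
In other words, each block would look roughly like this (a generic Conv-BN-ReLU sketch, not the actual Context class from model.py):

import torch.nn as nn

def conv_bn_relu(in_channels, out_channels):
    # 3x3 conv followed by BN, with the activation applied right after the BN
    return nn.Sequential(
        nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1),
        nn.BatchNorm2d(out_channels),
        nn.ReLU(inplace=True),
    )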

luyao777 commented 5 years ago

Thank you very much!