WongKinYiu / ScaledYOLOv4

Scaled-YOLOv4: Scaling Cross Stage Partial Network

Batch Norm layer missing #213

Closed eugene123tw closed 2 years ago

eugene123tw commented 3 years ago

In Scaled-YOLOv4 / yolov4-csp, each convolution layer is always followed by a batch normalization layer. However, this doesn't seem to be the case in this implementation.

For example, in BottleneckCSP, batchnorm layers are missing after the convolution layers cv2 and cv3. Even though there is a batchnorm layer after the concatenation, that is not the same as adding a batchnorm layer after each conv layer.

class BottleneckCSP(nn.Module):
    # CSP Bottleneck https://github.com/WongKinYiu/CrossStagePartialNetworks
    def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5):  # ch_in, ch_out, number, shortcut, groups, expansion
        super(BottleneckCSP, self).__init__()
        c_ = int(c2 * e)  # hidden channels
        self.cv1 = Conv(c1, c_, 1, 1)  # Conv = conv + BN + activation
        self.cv2 = nn.Conv2d(c1, c_, 1, 1, bias=False)  # plain conv, no BN
        self.cv3 = nn.Conv2d(c_, c_, 1, 1, bias=False)  # plain conv, no BN
        self.cv4 = Conv(2 * c_, c2, 1, 1)
        self.bn = nn.BatchNorm2d(2 * c_)  # applied to cat(cv2, cv3)
        self.act = Mish()
        self.m = nn.Sequential(*[Bottleneck(c_, c_, shortcut, g, e=1.0) for _ in range(n)])

    def forward(self, x):
        y1 = self.cv3(self.m(self.cv1(x)))
        y2 = self.cv2(x)
        return self.cv4(self.act(self.bn(torch.cat((y1, y2), dim=1))))

As in https://github.com/AlexeyAB/darknet/blob/master/cfg/yolov4-csp.cfg, a batchnorm layer is always part of every convolutional block (except the conv layers feeding the YOLO layers).
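
For reference, a convolutional block in the darknet cfg format carries the BN flag directly; the values below are illustrative rather than copied from a specific layer of yolov4-csp.cfg:

[convolutional]
batch_normalize=1
filters=256
size=1
stride=1
pad=1
activation=mish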

Did I miss something?

WongKinYiu commented 3 years ago

self.bn(torch.cat((y1, y2), dim=1))

WongKinYiu commented 3 years ago

you can use Conv to replace nn.Conv2d and remove self.bn and self.act.
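
A minimal sketch of that change, keeping the structure of the BottleneckCSP quoted above and assuming the repo's Conv (conv + BN + activation) and Bottleneck helpers are in scope; the class name is made up, and this layout will no longer match the pretrained checkpoints:

class BottleneckCSPB(nn.Module):  # hypothetical name for the all-Conv variant
    # Every branch now goes through Conv (conv + BN + activation),
    # so the separate self.bn / self.act after the concat are no longer needed.
    def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5):
        super().__init__()
        c_ = int(c2 * e)  # hidden channels
        self.cv1 = Conv(c1, c_, 1, 1)
        self.cv2 = Conv(c1, c_, 1, 1)  # was nn.Conv2d without BN
        self.cv3 = Conv(c_, c_, 1, 1)  # was nn.Conv2d without BN
        self.cv4 = Conv(2 * c_, c2, 1, 1)
        self.m = nn.Sequential(*[Bottleneck(c_, c_, shortcut, g, e=1.0) for _ in range(n)])

    def forward(self, x):
        y1 = self.cv3(self.m(self.cv1(x)))
        y2 = self.cv2(x)
        return self.cv4(torch.cat((y1, y2), dim=1))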

eugene123tw commented 3 years ago

Thank you for the speedy response. However, a batch norm layer after the concatenation is not the same as adding a BN layer after each conv layer.

I think we could replace it with ConvBNMish as you suggested, but in this case, the pre-trained model wouldn't work, right?
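
For what it's worth, a rough sketch of how one could check how much of an existing checkpoint still transfers after that change (model and the checkpoint path are placeholders; strict=False simply skips keys that no longer match):

import torch

# `model` is the network rebuilt with the all-Conv BottleneckCSP variant,
# and 'weights.pt' is an existing Scaled-YOLOv4 checkpoint (both placeholders).
ckpt = torch.load('weights.pt', map_location='cpu')
state = ckpt.get('model', ckpt) if isinstance(ckpt, dict) else ckpt
if hasattr(state, 'state_dict'):  # some checkpoints store the whole module
    state = state.float().state_dict()

model_state = model.state_dict()
# Keep only tensors whose name and shape both still match the new architecture.
matched = {k: v for k, v in state.items()
           if k in model_state and v.shape == model_state[k].shape}
print(f'{len(matched)}/{len(model_state)} tensors transferable')

# Load what matches; the new BN layers inside cv2/cv3 start from scratch,
# so the model would still need fine-tuning.
model.load_state_dict(matched, strict=False)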