shanglianlm0525 / PyTorch-Networks

PyTorch implementations of CNN networks
1.91k stars 485 forks source link

one question about LW bottleneck #26

Open linecodezp opened 2 years ago

linecodezp commented 2 years ago

class LWbottleneck(nn.Module):
    """Lightweight multi-branch bottleneck.

    Runs eight parallel ``ConvBNReLU`` branches on the same input. Each branch
    maps ``in_channels -> in_channels`` with the given ``stride``. The branch
    kernels are the asymmetric pairs 5x1/1x5, 3x1/1x3 and 2x1/1x2, plus the
    scalar kernel sizes 2 and 3. The branch outputs are cropped to a common
    spatial size, concatenated along dim 1, and projected to ``out_channels``
    by ``Conv1x1BN``.

    NOTE(review): no residual shortcut/addition is applied to the output.
    Issue #26 reports that the original paper adds one. Confirm against the
    paper before changing, because it would require matching shapes
    (stride == 1 and in_channels == out_channels, or a projection).

    Args:
        in_channels: number of input channels (also each branch's width).
        out_channels: number of channels produced by the final 1x1 shrink.
        stride: stride applied by every branch.
    """

    # (kernel_size, padding) for each parallel branch, in registration order.
    # The padding keeps each branch's output at least as large as the target
    # size, so forward() only ever needs to crop, never pad.
    _BRANCH_SPECS = (
        ([5, 1], [2, 0]),
        ([1, 5], [0, 2]),
        ([3, 1], [1, 0]),
        ([1, 3], [0, 1]),
        ([2, 1], [1, 0]),
        ([1, 2], [0, 1]),
        (2, 1),
        (3, 1),
    )

    def __init__(self, in_channels, out_channels, stride):
        super(LWbottleneck, self).__init__()
        self.stride = stride
        self.pyramid_list = nn.ModuleList(
            ConvBNReLU(in_channels, in_channels, kernel_size=kernel, stride=stride, padding=pad)
            for kernel, pad in self._BRANCH_SPECS
        )
        # All branch outputs are concatenated channel-wise before shrinking.
        self.shrink = Conv1x1BN(in_channels * len(self._BRANCH_SPECS), out_channels)

    def forward(self, x):
        """Apply all branches, crop to a common size, concatenate and shrink.

        Assumes ``x`` is (N, C, H, W), i.e. dims 2 and 3 are the spatial dims
        (TODO confirm with callers). The target spatial size is the input
        size, or ``size // stride`` per spatial dim when ``stride > 1``.
        """
        out_h, out_w = x.shape[2], x.shape[3]
        if self.stride > 1:
            out_h, out_w = out_h // self.stride, out_w // self.stride
        outputs = []
        for pyconv in self.pyramid_list:
            pyconv_x = pyconv(x)
            # Even-sized kernels (2) and strided convs can overshoot the
            # target size; crop the extra rows/cols from the bottom/right.
            if x.shape[2:] != pyconv_x.shape[2:]:
                pyconv_x = pyconv_x[:, :, :out_h, :out_w]
            outputs.append(pyconv_x)
        return self.shrink(torch.cat(outputs, 1))

In the original paper, they used a shortcut connection with element-wise addition, which seems to be missing from your code.