4uiiurz1 / pytorch-deform-conv-v2

PyTorch implementation of Deformable ConvNets v2 (Modulated Deformable Convolution)
MIT License
743 stars 141 forks source link

A large error when kernel_size > 3. #28

Open frotms opened 3 years ago

frotms commented 3 years ago

I compared this repository's DCNv2 against an implementation built on torchvision.ops.deform_conv2d. The two produce the same result with kernel_size=3, but the results differ when kernel_size>3. My torchvision-based implementation of DCNv2 is below:

def torch_initialize_weights(modules):
    """Initialize the parameters of every supported layer yielded by ``modules``.

    Args:
        modules: A zero-argument callable returning an iterable of
            ``torch.nn.Module`` objects, e.g. the bound method
            ``model.modules`` (note: the method itself, not its result).

    Initialization scheme:
        * ``Conv2d`` / ``ConvTranspose2d``: Kaiming-normal weights
          (``mode='fan_out'``), zero bias.
        * ``BatchNorm2d``: weight set to one, bias set to zero.
        * ``Linear``: weights drawn from N(0, 0.01), zero bias.

    Any other module type is left untouched.
    """
    conv_types = (torch.nn.Conv2d, torch.nn.ConvTranspose2d)
    for m in modules():
        if isinstance(m, conv_types):
            torch.nn.init.kaiming_normal_(m.weight, mode='fan_out')
            if m.bias is not None:
                torch.nn.init.zeros_(m.bias)
        elif isinstance(m, torch.nn.BatchNorm2d):
            # With affine=False both parameters are None; skip them
            # instead of crashing inside torch.nn.init.
            if m.weight is not None:
                torch.nn.init.ones_(m.weight)
            if m.bias is not None:
                torch.nn.init.zeros_(m.bias)
        # Use the fully qualified name: the bare ``nn`` alias is not
        # guaranteed to be imported alongside ``torch``.
        elif isinstance(m, torch.nn.Linear):
            torch.nn.init.normal_(m.weight, 0, 0.01)
            if m.bias is not None:
                torch.nn.init.zeros_(m.bias)

class TorchDeformableConvV2_split(torch.nn.Module):
    """Modulated deformable convolution (DCNv2) built on torchvision.

    Two ordinary convolutions predict, from the input itself, the sampling
    offsets (``conv_offset``) and the modulation mask (``conv_modulator``)
    that are then fed to ``torchvision.ops.DeformConv2d`` (``conv_dcn``).

    Args:
        in_channels: Number of channels of the input feature map.
        out_channels: Number of channels produced by the deformable conv.
        kernel_size: Side length of the square kernel. Must be an ``int``
            because the offset/mask channel counts are derived from
            ``kernel_size ** 2``.
        stride: Stride shared by all three convolutions.
        padding: Zero padding shared by all three convolutions.
        dilation: Dilation shared by all three convolutions.
        groups: Number of groups of the deformable convolution.
        bias: Whether the deformable convolution has a bias term.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size,
                 stride=1,
                 padding=0,
                 dilation=1,
                 groups=1,
                 bias=False,
                 ):
        # The original referenced a non-existent class name in super();
        # the zero-argument form always resolves to this class.
        super().__init__()
        # Two offset coordinates and one mask value per kernel tap.
        self.offset_channel = 2 * kernel_size**2
        self.mask_channel = kernel_size**2

        self.padding = padding
        self.dilation = dilation
        self.groups = groups
        self.stride = stride

        # The offset and mask maps must have exactly the spatial size of
        # the deformable conv's output. That only holds if all three
        # layers share kernel size, stride, padding and dilation, so the
        # same values are passed everywhere (previously the DCN layer
        # ignored ``padding`` and the predictors ignored ``dilation``,
        # which only lined up for dilation=1, padding=(kernel_size-1)//2).
        self.conv_offset = torch.nn.Conv2d(in_channels,
                                           self.offset_channel,
                                           kernel_size=kernel_size,
                                           stride=stride,
                                           padding=self.padding,
                                           dilation=dilation,
                                           bias=True)

        self.conv_modulator = torch.nn.Conv2d(in_channels,
                                              self.mask_channel,
                                              kernel_size=kernel_size,
                                              stride=stride,
                                              padding=self.padding,
                                              dilation=dilation,
                                              bias=True)

        self.conv_dcn = torchvision.ops.DeformConv2d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=self.padding,
            dilation=dilation,
            groups=groups,
            bias=bias,
        )

        # Pass the bound method itself; the helper calls it to iterate.
        torch_initialize_weights(self.modules)

    def forward(self, x):
        """Apply the modulated deformable convolution to ``x``.

        Args:
            x: Input feature map with ``in_channels`` channels
                (assumed to be laid out as (N, C, H, W) -- the layout
                expected by the convolutions used here).

        Returns:
            The output of ``conv_dcn`` evaluated with offsets and a
            sigmoid-squashed mask predicted from ``x``.
        """
        offset = self.conv_offset(x)
        # Sigmoid keeps each modulation scalar in (0, 1).
        mask = torch.sigmoid(self.conv_modulator(x))
        y = self.conv_dcn(x, offset, mask=mask)
        return y

Is there something wrong with my code?