I ran DCNv2 using `torchvision.ops.deform_conv2d` and got the same result with `kernel_size=3`,
but a different result when `kernel_size > 3`.
My implementation of DCNv2 is below:
def torch_initialize_weights(modules):
    """Re-initialize, in place, the parameters of every module yielded by ``modules()``.

    Args:
        modules: a zero-argument callable returning an iterable of
            ``torch.nn.Module`` objects -- typically the bound method
            ``some_module.modules`` (pass the method itself, not its result).

    Initialization scheme:
        * ``torch.nn.Conv2d`` / ``torch.nn.ConvTranspose2d``: Kaiming-normal
          weights (``mode='fan_out'``), zero bias (if present).
        * ``torch.nn.BatchNorm2d``: weight 1, bias 0 (skipped when the layer
          was built with ``affine=False`` and has no parameters).
        * ``torch.nn.Linear``: weights drawn from N(0, 0.01), zero bias (if present).
        Any other module type is left untouched.
    """
    for m in modules():
        if isinstance(m, (torch.nn.Conv2d, torch.nn.ConvTranspose2d)):
            torch.nn.init.kaiming_normal_(m.weight, mode='fan_out')
            if m.bias is not None:
                torch.nn.init.zeros_(m.bias)
        elif isinstance(m, torch.nn.BatchNorm2d):
            # affine=False batch norms have weight/bias set to None.
            if m.weight is not None:
                torch.nn.init.ones_(m.weight)
            if m.bias is not None:
                torch.nn.init.zeros_(m.bias)
        # Fully qualified: a bare ``nn`` is not guaranteed to be imported, and
        # every module (including containers) reaches this check.
        elif isinstance(m, torch.nn.Linear):
            torch.nn.init.normal_(m.weight, 0, 0.01)
            if m.bias is not None:
                torch.nn.init.zeros_(m.bias)
class TorchDeformableConvV2_split(torch.nn.Module):
    """Modulated deformable convolution (DCNv2) with separate offset and mask branches.

    Two ordinary ``torch.nn.Conv2d`` layers look at the input ``x`` and predict:
      * the sampling offsets (``2 * kernel_size**2`` channels), and
      * the modulation mask (``kernel_size**2`` channels, passed through a sigmoid).
    Both are then given to ``torchvision.ops.DeformConv2d`` together with ``x``.

    Args:
        in_channels, out_channels: channel counts of the deformable convolution.
        kernel_size: square kernel size (an int). It should be odd: only then do
            the offset/mask maps and the DCN output share a spatial size for
            every dilation.
        stride: stride shared by the offset, mask and deformable convolutions.
        padding: kept on ``self.padding`` for interface compatibility, but not
            used to build the layers. The deformable conv always pads by
            ``(kernel_size - 1) // 2 * dilation``, and the offset/mask convs by
            ``(kernel_size - 1) // 2``, so that all three outputs line up.
        dilation: dilation of the deformable convolution only.
        groups: channel groups of the deformable convolution.
        bias: whether the deformable convolution has a bias.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size,
                 stride=1,
                 padding=0,
                 dilation=1,
                 groups=1,
                 bias=False,
                 ):
        # The original referenced a non-existent ``TorchDeformableConvV2`` here.
        super().__init__()
        self.offset_channel = 2 * kernel_size ** 2
        self.mask_channel = kernel_size ** 2
        self.padding = padding
        self.dilation = dilation
        self.groups = groups
        self.stride = stride

        # torchvision requires offset/mask of shape (N, C, out_h, out_w), where
        # (out_h, out_w) is the deformable conv's output size. Using the user
        # ``padding`` here (as before) only matched that when
        # padding == (kernel_size - 1) // 2. With this padding and an odd
        # kernel, both branches produce floor((H - 1) / stride) + 1, which is
        # exactly the DCN output size for any dilation.
        offset_padding = (kernel_size - 1) // 2
        dcn_padding = (kernel_size - 1) // 2 * dilation

        self.conv_offset = torch.nn.Conv2d(in_channels,
                                           self.offset_channel,
                                           kernel_size=kernel_size,
                                           stride=stride,
                                           padding=offset_padding,
                                           bias=True)
        self.conv_modulator = torch.nn.Conv2d(in_channels,
                                              self.mask_channel,
                                              kernel_size=kernel_size,
                                              stride=stride,
                                              padding=offset_padding,
                                              bias=True)
        self.conv_dcn = torchvision.ops.DeformConv2d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=dcn_padding,
            dilation=dilation,
            groups=groups,
            bias=bias,
        )
        # NOTE(review): only module types handled by torch_initialize_weights
        # are re-initialized. torchvision's DeformConv2d is presumably not a
        # torch.nn.Conv2d subclass and so keeps its own default init -- confirm.
        torch_initialize_weights(self.modules)

    def forward(self, x):
        """Apply the modulated deformable convolution to ``x``."""
        offset = self.conv_offset(x)
        # Sigmoid keeps each modulation scalar in (0, 1).
        mask = torch.sigmoid(self.conv_modulator(x))
        y = self.conv_dcn(x, offset, mask=mask)
        return y
I ran `dcnv2` with `torchvision.ops.deform_conv2d` and got the same result with `kernel_size=3`,
but a different result when `kernel_size>3`.
My implementation of `dcnv2` is below.
Is there something wrong with my code?