OpenGVLab / DCNv4

[CVPR 2024] Deformable Convolution v4
https://arxiv.org/pdf/2401.06197.pdf
MIT License

DCNv4 related #12

Closed qq1361096516 closed 10 months ago

qq1361096516 commented 10 months ago

DCNv4_op\DCNv4\modules\dcnv4.py, line 128

x_proj = x

x = DCNv4Function.apply(
    x, offset_mask,
    self.kernel_size, self.kernel_size,
    self.stride, self.stride,
    self.pad, self.pad,
    self.dilation, self.dilation,
    self.group, self.group_channels,
    self.offset_scale,
    256,
    self.remove_center
    )

x = x.view(N, L, -1)
if self.center_feature_scale:
    center_feature_scale = self.center_feature_scale_module(
        x, self.center_feature_scale_proj_weight, self.center_feature_scale_proj_bias)
    center_feature_scale = center_feature_scale[..., None].repeat(
        1, 1, 1, 1, self.channels // self.group).flatten(-2)
    x = x * (1 - center_feature_scale) + x_proj * center_feature_scale
if not self.without_pointwise:
    x = self.output_proj(x)

x_proj has shape [bs, h, w, c]. After x = x.view(N, L, -1), x has shape [bs, h*w, c]. When center_feature_scale is used, `x = x * (1 - center_feature_scale) + x_proj * center_feature_scale` raises a shape mismatch error.
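
To make the mismatch concrete, here is a minimal sketch in plain Python that applies the usual NumPy/PyTorch broadcasting rule to the two shapes mentioned above. The helper names `broadcast_shape` and `can_blend` and the sizes `bs, h, w, c = 2, 4, 5, 8` are made up for illustration; they are not part of DCNv4, and the sketch only checks shapes rather than running the real op.

```python
from itertools import zip_longest


def broadcast_shape(a, b):
    """Return the broadcast shape of two shapes, or raise ValueError.

    Rule: align trailing dimensions; each pair of sizes must be equal,
    or one of them must be 1.
    """
    out = []
    for da, db in zip_longest(reversed(a), reversed(b), fillvalue=1):
        if da == db or db == 1:
            out.append(da)
        elif da == 1:
            out.append(db)
        else:
            raise ValueError(f"cannot broadcast {tuple(a)} with {tuple(b)}")
    return tuple(reversed(out))


def can_blend(x_shape, x_proj_shape):
    """True if tensors of these two shapes could be added elementwise."""
    try:
        broadcast_shape(x_shape, x_proj_shape)
        return True
    except ValueError:
        return False


# Hypothetical sizes, chosen only for illustration.
bs, h, w, c = 2, 4, 5, 8

x_proj_shape = (bs, h, w, c)   # x_proj, saved before DCNv4Function.apply
x_flat_shape = (bs, h * w, c)  # x after x.view(N, L, -1), with L = h * w

print(can_blend(x_flat_shape, x_proj_shape))  # flattened x vs. x_proj
print(can_blend(x_proj_shape, x_proj_shape))  # unflattened x vs. x_proj
```

The first check fails because, once trailing dimensions are aligned, h*w is compared against w. The second check passes, which is why keeping x in [bs, h, w, c] layout until after the blend avoids the error.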

My solution is to move x = x.view(N, L, -1) so that it runs after the center_feature_scale block and just before if not self.without_pointwise:, like this:

x_proj = x

x = DCNv4Function.apply(
    x, offset_mask,
    self.kernel_size, self.kernel_size,
    self.stride, self.stride,
    self.pad, self.pad,
    self.dilation, self.dilation,
    self.group, self.group_channels,
    self.offset_scale,
    256,
    self.remove_center
    )

if self.center_feature_scale:
    center_feature_scale = self.center_feature_scale_module(
        x, self.center_feature_scale_proj_weight, self.center_feature_scale_proj_bias)
    center_feature_scale = center_feature_scale[..., None].repeat(
        1, 1, 1, 1, self.channels // self.group).flatten(-2)
    x = x * (1 - center_feature_scale) + x_proj * center_feature_scale

x = x.view(N, L, -1)  # move to here

if not self.without_pointwise:
    x = self.output_proj(x)

YuwenXiong commented 10 months ago

Hi, thanks for raising this; we never tried center_feature_scale in our experiments, so we were not aware of it. This issue has now been fixed in the 1.0.0.post2 release.