x_proj = x
x = DCNv4Function.apply(
    x, offset_mask,
    self.kernel_size, self.kernel_size,
    self.stride, self.stride,
    self.pad, self.pad,
    self.dilation, self.dilation,
    self.group, self.group_channels,
    self.offset_scale,
    256,
    self.remove_center
)
x = x.view(N, L, -1)
if self.center_feature_scale:
    center_feature_scale = self.center_feature_scale_module(
        x, self.center_feature_scale_proj_weight, self.center_feature_scale_proj_bias)
    center_feature_scale = center_feature_scale[..., None].repeat(
        1, 1, 1, 1, self.channels // self.group).flatten(-2)
    x = x * (1 - center_feature_scale) + x_proj * center_feature_scale
if not self.without_pointwise:
    x = self.output_proj(x)
`x_proj` has shape `[bs, h, w, c]`, but after `x = x.view(N, L, -1)` the shape of `x` is `[bs, h*w, c]`.
When `center_feature_scale` is enabled:
`x = x * (1 - center_feature_scale) + x_proj * center_feature_scale` causes a shape mismatch error.
My solution is to move `x = x.view(N, L, -1)` to just before `if not self.without_pointwise:`,
like:
x_proj = x
x = DCNv4Function.apply(
    x, offset_mask,
    self.kernel_size, self.kernel_size,
    self.stride, self.stride,
    self.pad, self.pad,
    self.dilation, self.dilation,
    self.group, self.group_channels,
    self.offset_scale,
    256,
    self.remove_center
)
if self.center_feature_scale:
    center_feature_scale = self.center_feature_scale_module(
        x, self.center_feature_scale_proj_weight, self.center_feature_scale_proj_bias)
    center_feature_scale = center_feature_scale[..., None].repeat(
        1, 1, 1, 1, self.channels // self.group).flatten(-2)
    x = x * (1 - center_feature_scale) + x_proj * center_feature_scale
x = x.view(N, L, -1)  # moved here
if not self.without_pointwise:
    x = self.output_proj(x)
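To see why the order matters without running the DCNv4 kernel, here is a minimal pure-Python sketch of the right-aligned broadcasting rule that PyTorch applies to elementwise operations. The helper `broadcast_shape` and the sizes `N, H, W, C` are hypothetical and chosen only for illustration; they are not part of the DCNv4 code.

```python
def broadcast_shape(a, b):
    """Return the broadcast shape of two shapes, or raise ValueError.

    Uses the usual right-aligned rule: each pair of dimensions must be
    equal, or one of them must be 1.
    """
    # Pad the shorter shape with leading 1s so both have the same rank.
    rank = max(len(a), len(b))
    a = (1,) * (rank - len(a)) + tuple(a)
    b = (1,) * (rank - len(b)) + tuple(b)
    result = []
    for da, db in zip(a, b):
        if da == db or da == 1 or db == 1:
            result.append(max(da, db))
        else:
            raise ValueError(f"cannot broadcast {a} with {b}")
    return tuple(result)


# Hypothetical sizes, for illustration only.
N, H, W, C = 2, 4, 5, 8
L = H * W

x_proj_shape = (N, H, W, C)   # x_proj is kept as [bs, h, w, c]
flattened_shape = (N, L, C)   # x after x.view(N, L, -1)

# Original order: x is flattened before the blend, so the shapes clash.
try:
    broadcast_shape(flattened_shape, x_proj_shape)
    blend_before_view_ok = True
except ValueError:
    blend_before_view_ok = False

# Proposed order: blend while both are [bs, h, w, c], then flatten.
blended_shape = broadcast_shape(x_proj_shape, x_proj_shape)
```

With these sizes the flattened `[bs, h*w, c]` shape cannot be combined with `[bs, h, w, c]`, while blending before the `view` call keeps both operands in the same layout.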
Hi, thanks for raising this; we never tried center_feature_scale in our experiments, so we were not aware of it. This issue has now been fixed in the 1.0.0.post2 release.
DCNv4_op\DCNv4\modules\dcnv4.py, line 128
`x_proj` has shape `[bs, h, w, c]`, but after
x = x.view(N, L, -1)
the shape of `x` is `[bs, h*w, c]`. When `center_feature_scale` is enabled, `x = x * (1 - center_feature_scale) + x_proj * center_feature_scale` causes a shape mismatch error. My solution is to move
x = x.view(N, L, -1)
to just before `if not self.without_pointwise:`,
like: