jbwang1997 / OBBDetection

OBBDetection is an oriented object detection library based on MMDetection.
Apache License 2.0

How to change the RoI's w and h #71

Closed: qfwysw closed this issue 2 years ago

qfwysw commented 2 years ago

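Thank you for open-sourcing such high-quality code! For my task I need to set the RoI's w and h to different sizes. During backpropagation, however, I got an error saying that 7 values were expected but only 6 were returned. I got rid of the error simply by adding one more None to the return value, but I am not sure whether this change is correct or whether it affects backpropagation.

`from torch.autograd import Function
from torch.autograd.function import once_differentiable

# roi_align_rotated_ext is the repository's compiled extension
# (its import was not included in the original post).


class RoIAlignRotatedFunction(Function):

    @staticmethod
    def forward(ctx,
                features,
                rois,
                out_w,
                out_h,
                spatial_scale,
                sample_num=0,
                aligned=True):
        # out_w and out_h are passed separately instead of a single out_size.
        assert isinstance(out_h, int) and isinstance(out_w, int)
        ctx.spatial_scale = spatial_scale
        ctx.sample_num = sample_num
        ctx.save_for_backward(rois)
        ctx.feature_size = features.size()
        ctx.aligned = aligned

        output = roi_align_rotated_ext.forward(
            features, rois, spatial_scale, out_h, out_w, sample_num, aligned)

        return output

    @staticmethod
    @once_differentiable
    def backward(ctx, grad_output):
        feature_size = ctx.feature_size
        spatial_scale = ctx.spatial_scale
        sample_num = ctx.sample_num
        rois = ctx.saved_tensors[0]
        aligned = ctx.aligned
        assert feature_size is not None

        batch_size, num_channels, data_height, data_width = feature_size
        # Recover the pooled size from the incoming gradient (N, C, out_h, out_w).
        out_w = grad_output.size(3)
        out_h = grad_output.size(2)

        grad_rois = None  # no gradient is computed for the RoIs
        grad_input = roi_align_rotated_ext.backward(
            grad_output, rois, spatial_scale, out_h, out_w,
            batch_size, num_channels, data_height, data_width,
            sample_num, aligned)

        # One return value per forward() argument after ctx (7 here);
        # None marks the non-differentiable ones.
        return grad_input, grad_rois, None, None, None, None, None`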

Looking forward to your reply.
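On the question of the extra None: PyTorch's autograd expects `backward` to return one value per argument of `forward` (not counting `ctx`), and `None` is the accepted placeholder for inputs that are not tensors or need no gradient. With `out_w` and `out_h` as two separate arguments, `forward` takes 7 inputs, so 7 return values would be the expected count. The sketch below checks that count with the standard library only. `RoIAlignRotatedSketch` and `num_forward_inputs` are hypothetical stand-ins written for illustration, not part of OBBDetection or PyTorch.

```python
import inspect


class RoIAlignRotatedSketch:
    """Hypothetical stand-in mirroring the forward signature from the snippet above."""

    @staticmethod
    def forward(ctx, features, rois, out_w, out_h, spatial_scale,
                sample_num=0, aligned=True):
        return features

    @staticmethod
    def backward(ctx, grad_output):
        grad_input, grad_rois = grad_output, None
        # One slot per forward argument after ctx; None for non-differentiable inputs.
        return grad_input, grad_rois, None, None, None, None, None


def num_forward_inputs(fn_cls):
    """Count the forward() parameters, excluding the leading ctx."""
    params = list(inspect.signature(fn_cls.forward).parameters)
    return len(params) - 1


expected = num_forward_inputs(RoIAlignRotatedSketch)
returned = len(RoIAlignRotatedSketch.backward(None, 1.0))
print(expected, returned)
```

If the two numbers match, the return tuple has the shape autograd asks for. The gradient for `features` itself is unaffected by how many trailing `None` entries follow it.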

jbwang1997 commented 2 years ago

If you only want to change the RoI's W and H, you can do it in the config file: in roi_layer=dict(type='RoIAlignRotated', out_size=7, sample_num=2), change out_size=7 to out_size=(h, w).
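As a rough illustration of that config change, the fragment below sets a non-square output size. The values 7 and 14 are arbitrary examples, and `as_hw` is a hypothetical helper that only shows how an int or an (h, w) pair could be normalised to a pair. Whether the layer normalises `out_size` in exactly this way is an assumption.

```python
def as_hw(out_size):
    """Hypothetical helper: turn an int or an (h, w) pair into an (h, w) tuple."""
    if isinstance(out_size, int):
        return (out_size, out_size)
    h, w = out_size
    return (int(h), int(w))


# Config fragment: a single int gives a square output,
# an (h, w) tuple gives a rectangular one.
roi_layer = dict(type='RoIAlignRotated', out_size=(7, 14), sample_num=2)

out_h, out_w = as_hw(roi_layer['out_size'])
print(out_h, out_w)
```

With this route the `RoIAlignRotatedFunction` signature stays as it is in the repository, so the `backward` return value does not need to be touched.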