Lyken17 / pytorch-OpCounter

Count the MACs / FLOPs of your PyTorch model.
MIT License

Does it support DCN V2? Or shall we add it as custom ops? #153

Open laisimiao opened 3 years ago

Lyken17 commented 3 years ago

Can you share the model definition? In most cases where the model is written with standard nn.Module layers, thop should work.
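
For reference, here is a minimal sketch of the standard thop usage on a plain nn.Module model (the layer sizes and input shape below are arbitrary placeholders):

import torch
import torch.nn as nn
from thop import profile

# any model built from standard nn.Module layers
model = nn.Sequential(
    nn.Conv2d(3, 16, kernel_size=3, padding=1),
    nn.ReLU(),
    nn.Conv2d(16, 32, kernel_size=3),
)

dummy_input = torch.randn(1, 3, 224, 224)
macs, params = profile(model, inputs=(dummy_input,))
print(macs, params)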

laisimiao commented 3 years ago

It's Deformable Convolution v2 (DCN v2). It needs to be compiled into a .so file, and its official site is: https://github.com/msracver/Deformable-ConvNets/tree/master/DCNv2_op

import torch
from torch.autograd import Function
from torch.autograd.function import once_differentiable
from torch.nn.modules.utils import _pair

# _C below refers to the compiled DCNv2 CUDA extension (.so); its import path
# depends on the project that ships it.


class DeformConvFunction(Function):
    @staticmethod
    def forward(
        ctx, 
        input, 
        offset, 
        weight,
        stride=1, 
        padding=0, 
        dilation=1, 
        groups=1, 
        deformable_groups=1, 
        im2col_step=64
    ):
        if input is not None and input.dim() != 4:
            raise ValueError(
                "Expected 4D tensor as input, got {}D tensor instead.".format(
                    input.dim()))
        ctx.stride = _pair(stride)
        ctx.padding = _pair(padding)
        ctx.dilation = _pair(dilation)
        ctx.groups = groups
        ctx.deformable_groups = deformable_groups
        ctx.im2col_step = im2col_step

        ctx.save_for_backward(input, offset, weight)

        output = input.new_empty(
            DeformConvFunction._output_size(input, weight, ctx.padding,
                                            ctx.dilation, ctx.stride))

        ctx.bufs_ = [input.new_empty(0), input.new_empty(0)]  # columns, ones

        if not input.is_cuda:
            raise NotImplementedError
        else:
            cur_im2col_step = min(ctx.im2col_step, input.shape[0])
            assert (input.shape[0] %
                    cur_im2col_step) == 0, 'im2col step must divide batchsize'
            _C.deform_conv_forward(
                input, 
                weight, 
                offset, 
                output, 
                ctx.bufs_[0], 
                ctx.bufs_[1],
                weight.size(3), 
                weight.size(2), 
                ctx.stride[1], 
                ctx.stride[0],
                ctx.padding[1], 
                ctx.padding[0], 
                ctx.dilation[1],
                ctx.dilation[0], 
                ctx.groups, 
                ctx.deformable_groups,
                cur_im2col_step
            )
        return output

    @staticmethod
    @once_differentiable
    def backward(ctx, grad_output):
        input, offset, weight = ctx.saved_tensors

        grad_input = grad_offset = grad_weight = None

        if not grad_output.is_cuda:
            raise NotImplementedError
        else:
            cur_im2col_step = min(ctx.im2col_step, input.shape[0])
            assert (input.shape[0] %
                    cur_im2col_step) == 0, 'im2col step must divide batchsize'

            if ctx.needs_input_grad[0] or ctx.needs_input_grad[1]:
                grad_input = torch.zeros_like(input)
                grad_offset = torch.zeros_like(offset)
                _C.deform_conv_backward_input(
                    input, 
                    offset, 
                    grad_output, 
                    grad_input,
                    grad_offset, 
                    weight, 
                    ctx.bufs_[0], 
                    weight.size(3),
                    weight.size(2), 
                    ctx.stride[1], 
                    ctx.stride[0],
                    ctx.padding[1], 
                    ctx.padding[0], 
                    ctx.dilation[1],
                    ctx.dilation[0], 
                    ctx.groups, 
                    ctx.deformable_groups,
                    cur_im2col_step
                )

            if ctx.needs_input_grad[2]:
                grad_weight = torch.zeros_like(weight)
                _C.deform_conv_backward_parameters(
                    input, 
                    offset, 
                    grad_output,
                    grad_weight, 
                    ctx.bufs_[0], 
                    ctx.bufs_[1], 
                    weight.size(3),
                    weight.size(2), 
                    ctx.stride[1], 
                    ctx.stride[0],
                    ctx.padding[1], 
                    ctx.padding[0], 
                    ctx.dilation[1],
                    ctx.dilation[0], 
                    ctx.groups, 
                    ctx.deformable_groups, 
                    1,
                    cur_im2col_step
                )

        # one gradient per forward argument: input, offset, weight, stride,
        # padding, dilation, groups, deformable_groups, im2col_step
        return (grad_input, grad_offset, grad_weight,
                None, None, None, None, None, None)

    @staticmethod
    def _output_size(input, weight, padding, dilation, stride):
        channels = weight.size(0)
        output_size = (input.size(0), channels)
        for d in range(input.dim() - 2):
            in_size = input.size(d + 2)
            pad = padding[d]
            kernel = dilation[d] * (weight.size(d + 2) - 1) + 1
            stride_ = stride[d]
            output_size += ((in_size + (2 * pad) - kernel) // stride_ + 1, )
        if not all(map(lambda s: s > 0, output_size)):
            raise ValueError(
                "convolution input is too small (output would be {})".format(
                    'x'.join(map(str, output_size))))
        return output_size

class ModulatedDeformConvFunction(Function):

    @staticmethod
    def forward(
        ctx,
        input,
        offset,
        mask,
        weight,
        bias=None,
        stride=1,
        padding=0,
        dilation=1,
        groups=1,
        deformable_groups=1
    ):
        ctx.stride = stride
        ctx.padding = padding
        ctx.dilation = dilation
        ctx.groups = groups
        ctx.deformable_groups = deformable_groups
        ctx.with_bias = bias is not None
        if not ctx.with_bias:
            bias = input.new_empty(1)  # fake tensor
        if not input.is_cuda:
            raise NotImplementedError
        if weight.requires_grad or mask.requires_grad or offset.requires_grad \
                or input.requires_grad:
            ctx.save_for_backward(input, offset, mask, weight, bias)
        output = input.new_empty(
            ModulatedDeformConvFunction._infer_shape(ctx, input, weight))
        ctx._bufs = [input.new_empty(0), input.new_empty(0)]
        _C.modulated_deform_conv_forward(
            input, 
            weight, 
            bias, 
            ctx._bufs[0], 
            offset, 
            mask, 
            output,
            ctx._bufs[1], 
            weight.shape[2], 
            weight.shape[3], 
            ctx.stride,
            ctx.stride, 
            ctx.padding, 
            ctx.padding, 
            ctx.dilation, 
            ctx.dilation,
            ctx.groups, 
            ctx.deformable_groups, 
            ctx.with_bias
        )
        return output

    @staticmethod
    @once_differentiable
    def backward(ctx, grad_output):
        if not grad_output.is_cuda:
            raise NotImplementedError
        input, offset, mask, weight, bias = ctx.saved_tensors
        grad_input = torch.zeros_like(input)
        grad_offset = torch.zeros_like(offset)
        grad_mask = torch.zeros_like(mask)
        grad_weight = torch.zeros_like(weight)
        grad_bias = torch.zeros_like(bias)
        _C.modulated_deform_conv_backward(
            input, 
            weight, 
            bias, 
            ctx._bufs[0], 
            offset, 
            mask, 
            ctx._bufs[1],
            grad_input, 
            grad_weight, 
            grad_bias, 
            grad_offset, 
            grad_mask,
            grad_output, 
            weight.shape[2], 
            weight.shape[3], 
            ctx.stride,
            ctx.stride, 
            ctx.padding, 
            ctx.padding, 
            ctx.dilation, 
            ctx.dilation,
            ctx.groups, 
            ctx.deformable_groups, 
            ctx.with_bias
        )
        if not ctx.with_bias:
            grad_bias = None

        return (grad_input, grad_offset, grad_mask, grad_weight, grad_bias,
                None, None, None, None, None)

    @staticmethod
    def _infer_shape(ctx, input, weight):
        n = input.size(0)
        channels_out = weight.size(0)
        height, width = input.shape[2:4]
        kernel_h, kernel_w = weight.shape[2:4]
        height_out = (height + 2 * ctx.padding -
                      (ctx.dilation * (kernel_h - 1) + 1)) // ctx.stride + 1
        width_out = (width + 2 * ctx.padding -
                     (ctx.dilation * (kernel_w - 1) + 1)) // ctx.stride + 1
        return n, channels_out, height_out, width_out

deform_conv = DeformConvFunction.apply
modulated_deform_conv = ModulatedDeformConvFunction.apply

Its usage is like a vanilla Conv except that it takes an extra offset input, like this:

cls_conv = DeformConv(256, 256, self.dcn_kernel, 1, self.dcn_pad)
cls_conv(cls_feat, dcn_offset)
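
For context (not part of the snippet above), DeformConv here is assumed to be a thin nn.Module wrapper around deform_conv = DeformConvFunction.apply, roughly along these lines; the actual wrapper in the project may differ:

import math
import torch
import torch.nn as nn
from torch.nn.modules.utils import _pair


class DeformConv(nn.Module):
    """Holds the conv weight and forwards (input, offset) to deform_conv."""

    def __init__(self, in_channels, out_channels, kernel_size, stride=1,
                 padding=0, dilation=1, groups=1, deformable_groups=1):
        super().__init__()
        self.stride = _pair(stride)
        self.padding = _pair(padding)
        self.dilation = _pair(dilation)
        self.groups = groups
        self.deformable_groups = deformable_groups
        kernel_size = _pair(kernel_size)
        self.weight = nn.Parameter(
            torch.empty(out_channels, in_channels // groups, *kernel_size))
        nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))

    def forward(self, x, offset):
        # offset has 2 * deformable_groups * kh * kw channels and is usually
        # predicted by a regular Conv2d applied to the same feature map
        # (deform_conv = DeformConvFunction.apply, as defined above)
        return deform_conv(x, offset, self.weight, self.stride, self.padding,
                           self.dilation, self.groups, self.deformable_groups)
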
Lyken17 commented 3 years ago

For customized operators, you need to define the counting rule (the calculation formula) yourself. You can refer to https://github.com/Lyken17/pytorch-OpCounter/blob/master/thop/vision/basic_hooks.py#L24 as an example.
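
For example, here is a hedged sketch of such a counting rule for DeformConv, registered through custom_ops. The name count_deform_conv is just for illustration; it counts only the convolution MACs (one multiply-accumulate per output element, per input channel in the group, per kernel tap) and ignores the extra cost of the bilinear sampling driven by the offsets:

import torch
from thop import profile

def count_deform_conv(m, x, y):
    # y is the output tensor of shape (N, C_out, H_out, W_out)
    kh, kw = m.weight.shape[2], m.weight.shape[3]
    cin_per_group = m.weight.shape[1]  # C_in // groups
    total_ops = y.nelement() * cin_per_group * kh * kw
    m.total_ops += torch.DoubleTensor([int(total_ops)])

# register the rule when profiling:
# macs, params = profile(model, inputs=(dummy_input,),
#                        custom_ops={DeformConv: count_deform_conv})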

zzzmm1 commented 2 years ago

Hello, have you implemented the custom ops for deformable convolution? Could you please share it? Thanks a lot!