alexandrosstergiou / SoftPool

[ICCV 2021] Code for approximated exponential maximum pooling
MIT License
288 stars 52 forks

softpool_cuda.forward_2d(input.contiguous(), kernel, stride, output) #43

Closed Chenchenwenyu closed 2 years ago

Chenchenwenyu commented 2 years ago

Hi, thanks for your great idea!

I have a question about `softpool_cuda.forward_2d(input.contiguous(), kernel, stride, output)`. I don't understand what this line of code does. What is it for?

class CUDA_SOFTPOOL2d(Function):
    @staticmethod
    @torch.cuda.amp.custom_fwd(cast_inputs=torch.float32)
    def forward(ctx, input, kernel=2, stride=None):
        # Create contiguous tensor (if tensor is not contiguous)
        no_batch = False
        if len(input.size()) == 3:
            no_batch = True
            input.unsqueeze_(0)
        B, C, H, W = input.size()
        kernel = _pair(kernel)
        if stride is None:
            stride = kernel
        else:
            stride = _pair(stride)
        oH = (H - kernel[0]) // stride[0] + 1
        oW = (W - kernel[1]) // stride[1] + 1
        output = input.new_zeros((B, C, oH, oW))
        softpool_cuda.forward_2d(input.contiguous(), kernel, stride, output)   #### what does this line mean ?####
        ctx.save_for_backward(input)
        ctx.kernel = kernel
        ctx.stride = stride
        if no_batch:
            return output.squeeze_(0)
        return output

Looking forward to your reply, thanks!

alexandrosstergiou commented 2 years ago

Hi @ChennolongerCai ,

The line is a call into the C++/CUDA extension defined in pytorch/CUDA/softpool_cuda.cpp: https://github.com/alexandrosstergiou/SoftPool/blob/2d2ec6dca10b7683ffd41061a27910d67816bfa5/pytorch/CUDA/softpool_cuda.cpp#L206 That file binds the C++ `softpool2d_forward_cuda` function to Python, which is what makes it callable from the Python script; the call runs the CUDA forward pass and writes the pooled result into the preallocated `output` tensor.
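For intuition, here is a rough pure-PyTorch sketch of what that forward call computes, based on the SoftPool definition (an exponentially weighted average over each pooling window). The function name is hypothetical and this is a reference implementation, not the repo's CUDA kernel:

```python
import torch
import torch.nn.functional as F

def softpool2d_reference(input, kernel=(2, 2), stride=(2, 2)):
    """Sketch of the SoftPool forward pass: each output element is the
    exp-weighted average of its pooling window, sum(exp(x)*x)/sum(exp(x))."""
    B, C, H, W = input.shape
    oH = (H - kernel[0]) // stride[0] + 1
    oW = (W - kernel[1]) // stride[1] + 1
    # (B, C*kH*kW, oH*oW): every pooling window flattened into a column
    windows = F.unfold(input, kernel_size=kernel, stride=stride)
    windows = windows.view(B, C, kernel[0] * kernel[1], oH * oW)
    weights = windows.exp()
    out = (weights * windows).sum(dim=2) / weights.sum(dim=2)
    return out.view(B, C, oH, oW)
```

The CUDA kernel fuses this per-window computation instead of materializing the unfolded windows, which is why the Python side only allocates `output` and hands everything to the extension.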

If you want to delve deeper, have a look at PyTorch's official (and rather extensive) guide on writing custom C++/CUDA extensions (link).