NVlabs / pacnet

Pixel-Adaptive Convolutional Neural Networks (CVPR '19)
https://suhangpro.github.io/pac/
Other
511 stars 79 forks source link

Question about Transpose #20

Open fkokkinos opened 4 years ago

fkokkinos commented 4 years ago

I would like to ask whether the transpose operator PacConvTranspose2d is indeed the transpose (adjoint) operator of PacConv2d, or rather a generic spatially-varying upsampling operator. A transpose/adjoint operator should pass the adjoint (dot-product) test (see https://en.wikipedia.org/wiki/Hermitian_adjoint); however, I have not been able to get this test to pass. Based on your test file, I created the following tests:

    def _allclose(x1, x2, rtol=1e-5, atol=1e-10):
        assert np.allclose(x1.cpu(), x2.cpu(), rtol=rtol, atol=atol)

    @repeat_impl_types
    def test_adjoint_const_kernel_th(self, native_impl):
        """Adjoint (dot-product) test for ``nn.Conv2d`` / ``nn.ConvTranspose2d``.

        Both layers are given the same random weights and a zero bias. For two
        random images ``im_1`` and ``im_2`` the test then checks that
        ``<conv(im_1), im_2> == <conv_t(im_2), im_1>``.

        ``native_impl`` is not used by this test; it is accepted only because the
        ``repeat_impl_types`` decorator (defined outside this snippet) presumably
        supplies it -- TODO confirm.
        """
        bs, sz = 1, 128
        args = dict(in_channels=3, out_channels=3, kernel_size=5, stride=1, padding=2, dilation=1)
        im_1 = th.rand(bs, args['in_channels'], sz, sz).to(self.device)
        im_2 = th.rand(bs, args['in_channels'], sz, sz).to(self.device)

        # Shared parameters: identical weights in both layers, bias forced to
        # zero so that each layer is a purely linear map.
        conv_w = th.rand(args['in_channels'], args['out_channels'],
                         args['kernel_size'], args['kernel_size']).to(self.device)
        conv_b = th.zeros(args['out_channels']).to(self.device)
        conv_th = nn.Conv2d(**args).to(self.device)
        conv_t_th = nn.ConvTranspose2d(**args).to(self.device)
        conv_th.weight.data[:] = conv_t_th.weight.data[:] = conv_w
        conv_th.bias.data[:] = conv_t_th.bias.data[:] = conv_b

        # Flatten the outputs so the inner products can be taken with ``dot``.
        # The outputs are detached here, so the dot products below need no
        # further ``detach()``.
        res1 = conv_th(im_1).detach().reshape(-1)
        res2 = conv_t_th(im_2).detach().reshape(-1)

        # <A x, y> must equal <A^T y, x> when the second layer is the adjoint.
        _allclose(res1.dot(im_2.reshape(-1)), res2.dot(im_1.reshape(-1)))

    @repeat_impl_types
    def test_adjoint_const_kernel_pac(self, native_impl):
        """Adjoint (dot-product) test for ``pac.PacConv2d`` / ``pac.PacConvTranspose2d``.

        Both layers are built with the same arguments and ``native_impl`` setting,
        given the same random weights and a zero bias, and called with the same
        second input ``im_k``. The test then checks that
        ``<conv(im_1, im_k), im_2> == <conv_t(im_2, im_k), im_1>``.
        """
        bs, sz, k_ch = 1, 6, 5
        args = dict(in_channels=3, out_channels=3, kernel_size=5, stride=1, padding=2, dilation=1)
        n_in, n_out, ksize = args['in_channels'], args['out_channels'], args['kernel_size']

        # Effective kernel extent once dilation is taken into account.
        k_with_d = (ksize - 1) * args['dilation'] + 1
        # No output padding is used in this test.
        output_padding = 0
        sz_out = (sz - 1) * args['stride'] - 2 * args['padding'] + k_with_d + output_padding

        # Random inputs; ``im_k`` has ``k_ch`` channels and spatial size ``sz_out``.
        im_1 = th.rand(bs, n_in, sz, sz).to(self.device)
        im_2 = th.rand(bs, n_in, sz, sz).to(self.device)
        im_k = th.rand(bs, k_ch, sz_out, sz_out).to(self.device)

        # Shared parameters: identical weights in both layers, zero bias.
        conv_w = th.rand(n_in, n_out, ksize, ksize).to(self.device)
        conv_b = th.zeros(n_out).to(self.device)
        conv = pac.PacConv2d(native_impl=native_impl, **args).to(self.device)
        conv_t = pac.PacConvTranspose2d(native_impl=native_impl, **args).to(self.device)
        conv.weight.data[:] = conv_t.weight.data[:] = conv_w
        conv.bias.data[:] = conv_t.bias.data[:] = conv_b

        res1 = conv(im_1, im_k).detach().reshape(-1)
        res2 = conv_t(im_2, im_k).detach().reshape(-1)

        # Compare the two inner products of the adjoint identity.
        lhs = res1.dot(im_2.reshape(-1)).detach()
        rhs = res2.dot(im_1.reshape(-1)).detach()
        _allclose(lhs, rhs)

The PyTorch implementations of Conv2d and ConvTranspose2d pass the test; however, PacConv2d and PacConvTranspose2d fail it.

Best regards, Filippos