jax-ml / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
http://jax.readthedocs.io/
Apache License 2.0

Incorrect gradients for Batch Size > 1 #1181

Closed: hzxie closed this issue 5 years ago

hzxie commented 5 years ago

Let the data be as follows:

import jax.numpy as np
from jax import grad

x = np.array([[[-0.259665, 0.0118517, -0.243257],
             [0.193962, 0.0118547, 0.0201565],
             [-0.0122578, 0.16738, 0.240322],
             [-0.0158106, 0.208632, 0.28453],
             [-0.181615, -0.16903, -0.260208],
             [0.0602822, 0.299544, 0.146627],
             [0.114829, -0.0579054, -0.107894],
             [0.204454, -0.208405, -0.111558],
             [-0.107981, 0.0350579, -0.0769472],
             [-0.0186282, 0.276509, -0.207822]],
             [[-0.307163, 0.203857, 0.137906],
             [-0.0558089, -0.22785, 0.120466],
             [0.0676829, -0.118868, 0.286869],
             [-0.0710349, 0.205583, 0.0285916],
             [0.101161, 0.13191, 0.0625195],
             [0.214593, -0.179271, -0.243364],
             [-0.235938, 0.0153024, 0.162884],
             [0.0704157, 0.0388867, -0.189852],
             [-0.0374643, 0.110724, 0.274795],
             [0.296245, -0.0113578, 0.313378]]])
y = np.array([[[0.251954, -0.207532, -0.285618],
             [-0.214598, -0.0834898, -0.3015],
             [-0.110343, 0.132818, -0.0156324],
             [0.0895449, 0.264934, 0.174155],
             [-0.29826, -0.0372416, -0.0834365],
             [0.0570304, 0.29785, -0.143989],
             [0.0748554, 0.121849, 0.0984506],
             [-0.0476878, 0.234047, -0.248025],
             [-0.1766, -0.0811705, 0.151986],
             [-0.0363526, -0.285259, 0.219628]],
             [[0.157594, -0.117827, -0.186256],
             [0.249807, 0.086539, 0.207432],
             [-0.051006, 0.0569197, -0.195966],
             [-0.196592, -0.284987, 0.160985],
             [-0.177391, 0.101723, -0.265756],
             [-0.117553, -0.274086, 0.13629],
             [-0.264577, 0.0443226, 0.221418],
             [-0.220708, 0.242977, -0.183755],
             [0.307753, -0.0836118, 0.113444],
             [-0.105519, 0.280715, 0.258871]]])

We define Chamfer Distance as follows:

def chamfer_distance(ptcloud1, ptcloud2):
    num, n_points, _ = ptcloud1.shape
    # Batched Gram matrices: xx[b, i, j] = <x_i, x_j>, and likewise for yy and xy.
    xx = np.matmul(ptcloud1, ptcloud1.transpose((0, 2, 1)))
    yy = np.matmul(ptcloud2, ptcloud2.transpose((0, 2, 1)))
    xy = np.matmul(ptcloud1, ptcloud2.transpose((0, 2, 1)))
    # Pairwise squared distances: dist[b, i, j] = |x_i|^2 + |y_j|^2 - 2 <x_i, y_j>.
    diag = np.arange(n_points)
    rx = xx[:, diag, diag][:, np.newaxis, :].repeat(axis=1, repeats=n_points)
    ry = yy[:, diag, diag][:, np.newaxis, :].repeat(axis=1, repeats=n_points)
    dist = (rx.transpose((0, 2, 1)) + ry - 2 * xy)
    dist1 = np.amin(dist, 1)[0]
    dist2 = np.amin(dist, 2)[0]
    return np.mean(dist1 + dist2)

The gradients of the function can be computed with grad as:

_grad1 = grad(chamfer_distance, 0)
_grad1(x, y)
_grad2 = grad(chamfer_distance, 1)
_grad2(x, y)

However, the results do not seem correct:

DeviceArray([[[-0.0012944 ,  0.02888696, -0.02031551],
              [ 0.02382132, -0.02199886, -0.01565882],
              [-0.03778318, -0.01040459,  0.04160768],
              [-0.0210711 , -0.0112604 ,  0.022075  ],
              [ 0.0131932 , -0.03421608,  0.01651679],
              [-0.01170508,  0.01384401, -0.0110112 ],
              [-0.027425  ,  0.02992532,  0.0355448 ],
              [ 0.02916133,  0.01502161,  0.0033868 ],
              [ 0.0146686 , -0.01585836, -0.07031256],
              [-0.00350788,  0.01271659,  0.0033146 ]],

             [[ 0.        ,  0.        ,  0.        ],
              [ 0.        ,  0.        ,  0.        ],
              [ 0.        ,  0.        ,  0.        ],
              [ 0.        ,  0.        ,  0.        ],
              [ 0.        ,  0.        ,  0.        ],
              [ 0.        ,  0.        ,  0.        ],
              [ 0.        ,  0.        ,  0.        ],
              [ 0.        ,  0.        ,  0.        ],
              [ 0.        ,  0.        ,  0.        ],
              [ 0.        ,  0.        ,  0.        ]]], dtype=float32)

DeviceArray([[[ 0.046425  , -0.02957612, -0.1051688 ],
              [-0.00417981,  0.01514778, -0.0281654 ],
              [-0.0009448 ,  0.03910404,  0.02452592],
              [ 0.05313672,  0.0169272 , -0.02429721],
              [-0.007719  , -0.00981866,  0.0319641 ],
              [ 0.01513172,  0.00426821,  0.0127666 ],
              [-0.00639868,  0.01289266, -0.01271546],
              [-0.01162384, -0.01698479, -0.0160812 ],
              [-0.0137238 , -0.02324568,  0.04578664],
              [-0.04816132, -0.0153708 ,  0.0662372 ]],

             [[ 0.        ,  0.        ,  0.        ],
              [ 0.        ,  0.        ,  0.        ],
              [ 0.        ,  0.        ,  0.        ],
              [ 0.        ,  0.        ,  0.        ],
              [ 0.        ,  0.        ,  0.        ],
              [ 0.        ,  0.        ,  0.        ],
              [ 0.        ,  0.        ,  0.        ],
              [ 0.        ,  0.        ,  0.        ],
              [ 0.        ,  0.        ,  0.        ],
              [ 0.        ,  0.        ,  0.        ]]], dtype=float32)

mattjj commented 5 years ago

Thanks for asking this!

The easiest way to check derivatives is against numerical differences. We have a handy utility you can use like this:

from jax.test_util import check_grads
check_grads(chamfer_distance, (x, y), order=1, modes=["rev"])

That just checks against numerical differences in random directions. It can also test higher-order derivatives and mixed forward- and reverse-mode differentiation.
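
For example, bumping the order and enabling both modes exercises those extra checks as well. A minimal sketch, reusing the x and y defined above:

from jax.test_util import check_grads

# Check derivatives up to second order, in both forward and reverse mode.
check_grads(chamfer_distance, (x, y), order=2, modes=["fwd", "rev"])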

It works on this example, as does this manual numerical check:

In [5]: import numpy as onp

In [6]: vec  = onp.random.RandomState(0).randn(*x.shape)

In [7]: vec = vec / onp.linalg.norm(vec)

In [8]: (chamfer_distance(x + 1e-5 * vec, y) - chamfer_distance(x, y)) / 1e-5
Out[8]: DeviceArray(0.00186265, dtype=float32)

In [9]: onp.vdot(_grad1(x, y), vec)
Out[9]: 0.0018528942794496933

So I think the gradient must be technically correct in some sense, but perhaps something unexpected is still going on.

Is it the zeros that are surprising? I haven't actually looked at the code yet; I just wanted to do these numerical checks, treating the code as a black box. I'm sure we can figure out what's surprising though.
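
One quick way to probe the zeros, along the same lines, is to take a finite difference in a direction that perturbs only the second batch element of x; if that also comes out near zero, the loss simply does not depend on that part of the input. A minimal sketch (vec2 is just an illustrative name):

import numpy as onp

# A unit direction supported only on the second batch element of x.
vec2 = onp.zeros(x.shape)
vec2[1] = onp.random.RandomState(1).randn(*x.shape[1:])
vec2 = vec2 / onp.linalg.norm(vec2)

# Compare a forward difference along vec2 with the reverse-mode gradient.
print((chamfer_distance(x + 1e-5 * vec2, y) - chamfer_distance(x, y)) / 1e-5)
print(onp.vdot(_grad1(x, y), vec2))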

hzxie commented 5 years ago

Here are the results from a real experiment. I believe there MUST BE something wrong with your code. (Screenshot: 2019-08-15 09-15-21)

The results show that the gradients for every batch element after the first are zeros, which should be impossible.

hzxie commented 5 years ago

Here's the implementation:

import caffe
import jax.numpy as np

from jax import grad, jit

class ChamferLoss(caffe.Layer):
    def setup(self, bottom, top):
        # check input pair
        if len(bottom) != 2:
            raise Exception("Need two inputs to compute distance.")

    def reshape(self, bottom, top):
        # check input dimensions match
        if bottom[0].count != bottom[1].count:
            raise Exception("Inputs must have the same dimension.")

        # build JIT-compiled gradient functions w.r.t. each input
        self.grad0 = jit(grad(self._chamfer_distance, 0))
        self.grad1 = jit(grad(self._chamfer_distance, 1))
        # loss output is scalar
        top[0].reshape(1)

    def forward(self, bottom, top):
        top[0].data[0] = float(self._chamfer_distance(bottom[0].data, bottom[1].data))

    def _chamfer_distance(self, ptcloud1, ptcloud2):
        num, n_points, _ = ptcloud1.shape
        xx = np.matmul(ptcloud1, ptcloud1.transpose((0, 2, 1)))
        yy = np.matmul(ptcloud2, ptcloud2.transpose((0, 2, 1)))
        xy = np.matmul(ptcloud1, ptcloud2.transpose((0, 2, 1)))

        diag = np.arange(n_points)
        rx = xx[:, diag, diag][:, np.newaxis, :].repeat(axis=1, repeats=n_points)
        ry = yy[:, diag, diag][:, np.newaxis, :].repeat(axis=1, repeats=n_points)
        dist = (rx.transpose((0, 2, 1)) + ry - 2 * xy)
        dist1 = np.amin(dist, 1)[0]
        dist2 = np.amin(dist, 2)[0]
        return np.mean(dist1 + dist2)

    def backward(self, top, propagate_down, bottom):
        _grad0 = self.grad0(bottom[0].data, bottom[1].data)
        _grad1 = self.grad1(bottom[0].data, bottom[1].data)

        bottom[0].diff[...] = _grad0
        bottom[1].diff[...] = _grad1

Is there anything wrong with the code?

hzxie commented 5 years ago

I compared the output with PyTorch's autograd. The results from JAX are not the same as PyTorch's.

The PyTorch code is as follows:

import torch

xt = torch.rand((2, 10, 3), requires_grad=True)
yt = torch.rand((2, 10, 3), requires_grad=True)

bs, num_points, points_dim = xt.size()
xx = torch.bmm(xt, xt.transpose(2,1))
yy = torch.bmm(yt, yt.transpose(2,1))
zz = torch.bmm(xt, yt.transpose(2,1))
diag_ind = torch.arange(0, num_points).long()
rx = xx[:, diag_ind, diag_ind].unsqueeze(1).expand_as(xx)
ry = yy[:, diag_ind, diag_ind].unsqueeze(1).expand_as(yy)
P = (rx.transpose(2,1) + ry - 2*zz)
loss = torch.mean(P.min(1)[0]) + torch.mean(P.min(2)[0])

loss.backward()
print(xt.grad)

hzxie commented 5 years ago

In my case, the values of x and y are as follows:

x = np.array([[[0.4011, 0.2908, 0.3521],
         [0.0861, 0.8777, 0.4353],
         [0.7420, 0.1549, 0.8189],
         [0.3323, 0.4617, 0.2126],
         [0.2867, 0.9608, 0.6667],
         [0.1773, 0.5807, 0.0388],
         [0.4252, 0.6412, 0.6042],
         [0.7260, 0.5044, 0.8364],
         [0.7606, 0.0383, 0.2995],
         [0.3635, 0.5297, 0.2056]],

        [[0.2465, 0.2190, 0.4713],
         [0.2807, 0.0673, 0.5035],
         [0.1563, 0.4924, 0.3246],
         [0.0118, 0.4919, 0.7165],
         [0.3805, 0.4503, 0.0220],
         [0.6650, 0.7816, 0.0317],
         [0.5174, 0.9315, 0.3767],
         [0.6830, 0.9195, 0.3598],
         [0.1752, 0.6147, 0.8423],
         [0.9798, 0.8040, 0.7357]]])
y = np.array([[[0.9321, 0.7667, 0.2552],
         [0.7745, 0.1306, 0.2459],
         [0.2194, 0.3606, 0.9956],
         [0.4877, 0.2506, 0.1692],
         [0.5293, 0.4913, 0.3616],
         [0.6573, 0.5852, 0.2728],
         [0.0400, 0.3716, 0.4207],
         [0.6753, 0.2750, 0.8677],
         [0.9813, 0.7300, 0.1378],
         [0.7216, 0.7457, 0.4421]],

        [[0.3020, 0.9361, 0.7699],
         [0.5319, 0.1045, 0.0440],
         [0.5918, 0.5900, 0.8566],
         [0.5447, 0.0098, 0.4900],
         [0.8860, 0.1312, 0.6319],
         [0.0322, 0.6574, 0.8497],
         [0.0314, 0.3995, 0.6097],
         [0.1369, 0.4054, 0.4159],
         [0.1296, 0.2437, 0.9193],
         [0.5789, 0.0539, 0.0890]]])

The gradients output by PyTorch are:

tensor([[[-0.4864,  0.2258,  1.0262],
         [ 0.1293,  1.4202,  0.0410],
         [ 0.3745, -0.6739, -0.2740],
         [ 0.2679,  0.1694, -1.0020],
         [-1.2204,  0.6036,  0.6303],
         [ 0.3855,  0.5868, -1.0716],
         [-0.5460,  0.9142,  0.0374],
         [ 0.1422,  0.6437, -0.0880],
         [-0.0782, -0.5182,  0.3007],
         [-5.0846, -1.1675, -1.0131]],

        [[ 0.3075, -0.5229,  0.1554],
         [-3.1811,  0.1435, -0.2846],
         [ 0.1089,  0.4886, -0.5124],
         [-0.4412,  1.2147,  0.0302],
         [-1.4068,  3.0532, -0.3112],
         [ 0.3735,  1.9002, -0.0343],
         [ 0.6046, -0.0127, -1.1035],
         [ 1.0692, -0.0465, -1.1509],
         [-0.7224, -1.0715,  0.1221],
         [ 1.0890,  0.6006, -0.3394]]])

While JAX outputs:

>>> _grad1 = grad(chamfer_distance, 0)
>>> _grad1(x, y)
DeviceArray([[[-0.03463998,  0.01608   ,  0.07316001],
              [ 0.00922   ,  0.10122   ,  0.00291999],
              [ 0.02667999, -0.04804001, -0.01951998],
              [ 0.01906002,  0.0121    , -0.07142001],
              [-0.08698   ,  0.04302   ,  0.04492   ],
              [ 0.02746   ,  0.04182   , -0.07638   ],
              [-0.03893998,  0.06520003,  0.00266001],
              [ 0.01014   ,  0.04588   , -0.00625999],
              [-0.00556001, -0.03692   ,  0.02143999],
              [-0.36236   , -0.08320004, -0.0722    ]],

             [[ 0.        ,  0.        ,  0.        ],
              [ 0.        ,  0.        ,  0.        ],
              [ 0.        ,  0.        ,  0.        ],
              [ 0.        ,  0.        ,  0.        ],
              [ 0.        ,  0.        ,  0.        ],
              [ 0.        ,  0.        ,  0.        ],
              [ 0.        ,  0.        ,  0.        ],
              [ 0.        ,  0.        ,  0.        ],
              [ 0.        ,  0.        ,  0.        ],
              [ 0.        ,  0.        ,  0.        ]]], dtype=float32)

hzxie commented 5 years ago

My fault.
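
The likely culprit is the trailing [0]: in the PyTorch version, P.min(1)[0] selects the values from the (values, indices) pair returned by min, whereas np.amin(dist, 1) already returns the values, so the extra [0] picks out only the first batch element and the loss never sees the rest of the batch. A minimal corrected sketch, assuming the intent is to match the PyTorch loss:

def chamfer_distance(ptcloud1, ptcloud2):
    num, n_points, _ = ptcloud1.shape
    xx = np.matmul(ptcloud1, ptcloud1.transpose((0, 2, 1)))
    yy = np.matmul(ptcloud2, ptcloud2.transpose((0, 2, 1)))
    xy = np.matmul(ptcloud1, ptcloud2.transpose((0, 2, 1)))
    diag = np.arange(n_points)
    rx = xx[:, diag, diag][:, np.newaxis, :].repeat(axis=1, repeats=n_points)
    ry = yy[:, diag, diag][:, np.newaxis, :].repeat(axis=1, repeats=n_points)
    dist = rx.transpose((0, 2, 1)) + ry - 2 * xy
    # Assumes the intended loss matches the PyTorch version:
    # reduce over the point axes only, keeping the full batch.
    dist1 = np.amin(dist, 1)
    dist2 = np.amin(dist, 2)
    return np.mean(dist1) + np.mean(dist2)

With that change the loss depends on every batch element, so their gradients are no longer forced to zero.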

mattjj commented 5 years ago

Glad to hear you figured it out!