hzxie closed this issue 5 years ago.
Thanks for asking this!
The easiest way to check derivatives is against numerical differences. We have a handy utility you can use like this:
from jax.test_util import check_grads
check_grads(chamfer_distance, (x, y), order=1, modes=["rev"])
That just checks against numerical differences in random directions. With other `order` and `modes` settings, it can also test higher-order derivatives and mixed forward- and reverse-mode differentiation.
It works on this example, as does this manual numerical check:
In [5]: import numpy as onp
In [6]: vec = onp.random.RandomState(0).randn(*x.shape)
In [7]: vec = vec / onp.linalg.norm(vec)
In [8]: (chamfer_distance(x + 1e-5 * vec, y) - chamfer_distance(x, y)) / 1e-5
Out[8]: DeviceArray(0.00186265, dtype=float32)
In [9]: onp.vdot(_grad1(x, y), vec)
Out[9]: 0.0018528942794496933
So I think the gradient must be technically correct in some sense, but perhaps something unexpected is still going on.
Is it the zeros that are surprising? I haven't actually looked at the code yet; I just wanted to do these numerical checks, treating the code as a black box. I'm sure we can figure out what's surprising though.
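The `check_grads` call from the reply above can be tried in isolation with a minimal sketch like the one below. The `chamfer_distance` here is an assumed batched version written with broadcasting, not the code from this thread. The sketch also enables 64-bit floats and passes a small `eps`, both choices made for this example, so that the finite differences are not dominated by float32 rounding and do not step across a change in the nearest-neighbour assignment:

```python
import jax
import jax.numpy as jnp
from jax.test_util import check_grads

# Use float64 so the finite-difference comparison is not swamped by rounding.
jax.config.update("jax_enable_x64", True)


def chamfer_distance(x, y):
    """Batched Chamfer distance for point clouds of shape (batch, n_points, dim)."""
    # d[b, i, j] = ||x[b, i] - y[b, j]||^2, computed by broadcasting.
    d = jnp.sum((x[:, :, None, :] - y[:, None, :, :]) ** 2, axis=-1)
    # Nearest neighbour in y for every x point, and in x for every y point,
    # each averaged over all batch elements and points.
    return jnp.mean(jnp.min(d, axis=2)) + jnp.mean(jnp.min(d, axis=1))


key_x, key_y = jax.random.split(jax.random.PRNGKey(0))
x = jax.random.uniform(key_x, (2, 10, 3))
y = jax.random.uniform(key_y, (2, 10, 3))

# First-order, reverse mode only (the call used in the reply above).
check_grads(chamfer_distance, (x, y), order=1, modes=["rev"], eps=1e-6)
# Second-order, checking both forward and reverse mode.
check_grads(chamfer_distance, (x, y), order=2, modes=["fwd", "rev"], eps=1e-6)
print("gradient checks passed")
```

`check_grads` raises an `AssertionError` on a mismatch, so returning silently means the autodiff gradients agree with the numerical differences.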
Here are the results from a real experiment. I believe there MUST BE something wrong with your code.
The results show that the gradients for every batch element after the first are zero, which should be impossible.
Here's the implementation:
import caffe
import jax.numpy as np
from jax import grad, jit


class ChamferLoss(caffe.Layer):
    def setup(self, bottom, top):
        # check input pair
        if len(bottom) != 2:
            raise Exception("Need two inputs to compute distance.")

    def reshape(self, bottom, top):
        # check input dimensions match
        if bottom[0].count != bottom[1].count:
            raise Exception("Inputs must have the same dimension.")
        # compiled gradients of the loss w.r.t. each input point cloud
        self.grad0 = jit(grad(self._chamfer_distance, 0))
        self.grad1 = jit(grad(self._chamfer_distance, 1))
        # loss output is scalar
        top[0].reshape(1)

    def forward(self, bottom, top):
        top[0].data[0] = float(self._chamfer_distance(bottom[0].data, bottom[1].data))

    def _chamfer_distance(self, ptcloud1, ptcloud2):
        # point clouds have shape (num, n_points, 3)
        num, n_points, _ = ptcloud1.shape
        xx = np.matmul(ptcloud1, ptcloud1.transpose((0, 2, 1)))
        yy = np.matmul(ptcloud2, ptcloud2.transpose((0, 2, 1)))
        xy = np.matmul(ptcloud1, ptcloud2.transpose((0, 2, 1)))
        diag = np.arange(n_points)
        rx = xx[:, diag, diag][:, np.newaxis, :].repeat(axis=1, repeats=n_points)
        ry = yy[:, diag, diag][:, np.newaxis, :].repeat(axis=1, repeats=n_points)
        # dist[b, i, j] = ||ptcloud1[b, i] - ptcloud2[b, j]||^2
        dist = (rx.transpose((0, 2, 1)) + ry - 2 * xy)
        dist1 = np.amin(dist, 1)[0]
        dist2 = np.amin(dist, 2)[0]
        return np.mean(dist1 + dist2)

    def backward(self, top, propagate_down, bottom):
        _grad0 = self.grad0(bottom[0].data, bottom[1].data)
        _grad1 = self.grad1(bottom[0].data, bottom[1].data)
        bottom[0].diff[...] = _grad0
        bottom[1].diff[...] = _grad1
Is there anything wrong with the code?
I compared the output with PyTorch's autograd, and JAX's results are not the same as PyTorch's.
The PyTorch code is as follows:
import torch

xt = torch.rand((2, 10, 3), requires_grad=True)
yt = torch.rand((2, 10, 3), requires_grad=True)
bs, num_points, points_dim = xt.size()
xx = torch.bmm(xt, xt.transpose(2, 1))
yy = torch.bmm(yt, yt.transpose(2, 1))
zz = torch.bmm(xt, yt.transpose(2, 1))
# index on the same device as the inputs (CUDA indices cannot index a CPU tensor)
diag_ind = torch.arange(0, num_points, device=xt.device)
rx = xx[:, diag_ind, diag_ind].unsqueeze(1).expand_as(xx)
ry = yy[:, diag_ind, diag_ind].unsqueeze(1).expand_as(yy)
P = (rx.transpose(2, 1) + ry - 2 * zz)
# Tensor.min(dim) returns (values, indices); [0] keeps the values
loss = torch.mean(P.min(1)[0]) + torch.mean(P.min(2)[0])
loss.backward()
print(xt.grad)
In my case, the values of x and y are as follows:
x = np.array([[[0.4011, 0.2908, 0.3521],
[0.0861, 0.8777, 0.4353],
[0.7420, 0.1549, 0.8189],
[0.3323, 0.4617, 0.2126],
[0.2867, 0.9608, 0.6667],
[0.1773, 0.5807, 0.0388],
[0.4252, 0.6412, 0.6042],
[0.7260, 0.5044, 0.8364],
[0.7606, 0.0383, 0.2995],
[0.3635, 0.5297, 0.2056]],
[[0.2465, 0.2190, 0.4713],
[0.2807, 0.0673, 0.5035],
[0.1563, 0.4924, 0.3246],
[0.0118, 0.4919, 0.7165],
[0.3805, 0.4503, 0.0220],
[0.6650, 0.7816, 0.0317],
[0.5174, 0.9315, 0.3767],
[0.6830, 0.9195, 0.3598],
[0.1752, 0.6147, 0.8423],
[0.9798, 0.8040, 0.7357]]])
y = np.array([[[0.9321, 0.7667, 0.2552],
[0.7745, 0.1306, 0.2459],
[0.2194, 0.3606, 0.9956],
[0.4877, 0.2506, 0.1692],
[0.5293, 0.4913, 0.3616],
[0.6573, 0.5852, 0.2728],
[0.0400, 0.3716, 0.4207],
[0.6753, 0.2750, 0.8677],
[0.9813, 0.7300, 0.1378],
[0.7216, 0.7457, 0.4421]],
[[0.3020, 0.9361, 0.7699],
[0.5319, 0.1045, 0.0440],
[0.5918, 0.5900, 0.8566],
[0.5447, 0.0098, 0.4900],
[0.8860, 0.1312, 0.6319],
[0.0322, 0.6574, 0.8497],
[0.0314, 0.3995, 0.6097],
[0.1369, 0.4054, 0.4159],
[0.1296, 0.2437, 0.9193],
[0.5789, 0.0539, 0.0890]]])
The gradients computed by PyTorch are:
tensor([[[-0.4864, 0.2258, 1.0262],
[ 0.1293, 1.4202, 0.0410],
[ 0.3745, -0.6739, -0.2740],
[ 0.2679, 0.1694, -1.0020],
[-1.2204, 0.6036, 0.6303],
[ 0.3855, 0.5868, -1.0716],
[-0.5460, 0.9142, 0.0374],
[ 0.1422, 0.6437, -0.0880],
[-0.0782, -0.5182, 0.3007],
[-5.0846, -1.1675, -1.0131]],
[[ 0.3075, -0.5229, 0.1554],
[-3.1811, 0.1435, -0.2846],
[ 0.1089, 0.4886, -0.5124],
[-0.4412, 1.2147, 0.0302],
[-1.4068, 3.0532, -0.3112],
[ 0.3735, 1.9002, -0.0343],
[ 0.6046, -0.0127, -1.1035],
[ 1.0692, -0.0465, -1.1509],
[-0.7224, -1.0715, 0.1221],
[ 1.0890, 0.6006, -0.3394]]])
JAX, on the other hand, outputs:
>>> _grad1 = grad(chamfer_distance, 0)
>>> _grad1(x, y)
DeviceArray([[[-0.03463998, 0.01608 , 0.07316001],
[ 0.00922 , 0.10122 , 0.00291999],
[ 0.02667999, -0.04804001, -0.01951998],
[ 0.01906002, 0.0121 , -0.07142001],
[-0.08698 , 0.04302 , 0.04492 ],
[ 0.02746 , 0.04182 , -0.07638 ],
[-0.03893998, 0.06520003, 0.00266001],
[ 0.01014 , 0.04588 , -0.00625999],
[-0.00556001, -0.03692 , 0.02143999],
[-0.36236 , -0.08320004, -0.0722 ]],
[[ 0. , 0. , 0. ],
[ 0. , 0. , 0. ],
[ 0. , 0. , 0. ],
[ 0. , 0. , 0. ],
[ 0. , 0. , 0. ],
[ 0. , 0. , 0. ],
[ 0. , 0. , 0. ],
[ 0. , 0. , 0. ],
[ 0. , 0. , 0. ],
[ 0. , 0. , 0. ]]], dtype=float32)
My fault.
Glad to hear you figured it out!
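The thread does not say what the fault turned out to be. One likely culprit, visible in the layer code above, is the trailing `[0]` in `np.amin(dist, 1)[0]`. In PyTorch, `P.min(1)` returns a `(values, indices)` pair, so `[0]` selects the values; `jax.numpy.amin` returns only the values, so the same `[0]` selects batch element 0 and discards the rest. That would explain the all-zero gradient for the second batch element. A minimal sketch comparing the two versions follows; `pairwise_sq_dist`, `chamfer_buggy` and `chamfer_fixed` are hypothetical names, and the distance uses broadcasting in place of `repeat` while computing the same `dist`:

```python
import jax
import jax.numpy as jnp


def pairwise_sq_dist(x, y):
    # d[b, i, j] = ||x[b, i] - y[b, j]||^2, the same quantity as `dist` above.
    xx = jnp.sum(x ** 2, axis=-1)                     # (batch, n)
    yy = jnp.sum(y ** 2, axis=-1)                     # (batch, n)
    xy = jnp.matmul(x, jnp.transpose(y, (0, 2, 1)))   # (batch, n, n)
    return xx[:, :, None] + yy[:, None, :] - 2 * xy


def chamfer_buggy(x, y):
    dist = pairwise_sq_dist(x, y)
    # jnp.amin returns values only, so [0] keeps batch element 0 alone.
    dist1 = jnp.amin(dist, 1)[0]
    dist2 = jnp.amin(dist, 2)[0]
    return jnp.mean(dist1 + dist2)


def chamfer_fixed(x, y):
    dist = pairwise_sq_dist(x, y)
    # Keep the batch axis: average over every batch element and point,
    # matching torch.mean(P.min(1)[0]) + torch.mean(P.min(2)[0]).
    dist1 = jnp.amin(dist, 1)
    dist2 = jnp.amin(dist, 2)
    return jnp.mean(dist1 + dist2)


key_x, key_y = jax.random.split(jax.random.PRNGKey(0))
x = jax.random.uniform(key_x, (2, 10, 3))
y = jax.random.uniform(key_y, (2, 10, 3))

g_buggy = jax.grad(chamfer_buggy)(x, y)
g_fixed = jax.grad(chamfer_fixed)(x, y)
print(jnp.abs(g_buggy[1]).max())  # second batch element receives no gradient
print(jnp.abs(g_fixed[1]).max())  # nonzero once the batch axis is kept
```

For the first batch element the two gradients differ only by the factor 2 (the batch size), because the buggy version averages over 10 entries instead of 20.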
Let the data be as follows:
We define Chamfer Distance as follows:
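The formula in the original post did not survive extraction. A reconstruction consistent with the PyTorch loss in this thread (an assumption about what the original showed), for B pairs of point clouds with N points each, x_i^(b), y_j^(b) in R^3:

```latex
d^{(b)}_{ij} = \lVert x^{(b)}_i - y^{(b)}_j \rVert_2^2, \qquad
\mathrm{CD}(X, Y) = \frac{1}{BN} \sum_{b=1}^{B} \left( \sum_{i=1}^{N} \min_{j} d^{(b)}_{ij} + \sum_{j=1}^{N} \min_{i} d^{(b)}_{ij} \right)
```

This equals `torch.mean(P.min(2)[0]) + torch.mean(P.min(1)[0])`, since each mean runs over B·N entries.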
The gradient of the function can be derived as:
However, the results do not seem correct:
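The original gradient formula is also missing. A sketch reconstructed from the code in this thread, assuming no ties in the nearest-neighbour assignment: with d_ij^(b) = ||x_i^(b) - y_j^(b)||^2, CD = (1/(BN)) Σ_b (Σ_i min_j d_ij^(b) + Σ_j min_i d_ij^(b)), σ_b(i) = argmin_j d_ij^(b) and τ_b(j) = argmin_i d_ij^(b), differentiating each minimum through its selected pair gives:

```latex
\begin{aligned}
\frac{\partial\,\mathrm{CD}}{\partial x^{(b)}_i} &= \frac{2}{BN} \Big[ \big( x^{(b)}_i - y^{(b)}_{\sigma_b(i)} \big) + \sum_{j \,:\, \tau_b(j) = i} \big( x^{(b)}_i - y^{(b)}_j \big) \Big] \\
\frac{\partial\,\mathrm{CD}}{\partial y^{(b)}_j} &= \frac{2}{BN} \Big[ \big( y^{(b)}_j - x^{(b)}_{\tau_b(j)} \big) + \sum_{i \,:\, \sigma_b(i) = j} \big( y^{(b)}_j - x^{(b)}_i \big) \Big]
\end{aligned}
```

Every batch element enters with the same weight 2/(BN), so for generic inputs no batch element should receive an all-zero gradient.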