I've found that using `torch.bmm` to compute the pairwise distances gives inaccurate results, sometimes producing negative Chamfer distances. My fix was to change lines 23-32 in `loss.py` so that the squared norms are computed directly rather than with `torch.bmm`:
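```python
# x: (B, Nx, D), y: (B, Ny, D)
# Squared norm of every point, computed directly from the coordinates
xx = x.pow(2).sum(dim=-1)                           # (B, Nx)
yy = y.pow(2).sum(dim=-1)                           # (B, Ny)
# Cross terms <x_i, y_j> still come from a batched matrix multiply
zz = torch.bmm(x, y.transpose(2, 1))                # (B, Nx, Ny)
rx = xx.unsqueeze(1).expand_as(zz.transpose(2, 1))  # (B, Ny, Nx)
ry = yy.unsqueeze(1).expand_as(zz)                  # (B, Nx, Ny)
```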
I didn't see much difference in speed. Results were slightly more stable because there were no longer any negative distances in `P`, the pairwise distance matrix, and the final performance was unchanged.
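For anyone who wants to try this outside the repo, here is a minimal, self-contained sketch of the patched pairwise-distance computation plus a Chamfer distance built on top of it. The names `batch_pairwise_dist` and `chamfer_distance` are hypothetical, and the mean reduction is an assumption; the actual `loss.py` may use different names and reduce with a sum. Inputs are assumed to be batched point clouds of shape (B, N, D).

```python
import torch


def batch_pairwise_dist(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: squared Euclidean distances between all point pairs.

    x: (B, Nx, D), y: (B, Ny, D) -> P: (B, Nx, Ny), P[b, i, j] = ||x[b, i] - y[b, j]||^2
    """
    # Squared norms computed directly from the coordinates.
    xx = x.pow(2).sum(dim=-1)                           # (B, Nx)
    yy = y.pow(2).sum(dim=-1)                           # (B, Ny)
    # Cross terms <x_i, y_j> via a batched matrix multiply.
    zz = torch.bmm(x, y.transpose(2, 1))                # (B, Nx, Ny)
    rx = xx.unsqueeze(1).expand_as(zz.transpose(2, 1))  # (B, Ny, Nx)
    ry = yy.unsqueeze(1).expand_as(zz)                  # (B, Nx, Ny)
    # ||x_i - y_j||^2 = ||x_i||^2 + ||y_j||^2 - 2 <x_i, y_j>
    return rx.transpose(2, 1) + ry - 2 * zz


def chamfer_distance(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: symmetric Chamfer distance, averaged over points and batch."""
    P = batch_pairwise_dist(x, y)
    # Nearest neighbour in y for each x point, and in x for each y point.
    x_to_y = P.min(dim=2).values  # (B, Nx)
    y_to_x = P.min(dim=1).values  # (B, Ny)
    return x_to_y.mean() + y_to_x.mean()


if __name__ == "__main__":
    torch.manual_seed(0)
    x = torch.rand(2, 128, 3)
    y = torch.rand(2, 96, 3)
    print(batch_pairwise_dist(x, y).shape)  # torch.Size([2, 128, 96])
    print(chamfer_distance(x, y).item())
```

Because `xx.unsqueeze(2)` and `yy.unsqueeze(1)` broadcast against `zz`, the return line could equivalently be written as `xx.unsqueeze(2) + yy.unsqueeze(1) - 2 * zz` without the `expand_as` calls; the sketch keeps the `expand_as` form so it matches the patch.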
Another option is to compute the loss in double precision (float64), but that is memory-intensive and slow.
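The double-precision alternative can be sketched as follows, assuming the same (B, N, D) inputs: cast to float64, build the distance matrix there, and cast the result back. `pairwise_dist_float64` is a hypothetical name, not something from the repo. Since the (B, Nx, Ny) matrix is the largest intermediate, holding it in float64 takes twice the memory of float32, which is where the extra cost comes from.

```python
import torch


def pairwise_dist_float64(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: pairwise squared distances computed in float64.

    x: (B, Nx, D), y: (B, Ny, D) -> (B, Nx, Ny), returned in x's original dtype.
    """
    x64, y64 = x.double(), y.double()  # dtype casts are differentiable, so gradients still flow
    xx = x64.pow(2).sum(dim=-1)               # (B, Nx)
    yy = y64.pow(2).sum(dim=-1)               # (B, Ny)
    zz = torch.bmm(x64, y64.transpose(2, 1))  # (B, Nx, Ny), 8 bytes per entry
    P = xx.unsqueeze(2) + yy.unsqueeze(1) - 2 * zz
    return P.to(x.dtype)


if __name__ == "__main__":
    x = torch.rand(2, 128, 3, requires_grad=True)
    y = torch.rand(2, 96, 3)
    P = pairwise_dist_float64(x, y)
    print(P.dtype)  # torch.float32
    P.min(dim=2).values.mean().backward()
    print(x.grad.shape)  # torch.Size([2, 128, 3])
```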