yanx27 / Pointnet_Pointnet2_pytorch

PointNet and PointNet++ implemented in PyTorch (pure Python), with experiments on ModelNet, ShapeNet and S3DIS.
MIT License
3.75k stars, 904 forks

Bug in query_ball_point #209

Status: Open. yc-shan opened this issue 2 years ago.

yc-shan commented 2 years ago

https://github.com/yanx27/Pointnet_Pointnet2_pytorch/blob/eb64fe0b4c24055559cea26299cb485dcb43d8dd/models/pointnet2_utils.py#L87

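# candidate indices 0..N-1 for every query point: [B, S, N]
group_idx = torch.arange(N, dtype=torch.long).to(device).view(1, 1, N).repeat([B, S, 1])
sqrdists = square_distance(new_xyz, xyz)  # squared distances: [B, S, N]
group_idx[sqrdists > radius ** 2] = N  # mark points outside the ball with the sentinel N
group_idx = group_idx.sort(dim=-1)[0][:, :, :nsample]  # sorts by index, not by distance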

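The elements of group_idx are point indices, not distances, so group_idx.sort() orders the candidates by index rather than by distance to the query point. When more than nsample points fall inside the ball, the ones kept are simply those with the smallest indices, not the nearest ones.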
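To make this concrete, here is a minimal self-contained reproduction of the quoted lines, followed by the same padding step used at the end of the proposed code, on a toy example with five points on the x-axis and one query at x = 0.9. The `square_distance` helper is my own re-implementation (assumed equivalent in effect to the repository's helper of the same name), and `query_ball_point_original` is a hypothetical name. It is a sketch for illustration, not the repository code itself.

```python
import torch


def square_distance(src, dst):
    # Pairwise squared Euclidean distances: src [B, S, C], dst [B, N, C] -> [B, S, N].
    # Re-implemented here for self-containment (assumption, not the repo's code).
    return ((src.unsqueeze(2) - dst.unsqueeze(1)) ** 2).sum(-1)


def query_ball_point_original(radius, nsample, xyz, new_xyz):
    # Mirrors the quoted logic: out-of-ball points get the sentinel N,
    # then the remaining candidates are sorted by *index*.
    device = xyz.device
    B, N, _ = xyz.shape
    _, S, _ = new_xyz.shape
    group_idx = torch.arange(N, dtype=torch.long, device=device).view(1, 1, N).repeat([B, S, 1])
    sqrdists = square_distance(new_xyz, xyz)
    group_idx[sqrdists > radius ** 2] = N
    group_idx = group_idx.sort(dim=-1)[0][:, :, :nsample]
    # Pad slots that still hold N with the first valid index.
    group_first = group_idx[:, :, 0].view(B, S, 1).repeat([1, 1, nsample])
    mask = group_idx == N
    group_idx[mask] = group_first[mask]
    return group_idx


# Five points on the x-axis; squared distances to the query are
# 0.81, 0.36, 0.01, 0.09, 4.41, so points 0-3 are inside a radius-1 ball.
xyz = torch.tensor([[[0.0, 0.0, 0.0], [0.3, 0.0, 0.0], [0.8, 0.0, 0.0],
                     [1.2, 0.0, 0.0], [3.0, 0.0, 0.0]]])
new_xyz = torch.tensor([[[0.9, 0.0, 0.0]]])
idx = query_ball_point_original(radius=1.0, nsample=2, xyz=xyz, new_xyz=new_xyz)
print(idx.tolist())  # → [[[0, 1]]]
```

Points 2 and 3 are the two closest to the query, but the index sort keeps 0 and 1 because they come first in the point list.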

The code should be:

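sort_dis, group_idx = sqrdists.sort(dim=-1)  # order candidates by distance to the query
group_idx[sort_dis > radius ** 2] = N  # mark points outside the ball with the sentinel N
group_idx = group_idx[:, :, :nsample]  # keep the nsample nearest candidates
# pad sentinel slots with the nearest in-ball index (column 0)
group_first = group_idx[:, :, 0].view(B, S, 1).repeat([1, 1, nsample])
mask = group_idx == N
group_idx[mask] = group_first[mask]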
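A self-contained sketch of the proposed fix, run on a toy layout of five points on the x-axis with the query at x = 0.9. The `square_distance` helper is re-implemented here (assumed equivalent to the repository's), and `query_ball_point_nearest` is a hypothetical name; this illustrates the idea and is not a tested patch for the repository.

```python
import torch


def square_distance(src, dst):
    # Pairwise squared distances [B, S, N]; re-implemented for self-containment.
    return ((src.unsqueeze(2) - dst.unsqueeze(1)) ** 2).sum(-1)


def query_ball_point_nearest(radius, nsample, xyz, new_xyz):
    """Proposed fix: sort candidates by distance, then drop those outside the ball."""
    N = xyz.shape[1]
    sqrdists = square_distance(new_xyz, xyz)            # [B, S, N]
    sort_dis, group_idx = sqrdists.sort(dim=-1)         # ascending distance
    group_idx[sort_dis > radius ** 2] = N               # sentinel for out-of-ball points
    group_idx = group_idx[:, :, :nsample]
    # Pad sentinel slots with the nearest in-ball index (column 0).
    group_first = group_idx[:, :, 0:1].repeat(1, 1, nsample)
    mask = group_idx == N
    group_idx[mask] = group_first[mask]
    return group_idx


xyz = torch.tensor([[[0.0, 0.0, 0.0], [0.3, 0.0, 0.0], [0.8, 0.0, 0.0],
                     [1.2, 0.0, 0.0], [3.0, 0.0, 0.0]]])
new_xyz = torch.tensor([[[0.9, 0.0, 0.0]]])
print(query_ball_point_nearest(1.0, 2, xyz, new_xyz).tolist())  # → [[[2, 3]]]
print(query_ball_point_nearest(0.5, 4, xyz, new_xyz).tolist())  # → [[[2, 3, 2, 2]]]
```

With radius 1.0 the two nearest in-ball points (2 and 3) are kept. With radius 0.5 only points 2 and 3 are inside the ball, so the remaining two slots are padded with the nearest one. Like the original, this sorts all N distances per query point.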
jasonkena commented 1 year ago

You can also just use PyTorch3D's ball query implementation: https://github.com/yanx27/Pointnet_Pointnet2_pytorch/issues/178#issuecomment-1587086798

HongqingThomas commented 1 year ago

I think you are right: when more than nsample points fall inside the epsilon ball, the original code does not pick the k nearest neighbors (kNN) among them.
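kNN restricted to the epsilon ball can also be written with `torch.topk`, which returns only the nsample smallest distances instead of sorting all N. A minimal sketch, assuming the same [B, N, C] / [B, S, C] layout as the repository; `knn_in_ball` is a hypothetical name and this is not code from the repository.

```python
import torch


def knn_in_ball(radius, nsample, xyz, new_xyz):
    """Indices of the nsample nearest points within `radius`, padded with the nearest one.

    xyz: [B, N, C] points, new_xyz: [B, S, C] query centres -> [B, S, nsample].
    Assumes nsample <= N (torch.topk requires k <= size of the dimension).
    """
    N = xyz.shape[1]
    sqrdists = torch.cdist(new_xyz, xyz) ** 2                   # [B, S, N]
    # k smallest squared distances, returned in ascending order.
    dists, idx = torch.topk(sqrdists, nsample, dim=-1, largest=False, sorted=True)
    idx = idx.masked_fill(dists > radius ** 2, N)              # N marks "outside the ball"
    first = idx[:, :, :1].expand_as(idx)                       # nearest neighbour per query
    return torch.where(idx == N, first, idx)


xyz = torch.tensor([[[0.0, 0.0, 0.0], [0.3, 0.0, 0.0], [0.8, 0.0, 0.0],
                     [1.2, 0.0, 0.0], [3.0, 0.0, 0.0]]])
new_xyz = torch.tensor([[[0.9, 0.0, 0.0]]])
print(knn_in_ball(0.5, 4, xyz, new_xyz).tolist())  # → [[[2, 3, 2, 2]]]
```

As in the repository code, if no point lies inside the ball the row stays filled with N; when new_xyz is sampled from xyz (as with farthest point sampling) each query finds itself at distance 0, so this does not happen.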