Lilac-Lee / PointNetLK_Revisited

Implementation for our CVPR 2021 oral paper "PointNetLK Revisited".
MIT License

Issue in ThreeDMatch_Testing dataset #9

Closed gitouni closed 2 years ago

gitouni commented 2 years ago

In data_utils.py, because find_voxel_overlaps(p0, p1, voxel) is used in the class ThreeDMatch_Testing, additional prior knowledge leaks into the test dataloader: the voxel overlap relationship is computed on the already-registered point cloud pair (p0_pre, p1_pre). Note that the perturbation (random transform) is applied only after this operation, so at test time the algorithm could never recover this overlap relationship from an unregistered pair. The relevant code is quoted below, followed by a small sketch that makes the ordering issue concrete:


def find_voxel_overlaps(p0, p1, voxel):
    xmin, ymin, zmin = np.max(np.stack([np.min(p0, 0), np.min(p1, 0)]), 0)
    xmax, ymax, zmax = np.min(np.stack([np.max(p0, 0), np.max(p1, 0)]), 0)

    # truncate the point cloud
    eps = 1e-6
    p0_ = p0[np.all(p0>[xmin+eps,ymin+eps,zmin+eps], axis=1) & np.all(p0<[xmax-eps,ymax-eps,zmax-eps], axis=1)]
    p1_ = p1[np.all(p1>[xmin+eps,ymin+eps,zmin+eps], axis=1) & np.all(p1<[xmax-eps,ymax-eps,zmax-eps], axis=1)]

    # recalculate the constraints using the truncated clouds
    xmin, ymin, zmin = np.max(np.stack([np.min(p0_, 0), np.min(p1_, 0)]), 0)
    xmax, ymax, zmax = np.min(np.stack([np.max(p0_, 0), np.max(p1_, 0)]), 0)
    vx = (xmax - xmin) / voxel
    vy = (ymax - ymin) / voxel
    vz = (zmax - zmin) / voxel

    return p0_, p1_, xmin, ymin, zmin, xmax, ymax, zmax, vx, vy, vz

class ThreeDMatch_Testing(torch.utils.data.Dataset):
    def __init__(self, dataset_path, category, overlap_ratio, voxel_ratio, voxel, max_voxel_points, num_voxels, rigid_transform, vis):
        self.dataset_path = dataset_path
        self.pairs = []
        with open(category, 'r') as fi:
            cinfo_fi = fi.read().split()   # category names
            for i in range(len(cinfo_fi)):
                cat_name = cinfo_fi[i]
                cinfo_name = cat_name + '*%.2f.txt' % overlap_ratio
                cinfo = glob.glob(os.path.join(self.dataset_path, cinfo_name))
                for fi_name in cinfo:
                    with open(fi_name) as fi:
                        fi_list = [x.strip().split() for x in fi.readlines()]
                    for fi in fi_list:
                        self.pairs.append([fi[0], fi[1]])

        self.voxel_ratio = voxel_ratio
        self.voxel = int(voxel)
        self.max_voxel_points = max_voxel_points
        self.num_voxels = num_voxels
        self.perturbation = load_pose(rigid_transform, len(self.pairs))
        self.vis = vis

    def __len__(self):
        return len(self.pairs)

    def do_transform(self, p0, x):
        # p0: [N, 3]
        # x: [1, 6], twist-params
        g = utils.exp(x).to(p0) # [1, 4, 4]
        p1 = utils.transform(g, p0)
        igt = g.squeeze(0) # igt: p0 -> p1
        return p1, igt

    def __getitem__(self, index):
        p0_pre, p1_pre = load_3dmatch_batch_data(os.path.join(self.dataset_path, self.pairs[index][0]), os.path.join(self.dataset_path, self.pairs[index][1]), self.voxel_ratio)

        # voxelization
        p0, p1, xmin, ymin, zmin, xmax, ymax, zmax, vx, vy, vz = find_voxel_overlaps(p0_pre, p1_pre, self.voxel)   # intersection bounds of p0 and p1, which roughly contain the overlapping area
        voxels_p0, coords_p0, num_points_per_voxel_p0 = points_to_voxel_second(p0, (xmin, ymin, zmin, xmax, ymax, zmax), 
                        (vx, vy, vz), self.max_voxel_points, reverse_index=False, max_voxels=self.num_voxels)
        voxels_p1, coords_p1, num_points_per_voxel_p1 = points_to_voxel_second(p1, (xmin, ymin, zmin, xmax, ymax, zmax), 
                        (vx, vy, vz), self.max_voxel_points, reverse_index=False, max_voxels=self.num_voxels)

        # flatten the 3-D voxel coordinates (x, y, z) into a linear index;
        # y is the slowest axis, matching the np.meshgrid ordering below
        coords_p0_idx = coords_p0[:,1]*(int(self.voxel**2)) + coords_p0[:,0]*(int(self.voxel)) + coords_p0[:,2]
        coords_p1_idx = coords_p1[:,1]*(int(self.voxel**2)) + coords_p1[:,0]*(int(self.voxel)) + coords_p1[:,2]

        # calculate the voxel centers
        xm_x = np.linspace(xmin+vx/2, xmax-vx/2, int(self.voxel))
        xm_y = np.linspace(ymin+vy/2, ymax-vy/2, int(self.voxel))
        xm_z = np.linspace(zmin+vz/2, zmax-vz/2, int(self.voxel))
        mesh3d = np.vstack(np.meshgrid(xm_x,xm_y,xm_z)).reshape(3,-1).T
        voxel_coords_p0 = mesh3d[coords_p0_idx]
        voxel_coords_p1 = mesh3d[coords_p1_idx]

        # keep voxels whose number of points is >= 10% of the maximum number of points, in both clouds
        idx_conditioned_p0 = coords_p0_idx[np.where(num_points_per_voxel_p0>=0.1*self.max_voxel_points)]
        idx_conditioned_p1 = coords_p1_idx[np.where(num_points_per_voxel_p1>=0.1*self.max_voxel_points)]
        idx_conditioned, _, _ = np.intersect1d(idx_conditioned_p0, idx_conditioned_p1, assume_unique=True, return_indices=True)
        _, _, idx_p0 = np.intersect1d(idx_conditioned, coords_p0_idx, assume_unique=True, return_indices=True)
        _, _, idx_p1 = np.intersect1d(idx_conditioned, coords_p1_idx, assume_unique=True, return_indices=True)
        voxel_coords_p0 = voxel_coords_p0[idx_p0]
        voxel_coords_p1 = voxel_coords_p1[idx_p1]
        voxels_p0 = voxels_p0[idx_p0]
        voxels_p1 = voxels_p1[idx_p1]

        x = torch.from_numpy(self.perturbation[index][np.newaxis,...])
        voxels_p1_, igt = self.do_transform(torch.from_numpy(voxels_p1.reshape(-1,3)), x)
        voxels_p1 = voxels_p1_.reshape(voxels_p1.shape)
        voxel_coords_p1, _ = self.do_transform(torch.from_numpy(voxel_coords_p1).double(), x)
        p1, _ = self.do_transform(torch.from_numpy(p1), x)

        if self.vis:
            return voxels_p0, voxel_coords_p0, voxels_p1, voxel_coords_p1, igt, p0, p1
        else:    
            return voxels_p0, voxel_coords_p0, voxels_p1, voxel_coords_p1, igt
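
To make the ordering issue concrete, below is a minimal, self-contained sketch (my own illustration, not code from this repository; apply_rigid, overlap_bounds, and the random clouds are hypothetical stand-ins) showing that the intersection box computed by the find_voxel_overlaps logic changes once the perturbation is applied first:

import numpy as np

def apply_rigid(R, t, pts):
    # hypothetical helper: apply rotation R and translation t to an [N, 3] cloud
    return pts @ R.T + t

def overlap_bounds(a, b):
    # the intersection-box computation at the top of find_voxel_overlaps
    lo = np.max(np.stack([a.min(0), b.min(0)]), 0)
    hi = np.min(np.stack([a.max(0), b.max(0)]), 0)
    return lo, hi

rng = np.random.default_rng(0)
p0_pre = rng.uniform(0.0, 1.0, (1000, 3))            # stand-in for the first scan
p1_pre = p0_pre + rng.normal(0.0, 0.01, (1000, 3))   # its roughly registered partner

theta = np.deg2rad(30.0)                             # a rigid perturbation:
R = np.array([[np.cos(theta), -np.sin(theta), 0.0],  # 30-degree rotation about z
              [np.sin(theta),  np.cos(theta), 0.0],
              [0.0,            0.0,           1.0]])
t = np.array([0.5, 0.0, 0.0])                        # plus a translation

print(overlap_bounds(p0_pre, p1_pre))                    # crop the current dataloader computes (registered pair)
print(overlap_bounds(p0_pre, apply_rigid(R, t, p1_pre))) # crop an unbiased pipeline would see (unregistered pair)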
Lilac-Lee commented 2 years ago

Hi, thanks for your comments. Yes, the current voxelization setting still relies on this information. I have now provided code for voxelization after the transformation, and updated data_utils.py, models.py, test.py, and the README. Cheers.
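
The updated data_utils.py is not quoted in this thread; for readers landing here later, a minimal sketch of what the "voxelization after transformation" ordering could look like inside __getitem__ is given below (my assumption of the reordering, reusing do_transform as defined above, not necessarily the author's exact code):

    def __getitem__(self, index):
        p0_pre, p1_pre = load_3dmatch_batch_data(os.path.join(self.dataset_path, self.pairs[index][0]),
                                                 os.path.join(self.dataset_path, self.pairs[index][1]),
                                                 self.voxel_ratio)

        # apply the rigid perturbation to the raw cloud FIRST, so the
        # overlap/crop estimate below never sees the registered pair
        x = torch.from_numpy(self.perturbation[index][np.newaxis, ...])
        p1_pre_t, igt = self.do_transform(torch.from_numpy(p1_pre).double(), x)
        p1_pre = p1_pre_t.numpy()

        # voxelization now operates on an unregistered pair
        p0, p1, xmin, ymin, zmin, xmax, ymax, zmax, vx, vy, vz = find_voxel_overlaps(p0_pre, p1_pre, self.voxel)
        # ... voxelize, index, and filter exactly as before, but without the
        # final do_transform calls on voxels_p1 / voxel_coords_p1 / p1 ...
        return voxels_p0, voxel_coords_p0, voxels_p1, voxel_coords_p1, igt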

gitouni commented 2 years ago

Thanks for your reply and the fix! Cheers.