XiaoshuiHuang / fmr

This repository is the implementation of our CVPR 2020 work: "Feature-metric Registration: A Fast Semi-supervised Approach for Robust Point Cloud Registration without Correspondences"
MIT License
156 stars 20 forks source link

How can the registration results be visualized during testing? #5

Closed yangninghua closed 4 years ago

yangninghua commented 4 years ago
"""
create model
Creator: Xiaoshui Huang
Date: 2020-06-19
"""
import open3d
import copy
import pcl
import pcl.pcl_visualization
import random

import torch
import numpy as np
from random import sample

import se_math.se3 as se3
import se_math.invmat as invmat

# visualize the point clouds
def draw_registration_result(source, target, transformation):
    """Show a registration result with Open3D.

    Renders three clouds in one window: the original source (red), the
    source after applying *transformation* (green), and the target (blue).
    Blocks until the viewer window is closed.
    """
    def _as_o3d_cloud(points):
        # Wrap an (N, 3) array into an Open3D point cloud.
        cloud = open3d.geometry.PointCloud()
        cloud.points = open3d.utility.Vector3dVector(points)
        return cloud

    src_cloud = _as_o3d_cloud(source)
    tgt_cloud = _as_o3d_cloud(target)

    src_aligned = copy.deepcopy(src_cloud)
    tgt_view = copy.deepcopy(tgt_cloud)
    src_aligned.transform(transformation)

    src_cloud.paint_uniform_color([1, 0, 0])    # untransformed source: red
    src_aligned.paint_uniform_color([0, 1, 0])  # transformed source: green
    tgt_view.paint_uniform_color([0, 0, 1])     # target: blue

    open3d.visualization.draw_geometries([src_cloud, src_aligned, tgt_view])

def vis_pair(cloud1, cloud2, rdm=False):
    """Display two PCL point clouds in one viewer window.

    cloud1 is drawn red. cloud2 is green by default; with rdm=True it is
    given a random RGB color instead. Blocks (spins) until the viewer
    window is closed.
    """
    if rdm:
        second_color = [random.randint(0, 255) for _ in range(3)]
    else:
        second_color = [0, 255, 0]
    first_color = [255, 0, 0]

    handler1 = pcl.pcl_visualization.PointCloudColorHandleringCustom(
        cloud1, first_color[0], first_color[1], first_color[2])
    handler2 = pcl.pcl_visualization.PointCloudColorHandleringCustom(
        cloud2, second_color[0], second_color[1], second_color[2])

    viewer_api = pcl.pcl_visualization.PCLVisualizering
    viewer = pcl.pcl_visualization.PCLVisualizering()  # instantiating the viewer object is essential
    viewer_api.AddPointCloud_ColorHandler(viewer, cloud1, handler1, id=b'cloud', viewport=0)
    viewer_api.AddPointCloud_ColorHandler(viewer, cloud2, handler2, id=b'cloud1', viewport=0)
    viewer_api.SetBackgroundColor(viewer, 0, 0, 0)
    while not viewer_api.WasStopped(viewer):
        viewer_api.Spin(viewer)

def vis_triple(cloud1, cloud2, cloud3):
    """Display three PCL point clouds (red / green / blue) in one window.

    Blocks (spins) until the viewer window is closed.
    """
    colors = ((255, 0, 0), (0, 255, 0), (0, 0, 255))
    clouds = (cloud1, cloud2, cloud3)
    ids = (b'cloud', b'cloud1', b'cloud2')

    viewer_api = pcl.pcl_visualization.PCLVisualizering
    viewer = pcl.pcl_visualization.PCLVisualizering()  # instantiating the viewer object is essential
    for cloud, (r, g, b), cloud_id in zip(clouds, colors, ids):
        handler = pcl.pcl_visualization.PointCloudColorHandleringCustom(cloud, r, g, b)
        viewer_api.AddPointCloud_ColorHandler(viewer, cloud, handler, id=cloud_id, viewport=0)
    while not viewer_api.WasStopped(viewer):
        viewer_api.Spin(viewer)

def transform(cloud, transfer):
    """Apply a 4x4 homogeneous transform to an (N, 3) point array.

    Points are rotated by the upper-left 3x3 block of *transfer* and
    shifted by its translation column. The result is cast to float32
    (the dtype pcl.PointCloud expects).
    """
    rotation = transfer[:3, :3]
    translation = transfer[:3, 3].reshape(-1, 3)
    moved = cloud.dot(rotation.T) + translation
    return moved.astype(np.float32)

# visualize the point clouds
def draw_registration_result_pcl(source, target, transformation):
    """Visualize a registration result with PCL.

    Opens two viewer windows in sequence: first source vs. target, then
    source, target, and the transformed (aligned) source together.
    """
    aligned_points = transform(copy.deepcopy(source), transformation)
    source_cloud = pcl.PointCloud(source)
    target_cloud = pcl.PointCloud(target)
    aligned_cloud = pcl.PointCloud(aligned_points)

    vis_pair(source_cloud, target_cloud)
    vis_triple(source_cloud, target_cloud, aligned_cloud)

def norm_method2(tensor1, tensor2):
    """Jointly normalize two point sets to a shared scale.

    Both [N, D] tensors are divided by the single largest per-dimension
    bounding-box extent found across either set, then each is centered on
    its own mean. Returns the two normalized tensors.
    """
    extent1 = tensor1.max(dim=0)[0] - tensor1.min(dim=0)[0]  # [N, D] -> [D]
    extent2 = tensor2.max(dim=0)[0] - tensor2.min(dim=0)[0]  # [N, D] -> [D]
    # One common scale so the relative size of the two sets is preserved.
    scale = max(extent1.max(), extent2.max())
    scaled1 = tensor1 / scale
    scaled2 = tensor2 / scale
    centered1 = scaled1 - scaled1.mean(dim=0, keepdim=True)
    centered2 = scaled2 - scaled2.mean(dim=0, keepdim=True)
    return centered1, centered2

class FMRTest:
    """Test driver for feature-metric registration (FMR).

    Builds the registration solver and evaluates it over a test loader,
    writing per-sample pose estimates to ``args.outfile``.
    """

    def __init__(self, args):
        self.filename = args.outfile
        self.dim_k = args.dim_k
        self.max_iter = 10  # max iteration time for IC algorithm
        self._loss_type = 1  # see. self.compute_loss()
        # Debug switches. These were previously read as bare globals
        # (``mydata`` / ``vis_flag``) that are never defined in this file,
        # so evaluate() raised NameError; they are now attributes with
        # safe defaults (both paths disabled).
        # use_custom_data: replace each loader sample with a fixed local PCD pair.
        self.use_custom_data = 0
        # vis_flag: pop up a PCL visualization of each registration result.
        self.vis_flag = 0

    def create_model(self):
        """Build the registration model: PointNet encoder + FMR solver."""
        # Encoder network: extract feature for every point. Nx1024
        # NOTE(review): PointNet and SolveRegistration are defined elsewhere
        # in this project; they are not visible in this file chunk.
        ptnet = PointNet(dim_k=self.dim_k)
        # feature-metric registration (fmr) algorithm: estimate the transformation T
        fmr_solver = SolveRegistration(ptnet)
        return fmr_solver

    def evaluate(self, solver, testloader, device):
        """Evaluate *solver* on *testloader*, logging per-sample errors.

        For each pair (p0=template, p1=source) the solver estimates a
        transform; the error is the norm of the se(3) log of
        (estimated . ground truth), which is zero for a perfect estimate.
        Results are written to ``self.filename``.
        """
        solver.eval()
        with open(self.filename, 'w') as fout:
            # NOTE(review): eval_1__header / eval_1__write / ablation_study
            # are defined outside this chunk.
            self.eval_1__header(fout)
            with torch.no_grad():
                for i, data in enumerate(testloader):
                    p0, p1, igt = data  # igt: p0->p1
                    # # compute trans from p1->p0
                    # g = se3.log(igt)  # --> [-1, 6]
                    # igt = se3.exp(-g)  # [-1, 4, 4]

                    if self.use_custom_data == 1:
                        # Debug path: override the loader sample with a fixed
                        # local point-cloud pair, jointly normalized.
                        roi = pcl.load('./point_cloud_dataset/roi_index00003.pcd')
                        cur_nn = pcl.load('./point_cloud_dataset/cur_nn_index00003.pcd')
                        roi_torch = torch.from_numpy(np.array(roi))
                        cur_nn_torch = torch.from_numpy(np.array(cur_nn))
                        roi_torch_norm, cur_nn_torch_norm = norm_method2(roi_torch, cur_nn_torch)
                        # Add the batch dimension expected by the solver.
                        p1 = roi_torch_norm.unsqueeze(0)
                        p0 = cur_nn_torch_norm.unsqueeze(0)

                    p0, p1 = self.ablation_study(p0, p1)
                    p0 = p0.to(device)  # template (1, N, 3)
                    p1 = p1.to(device)  # source (1, M, 3)
                    solver.estimate_t(p0, p1, self.max_iter)

                    est_g = solver.g  # (1, 4, 4)

                    ig_gt = igt.cpu().contiguous().view(-1, 4, 4)  # --> [1, 4, 4]
                    g_hat = est_g.cpu().contiguous().view(-1, 4, 4)  # --> [1, 4, 4]

                    if self.vis_flag == 1:
                        # Visual sanity check of the estimated alignment.
                        p0_np = np.squeeze(p0.cpu().data.numpy())
                        p1_np = np.squeeze(p1.cpu().data.numpy())
                        g_hat_np = np.squeeze(g_hat.cpu().data.numpy())
                        draw_registration_result_pcl(p1_np, p0_np, g_hat_np)

                    dg = g_hat.bmm(ig_gt)  # if correct, dg == identity matrix.
                    dx = se3.log(dg)  # --> [1, 6] (if correct, dx == zero vector)
                    dn = dx.norm(p=2, dim=1)  # --> [1]
                    dm = dn.mean()

                    self.eval_1__write(fout, ig_gt, g_hat)
                    print('test, %d/%d, %f' % (i, len(testloader), dm))
yangninghua commented 4 years ago

image

yangninghua commented 4 years ago

Matching works well on the full point cloud, but local (partial) matching is poor because the dataset has not been trained on such features.

yangninghua commented 4 years ago

image