filaPro / oneformer3d

[CVPR2024] OneFormer3D: One Transformer for Unified Point Cloud Segmentation

I want to know how to use the predict function in oneformer3d.py: how do I run inference/prediction? #91

Closed: shaonianG closed this issue 1 day ago

shaonianG commented 2 weeks ago

Thank you for your outstanding work. I'd like to visualize the segmentation results, so I changed the code in oneformer3d.py following #57. However, I couldn't find how to run inference with OneFormer3D in the README, and I don't know where the predict function is called. I checked the runner and found only train, val, and test, with no predict step. Can you help me? Any advice would be appreciated.

filaPro commented 1 week ago

I think you can run test.py, set a debugger breakpoint in the predict function, and see what happens.

shaonianG commented 1 day ago

Thank you for your reply. I ran test.py and found the predict function at https://github.com/filaPro/oneformer3d/blob/main/oneformer3d/oneformer3d.py#L776. I replaced line 776 with the following code, which is adapted from #57:

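    # Imports needed by this snippet (they can also go at the top of oneformer3d.py)
    import os
    import numpy as np
    import open3d as o3d
    import torch

    pred_pts_seg = batch_data_samples[0].pred_pts_seg
    instance_labels = pred_pts_seg.instance_labels  # array, (num_instances,)
    instance_scores = pred_pts_seg.instance_scores  # array, (num_instances,)
    pts_instance_mask = pred_pts_seg.pts_instance_mask[0]  # array, (num_instances, num_points)
    # Move the input points to CPU/NumPy so they can be indexed with the NumPy masks below
    input_points = batch_inputs_dict["points"][0].cpu().numpy()  # (num_points, xyzrgb)
    input_point_name = batch_data_samples[0].lidar_path.split('/')[-1].split('.')[0]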
    def save_point_cloud(points, file_path, colors=None):
        """Save point cloud with optional color."""
        if isinstance(points, torch.Tensor):
            points = points.cpu().numpy()  # Convert tensor to NumPy array on the CPU
        points = np.asarray(points, dtype=np.float32)  # Ensure points are in float32 format
        pc = o3d.geometry.PointCloud()
        pc.points = o3d.utility.Vector3dVector(points[:, :3])

        # If colors are provided, apply them; otherwise, use default gray
        if colors is not None:
            pc.colors = o3d.utility.Vector3dVector(colors)
        else:
            pc.colors = o3d.utility.Vector3dVector(np.tile([0.5, 0.5, 0.5], (points.shape[0], 1)))  # Default gray

        o3d.io.write_point_cloud(file_path, pc)

    def filter_and_save_instances(instance_labels, instance_scores, pts_instance_mask, input_points,
                                  input_point_name, threshold=0.5):

        base_dir = f"./work_dirs/color/{input_point_name}"
        os.makedirs(base_dir, exist_ok=True)
        input_pc_path = os.path.join(base_dir, f"{input_point_name}.ply")
        save_point_cloud(input_points, input_pc_path)

        instance_count = {}

        colors = np.zeros((input_points.shape[0], 3))  # Initialize color array for all points

        # Assign one random color per predicted class label (a fixed color map would also work)
        unique_labels = np.unique(instance_labels)
        label_color_map = {label: np.random.rand(3) for label in unique_labels}  # RGB in [0, 1)

        for i in range(len(instance_scores)):
            if instance_scores[i] >= threshold:
                label = instance_labels[i]
                if label not in instance_count:
                    instance_count[label] = 0
                instance_count[label] += 1

                instance_mask = pts_instance_mask[i].astype(bool)
                instance_points = input_points[instance_mask]

                # Assign color to the points based on their label
                instance_color = label_color_map[label]
                colors[instance_mask] = instance_color

                # Save instance-specific point cloud with color
                instance_pc_path = os.path.join(base_dir,
                                                f"{input_point_name}_{label}_{instance_count[label]}_{instance_scores[i]:.2f}.ply")
                save_point_cloud(instance_points, instance_pc_path, colors=colors[instance_mask])

        # Overwrite the full-scene .ply saved above with a version colored by label (uncovered points stay black)
        save_point_cloud(input_points, input_pc_path, colors=colors)

    filter_and_save_instances(instance_labels, instance_scores, pts_instance_mask, input_points, input_point_name)
    return batch_data_samples
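
The coloring step inside `filter_and_save_instances` can be pulled out into a pure NumPy function, which makes it testable without open3d or a GPU. This is a sketch under the same assumptions as the snippet (boolean masks of shape (num_instances, num_points), one class label and score per instance); `color_points_by_instance` is a hypothetical name, and the seeded `np.random.default_rng` replaces the snippet's unseeded `np.random.rand` so colors are reproducible between runs.

```python
import numpy as np


def color_points_by_instance(pts_instance_mask, instance_labels, instance_scores,
                             threshold=0.5, seed=0):
    """Return (num_points, 3) RGB colors in [0, 1) built from instance masks.

    pts_instance_mask: (num_instances, num_points) boolean masks
    instance_labels:   (num_instances,) predicted class label of each instance
    instance_scores:   (num_instances,) confidence score of each instance
    Points not covered by any kept instance stay black (0, 0, 0). Where kept masks
    overlap, the later instance wins, as in the loop of the snippet above.
    """
    masks = np.asarray(pts_instance_mask, dtype=bool)
    labels = np.asarray(instance_labels)
    scores = np.asarray(instance_scores)

    # One color per class label; a seeded generator gives the same colors on every run
    rng = np.random.default_rng(seed)
    palette = {label: rng.random(3) for label in np.unique(labels)}

    colors = np.zeros((masks.shape[1], 3))
    for mask, label, score in zip(masks, labels, scores):
        if score >= threshold:  # same score filter as filter_and_save_instances
            colors[mask] = palette[label]
    return colors
```

Calling `color_points_by_instance(pts_instance_mask, instance_labels, instance_scores)` inside predict would give an equivalent per-point color array (up to the random colors chosen), which can then be passed to `save_point_cloud`.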

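Then run test.py. It generates .ply files under ./work_dirs/color/<scene name>/, relative to the directory test.py is launched from (in my case /workspace/work_dirs/color). Open them in MeshLab to see the visual result.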
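
If open3d is awkward to install, a colored .ply that MeshLab can open can also be written with NumPy alone. This is a minimal ASCII PLY sketch offered as a swapped-in alternative to the `o3d.io.write_point_cloud` call in the snippet, not something used in this thread; `ply_text` and `write_colored_ply` are hypothetical helper names.

```python
import numpy as np


def ply_text(xyz, colors):
    """Build an ASCII PLY string with per-vertex RGB colors.

    xyz:    (N, >=3) array; only the first three columns (x, y, z) are used
    colors: (N, 3) array of floats in [0, 1], as in the open3d snippet
    """
    xyz = np.asarray(xyz, dtype=np.float64)[:, :3]
    # Convert [0, 1] floats to 0-255 bytes, which is what "uchar" color properties expect
    rgb = np.clip(np.round(np.asarray(colors, dtype=np.float64) * 255), 0, 255).astype(np.uint8)
    if xyz.shape[0] != rgb.shape[0]:
        raise ValueError("xyz and colors must have the same number of rows")
    header = [
        "ply",
        "format ascii 1.0",
        f"element vertex {xyz.shape[0]}",
        "property float x",
        "property float y",
        "property float z",
        "property uchar red",
        "property uchar green",
        "property uchar blue",
        "end_header",
    ]
    body = [f"{x:.6f} {y:.6f} {z:.6f} {r} {g} {b}"
            for (x, y, z), (r, g, b) in zip(xyz, rgb)]
    return "\n".join(header + body) + "\n"


def write_colored_ply(path, xyz, colors):
    """Write the ASCII PLY produced by ply_text() to `path`."""
    with open(path, "w") as f:
        f.write(ply_text(xyz, colors))
```

For example, `write_colored_ply(input_pc_path, input_points, colors)` could stand in for the final `save_point_cloud` call. ASCII files are larger and slower to load than binary PLY, which is acceptable for inspecting a few scenes.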

[Screenshot of the visualized result, taken 2024-11-26 10:07:58 AM]