lz1oceani / pointcloud_rl

Apache License 2.0
26 stars 1 forks source link

Visualize pointcloud #2

Open kbkartik opened 2 months ago

kbkartik commented 2 months ago

How do I visualize the point cloud observation at every timestep for dm_control environments? In o3d_utils.py there is a method visualize_pcd, but it calls two functions, np2pcd and to_o3d, which have not been defined.

How can I save the pointcloud observation every timestep as frames of a video for visualization?

lz1oceani commented 2 months ago

You first need to get the image frames, and then use imageio to save them as a video.

import numpy as np
import open3d as o3d
import trimesh
from pyrl.utils.data import is_pcd

def is_o3d(x):
    """Return True when ``x`` is an instance of one of the open3d geometry
    classes this module knows how to handle (triangle mesh, point cloud,
    oriented or axis-aligned bounding box)."""
    supported_geometries = (
        o3d.geometry.TriangleMesh,
        o3d.geometry.PointCloud,
        o3d.geometry.OrientedBoundingBox,
        o3d.geometry.AxisAlignedBoundingBox,
    )
    return isinstance(x, supported_geometries)

def to_o3d(x):
    """Convert a supported geometry object into its open3d equivalent.

    Numpy support is for point clouds only: an ``np.ndarray`` must pass
    ``is_pcd`` and is converted to an ``o3d.geometry.PointCloud``.

    Args:
        x: An open3d geometry accepted by ``is_o3d`` (returned unchanged),
            a numpy point array, a ``trimesh.Trimesh`` or a
            ``trimesh.points.PointCloud``.

    Returns:
        The corresponding open3d geometry object.

    Raises:
        AssertionError: If ``x`` is an ndarray that ``is_pcd`` rejects.
        NotImplementedError: If the type of ``x`` is not supported.
    """
    if is_o3d(x):
        return x
    if isinstance(x, np.ndarray):
        assert is_pcd(x)
        return o3d.geometry.PointCloud(o3d.utility.Vector3dVector(x))
    if isinstance(x, trimesh.Trimesh):
        return o3d.geometry.TriangleMesh(o3d.utility.Vector3dVector(x.vertices), o3d.utility.Vector3iVector(x.faces))
    if isinstance(x, trimesh.points.PointCloud):
        # The PointCloud constructor expects a Vector3dVector, not a raw array,
        # so wrap the vertices exactly like the ndarray branch above does.
        return o3d.geometry.PointCloud(o3d.utility.Vector3dVector(np.asarray(x.vertices)))
    # Report the offending type in the exception instead of printing it to stdout.
    raise NotImplementedError(f"to_o3d does not support inputs of type {type(x)}")

def np2pcd(points, colors=None, normals=None):
    """Convert a numpy array (or other array-like) to an open3d PointCloud.

    Args:
        points: Point coordinates, one point per row. The data is copied, so
            the resulting cloud never aliases the caller's array. If this is
            already an ``o3d.geometry.PointCloud`` it is returned unchanged
            (``colors`` and ``normals`` are then ignored).
        colors: Optional colors. A 2-D input supplies one row per point; a
            1-D input is a single color tiled onto every point.
        normals: Optional per-point normals, one row per point.

    Returns:
        An ``o3d.geometry.PointCloud`` holding the given points and, when
        provided, colors and normals.

    Raises:
        AssertionError: If 2-D ``colors`` or ``normals`` do not have one row
            per point.
        RuntimeError: If ``colors`` is neither 1-D nor 2-D.
    """
    if isinstance(points, o3d.geometry.PointCloud):
        return points
    # np.array always copies and accepts any array-like (lists, tuples, ...),
    # whereas ``points.copy()`` fails on inputs without a ``copy`` method.
    points = np.array(points, dtype=np.float64)
    pc = o3d.geometry.PointCloud()
    pc.points = o3d.utility.Vector3dVector(points)
    if colors is not None:
        colors = np.array(colors)
        if colors.ndim == 2:
            assert len(colors) == len(points), f"got {len(colors)} colors for {len(points)} points"
        elif colors.ndim == 1:
            # A single color: repeat it once per point.
            colors = np.tile(colors, (len(points), 1))
        else:
            raise RuntimeError(f"colors must be 1-D or 2-D, got shape {colors.shape}")
        pc.colors = o3d.utility.Vector3dVector(colors)
    if normals is not None:
        normals = np.asarray(normals)
        assert len(normals) == len(points), f"got {len(normals)} normals for {len(points)} points"
        pc.normals = o3d.utility.Vector3dVector(normals)
    return pc