octree-nn / ocnn-pytorch

Octree-based 3D Convolutional Neural Networks
MIT License

Visualizing the octree #31

Closed jotix16 closed 6 months ago

jotix16 commented 6 months ago

Thank you for the great library. :)

Is there any script that could be used for visualizing the octree? That would be very useful for debugging purposes!

jotix16 commented 6 months ago

Straightforward when using octree.xyzb and open3d.
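
For reference, a minimal sketch of the idea (assuming an already built ocnn Octree named octree with points normalized to the [-1, 1] cube; the depth is illustrative):

import torch
import open3d as o3d

def show_octree_nodes(octree, depth):
    # Node coordinates at the given depth; xyzb returns x, y, z and the batch index b
    x, y, z, _ = octree.xyzb(depth, nempty=True)
    size = 2.0 / (2 ** depth)  # node edge length in the normalized [-1, 1] cube
    centers = (torch.stack([x, y, z], dim=1).cpu().numpy() + 0.5) * size - 1.0
    pcd = o3d.geometry.PointCloud()
    pcd.points = o3d.utility.Vector3dVector(centers)
    o3d.visualization.draw_geometries([pcd])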

dunbar12138 commented 3 months ago

Hi, I'm also looking into how to visualize an octree. Would you mind sharing your script that uses open3d and octree.xyzb for visualization? Thank you in advance!

harryseely commented 3 months ago

@dunbar12138

Here, this should work for you (a bit verbose... sorry, I don't have time to clean it up).

from laspy import read
import os
import open3d as o3d
from matplotlib.colors import LinearSegmentedColormap
import numpy as np

def get_custom_cmap(type):
    # Define a custom colourmap for matplotlib
    if type == "earthy":
        brown = [0.5, 0.25, 0.0]
        green = [0.0, 0.8, 0.0]
        colors_continuous = [brown, [0.6, 0.4, 0.2], [0.8, 0.6, 0.4], [0.6, 0.8, 0.4], [0.4, 0.6, 0.2], green]
    elif type == "jet":
        # Create a smooth color gradient for the custom colormap
        colors_continuous = [
            [0.0, 0.0, 0.4],  # Dark Blue
            [0.0, 0.0, 0.8],  # Blue
            [0.0, 0.5, 1.0],  # Light Blue
            [0.2, 0.8, 1.0],  # Cyan
            [0.0, 0.8, 0.0],  # Green
            [0.6, 0.8, 0.2],  # Lime Green
            [1.0, 1.0, 0.0],  # Yellow
            [1.0, 0.6, 0.0],  # Orange
            [0.8, 0.0, 0.0],  # Red
            [0.5, 0.0, 0.0],  # Dark Red
        ]
    else:
        raise ValueError("Invalid colormap type!")

    positions = np.linspace(0, 1, len(colors_continuous))
    cmap = LinearSegmentedColormap.from_list("CustomContinuous", list(zip(positions, colors_continuous)))

    return cmap
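
# Usage (illustrative): map normalized heights to RGB colours
#   cmap = get_custom_cmap("jet")
#   rgb = cmap(np.linspace(0, 1, 5))[:, :3]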

def read_las_to_np(las_fpath, use_ground_points=None, centralize_coords=True):
    """
    IMPORTANT: the np array storing the lidar data has the format N x C, where N is the
    number of points and C is the number of columns. The columns included in the np array,
    by index, are:
        0 - X
        1 - Y
        2 - Z
        3 - Intensity (normalized or raw, depending on argument)
        4 - Return Number
        5 - Classification
        6 - Scan Angle Rank
        7 - Number of Returns

    *If compute_normals is True:
        8 - x component of normals
        9 - y component of normals
        10 - z component of normals

    :param las_fpath: filepath to las file
    :param use_ground_points: whether to keep ground points; if falsy, ground points are dropped
    :param centralize_coords: whether to make all coords relative to the center of the point cloud (center is 0, 0, 0)
    :return: point cloud numpy array
    """

    # Read LAS for given plot ID
    inFile = read(las_fpath)

    # Correct for difference in file naming for scan angle (some LAS files call it scan_angle)
    try:
        scan_angle = inFile.scan_angle_rank
    except AttributeError:
        try:
            scan_angle = inFile.scan_angle
        except AttributeError:
            raise Exception("Issue with scan angle name in LAS file...")

    # Get point coordinates and attributes
    points = np.vstack([inFile.x,
                        inFile.y,
                        inFile.z,
                        inFile.intensity,
                        inFile.return_number,
                        inFile.classification,
                        scan_angle,
                        inFile.number_of_returns
                        ]).transpose()

    if not use_ground_points:
        # Drop all rows with a value of 2 (ground point) in the Classification column (index 5)
        points = points[points[:, 5] != 2]

    if centralize_coords:
        points[:, 0:3] = points[:, 0:3] - np.mean(points[:, 0:3], axis=0)

    # Check for NaNs and report
    if np.isnan(points).any():
        raise ValueError('NaN values in input point cloud!')

    return points
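
# Usage (illustrative path): grab only the XYZ columns of a plot
#   xyz = read_las_to_np(r"path\to\plot.las", use_ground_points=False)[:, 0:3]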

def check_resolution_of_smallest_octant(plot_diameter, max_depth, verbose=True, in_filepath=None, points=None):
    """
    ***NOTE: cannot use a normalized point cloud; must have the height of the tallest tree.

    Loads a LAS file and calculates the resolution (i.e., cell size) of the smallest octant.

    Note that since the plot diameter is in meters, the units of the smallest octant resolution are in meters as well.

    :param in_filepath: path to a las file whose points have not yet been centralized
    :return: resolution of octree at smallest octant
    """

    if points is None:
        xyz = read_las_to_np(in_filepath, centralize_coords=False)[:, 0:3]
    else:
        xyz = points[:, 0:3]

    # Use the max tree height as the cube length if it exceeds the plot diameter
    if np.max(xyz[:, 2]) > plot_diameter:
        cube_length = np.max(xyz[:, 2])
    else:
        cube_length = plot_diameter

    # Get volume of the cube which will contain the child leaves
    vol_cube = cube_length ** 3

    # Number of leaf nodes at the max depth (each level splits every node into 8)
    n_nodes = 8 ** max_depth

    # Get the volume of the smallest octant
    vol_smallest = vol_cube / n_nodes

    # Get the resolution of the smallest octant
    res_smallest = vol_smallest ** (1. / 3)

    if verbose:
        print(f"The resolution of this octree at a max depth of {max_depth} is {res_smallest}m")

    return res_smallest
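
# Example (illustrative numbers): for a 22.56 m plot and max_depth = 6, the cube side is
# 22.56 m and the smallest octant resolution is 22.56 / 2 ** 6, roughly 0.35 m.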

def plot_octree(in_filepath=None, points=None, cmap=None, octree_depth=8, use_ground_points=True, plot_diameter=22.56):
    """
    :param in_filepath: filepath to las file for plotting
    :param points: input point cloud provided instead of a filepath, default is None
    :param cmap: matplotlib colormap
    :param octree_depth: max octree depth
    :param use_ground_points: whether to include ground points in the plot
    :param plot_diameter: plot diameter in meters
    :return:
    """

    if in_filepath is None and points is None:
        print("No input point cloud or filepath provided!")
        exit()

    if points is None:
        # Load points from a LAS file
        ext = os.path.basename(in_filepath).split(".")[-1]
        if ext == "las":
            points = read_las_to_np(in_filepath, use_ground_points=use_ground_points)
        else:
            print("Filetype not supported!")
            exit()

    # Get xyz
    xyz = points[:, 0:3]

    # Save as a temporary PLY file
    new_file = r"C:\Users\hseely\Downloads\temp_points.ply"

    # Pass xyz to Open3D.o3d.geometry.PointCloud and visualize
    pcd = o3d.geometry.PointCloud()
    pcd.points = o3d.utility.Vector3dVector(xyz)
    o3d.io.write_point_cloud(new_file, pcd)

    # Load saved point cloud to o3d
    pcd = o3d.io.read_point_cloud(new_file)

    # Set the colors of the point cloud based on normalized height
    z = xyz[:, 2]
    z_norm = (z - np.min(z)) / (np.max(z) - np.min(z))
    col_arr = cmap(z_norm)[:, :3]
    pcd.colors = o3d.utility.Vector3dVector(col_arr)

    # Scale the point cloud to unit size around its center
    pcd.scale(1 / np.max(pcd.get_max_bound() - pcd.get_min_bound()), center=pcd.get_center())

    # Check size of smallest octant
    small_octant_res = check_resolution_of_smallest_octant(points=xyz,
                                                           plot_diameter=plot_diameter,
                                                           max_depth=octree_depth,
                                                           verbose=False)

    print(
        f"For a plot with a radius of {plot_diameter / 2}m and an octree depth of {octree_depth}, the smallest octant size is {small_octant_res}m")

    # Build and plot the octree
    octree = o3d.geometry.Octree(max_depth=octree_depth)
    octree.convert_from_point_cloud(pcd, size_expand=0.01)

    # Define camera parameters
    lookat = octree.get_center()
    up = [0.5, 0.5, 0.7]  # Up direction of the camera
    front = [-0.3, -0.3, 0.3]  # Front direction of the camera
    zoom = 1.0  # Adjust as needed

    # Plot with specified camera view
    o3d.visualization.draw_geometries([octree], lookat=lookat, up=up, front=front, zoom=zoom)

if __name__ == "__main__":

    # Load a point cloud
    original_las_fpath = r"D:\Sync\RQ2\Analysis\data\new_brunswick\lidar\NBGOV11899.las"
    points = read_las_to_np(las_fpath=original_las_fpath, use_ground_points=True)

    # Try using the jet colormap
    cmap = get_custom_cmap('jet')

    plot_octree(points=points, cmap=cmap, octree_depth=6)

jotix16 commented 3 months ago

I'm sorry for not getting back to you sooner. The functions should be self-explanatory. Let me know if you have any questions.


import torch
import torch.nn
from typing import Dict, TypedDict
import numpy as np

import ocnn
from ocnn.octree import Octree, Points
import open3d as o3d

def get_position_from_xyz_coords(octree: ocnn.octree.Octree, depth):
    x, y, z, _ = octree.xyzb(depth, nempty=True)
    size = 1 / (2 ** depth) * 2  # node edge length in the normalized [-1, 1] cube
    centers = (torch.stack([x, y, z], dim=1).cpu().numpy() + 0.5) * size - 1.0  # node centers in [-1, 1]
    return centers, size
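
# Sanity check (illustrative): at depth 1 the normalized [-1, 1] cube is split into 8 nodes of
# edge length 1.0, so the returned centers are the eight points (+/-0.5, +/-0.5, +/-0.5).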

def get_lines_from_node(centers, size):
    """
        centers: (n_nodes, 3)
        size: scalar
    """
    corners_offsets = np.array([
        [[-1, -1, -1], [1, -1, -1], [1, 1, -1], [-1, 1, -1], [-1, -1, 1], [1, -1, 1], [1, 1, 1], [-1, 1, 1]]
    ]) # (1, 8, 3) corners of the cube
    edges_offsets = np.array([
        [0, 1], [1, 2], [2, 3], [3, 0],
        [4, 5], [5, 6], [6, 7], [7, 4],
        [0, 4], [1, 5], [2, 6], [3, 7]
    ])  # (12, 2) edges of the cube, each row is a pair of indices of the corners that form an edge

    half_size = size / 2
    corners = centers[:, None, :] + half_size * corners_offsets

    edges = np.stack((corners[:, edges_offsets[:, 0], :],
                     corners[:, edges_offsets[:, 1], :]), axis=-2)  # (n_nodes, 12, 2, 3)

    return edges.reshape(-1, 2, 3)  # (n_lines, 2, 3)
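
# Example (illustrative shapes): for centers of shape (2, 3) and a scalar size, the function
# returns the 12 cube edges per node, i.e. an array of shape (24, 2, 3).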

def plot_octree_from_points(points):
    xyz, xyz_mean, scale = transform2origin(points)  # scale the points to [-1, 1]
    # Build octree
    points = Points(points=xyz, features=None, labels=None, batch_size=1)  # batch_size=1 means that we have one set of points at a time
    octree = ocnn.octree.Octree(depth=8, full_depth=3, batch_size=1, device=xyz.device) # batch_size=1 means that we have one octree at a time
    octree.build_octree(points)
    octree.construct_all_neigh()

    # plot with open3d
    all_lines = np.concatenate([get_lines_from_node(*get_position_from_xyz_coords(octree, d)) for d in range(0,9)])

    line_set = o3d.geometry.LineSet(points=o3d.utility.Vector3dVector(all_lines.reshape(-1, 3)),
                                    lines=o3d.utility.Vector2iVector(np.arange(all_lines.reshape(-1, 3).shape[0]).reshape(-1, 2)))

    o3d_pcd = o3d.geometry.PointCloud()
    o3d_pcd.points = o3d.utility.Vector3dVector(xyz.detach().cpu().numpy())
    o3d.visualization.draw_geometries([o3d_pcd, line_set])

def plot_octree(octree: ocnn.octree.Octree):
    all_lines = np.concatenate([get_lines_from_node(*get_position_from_xyz_coords(octree, d))
                                for d in range(octree.full_depth, octree.depth+1)])

    line_set = o3d.geometry.LineSet(points=o3d.utility.Vector3dVector(all_lines.reshape(-1, 3)),
                                    lines=o3d.utility.Vector2iVector(np.arange(all_lines.reshape(-1, 3).shape[0]).reshape(-1, 2)))

    o3d_pcd = o3d.geometry.PointCloud()
    o3d_pcd.points = o3d.utility.Vector3dVector(octree.get_input_feature(feature='P', nempty=False).detach().cpu().numpy())
    o3d.visualization.draw_geometries([o3d_pcd, line_set])

def transform2origin(xyz):
    # scale the points to [-1, 1]
    # works for both single and batched inputs
    min_pos = torch.min(xyz, -2, keepdim=True)[0]  # -2 so that it works for batched inputs
    max_pos = torch.max(xyz, -2, keepdim=True)[0]

    center = 0.5*(min_pos + max_pos)
    scale = torch.max(max_pos - min_pos, dim=-1, keepdim=True)[0] / 2.0
    new_position_tensor = (xyz - center) / scale

    return new_position_tensor, center, scale

def transformback(xyz, center, scale):
    return xyz * scale + center
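
# Round-trip check (illustrative): transform2origin followed by transformback recovers the input, e.g.
#   xyz_n, center, scale = transform2origin(xyz)
#   assert torch.allclose(transformback(xyz_n, center, scale), xyz, atol=1e-6)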

def process_batch(xyz: torch.Tensor, features: torch.Tensor, depth: int, full_depth: int, feat: str, nempty: bool, normalize=False):
    """
    Process both single and batched inputs of xyz ([B], N, 3) and features ([B], N, F)
    into an octree and query_pts.

    The purpose of the octree is to encode the input points into a structure that
    keeps track of the spatial relationships between the points and allows the
    use of convolutions and other operations (attention/transformers) on the points.

    After we process the octree node features through the network, we can get back
    a feature for each point by interpolating the octree node features at the query points.

    depth: int -- depth of the octree
    full_depth: int -- full depth of the octree (depth up to which the octree is full, i.e. all nodes have 8 children)

    query_pts: [B*N, 4] (normalized points in [-1, 1] with batch index)
    """
    # scale the points to [-1, 1]
    if normalize:
        normalized_xyz, mean_xyz, scale_xyz = transform2origin(xyz)
    else:
        normalized_xyz = xyz

    if normalized_xyz.dim() == 2:
        # xyz: (N, 3) -- single
        points = Points(points=normalized_xyz, features=features, batch_size=1)
        query_pts = torch.cat([points.points, torch.zeros(normalized_xyz.shape[0], 1, device=normalized_xyz.device)], dim=1)
        B = 1
    else:
        # xyz: (B, N, 3) -- batched
        B, N, F = features.shape
        batch_ids = torch.arange(B, device=normalized_xyz.device).reshape(B, 1).repeat(1, N)  # (B, N)
        points = Points(points=normalized_xyz.reshape(B*N, 3), features=features.reshape(B*N, F), batch_id=batch_ids.reshape(B*N), batch_size=B)
        query_pts = torch.cat([points.points, points.batch_id.unsqueeze(-1)], dim=1)

    octree = ocnn.octree.Octree(depth=depth, full_depth=full_depth, batch_size=B, device=normalized_xyz.device)
    octree.build_octree(points)
    octree.construct_all_neigh()

    data = octree.get_input_feature(feat, nempty)  # get the feature on the octree nodes

    return octree, data, query_pts
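
# Usage sketch (shapes and hyperparameters are illustrative):
#   xyz = torch.rand(2, 1024, 3) * 2 - 1        # batched points already in [-1, 1]
#   features = torch.rand(2, 1024, 4)           # per-point features
#   octree, data, query_pts = process_batch(xyz, features, depth=6, full_depth=2,
#                                           feat='P', nempty=False)
# `data` holds the feature chosen by `feat` on the octree nodes, and per-point features can
# later be recovered by interpolating processed node features at `query_pts`.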