octree-nn / ocnn-pytorch

Octree-based Sparse Convolutional Neural Networks
MIT License

Batched points to octree vs merge_octrees #36

Closed jotix16 closed 4 months ago

jotix16 commented 4 months ago

Summary

I was comparing the two main ways of handling batched points (building one octree from batched points vs. merging per-sample octrees) and noticed a peculiarity: the members octree.nnum and octree.nnum_nempty change from int32 to int64.

Question

It looks like the two methods are equivalent value-wise. Would the long type make a difference for the bit operations?

Why it happens

https://github.com/octree-nn/ocnn-pytorch/blob/fb0dd36cb64bdbcd92989031d85d92effe4b3726/ocnn/octree/octree.py#L578-L579
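
Presumably the linked lines sum the per-octree node counts with torch.sum; for integral inputs, torch.sum promotes the result to int64 by default unless a dtype is passed. A minimal demonstration, independent of ocnn:

import torch

nnum = torch.tensor([4, 32, 190, 359], dtype=torch.int32)
print(nnum.dtype)                                       # torch.int32
print(torch.sum(nnum, dim=0).dtype)                     # torch.int64: integral inputs are promoted
print(torch.sum(nnum, dim=0, dtype=torch.int32).dtype)  # torch.int32: explicit dtype avoids promotion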

Reproduce

import torch
import ocnn
from ocnn.octree import Points

def process_batch(batched_points, batched_features):
    # Variant 1: build one octree per sample, then merge them into a single batched octree.
    def points2octree(points):
        octree = ocnn.octree.Octree(5, 2, batch_size=1, device=points.device)
        octree.build_octree(points)
        return octree

    # batched_points: (B, N, 3)
    B = batched_points.shape[0]
    batch_ids = torch.arange(B, device=batched_points.device).view(-1, 1).repeat(1, batched_points.shape[1])  # (B, N)
    list_points = [Points(points=batched_points[i], features=batched_features[i], batch_id=batch_ids[i], batch_size=1) for i in range(B)]
    points = [pts.cuda(non_blocking=True) for pts in list_points]
    octrees = [points2octree(pts) for pts in points]
    octree = ocnn.octree.merge_octrees(octrees)
    octree.construct_all_neigh()
    new_points = ocnn.octree.merge_points(points)
    return octree, new_points

def process_batch2(batched_xyz, batched_features):
    # Variant 2: build a single octree directly from the batched point cloud.
    # batched_xyz: (B, N, 3)
    B = batched_xyz.shape[0]
    batch_ids = torch.arange(B, device=batched_xyz.device).view(-1, 1).repeat(1, batched_xyz.shape[1])  # (B, N)
    batched_points = Points(points=batched_xyz.reshape(-1, 3), features=batched_features.reshape(-1, 3), batch_id=batch_ids.view(-1), batch_size=B)
    octree = ocnn.octree.Octree(5, 2, batch_size=B, device=batched_xyz.device)
    octree.build_octree(batched_points)
    octree.construct_all_neigh()
    return octree, batched_points

al = torch.randn(4, 100, 3).cuda()
octree1, new_pts1 = process_batch(al, al)
octree2, new_pts2 = process_batch2(al, al)

print(torch.allclose(new_pts2.points, new_pts1.points))
print([torch.allclose(c1, c2) for c1, c2 in zip(octree1.children, octree2.children)])
print([torch.allclose(c1, c2) for c1, c2 in zip(octree1.keys, octree2.keys)])

print([torch.allclose(c1, c2) for c1, c2 in zip(octree1.points, octree2.points) if c1 is not None])
print([torch.allclose(c1, c2) for c1, c2 in zip(octree1.features, octree2.features) if c1 is not None])
print([torch.allclose(c1, c2) for c1, c2 in zip(octree1.nnum.to(dtype=torch.int32), octree2.nnum)])
print([torch.allclose(c1, c2) for c1, c2 in zip(octree1.nnum_nempty.to(dtype=torch.int32), octree2.nnum_nempty)])
print(octree1.nnum_nempty)
print(octree2.nnum_nempty)

torch.sum(octree2.nnum_nempty, dim=0).dtype  # !!! torch.int64 -- torch.sum promotes int32 inputs to int64 by default

Output

True
[True, True, True, True, True, True]
[True, True, True, True, True, True]
[True]
[True]
[True, True, True, True, True, True]
[True, True, True, True, True, True]
tensor([  4,  32, 190, 359, 396, 400])
tensor([  4,  32, 190, 359, 396, 400], dtype=torch.int32)
wang-ps commented 4 months ago

Thank you for pointing out this issue! I carefully read your comments and checked the code. Yes, you found the reason for this phenomenon. I will fix it in an upcoming commit.

Would the long type make a difference for the bit operations? nnum_nempty is the number of non-empty nodes. Sorry, I am not fully sure what you mean by bit operations. The code currently works fine whether it is long or int32.
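
A possible fix (a sketch only; the actual commit may differ) is to pass dtype explicitly when summing the merged counts. Here stacked_nnum is a hypothetical stand-in for the per-octree node counts gathered in merge_octrees:

import torch

# stacked_nnum: hypothetical per-octree node counts, shape (num_octrees, depth + 1)
stacked_nnum = torch.randint(0, 100, (4, 6), dtype=torch.int32)
nnum = torch.sum(stacked_nnum, dim=0, dtype=torch.int32)  # stays int32 instead of promoting to int64
print(nnum.dtype)  # torch.int32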

jotix16 commented 4 months ago

Hi @wang-ps :),

thank you for the response.

Does the ocnn library work with batched inputs?

EDIT 2

Answer: YES (issue can be closed)

Solution: If you use BatchNorm layers, make sure to call model.eval() or layer.eval() at test time. In train mode, BatchNorm normalizes with per-batch statistics, which differ between batched and single inputs, so the outputs differ. In the code below, conv1.eval() solves the issue, as sketched next.
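
A minimal sketch of the fix, using the conv1 layer defined in the reproduction below:

import ocnn

conv1 = ocnn.modules.OctreeConvBnRelu(3, 3, nempty=True).cuda()
conv1.eval()  # use BatchNorm running statistics instead of per-batch statistics,
              # so the output no longer depends on which samples are batched together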


Somehow, I cannot manage to get the same results for batched points and a sequence of non-batched points.

Reproduce

import torch
import ocnn
from ocnn.octree import Octree, Points

nempty = True
depth = 4
full_depth = 2
conv1 = ocnn.modules.OctreeConvBnRelu(3, 3, nempty=nempty).cuda()
# interp: 'data' is the feature on the octree nodes to be interpolated; 'pts' are the query points
interp = ocnn.nn.OctreeInterp(method='linear', nempty=nempty, bound_check=True, rescale_pts=True)

def process_batch(normalized_xyz: torch.Tensor, features: torch.Tensor, depth: int, full_depth: int, feat: str, nempty: bool):
    """
    Process both single and batched inputs of xyz ([B], N, 3) and features ([B], N, F)
    into octree, data (node features) and query_pts.
    After we process the octree node features through the network, we can get back
    a feature for each point by interpolating the octree node features at the query points.

    query_pts: [B*N, 4] (points normalized to [-1, 1], plus a batch index)
    """
    # the points are assumed to be already scaled to [-1, 1]

    if normalized_xyz.dim() == 2:
        # xyz: (N, 3) -- single
        points = Points(points=normalized_xyz, features=features, batch_size=1)
        query_pts = torch.cat([points.points, torch.zeros(normalized_xyz.shape[0], 1, device=normalized_xyz.device)], dim=1)
        B = 1
    else:
        # xyz: (B, N, 3) -- batched
        B, N, F = features.shape
        batch_ids = torch.arange(B, device=normalized_xyz.device).reshape(B, 1).repeat(1, N)  # (B, N)
        points = Points(points=normalized_xyz.reshape(B*N, 3), features=features.reshape(B*N, F), batch_id=batch_ids.reshape(B*N), batch_size=B)
        query_pts = torch.cat([points.points, points.batch_id.unsqueeze(-1)], dim=1)

    octree = ocnn.octree.Octree(depth=depth, full_depth=full_depth, batch_size=B, device=normalized_xyz.device)
    octree.build_octree(points)
    octree.construct_all_neigh()

    data = octree.get_input_feature(feat, nempty)  # get the feature on the octree nodes

    return octree, data, query_pts

def call(in_points):
    octree, data, query_pts = process_batch(in_points, in_points, depth, full_depth, 'F', nempty)
    out = conv1(data, octree, depth)

    points_b = interp(out, octree, depth=octree.depth, pts=query_pts)
    return points_b

input_points = torch.rand(3, 32, 3).cuda()

out_batched_3 = call(input_points).reshape(3, 32, 3)
out_batched_2 = call(input_points[0:2]).reshape(2, 32, 3)
out_non_batched = call(input_points[0])

print(out_batched_3.shape, out_batched_2.shape, out_non_batched.shape)

print(torch.allclose(out_batched_3[:2], out_batched_2))
print(torch.allclose(out_batched_3[0], out_non_batched))
# out_non_batched - out_batched_3[0]

Output

torch.Size([3, 32, 3]) torch.Size([2, 32, 3]) torch.Size([32, 3])
False
False

The problem arises during the convolution [EDIT!]

If I only run process_batch and interp (skipping the convolution), I get the same values, which leaves the conv layer as the only remaining problem source; see the sketch after the output below.

octree_b, data_b, query_pts_b = process_batch(input_points, input_points, depth, full_depth, 'F', nempty)
octree_s, data_s, query_pts_s = process_batch(input_points[0], input_points[0], depth, full_depth, 'F', nempty)

points_b = interp(data_b, octree_b, depth=octree_b.depth, pts=query_pts_b).reshape(3, 32, 3)
points_s = interp(data_s, octree_s, depth=octree_s.depth, pts=query_pts_s).reshape(-1, 3)

print(points_b.shape, points_s.shape)
print(torch.allclose(points_b[0], points_s))

Output

torch.Size([3, 32, 3]) torch.Size([32, 3])
True
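
For completeness, applying the fix from EDIT 2 to the full pipeline; this sketch assumes the reproduction above has already been run, so conv1, call and input_points are defined:

conv1.eval()  # freeze BatchNorm statistics
out_batched_3 = call(input_points).reshape(3, 32, 3)
out_non_batched = call(input_points[0])
print(torch.allclose(out_batched_3[0], out_non_batched))  # now True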
wang-ps commented 4 months ago

Thank you for your feedback!