Closed: jotix16 closed this issue 6 months ago
Thank you for pointing out this issue! I carefully read your comments and checked the code. Yes, you found the reason for this phenomenon. I will fix it in the next commits.
> Would the `long` type make a difference for the bit operations?
`nnum_nempty` is the number of non-empty nodes. Sorry, I am not fully sure what you mean by bit operations. The code currently works fine whether it is `long` or `int32`.
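As a quick standalone check (my own illustration, not code from the library), bitwise operations give identical results for `int32` and `int64` tensors as long as the values fit in 32 bits:

```python
import torch

# Standalone check: bitwise ops agree between int32 and int64
# as long as the values fit in 32 bits.
key32 = torch.tensor([5, 6, 7], dtype=torch.int32)
key64 = key32.long()
print(torch.equal((key32 >> 1).long(), key64 >> 1))      # True
print(torch.equal((key32 & 0b11).long(), key64 & 0b11))  # True
```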
Hi @wang-ps :), thank you for the response.

Does the ocnn library work with batched inputs?
Answer: YES (the issue can be closed).
Solution: If you use `BatchNorm` layers, make sure to call `model.eval()` or `layer.eval()` at test time. In the code below, `conv1.eval()` solves the issue.
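A minimal sketch of the fix, reusing `conv1`, `call`, and `input_points` from the repro script below. With the layer in eval mode, BatchNorm uses its running statistics instead of per-batch statistics, so batched and single inputs should produce matching outputs:

```python
conv1.eval()  # or model.eval() on a full network, before testing

with torch.no_grad():
    out_b = call(input_points).reshape(3, 32, 3)
    out_s = call(input_points[0])
print(torch.allclose(out_b[0], out_s))  # expected: True once BatchNorm is in eval mode
```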
Somehow, I cannot manage to get the same results for batched points and for a sequence of non-batched points.
```python
import torch
import ocnn
from ocnn.octree import Octree, Points

nempty = True
depth = 4
full_depth = 2
conv1 = ocnn.modules.OctreeConvBnRelu(3, 3, nempty=nempty).cuda()
# data is the feature on the octree nodes to be interpolated;
# pts are the points at which to interpolate
interp = ocnn.nn.OctreeInterp(method='linear', nempty=nempty,
                              bound_check=True, rescale_pts=True)


def process_batch(normalized_xyz: torch.Tensor, features: torch.Tensor,
                  depth: int, full_depth: int, feat: str, nempty: bool):
    """Process single or batched inputs of xyz ([B], N, 3) and features ([B], N, F)
    into an octree, data (node features), and query_pts.

    After the octree node features have been processed by the network, a feature
    for each point is recovered by interpolating the node features at the query
    points.

    query_pts: [B*N, 4] (points normalized to [-1, 1], plus a batch index)
    """
    # the points are assumed to be already scaled to [-1, 1]
    if normalized_xyz.dim() == 2:
        # xyz: (N, 3) -- single
        points = Points(points=normalized_xyz, features=features, batch_size=1)
        query_pts = torch.cat(
            [points.points, torch.zeros(normalized_xyz.shape[0], 1, device="cuda")], dim=1)
        B = 1
    else:
        # xyz: (B, N, 3) -- batched
        B, N, F = features.shape
        batch_ids = torch.arange(B, device=normalized_xyz.device).reshape(B, 1).repeat(1, N)  # (B, N)
        points = Points(points=normalized_xyz.reshape(B * N, 3),
                        features=features.reshape(B * N, F),
                        batch_id=batch_ids.reshape(B * N), batch_size=B)
        query_pts = torch.cat([points.points, points.batch_id.unsqueeze(-1)], dim=1)

    octree = ocnn.octree.Octree(depth=depth, full_depth=full_depth,
                                batch_size=B, device=normalized_xyz.device)
    octree.build_octree(points)
    octree.construct_all_neigh()
    data = octree.get_input_feature(feat, nempty)  # feature on the octree nodes
    return octree, data, query_pts


def call(in_points):
    octree, data, query_pts = process_batch(in_points, in_points, depth, full_depth, 'F', nempty)
    out = conv1(data, octree, depth)
    points_b = interp(out, octree, depth=octree.depth, pts=query_pts)
    return points_b


input_points = torch.rand(3, 32, 3).cuda()
out_batched_3 = call(input_points).reshape(3, 32, 3)
out_batched_2 = call(input_points[0:2]).reshape(2, 32, 3)
out_non_batched = call(input_points[0])
print(out_batched_3.shape, out_batched_2.shape, out_non_batched.shape)
print(torch.allclose(out_batched_3[:2], out_batched_2))
print(torch.allclose(out_batched_3[0], out_non_batched))
# out_non_batched - out_batched_3[0]
```
Output:

```
torch.Size([3, 32, 3]) torch.Size([2, 32, 3]) torch.Size([32, 3])
False
False
```
If I only run `process_batch` and `interp`, I get the same values, which leaves the conv layer as the only possible problem source.
```python
octree_b, data_b, query_pts_b = process_batch(input_points, input_points, depth, full_depth, 'F', nempty)
octree_s, data_s, query_pts_s = process_batch(input_points[0], input_points[0], depth, full_depth, 'F', nempty)
points_b = interp(data_b, octree_b, depth=octree_b.depth, pts=query_pts_b).reshape(3, 32, 3)
points_s = interp(data_s, octree_s, depth=octree_s.depth, pts=query_pts_s).reshape(-1, 3)
print(points_b.shape, points_s.shape)
print(torch.allclose(points_b[0], points_s))
```
Output:

```
torch.Size([3, 32, 3]) torch.Size([32, 3])
True
```
Thank you for your feedback!
Summary
I was comparing the two main ways of handling batched points (creating batched points vs. merging the octrees of individual point clouds) and noticed a peculiarity: the members `octree.nnum` and `octree.nnum_nempty` change from `int32` to `int64`.
Question
It looks like the two methods are equivalent value-wise. Would the `long` type make a difference for the bit operations?
Why it happens
https://github.com/octree-nn/ocnn-pytorch/blob/fb0dd36cb64bdbcd92989031d85d92effe4b3726/ocnn/octree/octree.py#L578-L579
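For reference, a likely mechanism (my assumption, based on the linked lines summing the per-octree node counts when merging): `torch.sum` accumulates integer tensors in `int64` by default, so summing `int32` counts silently promotes them:

```python
import torch

nnum = torch.tensor([8, 64, 512], dtype=torch.int32)
print(nnum.dtype)        # torch.int32
print(nnum.sum().dtype)  # torch.int64 -- integer sums default to int64 accumulation
```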
Reproduce
Output