octree-nn / ocnn-pytorch

Octree-based Sparse Convolutional Neural Networks
MIT License

Is it possible to mix nempty=False convolutions with nempty=True convolutions? #29

Closed: ErenBalatkan closed this issue 5 months ago.

ErenBalatkan commented 5 months ago

Greetings,

Thanks for the incredible work!

I'm experimenting with a setting where I would like to first downscale the input voxel grid to 1/8 of its resolution and then perform shape reconstruction from that point. For the encoding layers, I would prefer the network to behave like a sparse net (nempty=True). During upscaling, however, since I want it to perform reconstruction, it should also perform convolution on empty nodes (nempty=False).
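For orientation, the depth bookkeeping of such an encoder/decoder can be sketched in plain Python. This assumes that "/8" means a factor of 8 per axis and that each stride-2 octree layer halves the resolution per axis; the helper names are hypothetical and not part of ocnn.

```python
def grid_resolution(depth: int) -> int:
    """Voxels per axis for a full octree level at the given depth (2**depth)."""
    return 2 ** depth


def depth_after_downscale(input_depth: int, factor: int) -> int:
    """Octree depth reached after downscaling by `factor` per axis.

    Assumes every stride-2 layer halves the per-axis resolution, so `factor`
    must be a power of two and consumes log2(factor) octree levels.
    """
    if factor < 1 or factor & (factor - 1):
        raise ValueError("factor must be a power of two")
    levels = factor.bit_length() - 1  # log2(factor) for a power of two
    if levels > input_depth:
        raise ValueError("cannot downscale below depth 0")
    return input_depth - levels


# Example: a 64^3 input (depth 6) downscaled to 1/8 per axis.
d = depth_after_downscale(6, 8)
print(d, grid_resolution(d))  # 3 8
```

Under these assumptions, the encoder would be three stride-2 stages running with nempty=True, and the decoder three stride-2 upsampling stages running with nempty=False.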

I have tried to test this approach with a simple script, but it gives me a shape error:

```python
import torch
import ocnn
from ocnn.nn import octree2voxel

if __name__ == '__main__':
    # Positional args: in_channels, out_channels, kernel_size, stride, nempty, ...
    conv = ocnn.nn.octree_conv.OctreeConv(3, 3, [3], 2, True, False, False)  # encoder, nempty=True
    upconv = ocnn.nn.octree_conv.OctreeDeconv(3, 3, [2], 2, False)  # decoder, nempty=False

    # Tiny octree of depth 3 built from a single point
    small_tree = ocnn.octree.Octree(3, full_depth=1)
    point = [0.2, 0.5, 0.5]
    point = torch.tensor(point).unsqueeze(0)
    small_tree.build_octree(ocnn.octree.points.Points(point))
    small_tree.construct_all_neigh()
    print("Octree shape:", octree2voxel(small_tree.get_input_feature(), small_tree, 3, True)[0].shape)
    encoding = ocnn.nn.octree_max_pool(conv(small_tree.get_input_feature(), small_tree, 3), small_tree, 3, nempty=True)
    print("Octree shape:", octree2voxel(encoding, small_tree, 2, True)[0].shape)

    # OK to this point

    reconstruction = upconv(encoding, small_tree, 2)  # Crash
```
wang-ps commented 5 months ago

> The network behaves like a sparse net if nempty=True. For reconstruction, it should also perform convolution on empty nodes (nempty=False).

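Yes, your understanding and usage are correct. If nempty=True, ocnn behaves exactly like other sparse nets that use hash tables. With nempty=True, the tensor contains only the features of non-empty octree nodes; with nempty=False, it contains the features of both non-empty and empty octree nodes. The two tensors therefore have different sizes, and this is the reason for the error: `encoding` was produced in nempty=True mode, so it holds rows only for the non-empty nodes, while `upconv` was built with nempty=False and expects one row for every node at that depth.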
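The size mismatch can be illustrated with a pure-Python sketch; it is not ocnn code. The helpers `pad_to_all_nodes` and `depad_to_nonempty` are hypothetical and only mimic the idea of converting between the two layouts, assuming that the nempty=True rows appear in the same order as the non-empty nodes in the full node list. The toy mask assumes a level like the one in the script: 8 nodes at that depth, of which one is non-empty.

```python
from typing import List, Sequence


def pad_to_all_nodes(features: Sequence[List[float]], nonempty_mask: Sequence[bool],
                     channels: int, fill: float = 0.0) -> List[List[float]]:
    """Scatter nempty=True features (one row per non-empty node) into the
    nempty=False layout (one row per node), filling empty nodes with `fill`."""
    n_nonempty = sum(nonempty_mask)
    if len(features) != n_nonempty:
        raise ValueError(f"expected {n_nonempty} rows, got {len(features)}")
    rows = iter(features)
    return [list(next(rows)) if is_nonempty else [fill] * channels
            for is_nonempty in nonempty_mask]


def depad_to_nonempty(features: Sequence[List[float]],
                      nonempty_mask: Sequence[bool]) -> List[List[float]]:
    """Inverse direction: keep only the rows of non-empty nodes."""
    if len(features) != len(nonempty_mask):
        raise ValueError(f"expected {len(nonempty_mask)} rows, got {len(features)}")
    return [list(row) for row, is_nonempty in zip(features, nonempty_mask) if is_nonempty]


# 8 nodes at this depth, only one non-empty (its position is arbitrary here).
mask = [False, False, True, False, False, False, False, False]
encoding = [[0.1, 0.2, 0.3]]  # nempty=True layout: 1 row
padded = pad_to_all_nodes(encoding, mask, channels=3)
print(len(encoding), len(padded))  # 1 8
```

A layer built with nempty=False expects the 8-row layout, so passing the 1-row tensor fails. In ocnn-pytorch, the corresponding conversion is, to our understanding, provided by `ocnn.nn.octree_pad` and `ocnn.nn.octree_depad`; please check the library's documentation for their exact signatures before padding `encoding` ahead of `upconv`.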

ErenBalatkan commented 5 months ago

Thanks for the answer!