octree-nn / ocnn-pytorch

Octree-based Sparse Convolutional Neural Networks
MIT License

How to get batch_id when using average pool operation with stride 2 at depth 3 #37

Closed juanprietob closed 3 weeks ago

juanprietob commented 1 month ago

I would like to write a custom architecture. I'm using the layer OctreeAvgPool(kernel_size=[3, 3, 3], stride=2).

I also wrote a function that returns a per-batch view of the data, i.e., a tensor of shape [batch_size, NPoints, Features]. The function is defined as:

import torch
from torch.nn.utils.rnn import pad_sequence

# Returns the octree data as [batch_size, NPoints, Features]. The sequences are
# zero-padded when the number of points per octree differs.
def to_batch_view(self, data, octree, depth):
    batch_id = octree.batch_id(depth)
    out = []
    for i in range(octree.batch_size):
        # squeeze(1) keeps idx one-dimensional even when only one row matches,
        # so data[idx, :] always stays [n_i, Features].
        idx = torch.argwhere(batch_id == i).squeeze(1)
        out.append(data[idx, :])
    return pad_sequence(out, batch_first=True, padding_value=0.0)

The shape of batch_id does not always correspond to the shape returned by get_neigh with stride 2 at the next finer depth, i.e.,

depth: 6
neigh shape with stride 2 torch.Size([49152, 27])
neigh shape with stride 1 torch.Size([393216, 27])
batch_id shape torch.Size([393216])

depth: 5
neigh shape with stride 2 torch.Size([6144, 27])
neigh shape with stride 1 torch.Size([49152, 27])
batch_id shape torch.Size([49152])

depth: 4
neigh shape with stride 2 torch.Size([768, 27])
neigh shape with stride 1 torch.Size([6144, 27])
batch_id shape torch.Size([6144])

depth: 3
neigh shape with stride 2 torch.Size([96, 27])
neigh shape with stride 1 torch.Size([768, 27])
batch_id shape torch.Size([768])

depth: 2
neigh shape with stride 2 torch.Size([24, 27])
neigh shape with stride 1 torch.Size([192, 27])
batch_id shape torch.Size([192])

depth: 1
neigh shape with stride 2 torch.Size([3, 27])
neigh shape with stride 1 torch.Size([24, 27])
batch_id shape torch.Size([24])

The neigh shape with stride 2 at depth 6 (49152) equals the batch_id shape at depth 5 (49152). However, the neigh shape with stride 2 at depth 3 (96) does not equal the batch_id shape at depth 2 (192).
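The comparison can be checked mechanically. The sketch below is plain Python with the row counts copied from the listing; the helper name `mismatched_depths` is hypothetical. It confirms that the stride-2 row count is always one eighth of the node count at the same depth, and that depth 3 is the only level where it differs from the batch_id length one level up. One plausible reading, not confirmed in this thread, is that depth 2 is a full level (192 = 3 × 8², which suggests a batch of 3), whereas depth 3 holds only 768 of the 3 × 8³ = 1536 possible nodes, so the 96 pooled rows would cover just a subset of the depth-2 nodes.

```python
# Row counts copied from the listing above.
nodes = {6: 393216, 5: 49152, 4: 6144, 3: 768, 2: 192, 1: 24}  # batch_id / stride-1 rows
pooled = {6: 49152, 5: 6144, 4: 768, 3: 96, 2: 24, 1: 3}       # stride-2 rows


def mismatched_depths(nodes, pooled):
    """Return the depths d whose stride-2 row count differs from the
    batch_id length at depth d - 1."""
    bad = []
    for d in sorted(pooled):
        # One stride-2 row per group of eight sibling nodes.
        assert pooled[d] * 8 == nodes[d]
        if (d - 1) in nodes and pooled[d] != nodes[d - 1]:
            bad.append(d)
    return bad


print(mismatched_depths(nodes, pooled))  # [3]
```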

Is there a preferred/correct way to get the corresponding batch_id for the elements?

wang-ps commented 1 month ago

For small octree depths, the octree is full, so the number of nodes is the same for every octree in the batch. You can construct the batch_id directly.
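A minimal sketch of what "constructing the batch_id directly" could look like, under two assumptions that are not stated in the thread: a full level at depth d has 8^d nodes per octree, and the nodes of each octree are stored contiguously in batch order. The helper name `full_octree_batch_id` is hypothetical and is not part of ocnn.

```python
import torch


def full_octree_batch_id(batch_size: int, depth: int) -> torch.Tensor:
    """Batch index for every node of a full octree level.

    Assumes 8**depth nodes per octree, stored contiguously per octree.
    """
    nodes_per_octree = 8 ** depth
    return torch.arange(batch_size).repeat_interleave(nodes_per_octree)


batch_id = full_octree_batch_id(batch_size=3, depth=2)
print(batch_id.shape)  # torch.Size([192])

# With equal node counts no padding is needed: the data can be reshaped
# into [batch_size, NPoints, Features] directly.
data = torch.randn(batch_id.numel(), 16)  # 16 features, chosen for illustration
batch_view = data.reshape(3, 8 ** 2, -1)
print(batch_view.shape)  # torch.Size([3, 64, 16])
```

For a batch of 3 this gives 24 entries at depth 1 and 192 at depth 2, matching the batch_id shapes listed in the question. It applies only to tensors that have one row per node of a full level.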