I would like to write a custom architecture. I'm using the layer OctreeAvgPool(kernel_size=[3, 3, 3], stride=2).
I also wrote a function that returns a per-batch view of the data, i.e., a tensor of shape [batch_size, NPoints, Features].
The function is defined as:
import torch
from torch.nn.utils.rnn import pad_sequence

# Returns the octree data grouped by batch, zero-padding when the octrees
# in the batch have different numbers of points.
def to_batch_view(self, data, octree, depth):
    batch_id = octree.batch_id(depth)
    out = []
    for i in range(octree.batch_size):
        # A boolean mask keeps each chunk 2-D even when an octree has a single
        # point (argwhere(...).squeeze() would collapse that case to 0-D).
        out.append(data[batch_id == i])
    return pad_sequence(out, batch_first=True, padding_value=0.0)
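For reference, here is a minimal standalone sketch of the same grouping that takes the batch_id tensor directly, so it can be tried without an octree. The name `pad_by_batch` is hypothetical (not part of ocnn), and it assumes `batch_id` is a 1-D integer tensor with values in `[0, batch_size)`. It also returns a mask, since zero padding is otherwise indistinguishable from a real all-zero feature row.

```python
import torch
from torch.nn.utils.rnn import pad_sequence


def pad_by_batch(data: torch.Tensor, batch_id: torch.Tensor, batch_size: int):
    """Group rows of `data` ([N, F]) by `batch_id` ([N]) into [batch_size, max_n, F].

    Returns the zero-padded tensor and a boolean mask of shape [batch_size, max_n]
    that is True for real rows and False for padding. (Hypothetical helper.)
    """
    # Boolean masking keeps each chunk 2-D, even when a batch has a single row.
    chunks = [data[batch_id == i] for i in range(batch_size)]
    padded = pad_sequence(chunks, batch_first=True, padding_value=0.0)
    # Number of real rows per batch element.
    counts = torch.bincount(batch_id, minlength=batch_size)
    positions = torch.arange(padded.shape[1], device=counts.device)
    mask = positions.unsqueeze(0) < counts.unsqueeze(1)
    return padded, mask


data = torch.arange(12, dtype=torch.float32).reshape(6, 2)
batch_id = torch.tensor([0, 0, 1, 2, 2, 2])
padded, mask = pad_by_batch(data, batch_id, batch_size=3)
print(padded.shape)  # torch.Size([3, 3, 2])
```

Batch element 1 holds a single point here, which is exactly the case where `argwhere(...).squeeze()` would drop a dimension.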
The size of batch_id at a given depth does not always match the number of rows returned by get_neigh with stride 2 at the previous level (depth + 1). Printing the shapes gives:
depth: 6
neigh shape with stride 2 torch.Size([49152, 27])
neigh shape with stride 1 torch.Size([393216, 27])
batch_id shape torch.Size([393216])
depth: 5
neigh shape with stride 2 torch.Size([6144, 27])
neigh shape with stride 1 torch.Size([49152, 27])
batch_id shape torch.Size([49152])
depth: 4
neigh shape with stride 2 torch.Size([768, 27])
neigh shape with stride 1 torch.Size([6144, 27])
batch_id shape torch.Size([6144])
depth: 3
neigh shape with stride 2 torch.Size([96, 27])
neigh shape with stride 1 torch.Size([768, 27])
batch_id shape torch.Size([768])
depth: 2
neigh shape with stride 2 torch.Size([24, 27])
neigh shape with stride 1 torch.Size([192, 27])
batch_id shape torch.Size([192])
depth: 1
neigh shape with stride 2 torch.Size([3, 27])
neigh shape with stride 1 torch.Size([24, 27])
batch_id shape torch.Size([24])
At depth 6, the stride-2 neighbor tensor has 49152 rows, which equals the size of batch_id at depth 5 (49152). However, at depth 3 the stride-2 neighbor tensor has 96 rows, while batch_id at depth 2 has 192 entries.
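Tabulating the printed sizes makes the pattern explicit. The short check below uses only the numbers from the log above: at every depth the stride-1 row count equals the batch_id size, and the stride-2 row count is exactly one eighth of it, but the stride-2 count equals batch_id at depth - 1 only at some depths. One reading consistent with these numbers (an assumption, not verified against the library) is that each stride-2 row corresponds to one node at depth - 1 that has children, so the count falls below batch_id(depth - 1) whenever some nodes at depth - 1 are empty.

```python
# Sizes copied from the log: depth -> (stride-2 rows, stride-1 rows, batch_id size)
sizes = {
    6: (49152, 393216, 393216),
    5: (6144, 49152, 49152),
    4: (768, 6144, 6144),
    3: (96, 768, 768),
    2: (24, 192, 192),
    1: (3, 24, 24),
}


def check(sizes):
    """For each depth (deepest first), return a tuple of:
    (depth, stride-1 rows == batch_id size, stride-2 rows * 8 == node count,
     stride-2 rows == batch_id size at depth - 1, or None if that depth was not logged)
    """
    rows = []
    for depth in sorted(sizes, reverse=True):
        s2, s1, nodes = sizes[depth]
        parent = sizes.get(depth - 1)
        matches_parent = (s2 == parent[2]) if parent else None
        rows.append((depth, s1 == nodes, s2 * 8 == nodes, matches_parent))
    return rows


for row in check(sizes):
    print(row)
```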
Is there a preferred/correct way to get the corresponding batch_id for the elements?