xiaobaitu123344 closed this issue 1 year ago.
Isn't the partition-based pooling in your paper based on voxels?
No. Partition-based pooling is defined as separating a point cloud into non-overlapping partitions and fusing the points that share the same partition. In our grid pooling implementation, we likewise compute which grid partition each point belongs to and then fuse the points within each partition. The implementation is attached below.
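To make the definition concrete, here is a minimal pure-Python sketch (not the PTv2 code) for a single point cloud. Each point is assigned to the non-overlapping grid cell that contains it, measured from the cloud's minimum corner in the same spirit as the `start` computation in `GridPool`. Points that share a cell are then fused: coordinates by mean, features by element-wise max, matching the reductions used in `GridPool`. The function name `grid_partition_pool` and the list-based input format are hypothetical, chosen only for illustration.

```python
import math
from typing import Dict, List, Sequence, Tuple

Point = Tuple[float, float, float]


def grid_partition_pool(
    coords: Sequence[Point],
    feats: Sequence[Sequence[float]],
    grid_size: float,
) -> Tuple[List[Tuple[float, ...]], List[List[float]], List[int]]:
    """Hypothetical sketch of partition-based (grid) pooling for one cloud.

    Returns the pooled coordinates, the pooled features and, for every input
    point, the index of the partition it was fused into.
    """
    # Use the cloud's minimum corner as the grid origin (like `start` in GridPool).
    origin = [min(c[d] for c in coords) for d in range(3)]
    # Integer index of the non-overlapping grid cell that contains each point.
    cells = [
        tuple(math.floor((c[d] - origin[d]) / grid_size) for d in range(3))
        for c in coords
    ]
    # Number the occupied cells densely, in sorted order.
    cell_id: Dict[Tuple[int, ...], int] = {
        cell: i for i, cell in enumerate(sorted(set(cells)))
    }
    cluster = [cell_id[cell] for cell in cells]

    # Collect the indices of the points in each partition.
    members: List[List[int]] = [[] for _ in range(len(cell_id))]
    for i, k in enumerate(cluster):
        members[k].append(i)

    # Fuse coordinates with the mean over each partition.
    pooled_coords = [
        tuple(sum(coords[i][d] for i in idx) / len(idx) for d in range(3))
        for idx in members
    ]
    # Fuse features with an element-wise max over each partition.
    pooled_feats = [
        [max(feats[i][j] for i in idx) for j in range(len(feats[idx[0]]))]
        for idx in members
    ]
    return pooled_coords, pooled_feats, cluster
```

For example, with `grid_size=1.0`, the points `(0.0, 0.0, 0.0)` and `(0.4, 0.0, 0.0)` share one cell and are fused into a single point at `(0.2, 0.0, 0.0)`, while `(1.2, 0.0, 0.0)` stays on its own.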
In my view, there is not much difference between voxel-based and point-based methods; a voxel can perhaps be regarded as a kind of point obtained after grid sampling. You may also have noticed that current point-based methods apply voxelization during data augmentation to downsample points.
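The "voxel as a point after grid sampling" view can be illustrated with a small voxel-grid downsampling sketch of the kind used as a preprocessing or augmentation step. This is a hypothetical helper, not a specific library's transform: it keeps the first point that falls into each occupied voxel, which is an assumption; other pipelines keep a random point or the voxel centroid instead.

```python
import math
from typing import List, Sequence, Tuple

Point = Tuple[float, float, float]


def grid_sample(coords: Sequence[Point], grid_size: float) -> List[int]:
    """Hypothetical voxel-grid downsampling: keep one point per occupied voxel.

    Returns the indices of the kept points in input order. Keeping the first
    point seen in each voxel is an illustrative choice.
    """
    seen = set()
    keep: List[int] = []
    for i, (x, y, z) in enumerate(coords):
        # Integer key of the voxel that contains this point.
        key = (
            math.floor(x / grid_size),
            math.floor(y / grid_size),
            math.floor(z / grid_size),
        )
        if key not in seen:
            seen.add(key)
            keep.append(i)
    return keep
```

With `grid_size=1.0`, the points `(0.1, 0.1, 0.1)` and `(0.9, 0.9, 0.9)` fall into the same voxel, so only the first one survives; the surviving points are then processed as ordinary points by a point-based network.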
```python
import torch
import torch.nn as nn
from torch_geometric.nn.pool import voxel_grid
from torch_scatter import segment_csr

# PointBatchNorm, offset2batch and batch2offset are helpers from the PTv2
# codebase and are not shown here.


class GridPool(nn.Module):
    """
    Partition-based Pooling (Grid Pooling)
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 grid_size,
                 bias=False):
        super(GridPool, self).__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.grid_size = grid_size
        self.fc = nn.Linear(in_channels, out_channels, bias=bias)
        self.norm = PointBatchNorm(out_channels)
        self.act = nn.ReLU(inplace=True)

    def forward(self, points, start=None):
        coord, feat, offset = points
        batch = offset2batch(offset)
        feat = self.act(self.norm(self.fc(feat)))
        # Per-cloud minimum coordinate, used as the origin of each cloud's grid.
        start = segment_csr(coord, torch.cat([batch.new_zeros(1), torch.cumsum(batch.bincount(), dim=0)]),
                            reduce="min") if start is None else start
        # Id of the grid partition each point falls into (kept separate per cloud).
        cluster = voxel_grid(pos=coord - start[batch], size=self.grid_size, batch=batch, start=0)
        # Re-index partitions densely; counts = number of points per partition.
        unique, cluster, counts = torch.unique(cluster, sorted=True, return_inverse=True, return_counts=True)
        # Sort points by partition so that each partition is a contiguous segment.
        _, sorted_cluster_indices = torch.sort(cluster)
        idx_ptr = torch.cat([counts.new_zeros(1), torch.cumsum(counts, dim=0)])
        # Fuse the points of each partition: mean coordinate, max feature.
        coord = segment_csr(coord[sorted_cluster_indices], idx_ptr, reduce="mean")
        feat = segment_csr(feat[sorted_cluster_indices], idx_ptr, reduce="max")
        # Cloud index of each pooled point, then back to the offset format.
        batch = batch[idx_ptr[:-1]]
        offset = batch2offset(batch)
        return [coord, feat, offset], cluster
```
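The `forward` method relies on two bookkeeping ideas that may be easier to follow without PyTorch: converting between the per-cloud `offset` format and a per-point `batch` vector, and reducing each partition as a contiguous segment delimited by `idx_ptr`. The pure-Python sketch below is not the PTv2 helpers. `offset2batch` and `batch2offset` are re-implemented under the assumption that `offset` stores the cumulative number of points per cloud, and `segment_reduce` is a hypothetical stand-in for what `torch.sort` followed by `segment_csr` computes on a single feature channel.

```python
from itertools import accumulate
from typing import List, Sequence


def offset2batch(offset: Sequence[int]) -> List[int]:
    # Assumes offset[i] is the cumulative point count up to and including
    # cloud i, e.g. [3, 5] -> [0, 0, 0, 1, 1].
    batch: List[int] = []
    prev = 0
    for b, end in enumerate(offset):
        batch.extend([b] * (end - prev))
        prev = end
    return batch


def batch2offset(batch: Sequence[int]) -> List[int]:
    # Inverse of offset2batch for a sorted, contiguous batch vector.
    counts = [0] * (max(batch) + 1)
    for b in batch:
        counts[b] += 1
    return list(accumulate(counts))


def segment_reduce(values: Sequence[float], cluster: Sequence[int], reduce: str) -> List[float]:
    """Hypothetical emulation of the sort + idx_ptr + segment_csr steps.

    cluster[i] is the dense, 0-based partition id of point i, as produced by
    torch.unique(..., return_inverse=True).
    """
    if reduce not in ("mean", "max"):
        raise ValueError(f"unsupported reduce: {reduce}")
    n = max(cluster) + 1
    counts = [0] * n
    for k in cluster:
        counts[k] += 1
    # Partition k occupies sorted positions idx_ptr[k] .. idx_ptr[k + 1] - 1.
    idx_ptr = [0] + list(accumulate(counts))
    # Sort point indices by partition id; the order inside a partition does
    # not matter for mean or max.
    order = sorted(range(len(cluster)), key=lambda i: cluster[i])
    sorted_vals = [values[i] for i in order]
    out: List[float] = []
    for k in range(n):
        seg = sorted_vals[idx_ptr[k]:idx_ptr[k + 1]]
        out.append(sum(seg) / len(seg) if reduce == "mean" else max(seg))
    return out
```

Sorting once and then reducing contiguous slices avoids scattering into per-partition buffers, which is why the GridPool code builds `idx_ptr` from the partition counts before calling `segment_csr`.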
Thank you for your patience. I was not able to reply in time because of the epidemic, and I apologize for the delay in answering your questions.
PTv2 is a point-based method. What do you mean by "empty voxels" in your question?