MichaleWong closed this issue 2 years ago.
Generally speaking, getting `cluster` is the hard part :)
Note that the `cluster` vector in `max_pool_x` only assigns a single cluster to each node. I am not sure that is a good choice here: a node that lies within the radius of several sampled points can only go to one of them, so your `cluster` vector depends on the order of nodes.
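To make the order dependence concrete, here is a toy sketch in plain Python (no PyG involved). The helpers `assign_in_order` and `max_pool` and the five-node data are hypothetical: two sampled centers (nodes 0 and 3) whose radii overlap at node 2, plus a simple per-cluster max standing in for what `max_pool_x` computes.

```python
def assign_in_order(num_nodes, centers, neighbors, order):
    """Give every node a single cluster id; later centers overwrite earlier ones."""
    cluster = [-1] * num_nodes  # -1 = not within any center's radius
    for k in order:  # k is the cluster id, i.e. the position of the center in `centers`
        for node in neighbors[centers[k]]:
            cluster[node] = k
    return cluster


def max_pool(cluster, x):
    """Max of x over the nodes of each cluster, listed by cluster id."""
    best = {}
    for c, value in zip(cluster, x):
        best[c] = max(best.get(c, value), value)
    return [best[c] for c in sorted(best)]


centers = [0, 3]                          # hypothetical sampled (fps) nodes
neighbors = {0: [0, 1, 2], 3: [2, 3, 4]}  # nodes within the radius of each center
x = [1.0, 2.0, 9.0, 3.0, 4.0]             # one scalar feature per node

for order in ([0, 1], [1, 0]):
    cluster = assign_in_order(5, centers, neighbors, order)
    print(order, cluster, max_pool(cluster, x))
# [0, 1] [0, 0, 1, 1, 1] [2.0, 9.0]
# [1, 0] [0, 0, 0, 1, 1] [9.0, 4.0]
```

Swapping the order in which the centers are processed moves node 2 from one cluster to the other, and the pooled features change with it.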
A faster implementation may look like this:
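```python
import torch
from torch_geometric.nn import max_pool_x

row, col = edge_index  # row: neighbor node, col: the fps center it was found around
cluster = torch.full((x.size(0), ), -1, dtype=torch.long, device=x.device)  # nodes outside every radius stay -1
cluster[fps_idx] = torch.arange(fps_idx.size(0), device=x.device)
cluster.scatter_(0, row, cluster[col])  # each neighbor copies its center's cluster id
cluster[fps_idx] = torch.arange(fps_idx.size(0), device=x.device)  # centers keep their own id
max_pool_x(...)  # cluster is the first argument: max_pool_x(cluster, x, batch)
```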
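This implementation assigns clusters to neighbors in a greedy fashion: a node that lies within the radius of several sampled points keeps only one of their cluster ids, and since `scatter_` is non-deterministic for duplicate indices, which one wins is not defined.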
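Below is a self-contained sketch of the greedy approach that runs with plain PyTorch (it assumes PyTorch 1.12 or newer for `scatter_reduce`). To avoid depending on `torch_cluster`, the hypothetical helper `radius_edges` replaces `radius()` with a brute-force `torch.cdist` search and returns `edge_index` in the orientation the snippet above expects: `row` is the neighbor and `col` is its center, both indexed into the full point set. With PyG's `radius(pos, pos[fps_idx], r, batch, batch[fps_idx])`, the first output row instead indexes into the sampled points, so it would need to be mapped through `fps_idx` first. The other helper, `max_pool_by_cluster`, is a hand-written stand-in for `max_pool_x` that simply drops nodes left at `-1`.

```python
import torch


def radius_edges(pos, fps_idx, r):
    """Brute-force radius search (hypothetical helper, O(num_centers * num_nodes)).

    Returns edge_index = [row, col] with row = neighbor node and col = center node,
    both as indices into `pos`.
    """
    dist = torch.cdist(pos[fps_idx], pos)  # [num_centers, num_nodes]
    center_local, nbr = (dist <= r).nonzero(as_tuple=True)
    return torch.stack([nbr, fps_idx[center_local]], dim=0)


def greedy_cluster(num_nodes, fps_idx, edge_index):
    """One cluster id per node; a node near several centers gets one of them."""
    row, col = edge_index
    ids = torch.arange(fps_idx.size(0), device=fps_idx.device)
    cluster = torch.full((num_nodes,), -1, dtype=torch.long, device=fps_idx.device)
    cluster[fps_idx] = ids
    cluster.scatter_(0, row, cluster[col])  # duplicate rows: one write wins
    cluster[fps_idx] = ids                  # centers always keep their own id
    return cluster


def max_pool_by_cluster(cluster, x, num_clusters):
    """Per-cluster max of x; nodes with cluster == -1 are ignored."""
    mask = cluster >= 0
    index = cluster[mask].unsqueeze(-1).expand(-1, x.size(1))
    out = x.new_zeros(num_clusters, x.size(1))
    return out.scatter_reduce(0, index, x[mask], reduce='amax', include_self=False)


pos = torch.tensor([[0.0], [0.5], [1.0], [1.5], [2.0], [5.0]])  # toy 1-D points
x = torch.tensor([[1.0], [2.0], [9.0], [3.0], [4.0], [7.0]])
fps_idx = torch.tensor([0, 4])  # pretend these came from fps()
edge_index = radius_edges(pos, fps_idx, r=1.2)
cluster = greedy_cluster(pos.size(0), fps_idx, edge_index)
pooled = max_pool_by_cluster(cluster, x, fps_idx.size(0))
print(cluster)  # node 2 is in range of both centers, node 5 of neither (-1)
print(pooled)
```

Everything stays vectorized, so there is no Python loop over the sampled points, which is usually where a hand-rolled conversion from `edge_index` to a cluster vector gets slow.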
If you want to allow multiple clusters per node, the implementation is also straightforward (using `torch-sparse`):
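```python
from torch_sparse import SparseTensor

row, col = edge_index  # row: neighbor node, col: fps center
# Transposed adjacency: adj_t[c, n] is set for every edge with row=n, col=c.
adj_t = SparseTensor(row=row, col=col, sparse_sizes=(x.size(0), x.size(0))).t()
adj_t = adj_t[fps_idx]  # keep only the rows of the sampled centers
x = adj_t.matmul(x, reduce='max')  # max over each center's whole neighborhood
```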
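If `torch-sparse` is not available, the same multi-cluster max can be sketched with plain PyTorch (again assuming 1.12 or newer for `scatter_reduce`). The helper `neighborhood_max` and the toy data are hypothetical. It assumes the same `edge_index` orientation (`row` = neighbor, `col` = center) and scatters every edge's neighbor features into its center's output slot, so a node that sits in several neighborhoods contributes to all of them.

```python
import torch


def neighborhood_max(x, fps_idx, edge_index):
    """Max of x over each center's full neighborhood (a node may feed several centers)."""
    row, col = edge_index  # row: neighbor node, col: center node (indices into x)
    # slot[n] = position of node n in fps_idx, or -1 if n is not a center
    slot = torch.full((x.size(0),), -1, dtype=torch.long, device=x.device)
    slot[fps_idx] = torch.arange(fps_idx.size(0), device=x.device)
    target = slot[col]
    keep = target >= 0  # edges whose col is not a center are ignored
    index = target[keep].unsqueeze(-1).expand(-1, x.size(1))
    out = x.new_zeros(fps_idx.size(0), x.size(1))  # centers with no edges stay 0
    return out.scatter_reduce(0, index, x[row[keep]], reduce='amax', include_self=False)


x = torch.tensor([[1.0], [2.0], [9.0], [3.0], [4.0], [7.0]])
fps_idx = torch.tensor([0, 4])
# Node 2 lies in the neighborhood of both centers (0 and 4); node 5 in neither.
edge_index = torch.tensor([[0, 1, 2, 2, 3, 4],
                           [0, 0, 0, 4, 4, 4]])
print(neighborhood_max(x, fps_idx, edge_index).tolist())  # [[9.0], [9.0]]
```

Unlike a single cluster vector, node 2's feature reaches both centers here, so the result no longer depends on the order in which neighbors are assigned.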
❓ Questions & Help
I am trying to implement a graph pooling layer similar to GACNet. I used `max_pool_x()`, but I cannot compute its `cluster` vector argument efficiently.
I use `fps()` to downsample the point cloud and `radius()` to get the `edge_index` around each sampled point, then try to build the cluster vector from that `edge_index`, but this is very slow. I don't think this is the right way to get the cluster vector; do you have any suggestions?
How do I get the cluster vector correctly?