Closed. ArkadiyD closed this issue 1 year ago.
Hi,
I am not sure why it happens, but if we modify line 118 of layers.py as follows:
out = self.propagate(hyperedge_index, x=out, norm=D, size=(num_nodes, num_edges))
it works again (I have not checked carefully why this works). While it runs without raising an error, I observe a performance drop. To fully reproduce our results, I recommend using the same PyG version that we list in the requirements.
Thanks for the quick response! One strange detail I notice is the following:
With ratio1 = 0.7, after the first Hyperedge Pooling (pool1) the shape changes from f1_conv1: torch.Size([2614, 100]) to edge1_conv1: torch.Size([1842, 100]). As far as I understand, this is the expected behaviour, since the number of nodes is lower after the top 70% of hyperedges are selected.
With ratio1 = 1.0, after the first Hyperedge Pooling (pool1) the shape changes from f1_conv1: torch.Size([2614, 100]) to edge1_conv1: torch.Size([32, 100]). However, as far as I understand, a ratio of 1.0 means no hyperedges are dropped. What is the intuition behind this behaviour?
Thanks!
Update: I think the reason is that the topk
function (link) doesn't work correctly when the ratio is 1, because when it is an integer it is treated not as a ratio but as the exact number of nodes to keep.
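The int-vs-float distinction described above can be sketched in plain Python. This is a simplified model of the dispatch, not PyG's actual code, and the function name is mine:

```python
import math

def num_to_keep(ratio, num_nodes):
    """Simplified model of how a top-k pooling `ratio` argument may be
    interpreted: an integer (including 1) is taken as an absolute count k,
    while a float is taken as a fraction of the nodes to keep."""
    if isinstance(ratio, int):
        return min(ratio, num_nodes)       # ratio1 = 1 -> keep 1 node per graph
    return math.ceil(ratio * num_nodes)    # ratio1 = 0.7 -> keep 70% per graph
```

Under this model, passing 1.0 (a float) keeps every node, while passing 1 (an int) keeps a single node per graph, which would explain the collapse to 32 rows (one per graph in the batch) observed above.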
Hi,
This is due to an update of the topk function; please refer here for more detail. topk no longer exists in topk_pool.py.
You can use the following code for the topk function:
import torch
from torch_scatter import scatter_add, scatter_max


def topk(x, ratio, batch, min_score=None, tol=1e-7):
    if min_score is not None:
        # Make sure that we do not drop all nodes in a graph.
        scores_max = scatter_max(x, batch)[0][batch] - tol
        scores_min = scores_max.clamp(max=min_score)

        perm = (x > scores_min).nonzero(as_tuple=False).view(-1)
    else:
        num_nodes = scatter_add(batch.new_ones(x.size(0)), batch, dim=0)
        batch_size, max_num_nodes = num_nodes.size(0), num_nodes.max().item()

        cum_num_nodes = torch.cat(
            [num_nodes.new_zeros(1),
             num_nodes.cumsum(dim=0)[:-1]], dim=0)

        index = torch.arange(batch.size(0), dtype=torch.long, device=x.device)
        index = (index - cum_num_nodes[batch]) + (batch * max_num_nodes)

        dense_x = x.new_full((batch_size * max_num_nodes, ),
                             torch.finfo(x.dtype).min)
        dense_x[index] = x
        dense_x = dense_x.view(batch_size, max_num_nodes)

        _, perm = dense_x.sort(dim=-1, descending=True)

        perm = perm + cum_num_nodes.view(-1, 1)
        perm = perm.view(-1)

        k = (ratio * num_nodes.to(torch.float)).ceil().to(torch.long)
        mask = [
            torch.arange(k[i], dtype=torch.long, device=x.device) +
            i * max_num_nodes for i in range(batch_size)
        ]
        mask = torch.cat(mask, dim=0)

        perm = perm[mask]

    return perm
This is the code from an older version of PyG.
I still recommend using the older version of PyG to reproduce our results.
Thanks.
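For intuition, the else branch of the function above amounts to, per graph in the batch: sort the scores in descending order and keep the indices of the top ceil(ratio * n) nodes. A plain-Python sketch of that behaviour (illustrative only; the function name and structure are mine, not PyG's):

```python
import math

def topk_per_graph(scores, batch, ratio):
    """Illustrative re-implementation of dense top-k selection: for each
    graph in `batch`, return the global indices of its ceil(ratio * n)
    highest-scoring nodes, in descending-score order within each graph."""
    graphs = {}
    for idx, (s, b) in enumerate(zip(scores, batch)):
        graphs.setdefault(b, []).append((s, idx))
    perm = []
    for b in sorted(graphs):
        nodes = sorted(graphs[b], key=lambda t: -t[0])  # descending by score
        k = math.ceil(ratio * len(nodes))               # float ratio = fraction
        perm.extend(idx for _, idx in nodes[:k])
    return perm
```

Note that with a float ratio of 1.0 this keeps every node, matching the behaviour the old PyG topk restores.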
That's clear now, thanks!
I'm trying to run the code (python main_classification.py --dataset openssl_min50) but I'm running into an error coming from the line:
I have torch_geometric 2.2.0, and it would be helpful to make the code work with current versions of PyG. Thanks!
P.S. There are also problems with data processing, but I've managed to fix them following https://stackoverflow.com/a/72032357.