qitianwu / NodeFormer

The official implementation of NeurIPS22 spotlight paper "NodeFormer: A Scalable Graph Structure Learning Transformer for Node Classification"

BUG: x_i doesn't match edge_index_i in main-batch.py #7

Closed: ObsisMc closed this issue 1 year ago.

ObsisMc commented 1 year ago

Description

There may be a bug in main-batch.py caused by the random permutation of the training indices: https://github.com/qitianwu/NodeFormer/blob/64d26581f571340ab750ce6e60a8bb524e22e726/main-batch.py#L133-L144

In line 135, train_idx is sorted but idx_i is shuffled. In line 136, x_i is gathered in that shuffled order, so its rows no longer follow the original node order. However, in line 138 the node indices in adjs[0] are still in ascending order, and subgraph() keeps that ascending order when it relabels the nodes. As a result, in line 144 x_i does not align with adjs_i.
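To make the misalignment concrete, here is a minimal sketch in plain Python, with lists standing in for the LongTensors. It assumes, as reported later in this thread for PyG 2.1.0, that relabel_nodes=True numbers the kept nodes by ascending node id rather than by their position in idx_i. The toy graph and the helper name relabel_sorted are hypothetical and not taken from main-batch.py.

```python
# Toy graph: each row of x_i is tagged with its original node id, so we
# can read off which nodes an edge actually connects.
edges = [(1, 4), (2, 5), (0, 3)]  # (source, target) pairs of the full graph
idx_i = [5, 1, 4, 2]              # a shuffled batch, as randperm would give
x_i = list(idx_i)                 # rows of x_i follow the shuffled order


def relabel_sorted(subset, edge_list):
    """Keep edges inside `subset` and relabel each endpoint by its rank in
    sorted(subset), the behaviour reported in this thread for PyG 2.1.0."""
    rank = {v: r for r, v in enumerate(sorted(subset))}
    return [(rank[u], rank[v]) for u, v in edge_list if u in rank and v in rank]


edge_index_i = relabel_sorted(idx_i, edges)
# The model pairs each label with a row of x_i; recover the node ids it sees.
seen_by_model = [(x_i[u], x_i[v]) for u, v in edge_index_i]
print(edge_index_i)   # [(0, 2), (1, 3)]
print(seen_by_model)  # [(5, 4), (1, 2)], but the true edges are (1, 4), (2, 5)
```

The edge (0, 3) is dropped because node 0 is not in the batch; the two kept edges end up attached to the wrong rows of x_i.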

I then changed the code as follows, so that the relabeled node indices match the row order of x_i:

idx = torch.randperm(train_idx.size(0))
for i in range(num_batch):
    idx_i = train_idx[idx[i * args.batch_size:(i + 1) * args.batch_size]]
    x_i = x[idx_i].to(device)
    adjs_i = []
    edge_index_i, _ = subgraph(idx_i, adjs[0], num_nodes=n, relabel_nodes=True)

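    # Fix: in PyG 2.1.0, relabel_nodes=True numbers the kept nodes by
    # ascending node id, not by their position in the shuffled idx_i.
    # argsort(idx_i)[r] is the row of x_i holding the r-th smallest id,
    # so indexing with it maps each relabeled endpoint onto x_i's rows.
    idx_perm = torch.argsort(idx_i)
    edge_index_i = idx_perm[edge_index_i]

    adjs_i.append(edge_index_i.to(device))
    for k in range(args.rb_order - 1):
        edge_index_i, _ = subgraph(idx_i, adjs[k + 1], num_nodes=n, relabel_nodes=True)
        edge_index_i = idx_perm[edge_index_i]  # same remap for higher orders
        adjs_i.append(edge_index_i.to(device))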
    optimizer.zero_grad()
    out_i, link_loss_ = model(x_i, adjs_i, args.tau)
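Why the argsort remap works: if subgraph() relabels by ascending node id, label r refers to sorted(idx_i)[r], and sorted(idx_i)[r] == idx_i[argsort(idx_i)[r]], so indexing with idx_perm turns each label into the row of x_i that holds that node. The plain-Python sketch below checks this identity; argsort is written by hand to mirror torch.argsort, and remap_sorted_labels is a hypothetical name. The remap is only correct when subgraph() really relabels by sorted id. With a PyG release that already follows the order of idx_i (this thread reports 1.7.2 does), applying it would scramble the edges again.

```python
def argsort(values):
    """Positions that would sort `values` (the same idea as torch.argsort)."""
    return sorted(range(len(values)), key=lambda j: values[j])


def remap_sorted_labels(idx_i, edge_index_i):
    """Map sorted-rank labels to positions in idx_i, i.e. rows of x_i."""
    idx_perm = argsort(idx_i)
    return [(idx_perm[u], idx_perm[v]) for u, v in edge_index_i]


idx_i = [5, 1, 4, 2]
# Identity behind the fix: the r-th smallest id sits at row idx_perm[r].
idx_perm = argsort(idx_i)
assert all(sorted(idx_i)[r] == idx_i[idx_perm[r]] for r in range(len(idx_i)))

# Sorted-rank labels of the edges (1, 4) and (2, 5) for this batch.
fixed = remap_sorted_labels(idx_i, [(0, 2), (1, 3)])
print(fixed)                                     # [(1, 2), (3, 0)]
print([(idx_i[u], idx_i[v]) for u, v in fixed])  # [(1, 4), (2, 5)]
```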

Experiments

python main-batch.py --dataset ogbn-arxiv --metric acc --method nodeformer --lr 1e-2 --weight_decay 0. --num_layers 3 --hidden_channels 64 --num_heads 1 --rb_order 1 --rb_trans identity --lamda 0.1 --M 50 --K 5 --use_bn --use_residual --use_gumbel --use_act --use_jk --batch_size 20000 --runs 1 --epochs 1000 --eval_step 9 --device 0

Before the modification, test accuracy on ogbn-arxiv is only about 55%. After the modification, it exceeds 65%.

qitianwu commented 1 year ago

Hi Ruihao,

Thank you for carefully checking our code and raising this issue. However, I think our implementation is correct: for the subgraph function we set relabel_nodes=True, which relabels the nodes in idx_i, in the order given by idx_i, as 0, ..., idx_i.shape[0] - 1 in the returned edge_index. Therefore, the node indices in the resulting edge_index are consistent with x_i.

Regarding the experiments on ogbn-arxiv, thank you for running our model on this dataset, which we have not tried before. A 65% test accuracy does not seem high for ogbn-arxiv, so the performance improvement here might not stem from the modified ordering; the results may instead reflect hyper-parameter settings under which NodeFormer does not perform well enough.

ObsisMc commented 1 year ago

Thank you for your reply. I found that the PyG version causes the problem: mine is 2.1.0 and yours is 1.7.2. In 2.1.0, relabel_nodes=True does not relabel the nodes in the order given by idx_i (the new indices follow ascending node id), whereas in 1.7.2 it does.

By the way, because of this problem my test accuracy on ogbn-proteins was also only about 72%, as reported in issue #3; after fixing it, I get around 77%, in line with your paper. Issue #3 may have run into the same problem.
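Because the relabeling order differs between PyG releases, one way to keep the batch loop independent of it is to build the node-to-row mapping explicitly, so that idx_i[j] always becomes label j. The sketch below shows the idea in plain Python; subgraph_in_subset_order is a hypothetical helper, not a PyG function. In PyTorch, the same lookup table could be a LongTensor of size n filled with -1, assigned with mapping[idx_i] = torch.arange(idx_i.size(0)) and then indexed with the kept edges.

```python
def subgraph_in_subset_order(subset, edge_list):
    """Keep edges whose endpoints are both in `subset` and relabel node
    subset[j] as j, so the labels index the rows of x[subset] directly.
    Hypothetical helper; it does not depend on any PyG version."""
    position = {v: j for j, v in enumerate(subset)}
    return [
        (position[u], position[v])
        for u, v in edge_list
        if u in position and v in position
    ]


idx_i = [5, 1, 4, 2]              # shuffled batch
edges = [(1, 4), (2, 5), (0, 3)]  # (source, target) pairs of the full graph
edge_index_i = subgraph_in_subset_order(idx_i, edges)
print(edge_index_i)  # [(1, 2), (3, 0)]
# Every label now points at the row of x_i that holds the right node.
assert [(idx_i[u], idx_i[v]) for u, v in edge_index_i] == [(1, 4), (2, 5)]
```

Unlike the argsort remap, this does not assume any particular relabeling convention, at the cost of building the mapping by hand.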

qitianwu commented 1 year ago

I see. The PyG version indeed matters in practice. Thanks a lot for letting me know about this.