
Shift-Robust GNNs: Overcoming the Limitations of Localized Graph Training Data

Source code for "Shift-Robust GNNs: Overcoming the Limitations of Localized Graph Training Data", published in NeurIPS 2021.

This project also provides pre-processed biased training samples for three small (Cora/Citeseer/PubMed) and two medium-size (ogbn-arxiv/Reddit) graph benchmarks. In addition, a scalable biased sampler is included as a tool to create biased training samples for any large graph; a sketch of the idea follows.
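
As a rough illustration (not the repository's exact procedure), one simple way to create a localized, i.e. biased, training sample is to expand outward from a single random seed node so that the selected training nodes are concentrated in one region of the graph. The function name, graph representation, and parameters below are hypothetical:

```python
# Sketch: a localized (biased) training sampler that grows a training
# set around one random seed node via breadth-first expansion.
# Illustrative only; the repository's scalable sampler may differ.
import random
from collections import deque

def localized_sample(neighbors: dict, num_train: int, rng=random) -> list:
    """neighbors: node -> list of adjacent nodes."""
    start = rng.choice(list(neighbors))
    seen, order, queue = {start}, [], deque([start])
    while queue and len(order) < num_train:
        u = queue.popleft()
        order.append(u)
        for v in neighbors[u]:
            if v not in seen:
                seen.add(v)
                queue.append(v)
    return order  # training nodes concentrated in one region of the graph
```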

If you find our paper useful, please consider citing it:

@article{zhu2021shift,
  title={Shift-robust gnns: Overcoming the limitations of localized graph training data},
  author={Zhu, Qi and Ponomareva, Natalia and Han, Jiawei and Perozzi, Bryan},
  journal={Advances in Neural Information Processing Systems},
  volume={34},
  year={2021}
}

Installation

pip3 install -r requirements.txt

Quick Start

Apply SRGNN and DAGNN

SRGNN is our method, and DAGNN is our reference implementation of a domain-adversarial neural network (DANN) for GNNs. Our shift-robust function has two hyper-parameters: alpha, the weight of the regularization term, and beta, the minimal weight used for instance re-weighting in Equation 10 of the main paper.

python toy_gnn.py

Please refer to toy_gnn.py for the reference implementations of our method and DANN (see Table 4 of the main paper for more comparisons). We recommend implementing your own distribution-shift method as a classification head, as we did in ToyGNN; a sketch of that pattern follows.
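
A minimal sketch of the classification-head pattern, assuming PyTorch. Here `emb` is the output of any GNN encoder, and the discrepancy term is a simple first-moment (mean) matching penalty standing in for the regularizer described in the paper; all names are illustrative:

```python
# Sketch: a distribution-shift method packaged as a classification head.
# The shift penalty here (mean matching) is a stand-in, not the exact
# regularizer from the paper.
import torch
import torch.nn as nn
import torch.nn.functional as F

class ShiftRobustHead(nn.Module):
    def __init__(self, in_dim: int, n_classes: int, alpha: float = 1.0):
        super().__init__()
        self.classifier = nn.Linear(in_dim, n_classes)
        self.alpha = alpha  # weight of the shift regularizer

    def forward(self, emb, labels, train_idx, iid_idx):
        logits = self.classifier(emb)
        ce = F.cross_entropy(logits[train_idx], labels[train_idx])
        # Penalize the discrepancy between the biased training embeddings
        # and an unbiased (IID) sample of node embeddings.
        shift = (emb[train_idx].mean(0) - emb[iid_idx].mean(0)).pow(2).sum()
        return logits, ce + self.alpha * shift
```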

Processed data

For the three smaller GNN benchmarks, we provide 100 different sets of biased training seeds under data/localized_seeds_{dataset}.p. We also provide a tool to load our indexed networks, so you can test the performance of your own GNN with our biased seeds. For the two medium-size graphs, we provide 10 different sets of biased training seeds as data/{dataset}_{label_ratio}.
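
Judging by the .p extension, the seed files are Python pickles, so loading one set of biased training indices could look like the following (the exact layout of the pickled object is an assumption):

```python
import pickle

# Assumed layout: a list of index collections, one per trial.
with open('data/localized_seeds_cora.p', 'rb') as f:
    seeds = pickle.load(f)
train_idx = seeds[0]  # biased training indices for the first trial
```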

We use pre-computed personalized PageRank (PPR) vectors in our model, which can be downloaded from:

intermediate: https://drive.google.com/file/d/1zlJp9KEqxiApWX3IxC8nRx6tlkiqR-5V/view?usp=sharing

Please unzip the file into intermediate/.
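
For reference, a PPR vector can be computed by power iteration, as in the sketch below. This is illustrative only; the provided files contain pre-computed vectors and may use a different (approximate) algorithm:

```python
# Sketch: personalized PageRank (PPR) vector for one source node.
import numpy as np
import scipy.sparse as sp

def ppr_vector(adj: sp.csr_matrix, source: int,
               alpha: float = 0.15, iters: int = 100) -> np.ndarray:
    deg = np.asarray(adj.sum(axis=1)).ravel()
    deg[deg == 0] = 1.0
    P = sp.diags(1.0 / deg) @ adj      # row-stochastic transition matrix
    e = np.zeros(adj.shape[0])
    e[source] = 1.0                    # teleport distribution
    pi = e.copy()
    for _ in range(iters):
        pi = alpha * e + (1.0 - alpha) * (P.T @ pi)
    return pi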

Small Graph Benchmark

To reproduce the results of different GNN methods under biased training,

python main_gnn.py --n-epochs=200 --dataset=$DATASET --gnn-arch=$METHOD --n-repeats=100 --n-hidden=32 --dropout=0.5 --weight-decay=0.0005 --biased-sample=True --gpu=$GPU_ID --SR=$SR

$METHOD can be one of graphsage, gat, ppnp, or sgc; $DATASET can be one of cora, citeseer, or pubmed. Shift-Robust is applied when $SR is True.
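
For example, to run SR-PPNP on Cora:

python main_gnn.py --n-epochs=200 --dataset=cora --gnn-arch=ppnp --n-repeats=100 --n-hidden=32 --dropout=0.5 --weight-decay=0.0005 --biased-sample=True --gpu=0 --SR=True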

The mean accuracy (Micro-F1) and standard deviation, with a subset of the baselines, are summarized in the following table:

| Dataset | Cora | Citeseer | Pubmed |
|---|---|---|---|
| GCN | 68.3 (4.3) | 62.1 (1.4) | 59.4 (4.5) |
| APPNP | 71.2 (5.0) | 63.7 (1.4) | 65.8 (4.7) |
| SR-APPNP | 74.0 (3.3) | 67.2 (1.5) | 70.5 (3.0) |

The results in the table are slightly different from the paper, since we use a different random seed to generate the biased training seeds. More comprehensive results can be found in Table 2 of the paper.

Larger Graph Experiments

We use GCN as the base model in the larger-graph experiments, as an example of applying SR-GNN to deep GNNs. Running the following command will first download the ogbn-arxiv and Reddit datasets. The mean accuracy (Micro-F1) and standard deviation, with a subset of the baselines, are summarized in the table below.

python main_gnn_large.py --n-epochs=100 --dataset=$DATASET --gnn-arch=gcn --n-repeats=10 --n-hidden=256 --dropout=0.5 --weight-decay=0.0005 --biased-sample=True --gpu=$GPU_ID --SR=$SR

| Dataset | ogbn-arxiv | Reddit |
|---|---|---|
| GCN | 59.3 (1.2) | 89.6 (0.9) |
| SR-GCN | 61.6 (0.6) | 91.3 (0.5) |