This repository contains code to generate data and reproduce experiments from our NeurIPS 2019 paper: Understanding attention and generalization in graph neural networks.
See slides here.
An earlier short version of our paper was presented as a contributed talk at the ICLR 2019 Workshop on Representation Learning on Graphs and Manifolds.
Update:
In the MNIST code, the dist variable should have been squared to make the edge weights Gaussian. All figures and results were generated without squaring it. I don't think this matters much for the results, but if you square it, sigma should be adjusted accordingly.
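To make the update concrete, here is a small sketch comparing the released weighting exp(-dist / sigma²) with the squared Gaussian exp(-dist² / sigma²). The helper matched_sigma and the reference distance ref are hypothetical: they show only one possible way to "adjust sigma accordingly", by choosing the new sigma so that both kernels give the same weight at a single reference distance. With that choice, the squared kernel decays more slowly below ref and faster beyond it.

```python
import math


def closeness(dist, sigma, squared):
    """Edge weight for two nodes at Euclidean distance `dist`.

    squared=False reproduces the released MNIST code: exp(-dist / sigma**2).
    squared=True is a proper Gaussian kernel:         exp(-dist**2 / sigma**2).
    """
    d = dist ** 2 if squared else dist
    return math.exp(-d / sigma ** 2)


def matched_sigma(sigma, ref_dist):
    """Hypothetical helper: sigma for the squared kernel that matches the
    unsquared kernel at one reference distance.

    exp(-ref / sigma**2) == exp(-ref**2 / s**2)  <=>  s = sigma * sqrt(ref)
    """
    return sigma * math.sqrt(ref_dist)


sigma = 0.1 * math.pi  # value used in the MNIST example
ref = 0.1              # assumption: a typical distance (coordinates are divided by 28)
new_sigma = matched_sigma(sigma, ref)

for d in (0.05, 0.1, 0.3):
    print(d, round(closeness(d, sigma, False), 3), round(closeness(d, new_sigma, True), 3))
```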
Figure: MNIST (left column) and TRIANGLES (right column) examples, described row by row from top to bottom.
Note that during training, our MNIST models never encountered noisy images, and our TRIANGLES models never encountered graphs with more than N=25 nodes.
The COLORS and TRIANGLES datasets are now also available in the TU format, so you can load them with a general TU data reader. See the PyTorch Geometric examples for COLORS and TRIANGLES.
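If you'd rather not depend on a library reader, the TU text format itself is simple. Below is a minimal sketch of a reader, assuming the standard TU layout: DS_A.txt (one "i, j" line per directed edge, 1-based global node ids), DS_graph_indicator.txt (graph id of each node) and DS_graph_labels.txt (one label per graph). It reads file contents passed as strings, ignores node attributes, and is not the reader used by PyTorch Geometric.

```python
from collections import defaultdict


def parse_tu(a_txt, indicator_txt, graph_labels_txt):
    """Minimal sketch of a TU-format reader.

    Returns a list of (edges, label) per graph, with 0-based node ids
    local to each graph. Assumes the nodes of each graph are contiguous,
    as in the TU format.
    """
    graph_of = [int(x) for x in indicator_txt.split()]  # node k -> graph id
    labels = [int(x) for x in graph_labels_txt.split()]
    # First global (0-based) node id of each graph, used to make ids local
    offset = {}
    for node, g in enumerate(graph_of):
        offset.setdefault(g, node)
    edges = defaultdict(list)
    for line in a_txt.strip().splitlines():
        i, j = (int(v) - 1 for v in line.split(','))
        g = graph_of[i]
        edges[g].append((i - offset[g], j - offset[g]))
    return [(edges[g], labels[g - 1]) for g in range(1, len(labels) + 1)]


# Toy data: graph 1 is a triangle (nodes 1-3), graph 2 a single edge (nodes 4-5)
a_txt = "1, 2\n2, 1\n2, 3\n3, 2\n1, 3\n3, 1\n4, 5\n5, 4"
graphs = parse_tu(a_txt, "1\n1\n1\n2\n2", "1\n0")
print(graphs)
```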
For more examples, see MNIST_eval_models and TRIANGLES_eval_models.
# Download model checkpoint or 'git clone' this repo
import urllib.request
# Let's use the model with supervised attention (other models can be found in the Table below)
model_name = 'checkpoint_mnist-75sp_139255_epoch30_seed0000111.pth.tar'
model_url = 'https://github.com/bknyaz/graph_attention_pool/raw/master/checkpoints/%s' % model_name
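model_path = 'checkpoints/%s' % model_name
import os
os.makedirs('checkpoints', exist_ok=True)  # urlretrieve does not create missing directories
urllib.request.urlretrieve(model_url, model_path)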
# Load the model
import torch
from chebygin import ChebyGIN
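# Load on CPU; recent PyTorch versions may also need weights_only=False, since the checkpoint stores argparse args
state = torch.load(model_path, map_location='cpu')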
args = state['args']
model = ChebyGIN(in_features=5, out_features=10, filters=args.filters, K=args.filter_scale,
n_hidden=args.n_hidden, aggregation=args.aggregation, dropout=args.dropout,
readout=args.readout, pool=args.pool, pool_arch=args.pool_arch)
model.load_state_dict(state['state_dict'])
model = model.eval()
# Load image using standard PyTorch Dataset
from torchvision import datasets
data = datasets.MNIST('./data', train=False, download=True)
images = data.data.numpy() / 255.  # data.test_data is deprecated in torchvision
import numpy as np
img = images[0].astype(np.float32) # 28x28 MNIST image
# Extract superpixels and create node features
import scipy.ndimage
from skimage.segmentation import slic
from scipy.spatial.distance import cdist
# The number (n_segments) of superpixels returned by SLIC is usually smaller than requested, so we request more
# Grayscale image; scikit-image < 0.19 uses multichannel=False instead of channel_axis=None
superpixels = slic(img, n_segments=95, compactness=0.25, channel_axis=None)
sp_indices = np.unique(superpixels)
n_sp = len(sp_indices) # should be 74 with these parameters of slic
sp_intensity = np.zeros((n_sp, 1), np.float32)
sp_coord = np.zeros((n_sp, 2), np.float32) # row, col
for i, seg in enumerate(sp_indices):  # i is a 0-based row index even if SLIC labels start at 1
    mask = superpixels == seg
    sp_intensity[i] = np.mean(img[mask])
    sp_coord[i] = np.array(scipy.ndimage.center_of_mass(mask))
# The model is invariant to the order of nodes in a graph
# We can shuffle nodes and obtain exactly the same results
ind = np.random.permutation(n_sp)
sp_coord = sp_coord[ind]
sp_intensity = sp_intensity[ind]
# Create edges between nodes in the form of adjacency matrix
sp_coord = sp_coord / images.shape[1]
dist = cdist(sp_coord, sp_coord) # distance between all pairs of nodes
sigma = 0.1 * np.pi  # width of a Gaussian
# Transform distance to spatial closeness (dist is not squared, matching the released models; see the Update above)
A = np.exp(- dist / sigma ** 2)
A[np.diag_indices_from(A)] = 0 # remove self-loops
A = torch.from_numpy(A).float().unsqueeze(0)
# Prepare an input to the model and process it
N_nodes = sp_intensity.shape[0]
mask = torch.ones(1, N_nodes, dtype=torch.uint8)
# mean and std computed for superpixel features in the training set
mn = torch.tensor([0.11225057, 0.11225057, 0.11225057, 0.44206527, 0.43950436]).view(1, 1, -1)
sd = torch.tensor([0.2721889, 0.2721889, 0.2721889, 0.2987583, 0.30080357]).view(1, 1, -1)
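# Node features: intensity replicated to 3 channels by 'edge' padding, followed by (row, col) coordinates
node_features = (torch.from_numpy(np.pad(np.concatenate((sp_intensity, sp_coord), axis=1),
                 ((0, 0), (2, 0)), 'edge')).unsqueeze(0) - mn) / sd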
y, other_outputs = model([node_features, A, mask, None, {'N_nodes': torch.zeros(1, 1) + N_nodes}])
alpha = other_outputs['alpha'][0].data
y is a vector of 10 unnormalized class scores; to get the predicted label, use torch.argmax(y). alpha is a vector of attention coefficients, one per node.
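A minimal stdlib sketch of post-processing these two outputs, with made-up numbers standing in for y[0] and alpha: the argmax over class scores, softmax probabilities, and the indices of the most-attended nodes. Because the example above shuffles nodes, node i in alpha corresponds to superpixel ind[i]. The helpers predict and top_attended are hypothetical, not part of this repository.

```python
import math


def predict(scores):
    """Return (label, probabilities) from unnormalized class scores, like torch.argmax(y)."""
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]  # subtract the max for numerical stability
    total = sum(exps)
    probs = [e / total for e in exps]
    label = max(range(len(scores)), key=lambda c: scores[c])
    return label, probs


def top_attended(alpha, k):
    """Indices of the k nodes with the largest attention coefficients."""
    return sorted(range(len(alpha)), key=lambda i: alpha[i], reverse=True)[:k]


# Hypothetical values standing in for y[0] (10 class scores) and alpha (one per node)
y = [0.1, -1.2, 0.3, 2.5, 0.0, -0.4, 1.1, 0.2, -2.0, 0.7]
alpha = [0.02, 0.30, 0.05, 0.40, 0.23]

label, probs = predict(y)
print(label, round(probs[label], 3), top_attended(alpha, 2))
```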
We design two synthetic graph tasks, COLORS and TRIANGLES, in which we predict the number of green nodes and the number of triangles respectively.
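To make the two targets concrete, here is a sketch of computing them from a graph. It assumes each COLORS node carries a one-hot color feature with green at channel 1 (as in an [R, G, B] ordering); that channel index is an assumption, not taken from generate_data.py. Triangles are counted as 3-cliques in an undirected graph.

```python
from itertools import combinations

GREEN = 1  # assumption: green is channel 1 of a one-hot [R, G, B, ...] node feature


def count_green(node_features):
    """COLORS-style target: number of nodes whose one-hot color is green."""
    return sum(1 for f in node_features if f.index(max(f)) == GREEN)


def count_triangles(n_nodes, edges):
    """TRIANGLES-style target: number of 3-cliques in an undirected graph."""
    nbrs = [set() for _ in range(n_nodes)]
    for u, v in edges:
        if u != v:
            nbrs[u].add(v)
            nbrs[v].add(u)
    # Each triangle {a, b, c} with a < b < c is counted exactly once, at node a
    return sum(1 for a in range(n_nodes)
               for b, c in combinations(sorted(x for x in nbrs[a] if x > a), 2)
               if c in nbrs[b])


features = [[0, 1, 0], [1, 0, 0], [0, 1, 0], [0, 0, 1]]  # toy one-hot colors
k4 = [(0, 1), (0, 2), (0, 3), (1, 2), (1, 3), (2, 3)]    # complete graph on 4 nodes
print(count_green(features), count_triangles(4, k4))     # prints: 2 4
```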
We also experiment with the MNIST image classification dataset, which we preprocess by extracting superpixels, a more natural way to feed images to a graph neural network. We denote this dataset as MNIST-75sp.
We validate our weakly-supervised approach on three common graph classification benchmarks: COLLAB, PROTEINS and D&D.
For COLORS, TRIANGLES and MNIST, we know the ground-truth attention for nodes, which allows us to study attention in graph neural networks in depth.
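As an illustration of what knowing ground-truth attention enables, here is a hypothetical sketch: it spreads attention uniformly over the "relevant" nodes (for example, green nodes in COLORS) and measures what share of a model's predicted attention falls on them. Both the uniform definition and this diagnostic are assumptions for illustration; they are not necessarily the definitions or metrics used in the paper.

```python
def ground_truth_attention(relevant):
    """Hypothetical helper: uniform attention over relevant nodes (list of bools)."""
    n = sum(relevant)
    if n == 0:
        return [0.0] * len(relevant)
    return [1.0 / n if r else 0.0 for r in relevant]


def mass_on_relevant(alpha, relevant):
    """Hypothetical diagnostic: share of predicted attention that lands on relevant nodes."""
    total = sum(alpha)
    return sum(a for a, r in zip(alpha, relevant) if r) / total if total > 0 else 0.0


relevant = [True, False, True, False]          # e.g. which nodes are green
gt = ground_truth_attention(relevant)          # [0.5, 0.0, 0.5, 0.0]
print(gt, mass_on_relevant([0.4, 0.1, 0.3, 0.2], relevant))
```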
To generate all data with a single command, run ./scripts/prepare_data.sh.
All generated/downloaded data will be stored in the local ./data directory.
Preparing all data can take about 1 hour (see my log), and the data take about 2 GB of disk space.
Alternatively, you can generate data for each task as described below.
In case of any issues with running these scripts, data can be downloaded from here.
To generate training, validation and test data for our COLORS dataset with different dimensionalities:
for dim in 3 8 16 32; do python generate_data.py --dim $dim; done
To generate training and test data for our MNIST-75sp dataset using 4 CPU threads:
for split in train test; do python extract_superpixels.py -s $split -t 4; done
Once datasets are generated or downloaded, you can use the following IPython notebooks to load and visualize data:
COLORS and TRIANGLES, MNIST and COLLAB, PROTEINS and D&D.
Generalization results on the test sets for three tasks. Other results are available in the paper.
Click on the result to download a trained model in the PyTorch format.
Model | COLORS-Test-LargeC | TRIANGLES-Test-Large | MNIST-75sp-Test-Noisy |
---|---|---|---|
Script to train models | colors.sh | triangles.sh | mnist_75sp.sh |
Global pooling | 15 ± 7 | 30 ± 1 | 80 ± 12 |
Unsupervised attention | 11 ± 6 | 26 ± 2 | 80 ± 23 |
Supervised attention | 75 ± 17 | 48 ± 1 | 92.3 ± 0.4 |
Weakly-supervised attention | 73 ± 14 | 30 ± 1 | 88.8 ± 4 |
The scripts to train the models must be run from the main directory, e.g.: ./scripts/mnist_75sp.sh
Examples of evaluating our trained models can be found in notebooks: MNIST_eval_models and TRIANGLES_eval_models.
To tune hyperparameters on the validation set for COLORS, TRIANGLES and MNIST, use the --validation flag.
For COLLAB, PROTEINS and D&D, hyperparameter tuning is included in the training script; use the --ax flag.
Example of running weakly-supervised experiments on PROTEINS (10 data seeds × 10 model seeds) with cross-validation of hyperparameters, including the initialization parameters (distribution and scale) of the attention model (the --tune_init flag):
for i in $(seq 1 1 10); do dataseed=$(( ( RANDOM % 10000 ) + 1 )); for j in $(seq 1 1 10); do seed=$(( ( RANDOM % 10000 ) + 1 )); python main.py --seed $seed -D TU --n_nodes 25 --epochs 50 --lr_decay_step 25,35,45 --test_batch_size 100 -f 64,64,64 -K 1 --readout max --dropout 0.1 --pool attn_sup_threshold_skip_skip_0 --pool_arch fc_prev --results None --data_dir ./data/PROTEINS --seed_data $dataseed --cv --cv_folds 5 --cv_threads 5 --ax --ax_trials 30 --scale None --tune_init | tee logs/proteins_wsup_"$dataseed"_"$seed".log; done; done
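The --cv --cv_folds 5 flags select hyperparameters by 5-fold cross-validation, and --seed_data presumably fixes how the data are split. For readers unfamiliar with the procedure, here is a generic sketch of seeded k-fold splitting; the function kfold_indices is hypothetical and is not how main.py implements it.

```python
import random


def kfold_indices(n_samples, n_folds, seed):
    """Hypothetical sketch of seeded k-fold cross-validation splits.

    Shuffles sample indices with `seed` and yields (train_idx, val_idx) for each
    fold; every sample lands in exactly one validation fold.
    """
    idx = list(range(n_samples))
    random.Random(seed).shuffle(idx)  # same seed -> same split
    folds = [idx[f::n_folds] for f in range(n_folds)]
    for f in range(n_folds):
        val = sorted(folds[f])
        train = sorted(i for g, fold in enumerate(folds) if g != f for i in fold)
        yield train, val


for fold, (train, val) in enumerate(kfold_indices(10, 5, seed=1)):
    print(fold, len(train), len(val))
```

Each hyperparameter candidate would be scored by its average validation performance over the folds.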
The same on COLLAB, without initialization tuning:
for i in $(seq 1 1 10); do dataseed=$(( ( RANDOM % 10000 ) + 1 )); for j in $(seq 1 1 10); do seed=$(( ( RANDOM % 10000 ) + 1 )); python main.py --seed $seed -D TU --n_nodes 35 --epochs 50 --lr_decay_step 25,35,45 --test_batch_size 32 -f 64,64,64 -K 3 --readout max --dropout 0.1 --pool attn_sup_threshold_skip_skip_skip_0 --pool_arch fc_prev --results None --data_dir ./data/COLLAB --seed_data $dataseed --cv --cv_folds 5 --cv_threads 5 --ax --ax_trials 30 --scale None | tee logs/collab_wsup_"$dataseed"_"$seed".log; done; done
Note that results may be better with --pool_arch gnn_prev, but we did not focus on that.
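For intuition about what --pool_arch and the "threshold" in the --pool name control, here is a hypothetical sketch of thresholded attention pooling in plain Python. A linear scoring layer (w) stands in for what fc_prev may compute from the previous layer's features (an assumption based on the name); scores are softmax-normalized over nodes, nodes with attention at or below the threshold are dropped, and the kept features are scaled by their attention. This is an illustration, not the ChebyGIN implementation.

```python
import math


def attention_pool(x, w, threshold):
    """Hypothetical sketch of thresholded attention pooling.

    x: list of node feature vectors; w: weights of a linear scoring layer
    (an assumed stand-in for --pool_arch fc_prev).
    Returns (alpha, kept node indices, kept features scaled by alpha).
    """
    scores = [sum(wi * xi for wi, xi in zip(w, f)) for f in x]
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]  # softmax over nodes, numerically stable
    z = sum(exps)
    alpha = [e / z for e in exps]
    keep = [i for i, a in enumerate(alpha) if a > threshold]
    pooled = [[alpha[i] * v for v in x[i]] for i in keep]
    return alpha, keep, pooled


x = [[1.0, 0.0], [0.0, 1.0], [2.0, 0.0]]
alpha, keep, pooled = attention_pool(x, w=[1.0, -1.0], threshold=0.2)
print([round(a, 3) for a in alpha], keep)
```

Replacing the linear scorer with a small graph network would let the attention depend on each node's neighborhood, which is the kind of change the gnn_prev option refers to.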
Python packages required (can be installed via pip or conda):
Please cite our paper if you use our data or code:
@inproceedings{knyazev2019understanding,
title={Understanding attention and generalization in graph neural networks},
author={Knyazev, Boris and Taylor, Graham W and Amer, Mohamed},
booktitle={Advances in Neural Information Processing Systems},
pages={4202--4212},
year={2019},
pdf={http://arxiv.org/abs/1905.02850}
}