DSE-MSU / DeepRobust

A pytorch adversarial library for attack and defense methods on images and graphs
MIT License

Graph global attack API #19

Closed: henrykenlay closed this issue 2 years ago

henrykenlay commented 4 years ago

I've been using this library to generate global adversarial attacks on graphs using DICE and PGD. I have some suggestions that I'd be happy to implement and submit as a pull request, but I wanted to raise them here first to get your feedback and check whether this is something you'd be interested in.

Here's a code snippet that runs a DICE attack and a PGD attack. I'll use it to highlight the motivation for my suggestions.

from deeprobust.graph.data import Dataset
from deeprobust.graph.global_attack import DICE, PGDAttack
from deeprobust.graph.defense import GCN
import numpy as np
import torch
import scipy.sparse as sp

# parameters
perturbations = 20
device = 'cpu'

# get data
data = Dataset(root='/tmp', name='cora', setting='nettack')
adj, features, labels = data.adj, data.features, data.labels
idx_train, idx_val, idx_test = data.idx_train, data.idx_val, data.idx_test
idx_unlabeled = np.union1d(idx_val, idx_test)
features = torch.FloatTensor(features.todense())
labels = torch.LongTensor(labels)

# fit model
surrogate = GCN(nfeat=features.shape[1], nclass=labels.max().item()+1, nhid=16, device=device)
surrogate = surrogate.to(device)
surrogate.fit(features, adj, labels, idx_train, idx_val)

# dice attack
model = DICE(model=surrogate, nnodes=adj.shape[0], device=device)
model = model.to(device)
modified_adj = model.attack(adj=adj, labels=labels, n_perturbations=perturbations)
# DICE returns the modified adjacency matrix directly; no extra step is needed here

# pgd attack
model = PGDAttack(model=surrogate, nnodes=adj.shape[0], device=device)
model = model.to(device)
model.attack(ori_features=features, ori_adj=torch.FloatTensor(adj.todense()), labels=labels, idx_train=idx_train, perturbations=perturbations)
modified_adj = sp.csr_matrix(model.modified_adj.cpu().numpy())

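The DICE and PGDAttack classes both subclass BaseAttack, but their attack methods are used differently. DICE.attack takes the scipy.sparse adj and n_perturbations and returns the modified adjacency matrix. PGDAttack.attack instead takes dense torch tensors (ori_features, ori_adj) plus idx_train and perturbations, and the result has to be read from model.modified_adj and converted back to a sparse matrix.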
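To make the mismatch concrete, here is a minimal sketch using two hypothetical stand-in classes, ReturnsAdj and StoresAdj (not part of DeepRobust), that mimic the two calling conventions in the snippet above: one returns the perturbed graph, the other stores it on the instance. The hypothetical helper run_attack shows the kind of glue code a caller currently needs to treat both attacks uniformly. Graphs are plain sets of undirected edges so the sketch runs without DeepRobust, and the toy perturbation is not DICE or PGD.

```python
class ReturnsAdj:
    """Hypothetical stand-in for the DICE-style convention: attack() returns the result."""

    def attack(self, adj, labels, n_perturbations):
        # Toy perturbation (not DICE): drop the first n_perturbations edges in sorted order.
        removed = set(sorted(adj)[:n_perturbations])
        return adj - removed


class StoresAdj:
    """Hypothetical stand-in for the PGDAttack-style convention: attack() stores the result."""

    def __init__(self):
        self.modified_adj = None

    def attack(self, ori_features, ori_adj, labels, idx_train, perturbations):
        # Same toy perturbation, but the result is kept on the instance, not returned.
        removed = set(sorted(ori_adj)[:perturbations])
        self.modified_adj = ori_adj - removed


def run_attack(attacker, adj, features, labels, idx_train, budget):
    """Hypothetical glue code that hides the two conventions from the caller."""
    if isinstance(attacker, ReturnsAdj):
        # Different keyword names (adj, n_perturbations) and a return value.
        return attacker.attack(adj=adj, labels=labels, n_perturbations=budget)
    # Different keyword names (ori_adj, perturbations) and no return value.
    attacker.attack(ori_features=features, ori_adj=adj, labels=labels,
                    idx_train=idx_train, perturbations=budget)
    return attacker.modified_adj
```

With the square graph {(0, 1), (1, 2), (2, 3), (0, 3)} and a budget of 2, both attackers go through run_attack and yield the same edge set; the point is that every new attack class with yet another convention would need another branch in this helper.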

I think the following changes could benefit the deeprobust/graph/global_attack API

Let me know your feedback on these suggestions. I can turn them into separate issues, which will help us track progress. If you agree with some of these changes, I'll be happy to get started on them.

ChandlerBang commented 4 years ago

Hi Henry,

Thanks for your valuable suggestions!

It would be really great if you could work on the following suggestions. If the workload feels too heavy, you can pick one or two and I will work on the others.

  • Give BaseAttack.attack a fixed argument list and agree on what input types the methods will accept
  • Raise errors instead of using assert statements, for example in the PGDAttack.__init__ method
  • Replace BaseAttack.save_adj and BaseAttack.save_features with methods that return these variables rather than save them
  • Add an epochs parameter to PGDAttack.attack
  • Include a **kwargs parameter in attack methods that don't use all arguments. This would allow one to switch between attack methods with very little effort.
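As a rough illustration of how the points above could fit together, here is a sketch of a hypothetical base class, UnifiedBaseAttack, with two toy subclasses. None of these names exist in DeepRobust; graphs are sets of undirected edges and the perturbations are toys, not real attacks. The sketch shows a shared attack(ori_adj, labels, n_perturbations, **kwargs) signature, ValueError instead of assert, a returned result instead of save_adj, and an epochs parameter on the iterative attack.

```python
from abc import ABC, abstractmethod


class UnifiedBaseAttack(ABC):
    """Hypothetical base class illustrating the proposed API (not DeepRobust code)."""

    def __init__(self, nnodes, device="cpu"):
        # Raise a descriptive error instead of using `assert`, which Python
        # skips entirely when run with the -O flag.
        if not isinstance(nnodes, int) or nnodes <= 0:
            raise ValueError(f"nnodes must be a positive int, got {nnodes!r}")
        self.nnodes = nnodes
        self.device = device

    @staticmethod
    def _check_budget(ori_adj, n_perturbations):
        # Shared validation so every subclass rejects impossible budgets the same way.
        if not 0 <= n_perturbations <= len(ori_adj):
            raise ValueError(
                f"n_perturbations must be in [0, {len(ori_adj)}], got {n_perturbations}")

    @abstractmethod
    def attack(self, ori_adj, labels, n_perturbations, **kwargs):
        """Return the perturbed graph instead of saving it on the instance."""


class ToyDeletion(UnifiedBaseAttack):
    """One-shot toy attack; extra arguments such as features arrive via **kwargs."""

    def attack(self, ori_adj, labels, n_perturbations, **kwargs):
        self._check_budget(ori_adj, n_perturbations)
        # Toy perturbation (not DICE): drop the first edges in sorted order.
        return ori_adj - set(sorted(ori_adj)[:n_perturbations])


class ToyIterative(UnifiedBaseAttack):
    """Toy stand-in for an iterative attack such as PGD, with an epochs parameter."""

    def attack(self, ori_adj, labels, n_perturbations, epochs=200, **kwargs):
        self._check_budget(ori_adj, n_perturbations)
        if epochs < 1:
            raise ValueError(f"epochs must be >= 1, got {epochs}")
        modified = set(ori_adj)
        # Toy loop (not PGD): remove at most one edge per epoch until the budget is spent.
        for _ in range(min(epochs, n_perturbations)):
            modified.discard(min(modified))
        return modified
```

Because both subclasses share one signature and accept **kwargs, a caller can pass features=... and idx_train=... to either of them in the same loop and swap attacks by changing only the constructor.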

For this one,

Using a docstring style such as the numpy/scipy or Google style which gives argument and return descriptions and types

this is already planned as our next step, so we will start working on it soon.
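For reference, a numpy/scipy-style (numpydoc) docstring on an attack-like function might look like the following. The function delete_edges is hypothetical and exists only to carry the docstring; the section headers (Parameters, Returns, Raises, Examples) follow the numpydoc convention, and the graph is a toy set of edges rather than a DeepRobust input type.

```python
def delete_edges(ori_adj, n_perturbations, **kwargs):
    """Remove edges from an undirected graph (toy example, not a DeepRobust API).

    Parameters
    ----------
    ori_adj : set of tuple of int
        Original graph as a set of undirected edges ``(u, v)`` with ``u < v``.
    n_perturbations : int
        Number of edges to remove.
    **kwargs
        Extra arguments (e.g. ``features``) that are accepted and ignored so
        callers can switch between attacks without changing the call.

    Returns
    -------
    set of tuple of int
        The perturbed edge set. ``ori_adj`` is left unchanged.

    Raises
    ------
    ValueError
        If ``n_perturbations`` is negative or larger than ``len(ori_adj)``.

    Examples
    --------
    >>> sorted(delete_edges({(0, 1), (1, 2)}, 1))
    [(1, 2)]
    """
    if not 0 <= n_perturbations <= len(ori_adj):
        raise ValueError(
            f"n_perturbations must be in [0, {len(ori_adj)}], got {n_perturbations}")
    # Drop the first n_perturbations edges in sorted order; set difference builds a new set.
    return ori_adj - set(sorted(ori_adj)[:n_perturbations])
```

The Google style would carry the same information under Args:, Returns: and Raises: headings; the numpydoc layout shown here is what Sphinx's napoleon extension and the numpy/scipy projects use.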

Thank you again for your help!

henrykenlay commented 4 years ago

Hi Wei,

I should be able to manage these tasks, but I'll raise an issue for each so we can track progress and share the workload if needed. Of the five points you've listed above, I'll do bullet point 4 as a pull request later because it's just a few lines. I'll raise the first and fifth bullet points together as one issue, since they amount to a general refactor to bring the classes closer to the base class. The third bullet shouldn't be too much work once the classes are more consistent. Then we can work on the second, which may take more work.

Looking forward to collaborating!

ChandlerBang commented 4 years ago

Great. Thank you :)