bethgelab / foolbox

A Python toolbox to create adversarial examples that fool neural networks in PyTorch, TensorFlow, and JAX
https://foolbox.jonasrauber.de
MIT License

PyTorch batch_predictions doesn't work #307

Closed anianruoss closed 5 years ago

anianruoss commented 5 years ago

I have a custom PyTorch model (derived from nn.Module) and a batch of MNIST images (obtained via torch.utils.data.DataLoader) on which I want to perform an FGSM attack as follows:

foolbox_model = foolbox.models.PyTorchModel(
    model, bounds=(0, 1), num_classes=10
)
attack = foolbox.attacks.FGSM(foolbox_model)
perturbed_images = attack(images.numpy(), predicted_labels.numpy())

where the shape of images is (batch_size, 1, 28, 28).

Following the call graph:

I'm relatively new to Foolbox, so I don't know how to fix this, but it definitely seems wrong. An obvious solution would be to remove [np.newaxis], but then the attack would probably fail for a single (unbatched) image.

jonasrauber commented 5 years ago

You should pass a single image and label, not a batch. Batch support will come with Foolbox 2.0. You can already try a prototype by using this PR: https://github.com/bethgelab/foolbox/pull/295
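
Until batch support is available, one workaround is to loop over the batch and call the attack once per image. The sketch below is only an illustration: `attack_one_by_one` is a hypothetical helper, not part of Foolbox, and it assumes the attack accepts one image of shape (1, 28, 28) together with one integer label. As far as I know, a Foolbox 1.x attack returns None when it finds no adversarial example, so the results are kept in a list rather than stacked into an array.

```python
def attack_one_by_one(attack, images, labels):
    """Call a single-input attack on every (image, label) pair of a batch.

    Hypothetical helper, not part of Foolbox. `attack` is any callable that
    takes one image and one integer label, such as the FGSM attack object
    from the question.
    """
    adversarials = []
    for image, label in zip(images, labels):
        # Iterating over the batch yields one image at a time, e.g. with
        # shape (1, 28, 28) for MNIST, so no batch axis is passed in.
        adversarials.append(attack(image, int(label)))
    return adversarials


# Usage with the objects from the question (assumed, not run here):
# perturbed_images = attack_one_by_one(
#     attack, images.numpy(), predicted_labels.numpy()
# )
```

Entries in the returned list may be None for images where the attack did not succeed, so check for that before stacking the results back into one array.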