bethgelab / foolbox

A Python toolbox to create adversarial examples that fool neural networks in PyTorch, TensorFlow, and JAX
https://foolbox.jonasrauber.de
MIT License

Is there a way to change/override the expected loss function of Foolbox's 'attack' function? #689

Open · John-Chin opened this issue 2 years ago

John-Chin commented 2 years ago

Hi, I am currently trying to generate adversarial images from the CIFAR10 dataset that can fool a CNN in a simple color estimation task.

As a little background for this color estimation task, I created a CNN that estimates the average global color of an image as a 3d column vector [average red channel, average blue channel, average green channel]. To train this CNN, I used a relabeled CIFAR10 dataset (where each label is no longer a number denoting its class but a 3d vector of the average color) and MSE as my loss function.
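For concreteness, the relabeling step might look something like this (a minimal sketch, not from the original post; it assumes PyTorch tensors in NCHW layout):

```python
import torch

def average_color_labels(images: torch.Tensor) -> torch.Tensor:
    """Turn a batch of (N, 3, H, W) images into (N, 3) average-color targets.

    Averaging over the spatial dimensions (H, W) leaves one value per channel,
    in the dataset's storage order (RGB for torchvision's CIFAR10).
    """
    return images.mean(dim=(2, 3))
```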

But now, when I tried to use foolbox's attack call `raw_advs, clipped_advs, success = attack(fmodel, images, labels, epsilons=epsilons)` to create adversarial images, I ran into this error:

[Screenshot of the error traceback, dated 2022-05-30]

From this error trace, it seems that the attack function is geared toward classification models: it computes a cross-entropy loss and expects 1d class labels rather than my 3d color targets.

So my question is: is there any way to change the attack function's expected loss function from cross-entropy to mean squared error? Or should I skip the built-in attack function and instead generate the adversarial images manually? Thank you for any thoughts or tips.

zimmerrol commented 2 years ago

Hi. Usually, one considers the combination of loss function and optimization algorithm as the adversarial attack, so changing the loss function of an attack effectively gives you a new adversarial attack. For that reason, foolbox does not let you change the loss function of (most) attacks arbitrarily. However, you can always create a new class that inherits from whatever attack you want to modify and then overwrite the definition of the loss function. Does that make sense to you?
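For example, here is a minimal, untested sketch of that subclassing approach. It assumes foolbox 3.x, where the gradient-based attacks such as `LinfPGD` build their loss in a `get_loss_fn` method; the names `MSEPGD` and `ColorMisestimation` and the `0.01` threshold are illustrative choices, not part of foolbox:

```python
import eagerpy as ep
import foolbox as fb


class MSEPGD(fb.attacks.LinfPGD):
    """LinfPGD variant whose loss is the MSE between the model's color
    estimate and the true (N, 3) average-color targets, instead of the
    default cross-entropy on class labels."""

    def get_loss_fn(self, model, labels):
        # here `labels` are (N, 3) average-color targets, not class indices
        def loss_fn(inputs: ep.Tensor) -> ep.Tensor:
            outputs = model(inputs)
            # summed squared error; the attack performs gradient ascent on it
            return ((outputs - labels) ** 2).sum()

        return loss_fn


class ColorMisestimation(fb.criteria.Misclassification):
    """Counts an input as adversarial once the per-sample MSE of the color
    estimate exceeds `threshold` (an arbitrary, illustrative choice)."""

    def __init__(self, targets, threshold: float):
        super().__init__(targets)  # targets end up in self.labels
        self.threshold = threshold

    def __call__(self, perturbed, outputs):
        outputs_, restore_type = ep.astensor_(outputs)
        mse = ((outputs_ - self.labels) ** 2).mean(axis=-1)
        return restore_type(mse > self.threshold)


# usage: pass the criterion object instead of raw labels
# attack = MSEPGD()
# criterion = ColorMisestimation(color_targets, threshold=0.01)
# raw_advs, clipped_advs, success = attack(fmodel, images, criterion, epsilons=epsilons)
```

Subclassing `Misclassification` rather than the base `Criterion` is deliberate here: the gradient-descent attacks only accept Misclassification-style criteria and read the targets from `criterion.labels`, so this keeps their `run()` method working without further changes, while the overridden `__call__` redefines what counts as a successful attack.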