bethgelab / foolbox

A Python toolbox to create adversarial examples that fool neural networks in PyTorch, TensorFlow, and JAX
https://foolbox.jonasrauber.de
MIT License

Batch support for PGD? #636

minkyu-choi04 closed this issue 3 years ago

minkyu-choi04 commented 3 years ago

In foolbox/foolbox/attacks/gradient_descent_base.py, the class BaseGradientDescent(FixedEpsilonAttack, ABC) has a get_loss_fn method:

```python
def get_loss_fn(
    self, model: Model, labels: ep.Tensor
) -> Callable[[ep.Tensor], ep.Tensor]:
    # can be overridden by users
    def loss_fn(inputs: ep.Tensor) -> ep.Tensor:
        logits = model(inputs)
        return ep.crossentropy(logits, labels).sum()

    return loss_fn
```

The function loss_fn returns ep.crossentropy(logits, labels).sum(). If so, I wonder whether the PGD attack on one image is affected by the other images in the same mini-batch because of the .sum(). I thought the adversarial examples were computed independently, but it seems that .sum() adds up the losses, and therefore the gradients, across the different images in the mini-batch.

Am I right?
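
The concern can be checked directly with a minimal pure-Python sketch. A hypothetical toy linear classifier (the weights W, b and all helper names are made up for illustration; this is not foolbox code) stands in for the network. The sketch estimates the gradient of a summed cross-entropy with respect to the first input by finite differences, and compares it with the gradient of that input's own loss and with the summed-loss gradient after the second input is changed.

```python
import math

# Hypothetical toy classifier (3 input features, 2 classes) standing in for the network.
W = [[0.5, -1.0, 0.3], [-0.2, 0.8, 1.1]]
b = [0.1, -0.1]


def logits(x):
    return [sum(w * xi for w, xi in zip(row, x)) + bj for row, bj in zip(W, b)]


def crossentropy(x, label):
    # Numerically stable softmax cross-entropy for a single input.
    z = logits(x)
    m = max(z)
    return m + math.log(sum(math.exp(v - m) for v in z)) - z[label]


def summed_loss(batch, labels):
    # Analogue of `ep.crossentropy(logits, labels).sum()`: one scalar for the whole batch.
    return sum(crossentropy(x, y) for x, y in zip(batch, labels))


def grad_wrt_input(f, batch, i, h=1e-6):
    # Central finite differences of f with respect to the features of input i only.
    grad = []
    for k in range(len(batch[i])):
        plus = [list(x) for x in batch]
        minus = [list(x) for x in batch]
        plus[i][k] += h
        minus[i][k] -= h
        grad.append((f(plus) - f(minus)) / (2 * h))
    return grad


labels = [0, 1]
batch = [[1.0, 2.0, -0.5], [0.3, -1.2, 0.7]]
other_batch = [[1.0, 2.0, -0.5], [5.0, 5.0, 5.0]]  # same first input, different second input

g_sum = grad_wrt_input(lambda xs: summed_loss(xs, labels), batch, 0)
g_sum_other = grad_wrt_input(lambda xs: summed_loss(xs, labels), other_batch, 0)
g_alone = grad_wrt_input(lambda xs: crossentropy(xs[0], labels[0]), batch, 0)

print(g_sum)
print(g_sum_other)
print(g_alone)
```

All three gradients agree up to finite-difference error: the other terms of the sum do not depend on the first input, so they drop out of its derivative. This relies on the model treating each sample independently (e.g. a PyTorch model in eval mode); BatchNorm in training mode would couple the images through the batch statistics.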

xmodar commented 3 years ago

The images in the batch will not affect each other, because each term of the sum depends only on its own image. However, you are absolutely right to raise this issue. The effect of .sum() is that the loss is batch_size times larger than it would be with .mean(). In training, this would effectively rescale the learning rate, but it makes no difference for the attacks because we normalize the gradients before using them (e.g. sign(grad) in LinfPGD). For numerical stability, we should still use the mean instead of the sum. I wrote a simple test script that compares attacking the full batch at once with attacking it in chunks (two halves of 8 images, and one chunk of all 16) to demonstrate this:

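```python
import torch
import foolbox as fb
from torchvision.models import resnet18

# Pretrained ResNet-18 in eval mode, so BatchNorm uses running statistics
model = fb.PyTorchModel(
    resnet18(pretrained=True).cuda().eval().requires_grad_(False),
    bounds=(0, 1),
    preprocessing=dict(mean=[0.485, 0.456, 0.406],
                       std=[0.229, 0.224, 0.225],
                       axis=-3),
)

# 16 sample ImageNet images shipped with foolbox
images, labels = fb.utils.samples(model, dataset='imagenet', batchsize=16)
# A single deterministic Linf PGD step (no random start)
attack = fb.attacks.PGD(rel_stepsize=0.1, steps=1, random_start=False)

# Reference: attack all 16 images in one batch; [0] selects the raw adversarial examples
out1 = attack(model, images, labels, epsilons=8 / 255)[0]

for batch_size in (8, 16):
    # Attack the same images in chunks of batch_size and concatenate the results
    out2 = torch.cat([
        attack(model, m, l, epsilons=8 / 255)[0]
        for m, l in zip(images.split(batch_size), labels.split(batch_size))
    ])
    # Total absolute difference from the full-batch result
    print(batch_size, (out1 - out2).abs().sum().item())
```
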
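
Why the sum-versus-mean scaling does not matter for a sign-based step can be illustrated with a minimal pure-Python sketch. The gradient values are made up, and linf_step and plain_step are hypothetical helpers written for illustration, not foolbox functions. linf_step follows the usual L-inf PGD update (step along sign(grad), project onto the epsilon-ball, clip to the bounds), with the step size assumed to be 0.1 * epsilon, analogous to rel_stepsize=0.1 in the script.

```python
def sign(v):
    return (v > 0) - (v < 0)


def linf_step(x, x0, grad, stepsize, epsilon, bounds=(0.0, 1.0)):
    # One L-inf PGD step: move along sign(grad), project onto the epsilon-ball
    # around x0, then clip to the valid input range.
    out = []
    for xi, x0i, gi in zip(x, x0, grad):
        xi = xi + stepsize * sign(gi)
        xi = min(max(xi, x0i - epsilon), x0i + epsilon)
        out.append(min(max(xi, bounds[0]), bounds[1]))
    return out


def plain_step(x, grad, lr):
    # Unnormalized gradient ascent step: its size depends on the gradient's scale.
    return [xi + lr * gi for xi, gi in zip(x, grad)]


N = 16  # batch size
grad_sum = [0.8, -0.05, 0.0, 2.4]  # made-up input gradient of the summed loss for one image
grad_mean = [g / N for g in grad_sum]  # the mean loss scales that gradient by 1 / N

x0 = [0.5, 0.5, 0.5, 0.99]
epsilon = 8 / 255
stepsize = 0.1 * epsilon

step_sum = linf_step(x0, x0, grad_sum, stepsize, epsilon)
step_mean = linf_step(x0, x0, grad_mean, stepsize, epsilon)
same_linf = step_sum == step_mean
same_plain = plain_step(x0, grad_sum, 0.01) == plain_step(x0, grad_mean, 0.01)
print(same_linf, same_plain)  # True False
```

Dividing a gradient by N never changes its sign, so the L-inf step is identical, whereas the unnormalized step shrinks by a factor of N, which is the learning-rate effect described in the comment.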
minkyu-choi04 commented 3 years ago

Thank you for the detailed explanation. It is all clear to me now.