fra31 / auto-attack

Code relative to "Reliable evaluation of adversarial robustness with an ensemble of diverse parameter-free attacks"
https://arxiv.org/abs/2003.01690
MIT License

gradient computation issue #87

Closed. CNOCycle closed this issue 2 years ago.

CNOCycle commented 2 years ago

Hi authors,

I have a minor concern about the loss computation. The following snippet shows part of the implementation of the APGD attack:

        x_adv.requires_grad_()
        grad = torch.zeros_like(x)
        for _ in range(self.eot_iter):
            if not self.is_tf_model:
                with torch.enable_grad():
                    logits = self.model(x_adv)
                    loss_indiv = criterion_indiv(logits, y)
                    loss = loss_indiv.sum()

                grad += torch.autograd.grad(loss, [x_adv])[0].detach()
            else:
                if self.y_target is None:
                    logits, loss_indiv, grad_curr = criterion_indiv(x_adv, y)
                else:
                    logits, loss_indiv, grad_curr = criterion_indiv(x_adv, y,
                        self.y_target)
                grad += grad_curr

        grad /= float(self.eot_iter)
        grad_best = grad.clone()

The loss is the sum of the per-example losses, so the scale of the resulting gradient depends on the batch size. This can cause gradient explosion in occasional cases, for example when the batch size is very large or when the gradient is stored in mixed precision. I have not read the details of the FAB and Square attacks, but I suspect they face the same issue.

I propose dividing the loss by the batch size (loss = loss_indiv.sum() / y.shape[0]). For the L-inf norm the perturbation is torch.sign(grad), and the gradient is normalized before the image is updated, so this modification does not affect the AA accuracy, while the cases mentioned above are avoided.
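
To make this concrete, here is a minimal sketch of the change, reusing the variable names from the snippet above (only the loss line changes, everything else stays as in the current implementation):

# average the per-example losses so the gradient scale does not grow
# with the batch size
loss = loss_indiv.sum() / y.shape[0]   # equivalent to loss_indiv.mean()
grad += torch.autograd.grad(loss, [x_adv])[0].detach()
# for the L-inf attack only torch.sign(grad) is used afterwards, so this
# positive rescaling leaves the update direction unchanged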

fra31 commented 2 years ago

Hi,

interesting point. I think using loss_indiv.mean() instead of loss_indiv.sum() wouldn't be a problem, but I've actually never experienced such an issue with the sum. Rescaling the gradient shouldn't be a problem either, since it is normalized for the other threat models as well (and in principle one can rescale it back after it's computed). The only doubt I have is whether, for large batch sizes, the opposite effect can happen, i.e. the loss for some point becomes small enough to be approximated with zero (which wouldn't allow gradient computation). I guess this is as rare as the explosion case. Any thoughts on this? Maybe a more refined solution is possible.
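
To illustrate the doubt, a toy check in half precision (fp16 is only an example here, since mixed precision was mentioned above; the exact numbers are not important):

import torch

big = torch.full((2048,), 100.0, dtype=torch.float16)
print(big.sum(), big.mean())    # sum overflows to inf (fp16 max ~65504), mean stays at 100

tiny = torch.full((2048,), 1e-4, dtype=torch.float16)
print(tiny.sum(), tiny.mean())  # sum is ~0.2, mean is ~1e-4, already close to
                                # fp16's smallest normal value (~6.1e-5)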

CNOCycle commented 2 years ago

As far as I know, the gradient should be independent of the batch size in gradient descent. In particular, when weight decay or other regularization terms are included in the loss, their weight would have to be scaled by the batch size as well if the loss were summed. Moreover, the final gradient is already divided by eot_iter rather than summed. In my opinion, averaging is the more general choice.

Regarding vanishing gradients, a detection mechanism has already been implemented in APGD:

        if self.loss in ['dlr', 'dlr-targeted']:
            # check if there are zero gradients
            check_zero_gradients(grad, logger=self.logger)

A more elegant solution is dynamic loss scaling, which is already implemented for automatic mixed precision in PyTorch and TensorFlow. The core concept is sketched in the following snippet:

# dynamic loss scaling (the idea behind AMP's GradScaler): rescale the loss
# until the gradient is neither all-zero (vanished) nor non-finite (exploded),
# then undo the scaling before using the gradient
scale = self.init_scale  # e.g. starts at 1

grad = torch.autograd.grad(scale * loss, [x_adv], retain_graph=True)[0]

while grad.abs().max() == 0:           # gradient vanished -> increase the scale
    scale = scale * 2.0
    grad = torch.autograd.grad(scale * loss, [x_adv], retain_graph=True)[0]

while not torch.isfinite(grad).all():  # gradient exploded -> decrease the scale
    scale = scale / 2.0
    grad = torch.autograd.grad(scale * loss, [x_adv], retain_graph=True)[0]

self.init_scale = scale                # keep the scale for the next iteration
grad = grad / scale                    # recover the gradient at the original scale

However, I agree that dynamic scaling makes the code more complex and is probably not worth implementing. For those rare cases, AA could simply raise a warning when an unreliable gradient is detected and abort immediately.
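
For example, a check along these lines, called right after the gradient is computed, would be enough (a rough sketch with a hypothetical helper name, not code from this repository):

import warnings

import torch


def warn_and_abort_on_unreliable_grad(grad):
    # gradient explosion: any inf/nan entry
    if not torch.isfinite(grad).all():
        warnings.warn('non-finite gradient detected, aborting the attack')
        raise RuntimeError('unreliable (inf/nan) gradient in APGD')
    # gradient vanishing: some example has an exactly zero gradient
    if grad.flatten(1).abs().sum(dim=1).eq(0).any():
        warnings.warn('zero gradient for some inputs, aborting the attack')
        raise RuntimeError('vanishing gradient in APGD')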

CNOCycle commented 2 years ago

Any progress on this issue? Admittedly, this is just a minor concern, so feel free to close it if no further updates are planned.

fra31 commented 2 years ago

I think this is something worth keeping in mind, but, as far as I know, it hasn't happened so far. So I wouldn't modify the implementation for the moment, but I'm ready to change it in case somebody experiences gradient explosion.