alessandrodepalma / expressive-losses

Expressive Losses for Verified Robustness via Convex Combinations
https://openreview.net/pdf?id=mzyZ4wzKlM
MIT License

Random Batch Norm Statistics? #1

Closed AlgebraLoveme closed 7 months ago

AlgebraLoveme commented 7 months ago

Dear authors,

Congratulations on the publication!

As noted in appendix F.2 and in the code comment at https://github.com/alessandrodepalma/expressive-losses/blob/99656729c3dfd385e62eb9c171628a02c0c8276f/train.py#L136, the batch norm statistics are supposed to be set based on clean inputs. However, at https://github.com/alessandrodepalma/expressive-losses/blob/99656729c3dfd385e62eb9c171628a02c0c8276f/train.py#L123 the batch norm layers are set to train mode, and the model is then evaluated on the adversarial inputs at https://github.com/alessandrodepalma/expressive-losses/blob/99656729c3dfd385e62eb9c171628a02c0c8276f/train.py#L126. This means the adversarial inputs will overwrite the batch norm statistics.

In addition, the batch norm layers are switched to eval mode at https://github.com/alessandrodepalma/expressive-losses/blob/99656729c3dfd385e62eb9c171628a02c0c8276f/train.py#L110 before the PGD attack is run. This changes the network: the clean input is evaluated with batch statistics, while the attack is based on the accumulated statistics, i.e., the running mean and running variance.
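To make the train/eval distinction concrete, here is a minimal pure-Python sketch of a single-feature BatchNorm layer. It is a toy model, not code from this repository: the class name `ToyBatchNorm1d` is hypothetical, affine parameters are omitted, and the update rule is assumed to follow PyTorch's default convention (momentum 0.1, biased variance for normalisation, unbiased variance for the running buffer).

```python
from statistics import fmean, pvariance, variance


class ToyBatchNorm1d:
    """Toy single-feature BatchNorm (hypothetical helper, no affine parameters)."""

    def __init__(self, momentum=0.1, eps=1e-5):
        self.momentum, self.eps = momentum, eps
        self.running_mean, self.running_var = 0.0, 1.0
        self.training = True

    def __call__(self, batch):
        if self.training:
            # Train mode: normalise with the statistics of the current batch.
            mean, var = fmean(batch), pvariance(batch)
            # Side effect: the running buffers move towards the batch statistics.
            self.running_mean += self.momentum * (mean - self.running_mean)
            self.running_var += self.momentum * (variance(batch) - self.running_var)
        else:
            # Eval mode: normalise with the accumulated running statistics.
            mean, var = self.running_mean, self.running_var
        return [(x - mean) / (var + self.eps) ** 0.5 for x in batch]


bn = ToyBatchNorm1d()
clean = [1.0, 2.0, 3.0, 4.0]

bn.training = True
out_train = bn(clean)   # uses batch statistics, updates running buffers

bn.training = False
out_eval = bn(clean)    # uses running statistics, buffers untouched
```

The same input yields different outputs in the two modes, which is the sense in which the clean-input network and the attacked network can differ.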

In conclusion, it seems that the network used for clean input (batch stat) is different from the attacked network (running stat), and the network used for IBP (running stat + adv stat) is different from both. Could you provide some explanations for this?

Best

alessandrodepalma commented 7 months ago

Hi Yuhao,

Thank you for your continued interest in our work.

BatchNorm stats for train-time IBP.

In training mode (which is the case when we compute the bounds at training time), auto_LiRPA's IBP implementation uses input batch statistics. By default, these are taken from the last forward pass. However, as we pass the clean input to the bounding computation (see code pointers below), the IBP computation will use batch statistics on the provided clean input (see https://github.com/Verified-Intelligence/auto_LiRPA/blob/2553832b5a5bbfe643b694458867ebd1dbdece65/auto_LiRPA/bound_general.py#L971).

https://github.com/alessandrodepalma/expressive-losses/blob/99656729c3dfd385e62eb9c171628a02c0c8276f/train.py#L99
https://github.com/alessandrodepalma/expressive-losses/blob/99656729c3dfd385e62eb9c171628a02c0c8276f/train.py#L140-L141
https://github.com/alessandrodepalma/expressive-losses/blob/99656729c3dfd385e62eb9c171628a02c0c8276f/certified.py#L28

Hence, as described in appendix F.2, and as suggested by the FastIBP authors, train-time IBP computations use the clean input batch statistics.

BatchNorm stats for train-time forward passes.

Consistent with SABR, the clean input evaluation uses its own batch statistics, and the same holds for the adversarial input evaluation.

Eval-time statistics.

As we state in appendix F.2, "The statistics used at evaluation time are instead computed including both the original data and the attacks used during training." At eval time, of course, the same statistics are used for all forward passes, attacks and bounding computations. The idea behind computing running stats on the adversarial points as well is that the network is trained for robustness, and as such it should also expect adversaries at evaluation time.

Running the train-time attacks in eval mode.

We are only interested in the statistics from the output of the attacks (we do not want to accumulate statistics on the intermediate attack steps, as that would for instance include the random init), so the network is kept in evaluation mode during the attack process itself. This is not uncommon: as you may know, SABR performs the attacks in eval mode too.
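The accumulation pattern described above can be pictured with a small pure-Python sketch. Everything in it is hypothetical and for illustration only: `RunningMeanTracker` stands in for a BatchNorm running-mean buffer, and `toy_attack` uses a placeholder update instead of a real gradient step. The only point is which forward passes touch the running statistics.

```python
import random
from statistics import fmean


class RunningMeanTracker:
    """Toy stand-in for the running-mean buffer of a BatchNorm layer."""

    def __init__(self, momentum=0.1):
        self.momentum = momentum
        self.running_mean = 0.0
        self.training = True
        self.num_updates = 0

    def forward(self, batch):
        if self.training:  # only train-mode passes touch the buffer
            self.running_mean += self.momentum * (fmean(batch) - self.running_mean)
            self.num_updates += 1
        return batch


def toy_attack(layer, clean, eps, steps=5, seed=0):
    """Hypothetical attack loop: random init plus a few steps, all in eval mode."""
    rng = random.Random(seed)
    layer.training = False
    adv = [x + rng.uniform(-eps, eps) for x in clean]  # random initialisation
    for _ in range(steps):
        layer.forward(adv)  # intermediate iterates: no statistics accumulated
        # Placeholder step (not a real gradient), projected back onto the eps-ball.
        adv = [min(max(a + 0.25 * eps, x - eps), x + eps) for a, x in zip(adv, clean)]
    layer.training = True
    return adv


layer = RunningMeanTracker()
clean = [1.0, 2.0, 3.0, 4.0]

layer.forward(clean)                      # clean pass in train mode: update 1
adv = toy_attack(layer, clean, eps=0.5)   # attack in eval mode: no updates
updates_after_attack = layer.num_updates
layer.forward(adv)                        # attack output in train mode: update 2
```

In this sketch the buffer is updated exactly twice, once by the clean batch and once by the final adversarial batch, while the random init and the intermediate iterates leave it unchanged.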

Best, Alessandro

AlgebraLoveme commented 7 months ago

Hi Alessandro,

Thanks for the reply!

First, I would like to point out that SABR does not evaluate its adversarial attack based on adversarial batch statistics. As shown at https://github.com/eth-sri/SABR/blob/1a6c2582ec9fa2d81eacdd3c86f90732bb72f4e7/src/train.py#L88, instead of calling model.eval(), it uses a helper function that forces the model to keep using the original batch statistics without modifying its running statistics, thus keeping the attacked model the same as the model used for clean inputs. Therefore, it is indeed uncommon (at least in certified training) to call model.eval(), which replaces the batch statistics with the running statistics.

Second, I agree that auto_LiRPA reuses the batch statistics of the last forward pass. However, as I said in the first comment, at https://github.com/alessandrodepalma/expressive-losses/blob/99656729c3dfd385e62eb9c171628a02c0c8276f/train.py#L126 the model performs a forward pass on the adversarial inputs before the IBP bounds are computed at https://github.com/alessandrodepalma/expressive-losses/blob/99656729c3dfd385e62eb9c171628a02c0c8276f/train.py#L140. Therefore, the IBP bounds are computed based on a mix of clean and adversarial batch statistics, which differs from both the network used for the clean forward pass and the network used for the attacks.

Best

alessandrodepalma commented 7 months ago

Thanks a lot, Yuhao, for your reply and the interesting discussion.

Maybe I am misunderstanding something, but it seems to me that bn_mode_attack is True for the provided SABR hyper-parameters, so this statement is executed: https://github.com/eth-sri/SABR/blob/1a6c2582ec9fa2d81eacdd3c86f90732bb72f4e7/src/train.py#L85. Moreover, in evaluation mode use_old_train_stats is not used: https://github.com/eth-sri/SABR/blob/1a6c2582ec9fa2d81eacdd3c86f90732bb72f4e7/src/AIDomains/abstract_layers.py#L633. The fact that SABR runs the attacks in evaluation mode is further confirmed by the appendix of the original paper (appendix C, "architectures" paragraph).

In any case, your enquiries pushed me to test the behaviour of auto_LiRPA empirically, and it does behave somewhat unexpectedly with respect to train-time BatchNorm statistics for IBP bounds. Instead of using the statistics of the point passed with the compute_bounds call, which was my previous understanding, it simply uses the statistics from the last forward pass, even when that pass was in eval mode. As a result, and contrary to what was written in appendix F.2, the train-time IBP bounds are computed using (only) the statistics from the perturbed input batch. These are nevertheless the same statistics employed to compute the train-time adversarial logit differences and loss (this forward pass is not used in the SABR loss, so there is no direct precedent). Also note that, after the warm-up phase, the clean loss is not part of the training loss.

In summary, both the train-time IBP and adversarial logits/losses are computed using only stats from the perturbed points, and the attack itself is computed in evaluation mode. The evaluation statistics are computed using both perturbed and unperturbed points, consistent with our original appendix description. I do not think there is a single "correct" way to handle this. We were consistent throughout the presented experiments.
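As an illustration of the "last forward pass" behaviour summarised above, the following pure-Python sketch models a normalisation layer that caches the statistics of its most recent forward pass and reuses them when computing interval bounds. This is a toy model under stated assumptions, not auto_LiRPA's implementation: `CachingNorm` and `interval_bounds` are hypothetical names, and the layer has a single feature with no affine parameters.

```python
from statistics import fmean, pvariance


class CachingNorm:
    """Toy normaliser that remembers the statistics of its last forward pass."""

    def __init__(self, eps=1e-5):
        self.eps = eps
        self.cached = None  # (mean, var) from the most recent forward pass

    def forward(self, batch):
        self.cached = (fmean(batch), pvariance(batch))
        return self.normalise(batch)

    def normalise(self, values):
        mean, var = self.cached
        return [(v - mean) / (var + self.eps) ** 0.5 for v in values]

    def interval_bounds(self, centre, radius):
        """Interval bounds around `centre`, using the cached statistics,
        not the statistics of `centre` itself."""
        lower = self.normalise([c - radius for c in centre])
        upper = self.normalise([c + radius for c in centre])
        return lower, upper


clean = [1.0, 2.0, 3.0, 4.0]
adv = [1.5, 2.5, 3.5, 3.5]   # made-up perturbed batch within radius 0.5
norm = CachingNorm()

norm.forward(clean)
bounds_after_clean = norm.interval_bounds(clean, radius=0.5)

norm.forward(adv)            # the last forward pass is now the perturbed batch
bounds_after_adv = norm.interval_bounds(clean, radius=0.5)
```

Although both bound computations are centred on the clean batch, the two results differ, because the second one is normalised with the statistics of the perturbed batch.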

I edited the appendix to clarify these points (on OpenReview), and added you to the acknowledgments for it. I will also edit the comment in the code. Thanks a lot for pointing me to this!

Best, Alessandro

AlgebraLoveme commented 7 months ago

Hi Alessandro,

Indeed, you were right about SABR behavior. I have consulted SABR authors and confirmed this. Thanks for pointing this out.

I really appreciate your new clarification. It convinces me that the batch norm statistics are consistent, at least between the PGD loss and the IBP loss, although the attacked network is somewhat different. The only difference from the original description is that the batch norm statistics are all based on adversarial inputs instead of clean inputs.

Thanks, Yuhao