Harry24k / adversarial-attacks-pytorch

PyTorch implementation of adversarial attacks [torchattacks].
https://adversarial-attacks-pytorch.readthedocs.io/en/latest/index.html
MIT License

CW efficiency improvements and bug fix; add CW binary-search version (CWBS) and early-stopped PGD version; support `L0` and `Linf` for CW and CWBS; rewrite the FAB attack; fix an MI-FGSM bug; rewrite JSMA. #168

Open rikonaka opened 7 months ago

rikonaka commented 7 months ago

PR Type and Checklist

What kind of change does this PR introduce?

CW attack fix

There is an obscure bug in the F function of the original CW attack code.

In Carlini's original code, `real` is calculated as a sum over the one-hot-masked logits:

(screenshot from Carlini's code)

https://github.com/carlini/nn_robust_attacks/blob/c6b8f6a254e82a79a52cfbc673b632cad5ea1ab1/l2_attack.py#L96

In Carlini's code it is a sum, but in torchattacks it became a max. I discovered this problem accidentally πŸ˜‹.

https://github.com/Harry24k/adversarial-attacks-pytorch/blob/936e86d6387ef5ca57e4114d83745cdf199b46cf/torchattacks/attacks/cw.py#L136
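For reference, Carlini's TensorFlow formulation translates to PyTorch roughly as follows. This is a hedged sketch, not the library's exact code: the function name `cw_f` and the kappa handling are illustrative.

```python
import torch

def cw_f(outputs: torch.Tensor, one_hot_labels: torch.Tensor, kappa: float = 0.0) -> torch.Tensor:
    # real: the true-class logit. Carlini computes this as a SUM over the
    # one-hot-masked logits (only the true-class entry survives), not a max.
    real = torch.sum(one_hot_labels * outputs, dim=1)
    # other: the largest logit among the remaining classes; subtracting a large
    # constant pushes the true-class slot out of contention.
    other = torch.max((1 - one_hot_labels) * outputs - one_hot_labels * 10000.0, dim=1)[0]
    # Margin loss for an untargeted attack (the sign flips for the targeted case).
    return torch.clamp(real - other + kappa, min=0.0)
```

For non-negative logits, sum and max over `one_hot_labels * outputs` happen to coincide, which is why the bug is easy to miss.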

I also reduced the large number of tensor detach() and view() operations in the original code, instead using index assignment on tensors, which is simpler and more efficient.
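The index-assignment pattern can be sketched as follows; all tensor names here are hypothetical bookkeeping, not the actual cw.py variables. The idea is to keep running "best" buffers and update them in place through a boolean mask instead of rebuilding tensors with detach()/view().

```python
import torch

# Hypothetical bookkeeping inside an attack loop: keep the best (smallest-L2)
# adversarial image found so far for each sample in a batch of 4.
best_L2 = torch.full((4,), float("inf"))
best_adv = torch.zeros(4, 3)                       # stand-in for best adversarial images

current_L2 = torch.tensor([1.0, 5.0, 0.5, 2.0])
current_adv = torch.ones(4, 3)
success = torch.tensor([True, False, True, False])  # attack succeeded on these samples

# Update only the entries that both succeeded and improved, via mask indexing.
mask = success & (current_L2 < best_L2)
best_L2[mask] = current_L2[mask]
best_adv[mask] = current_adv[mask]
```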

At the same time, I added a binary-search version of CW (CWBS); see issue https://github.com/Harry24k/adversarial-attacks-pytorch/issues/167 . Binary search can indeed significantly reduce the size of the perturbations. The red line is the value of best_L2.

(plot: best_L2)
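The binary search over the trade-off constant c follows Carlini's original scheme. A minimal sketch, where `run_attack` is a hypothetical callback standing in for one full CW optimization at a fixed c:

```python
def binary_search_c(run_attack, c_lo=0.0, c_hi=1e10, c_init=1e-2, steps=9):
    """Search for the smallest constant c that still yields a successful attack.
    run_attack(c) -> (success: bool, l2: float) is a hypothetical callback."""
    c = c_init
    best_c = None
    for _ in range(steps):
        success, _l2 = run_attack(c)
        if success:
            best_c = c
            c_hi = c            # attack works: try a smaller c for a smaller perturbation
        else:
            c_lo = c            # attack fails: need a larger c
        if c_hi < 1e9:
            c = (c_lo + c_hi) / 2
        else:
            c *= 10             # no upper bound found yet: grow geometrically
    return best_c
```

A smaller c weights the distance term more heavily, so the smallest successful c tends to give the smallest perturbation.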

I tested the three CW attack variants (L0, L2, and Linf) and found that a 100% attack success rate can be achieved on 50 test images.

(plot: attack success rate)

And the perturbations are still imperceptible.

(figure: adversarial examples)

FAB attack fix

The original FAB code was too complicated and difficult to maintain, so I rewrote the FAB attack and split the L1 and L2 attacks into separate files. I also found that the previous FAB code does not work well when the user specifies a target label for a targeted attack.

The old FAB code is renamed to AFAB so that it can still be used in AutoAttack.

In the FAB code's forward() function:

https://github.com/Harry24k/adversarial-attacks-pytorch/blob/23620a694a3660e4f194c3e4d28992bced7785a1/torchattacks/attacks/fab.py#L84

there is no parameter for the target label. In contrast, a targeted FAB attack requires two labels: the original label and the target label.

https://github.com/Harry24k/adversarial-attacks-pytorch/blob/23620a694a3660e4f194c3e4d28992bced7785a1/torchattacks/attacks/fab.py#L127

But only one label is passed through the entire code. If the user wants to specify a target label for the attack, the computations related to the targeted attack are effectively meaningless, since only the single input label is available.

https://github.com/Harry24k/adversarial-attacks-pytorch/blob/23620a694a3660e4f194c3e4d28992bced7785a1/torchattacks/attacks/fab.py#L132

For example, here la = la_target, so diffy is meaningless.

I'll try to fix this, but I don't have a clean solution at the moment: the targeted attack needs two labels as input, which conflicts with the existing framework. So for now I have submitted the FAB attack without the targeted version.

Update: the targeted FAB attack has now been completed.

codecov-commenter commented 7 months ago

:warning: Please install the Codecov GitHub app to ensure uploads and comments are reliably processed by Codecov.

Codecov Report

Attention: Patch coverage is 82.13115% with 327 lines in your changes missing coverage. Please review.

Project coverage is 76.85%. Comparing base (936e86d) to head (a40c9a4). Report is 1 commit behind head on master.

:exclamation: Current head a40c9a4 differs from pull request most recent head 8e6815c

Please upload reports for the commit 8e6815c to get more accurate results.

:exclamation: Your organization needs to install the Codecov GitHub app to enable full functionality.

Additional details and impacted files

```diff
@@            Coverage Diff             @@
##           master     #168      +/-   ##
==========================================
+ Coverage   73.37%   76.85%   +3.48%
==========================================
  Files          44       54      +10
  Lines        3827     4926    +1099
  Branches      578      631      +53
==========================================
+ Hits         2808     3786     +978
- Misses        862      974     +112
- Partials      157      166       +9
```

| Files | Coverage Ξ” |
|---|---|
| code_coverage/test_atks.py | 100.00% <100.00%> (+6.89%) :arrow_up: |
| torchattacks/\_\_init\_\_.py | 100.00% <100.00%> (ΓΈ) |
| torchattacks/attacks/autoattack.py | 80.64% <100.00%> (ΓΈ) |
| torchattacks/attacks/cw.py | 100.00% <100.00%> (ΓΈ) |
| torchattacks/attacks/cwbs.py | 100.00% <100.00%> (ΓΈ) |
| torchattacks/attacks/cwl0.py | 100.00% <100.00%> (ΓΈ) |
| torchattacks/attacks/mifgsm.py | 100.00% <100.00%> (ΓΈ) |
| torchattacks/attacks/cwbsl0.py | 98.97% <98.97%> (ΓΈ) |
| torchattacks/attacks/cwbslinf.py | 98.95% <98.95%> (ΓΈ) |
| torchattacks/attacks/cwlinf.py | 98.52% <98.52%> (ΓΈ) |
| ... and 10 more | |

... and 4 files with indirect coverage changes. Last update 23620a6...8e6815c.
ZaberKo commented 7 months ago

I think the calculation of other is still incorrect: it neglects that the output logits could be negative numbers. cwl2.py#L146

rikonaka commented 7 months ago

> I think the calculation of other is still incorrect, which neglects that output logits could be negative numbers. cwl2.py#L146

Thank you very much for your advice, but this other calculation is actually translated directly from Carlini's code (TensorFlow to PyTorch). You can check it out 😁.

https://github.com/carlini/nn_robust_attacks/blob/c6b8f6a254e82a79a52cfbc673b632cad5ea1ab1/l2_attack.py#L97

As for your point that the logits may be negative: the original author's code also directly uses the pre-softmax values.

https://github.com/carlini/nn_robust_attacks/blob/c6b8f6a254e82a79a52cfbc673b632cad5ea1ab1/l2_attack.py#L90C33-L90C33

(screenshot: Carlini's real computation)

So this should be correct πŸ˜‰.

ZaberKo commented 7 months ago

> Thank you very much for your advice, but this other calculation is actually translated from carlini's code (Tensorflow to Pytorch). You can check it out 😁. […] So this should be correct πŸ˜‰.

Thanks for the quick response. I think you misunderstood the issue. A quick fix of cwl2.py#L146 would be:

```python
other = torch.max((1 - one_hot_labels) * outputs - one_hot_labels * 10000., dim=1)[0]
```
rikonaka commented 7 months ago

> Thanks for the quick response. I think you misunderstand the issue. A quick fix of cwl2.py#L146 would be like:
>
> `other = torch.max((1 - one_hot_labels) * outputs - one_hot_labels*10000., dim=1)[0]`

Good question. Here we pick the maximum value of the logits excluding the true label. So if we only have 1 image, the outputs will be

[
[x1, x2, x3, x4]
]

Then we use the one_hot_labels to mask one position (suppose x3), and we get

[
[x1, x2, 0, x4]
]

So torch.max will compute the max value of x1, x2, 0, and x4.

In TensorFlow, the original author subtracts that value (one_hot_labels * 10000) to prevent the case where all the logits are negative (I haven't used TensorFlow for a long time 🀣); this is a point that can be improved. But in PyTorch, the logits here are greater than 0.

(screenshot: logits values)

So the situation you are worried about, where all logits are negative, will not happen πŸ˜‰.

ZaberKo commented 7 months ago

> Good question, well, in here we will pick the maximum value of the logits except true label […] So the situation where all logits are negative that you are worried about will not happen πŸ˜‰.

However, there is no guarantee that the output logits must be non-negative in PyTorch, for arbitrary models under arbitrary training methods.

rikonaka commented 7 months ago

> However, there is no guarantee that the output logits must be non-negative in PyTorch, for arbitrary models under arbitrary training methods.

πŸ˜΅β€πŸ’« The same, there is also no such guarantee that the output logits must be negative in pytorch, for arbitrary models under any training methods. If you can provide any evidence that the logits output of some model is all negative, it may be able to further support your argument.

ZaberKo commented 7 months ago

> πŸ˜΅β€πŸ’« The same, there is also no such guarantee that the output logits must be negative in pytorch, for arbitrary models under any training methods. […]

That is not the point. The point is that we need to cover all cases, even if some of them are rare. Here are some other PyTorch implementations of the CW f function for reference:

* [imrahulr/adversarial_robustness_pytorch](https://github.com/imrahulr/adversarial_robustness_pytorch/blob/6df6a8f0cd49cf6d18507a4b574c004ab6eedf49/core/attacks/utils.py#L212)

* [thu-ml/ares](https://github.com/thu-ml/ares/blob/306e35fe4309d791f9252bb6aab51198d2b9b511/ares/attack/cw.py#L133)

rikonaka commented 7 months ago

> That is not the point. The point here is that we need to cover all cases, even though some of them are rare. Here are some other implementations of CW f_func in pytorch for reference: […]

Thanks for your suggestion πŸ‘, I will rewrite this f function quickly. Next time, please provide the detailed information directly from the beginning, instead of making people guess at and misunderstand a short comment.
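The edge case under discussion can be reproduced in a few lines; this is a minimal sketch with toy tensor values, not output from any real model:

```python
import torch

# All-negative logits for one sample; the true class is index 2.
logits = torch.tensor([[-5.0, -1.0, -3.0]])
one_hot = torch.tensor([[0.0, 0.0, 1.0]])

# Masking alone: the zeroed true-class slot wins the max, giving a wrong 0.0.
naive = torch.max((1 - one_hot) * logits, dim=1)[0]

# With the large subtraction, the true-class slot is pushed far below every
# other entry, so the correct best non-true logit (-1.0) is selected.
fixed = torch.max((1 - one_hot) * logits - one_hot * 10000.0, dim=1)[0]
```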

Adversarian commented 7 months ago

Thanks for the effort you made to improve the implementation of CW in this library. I have one suggestion (correct me if it is not feasible): wouldn't it be better to alias one of the CW variants (e.g. CWL0, CWLinf, etc.) as CW, so that this version doesn't introduce a breaking change for torchattacks.CW and backward compatibility is preserved?

You could use the variant that matches the previous behavior (I believe CWL2 in the current implementation) as the alias, as easily as something like CW = CWL2.
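A minimal sketch of the aliasing idea; CWL2 here is a stand-in class, not the library's actual definition:

```python
class CWL2:
    """Stand-in for the L2 variant of the CW attack."""
    def __init__(self, model, c=1.0, kappa=0.0, steps=50, lr=0.01):
        self.model, self.c = model, c
        self.kappa, self.steps, self.lr = kappa, steps, lr

# Backward-compatible alias: code written against the old name keeps working.
CW = CWL2

atk = CW(model=None)  # old call sites are untouched
```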

rikonaka commented 7 months ago

> Thanks for the effort you made to improve the implementation of CW in this library. […] You could use the version of CW that was previously used (I believe CWL2 in the current implementation) as an alias to remediate this (as easily as something like CW = CWL2 for instance).

Thank you very much for your suggestion. I will alias CWL2 as CW now. πŸ˜‰