cleverhans-lab / cleverhans

An adversarial example library for constructing attacks, building defenses, and benchmarking both
MIT License

Attacks in Jax only support cross entropy loss #1184

Open ZuowenWang0000 opened 3 years ago

ZuowenWang0000 commented 3 years ago

The fast_gradient_method in the Jax implementation currently hard-codes cross-entropy loss for crafting adversarial examples: https://github.com/cleverhans-lab/cleverhans/blob/4b5ce5421ced09e8531b112f97468869980884f2/cleverhans/future/jax/attacks/fast_gradient_method.py#L40

It is not always correct to assume people are using cross-entropy loss.

Describe the solution you'd like The most straightforward solution would be to pass the loss function as an extra parameter to both the fgsm and pgd functions. This would also be consistent with the attacks implemented in other frameworks, such as TensorFlow: https://github.com/cleverhans-lab/cleverhans/blob/4b5ce5421ced09e8531b112f97468869980884f2/cleverhans/attacks/fast_gradient_method.py#L58

An alternative would be to pass, instead of the predict function, a model object that has both the predict function and the loss function registered on it.
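The first option could be sketched roughly as follows. This is a minimal illustration, not the actual cleverhans code: the signature, helper name `cross_entropy_loss`, and the bare FGM body are assumptions for the sake of showing where a `loss_fn` parameter would slot in.

```python
import jax
import jax.numpy as jnp


def cross_entropy_loss(logits, labels):
    # Hypothetical default loss: mean cross-entropy over a batch,
    # with integer class labels and unnormalized logits.
    log_probs = jax.nn.log_softmax(logits)
    return -jnp.mean(jnp.take_along_axis(log_probs, labels[:, None], axis=1))


def fast_gradient_method(predict, x, y, eps, loss_fn=cross_entropy_loss):
    # loss_fn(logits, y) -> scalar; callers may pass any differentiable
    # loss (e.g. a margin loss) instead of the cross-entropy default.
    grads = jax.grad(lambda x_: loss_fn(predict(x_), y))(x)
    # Untargeted L-infinity FGM step: move in the sign of the gradient.
    return x + eps * jnp.sign(grads)
```

Keeping `loss_fn` as a keyword argument with the cross-entropy default would preserve backward compatibility for existing callers while opening up other losses.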

jonasguan commented 3 years ago

Thanks for the suggestion @ZuowenWang0000! If you can submit a PR with your proposed changes, we'd be glad to review and merge it.

jonasguan commented 3 years ago

Oops, did not mean to close this yet.