ZuowenWang0000 opened 3 years ago
The fast_gradient_method in the Jax implementation is now by default using cross-entropy loss for crafting adversarial examples: https://github.com/cleverhans-lab/cleverhans/blob/4b5ce5421ced09e8531b112f97468869980884f2/cleverhans/future/jax/attacks/fast_gradient_method.py#L40
It is apparently not always correct to assume people are using cross-entropy loss.
Describe the solution you'd like
The most straightforward solution would be to pass the loss function as an extra parameter to both the fgsm and pgd functions. This would also be consistent with the attacks implemented in other frameworks, such as in TF: https://github.com/cleverhans-lab/cleverhans/blob/4b5ce5421ced09e8531b112f97468869980884f2/cleverhans/attacks/fast_gradient_method.py#L58
An alternative would be to pass a model object that has both the predict function and the loss function registered, instead of passing only the predict function.
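To illustrate the first proposal, here is a minimal sketch of what an FGSM function with a caller-supplied loss might look like in JAX. This is not the cleverhans API; the signature, the `loss_fn` parameter, and the toy model below are all hypothetical, shown only to make the suggestion concrete.

```python
# Hypothetical sketch (NOT the cleverhans implementation): an FGSM
# variant that takes the loss function as a parameter instead of
# hard-coding cross-entropy.
import jax
import jax.numpy as jnp

def fast_gradient_method(predict, x, eps, loss_fn, y):
    """Craft an L-infinity FGSM perturbation using a caller-supplied loss.

    predict: model forward pass, logits = predict(x)
    loss_fn: scalar loss, loss_fn(logits, y)
    """
    def loss_wrt_input(x):
        return loss_fn(predict(x), y)

    grad = jax.grad(loss_wrt_input)(x)
    # FGSM step: move each input coordinate by eps in the direction
    # that increases the chosen loss.
    return x + eps * jnp.sign(grad)

# Example usage: the default cross-entropy behavior is recovered by
# passing a cross-entropy loss explicitly; any other differentiable
# loss could be swapped in the same way.
def cross_entropy(logits, labels):
    log_probs = jax.nn.log_softmax(logits)
    return -jnp.mean(jnp.take_along_axis(log_probs, labels[:, None], axis=1))

W = jnp.array([[1.0, -1.0], [0.5, 0.5]])   # toy linear "model"
predict = lambda x: x @ W
x = jnp.array([[0.2, -0.1]])
y = jnp.array([0])
x_adv = fast_gradient_method(predict, x, eps=0.1, loss_fn=cross_entropy, y=y)
```

The same idea extends to a PGD function: the inner gradient step would call the supplied `loss_fn` instead of a hard-coded cross-entropy.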
Thanks for the suggestion @ZuowenWang0000! If you can submit a PR with your proposed changes, we'd be glad to review and merge it.
Oops, did not mean to close this yet.