IntelLabs / MART

Modular Adversarial Robustness Toolkit
BSD 3-Clause "New" or "Revised" License

Adversary as `pl.LightningModule` #103

Closed dxoigmn closed 1 year ago

dxoigmn commented 1 year ago

What does this PR do?

Right now we treat adversaries as special objects with their own loops and callbacks, when really we should treat them like any other `LightningModule`. Doing so means we can simply use a `Trainer` to fit their parameters. This PR attempts to make that so.
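A minimal sketch of the idea, assuming a frozen victim model and a single universal perturbation; the class and argument names are illustrative, not MART's actual API:

```python
import pytorch_lightning as pl
import torch
import torch.nn.functional as F


class Adversary(pl.LightningModule):
    """Adversary whose only trainable parameter is the perturbation."""

    def __init__(self, model: torch.nn.Module, input_shape, lr: float = 0.01):
        super().__init__()
        self.model = model
        # Freeze the victim model; only the perturbation is optimized.
        for p in self.model.parameters():
            p.requires_grad_(False)
        self.perturbation = torch.nn.Parameter(torch.zeros(input_shape))
        self.lr = lr

    def forward(self, input):
        # Apply the perturbation to the benign input.
        return input + self.perturbation

    def training_step(self, batch, batch_idx):
        input, target = batch
        logits = self.model(self(input))
        # Maximize the victim's loss ("gain") by minimizing its negation.
        return -F.cross_entropy(logits, target)

    def configure_optimizers(self):
        return torch.optim.SGD([self.perturbation], lr=self.lr)
```

With the adversary in this shape, `pl.Trainer(...).fit(adversary, dataloader)` is all that is needed to run the attack.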

As of dcf7599, there is a bug in adversarial training.

Depends upon #146 and #147.

Type of change

Please check all relevant options.

Testing

Please describe the tests that you ran to verify your changes. Consider listing any relevant details of your test configuration.

Before submitting

Did you have fun?

Make sure you had fun coding 🙃

dxoigmn commented 1 year ago

I started splitting this PR into pieces. I think I'll also try to stick to the same config structure.

dxoigmn commented 1 year ago

There is a bug introduced in b6afdd14b276987d821905a1844ec7c4df611621 with adversarial training of CIFAR10. I'm not sure where it is, but this should be a good opportunity to write a test :)

Found it and wrote a test!
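For context, a hedged sketch of the kind of regression test this suggests (not the test actually added here): verify that fitting the attack produces inputs with strictly higher loss than the clean ones.

```python
import torch
import torch.nn.functional as F


def test_attack_increases_loss():
    torch.manual_seed(0)
    model = torch.nn.Linear(8, 3)
    input = torch.randn(4, 8)
    target = torch.randint(0, 3, (4,))

    clean_loss = F.cross_entropy(model(input), target)

    # One FGSM-like step on the input as a stand-in for fitting the adversary.
    input_adv = input.clone().requires_grad_(True)
    F.cross_entropy(model(input_adv), target).backward()
    perturbed = input + 0.1 * input_adv.grad.sign()

    adv_loss = F.cross_entropy(model(perturbed), target)
    assert adv_loss > clean_loss
```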

dxoigmn commented 1 year ago

One thing that may be worth pulling out of this PR is the change to `gain`.

Done.

dxoigmn commented 1 year ago

I started breaking up this PR into smaller PRs, hence the dismissed review.

dxoigmn commented 1 year ago

I made another batch of changes to this PR to accommodate the changes in #132. As such, the tests need to be fixed. Additionally, there are some other changes related to types (e.g., relying upon `Iterable[torch.Tensor]` instead of `tuple`) that can be pulled out and made more generally useful. I created #134 for those changes.
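To illustrate the typing change (the helper name here is hypothetical, not code from the PR): accepting any `Iterable[torch.Tensor]` lets callers pass tuples, lists, or generators interchangeably.

```python
from typing import Iterable, List

import torch


def detach_all(tensors: Iterable[torch.Tensor]) -> List[torch.Tensor]:
    # Works for tuples, lists, generators, or any other iterable of tensors.
    return [t.detach() for t in tensors]


batch = [torch.zeros(2), torch.ones(2)]
detach_all(tuple(batch))          # a tuple still works
detach_all(t + 1 for t in batch)  # so does a generator
```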

mzweilin commented 1 year ago

I tried to run adversarial training on 2 GPUs, but it failed. I then found that hiding the perturber parameters happened to resolve the issue.

```bash
python -m mart \
  experiment=CIFAR10_CNN_Adv \
  trainer=ddp \
  trainer.devices=2 \
  model.optimizer.lr=0.2 \
  trainer.max_steps=2925 \
  datamodule.ims_per_batch=256
```

```
  File "/home/weilinxu/coder/MART/.venv/lib/python3.9/site-packages/torch/nn/parallel/distributed.py", line 807, in <listcomp>
    for param_name, param in module.named_parameters(recurse=False)
  File "/home/weilinxu/coder/MART/mart/attack/perturber.py", line 83, in named_parameters
    raise MisconfigurationException("You need to call configure_perturbation before fit.")
pytorch_lightning.utilities.exceptions.MisconfigurationException: You need to call configure_perturbation before fit.
```
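A hedged sketch of why DDP trips over this (not MART's exact code): `DistributedDataParallel` enumerates `named_parameters()` when it wraps the model, so a perturber that raises until `configure_perturbation()` has been called fails at DDP setup time, before `fit()` ever runs. Hiding the perturbation from `named_parameters()` (or creating it lazily) avoids the exception.

```python
import torch
from pytorch_lightning.utilities.exceptions import MisconfigurationException


class Perturber(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.perturbation = None

    def configure_perturbation(self, input: torch.Tensor):
        # The perturbation is created lazily, once the input shape is known.
        self.perturbation = torch.nn.Parameter(torch.zeros_like(input))

    def named_parameters(self, *args, **kwargs):
        # This guard is what DDP hits when it enumerates parameters at wrap time.
        if self.perturbation is None:
            raise MisconfigurationException(
                "You need to call configure_perturbation before fit."
            )
        return super().named_parameters(*args, **kwargs)
```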