Closed dxoigmn closed 1 year ago
I started splitting this PR into pieces. I think I'll also try to stick to the same config structure.
There is a bug introduced in b6afdd14b276987d821905a1844ec7c4df611621 with adversarial training of CIFAR10. I'm not sure where it is but this should be a good opportunity to write a test :)
Found it and wrote a test!
One thing that may be worth pulling out of this PR is the change to `gain`.
Done.
I started breaking up this PR into smaller PRs hence the dismissed review.
I made another batch of changes to this PR to accommodate the changes in #132. As such, the tests need to be fixed. Additionally, there are some other changes related to types (e.g., relying upon `Iterable[torch.Tensor]` instead of `tuple`) that can be pulled out and made more generally useful. I created #134 for those changes.
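As a hedged illustration of why the typing change is more generally useful (the function below is hypothetical, not MART code): annotating a parameter as `Iterable[torch.Tensor]` accepts lists, tuples, generators, and other iterables, where a `tuple` annotation admits only one container type.

```python
from typing import Iterable

import torch


def total_norm(tensors: Iterable[torch.Tensor]) -> torch.Tensor:
    # Hypothetical helper: works for any iterable of tensors, not just tuples.
    return torch.stack([t.norm() for t in tensors]).sum()


# Callers may pass a list, a tuple, or a generator expression.
result = total_norm([torch.ones(2), torch.ones(3)])
```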
I tried to run adversarial training on 2 GPUs but it failed. Then hiding the perturber parameters unexpectedly resolved the issue.
```shell
python -m mart \
  experiment=CIFAR10_CNN_Adv \
  trainer=ddp \
  trainer.devices=2 \
  model.optimizer.lr=0.2 \
  trainer.max_steps=2925 \
  datamodule.ims_per_batch=256
```
```
File "/home/weilinxu/coder/MART/.venv/lib/python3.9/site-packages/torch/nn/parallel/distributed.py", line 807, in <listcomp>
    for param_name, param in module.named_parameters(recurse=False)
File "/home/weilinxu/coder/MART/mart/attack/perturber.py", line 83, in named_parameters
    raise MisconfigurationException("You need to call configure_perturbation before fit.")
pytorch_lightning.utilities.exceptions.MisconfigurationException: You need to call configure_perturbation before fit.
```
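The traceback suggests the conflict: `DistributedDataParallel` enumerates `module.named_parameters()` while wrapping the model, before `configure_perturbation` has been called. A minimal sketch of that interaction (this `Perturber` is a stand-in, not MART's actual implementation, and uses `RuntimeError` in place of Lightning's `MisconfigurationException`):

```python
import torch


class Perturber(torch.nn.Module):
    """Stand-in perturber whose parameter is created lazily."""

    def __init__(self):
        super().__init__()
        self.perturbation = None

    def configure_perturbation(self, input):
        # The parameter can only be created once the input shape is known.
        self.perturbation = torch.nn.Parameter(torch.zeros_like(input))

    def named_parameters(self, *args, **kwargs):
        # Guard against using an unconfigured perturbation.
        if self.perturbation is None:
            raise RuntimeError("You need to call configure_perturbation before fit.")
        return super().named_parameters(*args, **kwargs)

    def forward(self, input):
        return input + self.perturbation


perturber = Perturber()
try:
    # This is effectively what DDP does while wrapping the model,
    # which is why the guard fires before training even starts.
    list(perturber.named_parameters())
except RuntimeError as e:
    print(e)

perturber.configure_perturbation(torch.zeros(3))
configured_params = list(perturber.named_parameters())
```

Hiding the perturbation from `named_parameters` would keep DDP from tripping the guard, which is consistent with the workaround observed above.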
What does this PR do?
Right now we treat adversaries as special things with their own loops and callbacks, when really we should just treat them like LightningModules. Doing so means that we can just use a Trainer to fit its parameters. This PR attempts to make that so.
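The core idea can be sketched without any of the surrounding machinery (the `Adversary` class below is illustrative, not MART's API): if the adversary is just a module whose only parameter is the perturbation, then any standard optimizer loop, and hence a Lightning `Trainer`, can fit it by gradient ascent on the victim's loss.

```python
import torch


class Adversary(torch.nn.Module):
    """Illustrative adversary: its only parameter is the perturbation."""

    def __init__(self, input_shape):
        super().__init__()
        self.perturbation = torch.nn.Parameter(torch.zeros(input_shape))

    def forward(self, input):
        return input + self.perturbation


model = torch.nn.Linear(4, 2)  # frozen victim model
for p in model.parameters():
    p.requires_grad_(False)

adversary = Adversary((4,))
optimizer = torch.optim.SGD(adversary.parameters(), lr=0.1)
x = torch.randn(8, 4)
y = torch.zeros(8, dtype=torch.long)

for _ in range(10):
    optimizer.zero_grad()
    # Negate the loss so that minimizing it ascends the victim's loss.
    loss = -torch.nn.functional.cross_entropy(model(adversary(x)), y)
    loss.backward()
    optimizer.step()
# After the ascent steps, the perturbation has moved away from zero.
```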
As of dcf7599, there is a bug in adversarial training. Depends upon #146 and #147.
Type of change
Please check all relevant options.
Testing
Please describe the tests that you ran to verify your changes. Consider listing any relevant details of your test configuration.
- `pytest`
- `python -m mart experiment=CIFAR10_CNN_Adv trainer=gpu` achieves 71% accuracy.
- `python -m mart experiment=CIFAR10_CNN_Adv trainer=ddp datamodule.world_size=2 trainer.devices=2` achieves 71% accuracy.

Before submitting
- Ran the `pre-commit run -a` command without errors.

Did you have fun?
Make sure you had fun coding 🙃