pytorch / opacus

Training PyTorch models with differential privacy
https://opacus.ai
Apache License 2.0

AttributeError: 'Parameter' object has no attribute 'grad_sample' in Projected GAN #548

Open sword-king1 opened 1 year ago

sword-king1 commented 1 year ago

🐛 Bug

When I applied Opacus to the Projected GAN code, I hit this problem: `AttributeError: 'Parameter' object has no attribute 'grad_sample'`. I have already replaced batch_norm with group_norm in the discriminator module, but the error persists. This is the trace:

File "train.py", line 267, in main() # pylint: disable=no-value-for-parameter File "/mnt/LJH/wd/.conda/envs/new/lib/python3.8/site-packages/click/core.py", line 1128, in call return self.main(args, kwargs) File "/mnt/LJH/wd/.conda/envs/new/lib/python3.8/site-packages/click/core.py", line 1053, in main rv = self.invoke(ctx) File "/mnt/LJH/wd/.conda/envs/new/lib/python3.8/site-packages/click/core.py", line 1395, in invoke return ctx.invoke(self.callback, ctx.params) File "/mnt/LJH/wd/.conda/envs/new/lib/python3.8/site-packages/click/core.py", line 754, in invoke return __callback(args, kwargs) File "train.py", line 253, in main launch_training(c=c, desc=desc, outdir=opts.outdir, dry_run=opts.dry_run) File "train.py", line 101, in launch_training subprocess_fn(rank=0, c=c, temp_dir=temp_dir) File "train.py", line 47, in subprocess_fn training_loop.training_loop(rank=rank, c) File "/mnt/LJH/wd/test/training/training_loop.py", line 410, in training_loop loss.accumulate_gradients(phase=phase.name, real_img=real_img, real_c=real_c, gen_z=gen_z, gen_c=gen_c, gain=phase.interval, cur_nimg=cur_nimg) File "/mnt/LJH/wd/test/training/loss.py", line 86, in accumulate_gradients loss_Dgen.backward() File "/mnt/LJH/wd/.conda/envs/new/lib/python3.8/site-packages/torch/_tensor.py", line 396, in backward torch.autograd.backward(self, gradient, retain_graph, create_graph, inputs=inputs) File "/mnt/LJH/wd/.conda/envs/new/lib/python3.8/site-packages/torch/autograd/init.py", line 173, in backward Variable._execution_engine.run_backward( # Calls into the C++ engine to run the backward pass File "/mnt/LJH/wd/.conda/envs/new/lib/python3.8/site-packages/opacus/privacy_engine.py", line 71, in forbid_accumulation_hook if p.grad_sample is not None: AttributeError: 'Parameter' object has no attribute 'grad_sample'

Since this code is a bit complicated, let me explain the setup for your convenience. The original Projected GAN contains a generator and a Projected Discriminator. The Projected Discriminator consists of a feature_network (a pre-trained network plus CCM and CSM) that supplies image features to four discriminators; the feature_network is not updated during training, only the four discriminators are. Rather than splitting apart the structure and loss function that merge the multiple discriminators, I wrapped each discriminator with Opacus directly, and when training the Projected Discriminator I optimize with the four optimizers returned by Opacus. All of these changes live in the training loop file (a rough sketch follows the screenshot below) ...

[screenshot of the modified training loop]
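To make the setup concrete, here is a minimal sketch of wrapping each sub-discriminator separately; the names `D1`-`D4` and `disc_loader` plus the noise/clipping values are placeholders, not the actual code from my repository:

```python
import torch
from opacus import PrivacyEngine

# Hypothetical per-discriminator wrapping (the approach described above).
# D1..D4 are the four sub-discriminators; disc_loader is the real-image DataLoader.
wrapped_discs, disc_opts, disc_loaders = [], [], []
for D in [D1, D2, D3, D4]:
    opt = torch.optim.Adam(D.parameters(), lr=2e-4)
    privacy_engine = PrivacyEngine()
    D, opt, loader = privacy_engine.make_private(
        module=D,
        optimizer=opt,
        data_loader=disc_loader,
        noise_multiplier=1.0,   # placeholder value
        max_grad_norm=1.0,      # placeholder value
    )
    wrapped_discs.append(D)
    disc_opts.append(opt)
    disc_loaders.append(loader)
```

Wrapping the discriminators separately while their losses are merged into one backward pass is the part that the reply below addresses.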

Because the original Projected GAN code needs its DataLoader to be serialized, but `wrap_collate_with_empty()` in Opacus' data_loader file contains a nested function `collate(batch)` (a closure, which cannot be pickled), I made some small changes to remove the closure. You need to replace the data_loader file in the original Opacus package with the one from my repository.
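For illustration, the usual way to make such a closure picklable is to turn it into a module-level callable class. This is only a generic sketch of the pattern, not the actual patch, and the argument names are assumptions rather than Opacus' exact signature:

```python
import torch

class CollateWithEmpty:
    """Picklable replacement for a nested collate(batch) closure.

    Delegates to the wrapped collate_fn when the batch is non-empty, and
    builds zero-size tensors when a Poisson-sampled batch comes back empty.
    """

    def __init__(self, collate_fn, sample_empty_shapes):
        self.collate_fn = collate_fn
        self.sample_empty_shapes = sample_empty_shapes

    def __call__(self, batch):
        if len(batch) > 0:
            return self.collate_fn(batch)
        return [torch.zeros(shape) for shape in self.sample_empty_shapes]

# Usage: pass an instance where the closure used to be, e.g.
# DataLoader(dataset, batch_sampler=..., collate_fn=CollateWithEmpty(collate_fn, shapes))
```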

To Reproduce

1. Use the data_loader file in the GitHub repository to replace the original data_loader file of Opacus.
2. `python dataset_tool.py --source=./data --dest=./data/beauty256.zip --resolution=256x256 --transform=center-crop`
3. `python train.py --outdir=./training-runs/ --cfg=fastgan --data=./data/beauty256.zip --gpus=1 --batch=32 --mirror=1 --snap=50 --batch-gpu=16 --kimg=600`

github repository

Sorry, I tried to reproduce the issue on Colab, but probably because I haven't used Colab before, I ran into some errors. Here is my code: https://github.com/sword-king1/ProjectedGAN

Environment

alexandresablayrolles commented 1 year ago

Thanks for raising this issue. This problem is usually caused by doing more forward passes than backward passes, and I suspect something like that is going on here. I would advise writing the network you showed as one module containing everything, including D1-D4, and giving that module to the PrivacyEngine. (Let me also mention functorch, even though it requires significant changes to your codebase: you can use functorch to compute per-sample gradients and feed them directly to Opacus, as shown in this example.)
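A minimal sketch of that suggestion, assuming `D1`-`D4` are the four sub-discriminators, `train_loader` is the discriminator DataLoader, and the noise/clipping values are placeholders:

```python
import torch
import torch.nn as nn
from opacus import PrivacyEngine

class CombinedDiscriminator(nn.Module):
    """Single module wrapping all four sub-discriminators, so Opacus sees
    one wrapped forward pass per backward pass / optimizer step."""

    def __init__(self, d1, d2, d3, d4):
        super().__init__()
        self.discs = nn.ModuleList([d1, d2, d3, d4])

    def forward(self, features):
        # features: one feature map per sub-discriminator (as in Projected GAN)
        return [d(f) for d, f in zip(self.discs, features)]

combined = CombinedDiscriminator(D1, D2, D3, D4)
optimizer = torch.optim.Adam(combined.parameters(), lr=2e-4)

privacy_engine = PrivacyEngine()
combined, optimizer, train_loader = privacy_engine.make_private(
    module=combined,
    optimizer=optimizer,
    data_loader=train_loader,
    noise_multiplier=1.0,  # placeholder
    max_grad_norm=1.0,     # placeholder
)
```

The key point is that all trainable discriminator parameters sit under one wrapped module and one optimizer, so the merged discriminator loss produces a single matched forward/backward pass through Opacus.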