wenzhu23333 / Differential-Privacy-Based-Federated-Learning

Everything you want about DP-Based Federated Learning, including Papers and Code. (Mechanism: Laplace or Gaussian, Dataset: femnist, shakespeare, mnist, cifar-10 and fashion-mnist. )
GNU General Public License v3.0
348 stars 55 forks

Differences between optimizer.zero_grad() and net.zero_grad() #6

Closed FassyGit closed 1 year ago

FassyGit commented 1 year ago

Hi wenzhu,

Thanks a lot for this nice code! I found that with the optimizer defined as `optimizer = torch.optim.SGD(net.parameters(), lr=self.args.lr, momentum=self.args.momentum)`, zeroing the model's gradients in the training step with `optimizer.zero_grad()` consumes much more resources than `net.zero_grad()`. Generally speaking, with the above definition of the optimizer, `optimizer.zero_grad()` and `net.zero_grad()` should be equivalent. Is this a problem caused by the opacus package? I am not very familiar with opacus.

wenzhu23333 commented 1 year ago

Thanks for your question; I have also encountered this problem. Your guess is right: this is caused by the opacus package. We wrap the global model with `GradSampleModule` so that, during SGD training, the model exposes the gradient computed for each individual data point (i.e., the per-sample gradient); the native PyTorch library does not support this functionality.

With the wrapper in place, each model parameter gains a `grad_sample` attribute during training, which records the gradient computed for each data point and allows us to clip these per-sample gradients.
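To make "per-sample gradient" concrete, here is a minimal sketch in plain PyTorch that computes it the slow way, by running backward once per sample (Opacus's `GradSampleModule` computes the same quantity efficiently in a single backward pass). The model and loss here are hypothetical placeholders:

```python
import torch
import torch.nn as nn

# Toy model and a batch of 3 samples (placeholders for illustration).
net = nn.Linear(4, 2)
x = torch.randn(3, 4)

# Naive per-sample gradients: one backward pass per data point.
per_sample_grads = []
for i in range(x.shape[0]):
    net.zero_grad()                       # clear grads between samples
    net(x[i:i + 1]).sum().backward()      # loss on a single sample
    per_sample_grads.append([p.grad.clone() for p in net.parameters()])

# per_sample_grads[i] holds one gradient per parameter for sample i --
# exactly the quantity DP-SGD clips before aggregation and noising.
```

Opacus instead attaches these gradients directly to each parameter as `p.grad_sample`, avoiding the per-sample loop.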

Because the global model is wrapped by `GradSampleModule`, `optimizer.zero_grad()` and `net.zero_grad()` are no longer equivalent. `net.zero_grad()` calls the model's own method, which after wrapping clears both `grad` and `grad_sample` before each training step. `optimizer.zero_grad()`, however, knows nothing about the wrapper: it clears only `grad` and leaves `grad_sample` in place, so every client keeps accumulating this per-sample information in each round, which wastes resources.

If you need to use optimizer.zero_grad(), you can consider manually emptying grad_sample.
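A minimal sketch of that manual cleanup, in plain PyTorch. The `grad_sample` attributes are attached here by hand to stand in for what Opacus records during training, so the snippet runs without opacus installed:

```python
import torch
import torch.nn as nn

net = nn.Linear(4, 2)
optimizer = torch.optim.SGD(net.parameters(), lr=0.1)

# One forward/backward pass on a batch of 3 samples.
x = torch.randn(3, 4)
net(x).sum().backward()

# Simulate what GradSampleModule does: store one gradient per sample
# on each parameter (shape: [batch_size, *param_shape]).
for p in net.parameters():
    p.grad_sample = torch.zeros(3, *p.shape)

# optimizer.zero_grad() clears only .grad; .grad_sample survives ...
optimizer.zero_grad()

# ... so drop the per-sample gradients by hand after each step.
for p in net.parameters():
    if hasattr(p, "grad_sample"):
        del p.grad_sample
```

With this extra loop after `optimizer.zero_grad()`, the memory behavior matches what `net.zero_grad()` gives you on the wrapped model.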

Hope the above answer is useful to you.

Reference: https://opacus.ai/tutorials/guide_to_grad_sampler