pytorch / opacus

Training PyTorch models with differential privacy
https://opacus.ai
Apache License 2.0

Wrapper references can be easily replaced, consider using properties instead #649

Open andreatgretel opened 2 months ago

andreatgretel commented 2 months ago

(Apologies for skipping the template; this isn't really a bug report but rather a suggested improvement.)

Currently, DPOptimizer passes through the state, defaults, and param_groups attributes of the wrapped optimizer by simply referencing them. This becomes a problem when another object sets these attributes from the outside: the assignment replaces DPOptimizer's reference instead of updating original_optimizer's attributes. For instance, this happens when HF Accelerate's AcceleratedOptimizer wraps the DPOptimizer, causing issues such as this one.

A safer alternative might be to do what AcceleratedOptimizer itself does and expose these attributes as properties instead of plain references. For param_groups, that would look something like:

@property
def param_groups(self):
    return self.original_optimizer.param_groups

@param_groups.setter
def param_groups(self, param_groups):
    self.original_optimizer.param_groups = param_groups
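
For completeness, here is a minimal self-contained sketch of the same property pattern applied to all three pass-through attributes. The class name DPOptimizerSketch is purely illustrative (not part of Opacus); it only assumes the wrapper keeps the wrapped optimizer in self.original_optimizer, as DPOptimizer does today.

class DPOptimizerSketch:
    """Illustrative wrapper: every pass-through attribute forwards to the wrapped optimizer."""

    def __init__(self, original_optimizer):
        self.original_optimizer = original_optimizer

    @property
    def param_groups(self):
        # Reads always reflect the wrapped optimizer's current param_groups
        return self.original_optimizer.param_groups

    @param_groups.setter
    def param_groups(self, param_groups):
        # External assignments update the wrapped optimizer instead of rebinding the wrapper
        self.original_optimizer.param_groups = param_groups

    @property
    def defaults(self):
        return self.original_optimizer.defaults

    @defaults.setter
    def defaults(self, defaults):
        self.original_optimizer.defaults = defaults

    @property
    def state(self):
        return self.original_optimizer.state

    @state.setter
    def state(self, state):
        self.original_optimizer.state = state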

Reproducing the issue

Please take a look at this GitHub issue to see a particular situation where passing by reference becomes a problem. In summary: HF's accelerate puts yet another wrapper around the DPOptimizer, and ideally this "doubly-wrapped" optimizer should still function properly. However, because AcceleratedOptimizer replaces the whole param_groups attribute instead of mutating its contents, the changes never reach the original_optimizer.

Consider the following MWE of how this could happen:

from copy import deepcopy

import opacus
import torch

m = torch.nn.Linear(10, 1)
o = torch.optim.AdamW(m.parameters())
odp = opacus.optimizers.DPOptimizer(o, noise_multiplier=0.5, max_grad_norm=0.1, expected_batch_size=32)

# Simulate what an external wrapper (e.g. AcceleratedOptimizer) does:
# build a new param_groups list and assign it wholesale to the wrapper.
pg = deepcopy(odp.param_groups)
pg[0]['lr'] = 1.0
odp.param_groups = pg

# The assignment only rebinds DPOptimizer's reference; the wrapped
# optimizer never sees the new learning rate.
print('LR at DPOptimizer:', odp.param_groups[0]['lr'])
print('LR at original_optimizer:', odp.original_optimizer.param_groups[0]['lr'])
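
With the current pass-through-by-reference behaviour, the two prints are expected to disagree: the first line shows the new value 1.0, while the second still shows AdamW's default learning rate (1e-3), because the assignment only rebound DPOptimizer's reference. With a property-based setter like the one sketched above, both lines would show 1.0, since reads and writes go through original_optimizer.
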
HuanyuZhang commented 1 month ago

Thanks for reporting this. We will investigate and launch a fix.