pytorch / opacus

Training PyTorch models with differential privacy
https://opacus.ai
Apache License 2.0

Opacus reports an error when weights are frozen but bias gradients are kept; suggest adding support for weight freezing #542

Closed: xtyXpastor closed this issue 1 year ago

xtyXpastor commented 1 year ago

🚀 Feature

Motivation

I want to freeze the weights of my classifier (an nn.Linear layer) and train only its bias.

Pitch

When I tried this, Opacus threw a dimension mismatch error. A mismatch is to be expected if stale state survives between steps: the private data loader uses Poisson sampling, so the batch size changes from step to step. This made me suspect that a per-sample gradient cache belonging to a parameter that does not need gradients was never deleted between iterations. Specifically, I found that the cache in question is p._current_grad_sample.
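To illustrate the mechanism described above, here is a minimal sketch in plain PyTorch that does not touch Opacus internals. The sampling rate of 0.15 is an assumed value, and accumulate() is a hypothetical stand-in for whatever is done with a cached per-sample gradient; the sketch only shows that Poisson-sampled batch sizes vary and that combining per-sample tensors from two differently sized batches fails along dimension 0.

```python
import torch


def poisson_batch_size(n: int, sample_rate: float) -> int:
    """Size of one Poisson-sampled batch: each of the n examples is kept
    independently with probability sample_rate, i.e. a Binomial(n, q) draw."""
    return int((torch.rand(n) < sample_rate).sum().item())


def accumulate(stale: torch.Tensor, fresh: torch.Tensor) -> torch.Tensor:
    """Hypothetical accumulation of a cached per-sample gradient with a new one.
    Dim 0 is the batch dimension, so the shapes only agree when both steps
    happened to draw the same number of examples."""
    return stale + fresh


if __name__ == "__main__":
    torch.manual_seed(0)
    # Assumed sampling rate: batch sizes fluctuate around 100 * 0.15 = 15.
    print([poisson_batch_size(100, 0.15) for _ in range(5)])

    # Per-sample weight gradients for nn.Linear(2, 3) have shape (batch, 3, 2).
    stale = torch.zeros(10, 3, 2)  # left over from an earlier 10-example step
    fresh = torch.zeros(13, 3, 2)  # produced by a later 13-example step
    try:
        accumulate(stale, fresh)
    except RuntimeError as err:
        print(err)  # size mismatch at non-singleton dimension 0
```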

When the weight does not require gradients, del p._current_grad_sample is never triggered for it. So I looked into why the weight's _current_grad_sample is still generated even though the weight does not require gradients. In /opacus/grad_sample/linear.py, the returned dictionary ret always contains a per-sample gradient for the weight, even when the weight's gradient is not required. I think this is the source of the problem. I worked around it with a simple check:

if not layer.weight.requires_grad and layer.bias is not None:
    # Weight is frozen: compute per-sample gradients for the bias only
    ret = {layer.bias: torch.einsum("n...k->nk", backprops)}
else:
    gs = torch.einsum("n...i,n...j->nij", backprops, activations)
    ret = {layer.weight: gs}
    if layer.bias is not None:
        ret[layer.bias] = torch.einsum("n...k->nk", backprops)

Still, I hope the Opacus maintainers can add proper support for frozen weights.
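A sketch of a grad sampler that generalizes the workaround above: it emits a per-sample gradient for the weight and for the bias only when that parameter requires one. The function name linear_grad_sample and its standalone signature are assumptions for illustration; this is not Opacus's actual implementation or its registration API.

```python
from typing import Dict

import torch
import torch.nn as nn


def linear_grad_sample(
    layer: nn.Linear, activations: torch.Tensor, backprops: torch.Tensor
) -> Dict[nn.Parameter, torch.Tensor]:
    """Hypothetical per-sample gradient function for nn.Linear that skips
    frozen parameters.

    activations: layer input, shape (n, ..., in_features)
    backprops:   gradient of the loss w.r.t. the layer output, shape (n, ..., out_features)
    """
    ret: Dict[nn.Parameter, torch.Tensor] = {}
    if layer.weight.requires_grad:
        # Per-sample outer product backprop x activation, summed over any extra dims.
        ret[layer.weight] = torch.einsum("n...i,n...j->nij", backprops, activations)
    if layer.bias is not None and layer.bias.requires_grad:
        # Per-sample bias gradient: backprops summed over any extra dims.
        ret[layer.bias] = torch.einsum("n...k->nk", backprops)
    return ret
```

Assuming only parameters present in the returned dict get a cached per-sample gradient, a frozen weight then leaves nothing stale behind between steps.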

Alternatives

Additional context

pierrestock commented 1 year ago

Hey xtyXpastor,

Thanks for your interest in Opacus. Do you have a minimal reproduction example? Do you freeze part of your network after some training iterations? Both cases seem to work fine on a simple example, as shown below.

import torch
import torch.nn as nn
from opacus import PrivacyEngine 

# define your components as usual
model = nn.Linear(2, 3, bias=True)
optimizer = torch.optim.SGD(model.parameters(), lr=0.05)
dataset = torch.utils.data.TensorDataset(torch.rand(100, 2))
data_loader = torch.utils.data.DataLoader(dataset, batch_size=16)

# enter PrivacyEngine
privacy_engine = PrivacyEngine()
model, optimizer, data_loader = privacy_engine.make_private(
    module=model,
    optimizer=optimizer,
    data_loader=data_loader,
    noise_multiplier=1.1,
    max_grad_norm=1.0,
    poisson_sampling=True
)
# now it's business as usual
for i, x in enumerate(data_loader):
    if i > 1:
        model._module.weight.requires_grad = False 
    optimizer.zero_grad()
    loss = model(x[0]).sum()
    loss.backward()
    optimizer.step()

Hope this helps, Pierre

xtyXpastor commented 1 year ago

I am very happy to receive feedback from the Opacus community. My repro code is very close to yours; the only difference is that I freeze the weights of the linear layer before calling Opacus.

import torch
import torch.nn as nn
from opacus import PrivacyEngine
model = nn.Linear(2, 3, bias=True)
model.weight.requires_grad = False
optimizer = torch.optim.SGD(model.parameters(), lr=0.05)
dataset = torch.utils.data.TensorDataset(torch.rand(100, 2))
data_loader = torch.utils.data.DataLoader(dataset, batch_size=16)
privacy_engine = PrivacyEngine()
model, optimizer, data_loader = privacy_engine.make_private(
    module=model,
    optimizer=optimizer,
    data_loader=data_loader,
    noise_multiplier=1.1,
    max_grad_norm=1.0,
    poisson_sampling=True
)
for i, x in enumerate(data_loader):

    optimizer.zero_grad()
    loss = model(x[0]).sum()
    loss.backward()
    optimizer.step()

Then I get the error RuntimeError: The size of tensor a (10) must match the size of tensor b (13) at non-singleton dimension 0.

My Opacus version is 1.1.2. I would like to know whether my code style (i.e. freezing first and then calling Opacus) is discouraged by the Opacus community. Thank you sincerely for your reply.

ffuuugor commented 1 year ago

@xtyXpastor Your code is very reasonable and you are using the recommended way to freeze the parameters.
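For reference, the freeze-then-build-the-optimizer pattern in plain PyTorch, without Opacus, can be sketched as follows. The function train_bias_only is hypothetical; it mirrors the repro above (nn.Linear(2, 3), SGD with lr=0.05, batches of 16) and returns parameter snapshots so you can confirm that the frozen weight never changes while the bias does.

```python
import torch
import torch.nn as nn


def train_bias_only(steps: int = 3, lr: float = 0.05, batch_size: int = 16):
    """Hypothetical helper: freeze an nn.Linear weight before building the
    optimizer, then train. Plain PyTorch, no Opacus.

    Returns (weight_before, weight_after, bias_before, bias_after).
    """
    torch.manual_seed(0)
    model = nn.Linear(2, 3, bias=True)
    model.weight.requires_grad = False  # freeze first, as in the repro
    # All parameters are handed to SGD; the frozen weight's .grad stays None,
    # so SGD leaves it untouched.
    optimizer = torch.optim.SGD(model.parameters(), lr=lr)

    w_before = model.weight.detach().clone()
    b_before = model.bias.detach().clone()
    for _ in range(steps):
        x = torch.rand(batch_size, 2)
        optimizer.zero_grad()
        loss = model(x).sum()
        loss.backward()
        optimizer.step()
    return w_before, model.weight.detach().clone(), b_before, model.bias.detach().clone()


if __name__ == "__main__":
    w0, w1, b0, b1 = train_bias_only()
    print("weight unchanged:", torch.equal(w0, w1))
    print("bias change:", b1 - b0)
```

Because the loss is a plain sum over the batch and the outputs, each bias entry receives a gradient equal to the batch size, so it moves by lr × batch_size per step.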

The issue you're experiencing was fixed by #437, and your code should work fine with opacus >= 1.1.3 (I've double-checked it with the latest version).
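To confirm an environment is on 1.1.3 or later, a small check can compare the installed version against that threshold. The helper names are hypothetical; the sketch uses importlib.metadata.version from the standard library and deliberately ignores pre-release and local suffixes, so packaging.version.Version is the more robust choice when that package is available.

```python
import re
from importlib.metadata import PackageNotFoundError, version
from typing import Optional, Tuple


def parse_release(v: str) -> Tuple[int, ...]:
    """Leading numeric release components, e.g. '1.1.3' -> (1, 1, 3).
    Suffixes such as 'rc1' or '+local' are ignored (a simplification)."""
    m = re.match(r"\d+(?:\.\d+)*", v.strip())
    if m is None:
        raise ValueError(f"unrecognized version string: {v!r}")
    return tuple(int(p) for p in m.group(0).split("."))


def has_frozen_weight_fix(v: str, minimum: Tuple[int, ...] = (1, 1, 3)) -> bool:
    """True if version string v is at least `minimum` (tuple comparison)."""
    return parse_release(v) >= minimum


def installed_opacus_version() -> Optional[str]:
    """Installed Opacus version string, or None if Opacus is not installed."""
    try:
        return version("opacus")
    except PackageNotFoundError:
        return None


if __name__ == "__main__":
    v = installed_opacus_version()
    if v is None:
        print("opacus is not installed")
    else:
        print(v, "has fix:", has_frozen_weight_fix(v))
```

For example, has_frozen_weight_fix("1.1.2") is False, which matches the reporter's failing setup.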

Feel free to reopen the issue if you're still running into problems after updating.