pytorch / opacus

Training PyTorch models with differential privacy
https://opacus.ai
Apache License 2.0

Migrate GradSampler to Tensor hooks #259

Open romovpa opened 2 years ago

romovpa commented 2 years ago

Problem & Motivation

The grad sampler currently relies on nn.Module backward hooks. This approach has a limitation: it only covers the parts of the computation where the module is called directly, as in x = module(x). If someone adds an expression based on the module's parameters to the loss, its contribution is not reflected in the gradients Opacus computes.

An example where this problem occurs is adding a parameter regularizer in the loss (see #249), for example:

loss = criterion(y_pred, y_true)
loss += l2_loss(model.parameters())
loss += proximal_loss(model, another_model)   # e.g. encourage two models to have similar weights
loss.backward()

In this case, the grad hooks are simply never invoked for the regularizer terms. When running with PrivacyEngine, backward() silently omits them. This is surprising and incorrect behaviour.
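A minimal repro of the miss (a toy sketch, not Opacus code): a full backward hook on a Linear layer only observes the gradient flowing through the module call itself, so the contribution of a parameter-only L2 term shows up in .grad but is invisible to the hook.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
lin = nn.Linear(2, 2, bias=False)

seen = {}
def hook(module, grad_input, grad_output):
    # Fires for the module call, but knows nothing about terms that
    # use lin.weight directly.
    seen["grad_output"] = grad_output[0].clone()

lin.register_full_backward_hook(hook)

x = torch.randn(3, 2)
# The L2 term uses the parameter directly, bypassing the module call.
loss = lin(x).sum() + lin.weight.pow(2).sum()
loss.backward()

# The hook only saw the module-call path (an all-ones grad_output);
# the regularizer's contribution, 2 * weight, lives in .grad alone.
module_path_grad = torch.ones(3, 2).T @ x
reg_grad = lin.weight.grad - module_path_grad
```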

Using nn.Module hooks is a fundamental limitation. Migrating to full_backward_hooks doesn't solve the problem.
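Tensor-level hooks, by contrast, attach to the parameter itself and fire with the full gradient, however it was produced. A minimal sketch of the idea (an illustration, not Opacus's implementation):

```python
import torch
import torch.nn as nn

lin = nn.Linear(2, 2, bias=False)

captured = {}
# A hook registered on the parameter tensor runs when its gradient is
# ready, with contributions from all paths already summed.
handle = lin.weight.register_hook(lambda g: captured.setdefault("grad", g.clone()))

x = torch.randn(3, 2)
loss = lin(x).sum() + lin.weight.pow(2).sum()  # includes a parameter-only term
loss.backward()
handle.remove()

# captured["grad"] matches .grad, regularizer contribution included.
```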

Pitch

Bare minimum:

Ideally:

Alternatives

Suggestions are welcome.

alexandresablayrolles commented 2 years ago

With Opacus, we throw away the .grad given by PyTorch and recompute it using per-sample gradients. However, in this case, it looks like the .grad would also carry signal coming from the L2 loss.

We are discussing with the PyTorch team a way to disable the computation of .grad in the backward pass; I am wondering if that could do the trick?
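To illustrate the gap (a toy sketch, not Opacus internals): autograd's .grad carries the regularizer's signal, while a recomputation from per-sample gradients of the module path alone does not.

```python
import torch
import torch.nn as nn

lin = nn.Linear(2, 1, bias=False)
x = torch.randn(4, 2)
loss = lin(x).sum() + lin.weight.pow(2).sum()
loss.backward()

# Per-sample gradients reconstructed from the module path only:
# for y_i = x_i @ W.T with dL/dy_i = 1, grad_i = outer(1, x_i).
per_sample = torch.ones(4, 1, 1) * x.unsqueeze(1)   # shape (4, 1, 2)
recomputed = per_sample.sum(dim=0)                  # misses the L2 term

# The gap between .grad and the per-sample sum is exactly the
# regularizer gradient, 2 * weight.
gap = lin.weight.grad - recomputed
```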

tanmay-ty commented 2 years ago

I am facing a similar issue: when I add a regularization term to my loss function, the gradients vanish. Initially I thought the gradients were being clipped, hence the vanishing, but after a lot of adjustments I figured out that the regularization term has no effect under PrivacyEngine, while it works well without it. Do you already have a solution to this issue?

Error: The following layers do not have gradients: ['layers.0.0.weight', 'layers.0.0.bias', 'layers.0.1.weight', 'layers.0.1.bias', 'layers.1.0.weight', 'layers.1.0.bias', 'layers.1.1.weight', 'layers.1.1.bias', 'layers.2.0.weight', 'layers.2.0.bias', 'layers.2.1.weight', 'layers.2.1.bias', 'layers.3.weight', 'layers.3.bias']. Are you sure they were included in the backward pass?

My toy experimental code:

loss = loss_fn(logits, y) + 0.1 * mse_fn(out, y)
optimizer.zero_grad()
loss.backward()
optimizer.step()

The error is raised on optimizer.step().
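For the special case of a plain L2 penalty, one commonly suggested workaround (an assumption on my part, not something confirmed in this thread) is to express it as optimizer weight decay instead of a loss term, since weight decay is applied at the optimizer step and does not depend on the backward hooks at all:

```python
import torch
import torch.nn as nn

model = nn.Linear(2, 2)
# weight_decay adds lambda * w at the optimizer step, bypassing the
# hook machinery entirely. Note: under DP-SGD this is applied to the
# clipped, noised gradient, which is not identical to regularizing
# the per-sample loss.
optimizer = torch.optim.SGD(model.parameters(), lr=0.1, weight_decay=1e-4)
```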