pytorch / opacus

Training PyTorch models with differential privacy
https://opacus.ai
Apache License 2.0

Add context manager to toggle on/off privacy in training loop #634

Open bnestor opened 4 months ago

bnestor commented 4 months ago

🚀 Feature

Context manager to optionally disable privacy for mixtures of public and private data.

Motivation

Similar to how torch.no_grad() or torch.autocast(..., enabled=...) work, it would be nice to have a context manager to disable privacy. The main reason is to train on public and private data concurrently, without the public data eating away at the privacy budget.

Pitch

import torch
from torch.optim import SGD
from opacus import PrivacyEngine

# define your components as usual
model = Net()
optimizer = SGD(model.parameters(), lr=0.05)
data_loader = torch.utils.data.DataLoader(dataset, batch_size=1024)

# enter PrivacyEngine
privacy_engine = PrivacyEngine()
model, optimizer, data_loader = privacy_engine.make_private(
    module=model,
    optimizer=optimizer,
    data_loader=data_loader,
    noise_multiplier=1.1,
    max_grad_norm=1.0,
)

batch = next(iter(data_loader))

# case 1
output = model(batch)
loss = ...  # loss computation
loss.backward() # standard privacy engine with privacy applied

# case 2
with privacy_context(enabled=True):
    output = model(batch)
    loss = ...  # loss computation
    loss.backward() # standard privacy engine with privacy applied

# case 3
with privacy_context(enabled=False):
    output = model(batch)
    loss = ...  # loss computation
    loss.backward() # differential privacy is not applied; the gradient is computed as if the privacy engine had not been initialized
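For what it's worth, here is a rough sketch of how such a helper might be built on top of existing Opacus pieces. The privacy_context name is the proposed (hypothetical) API; the sketch assumes GradSampleModule.disable_hooks() / enable_hooks() and a DPOptimizer.original_optimizer attribute, which exist in current Opacus releases but should be verified for yours. It does not touch the data loader or the accountant.

from contextlib import contextmanager

@contextmanager
def privacy_context(model, optimizer, enabled=True):
    # model: the GradSampleModule returned by make_private
    # optimizer: the DPOptimizer returned by make_private
    if enabled:
        # leave everything as-is: per-sample gradients, clipping, noise
        yield
        return
    # stop capturing per-sample gradients, so backward() fills .grad as usual
    model.disable_hooks()
    # bypass clipping and noise by delegating step() to the wrapped optimizer
    # (assumes DPOptimizer exposes the base optimizer as `original_optimizer`)
    dp_step = optimizer.step
    optimizer.step = optimizer.original_optimizer.step
    try:
        yield
    finally:
        optimizer.step = dp_step
        model.enable_hooks()

With something like this, case 3 would run a plain optimizer step. Note that batches drawn from the data loader returned by make_private are, by default, still Poisson-sampled, so in practice the public batches would likely come from a separate, ordinary DataLoader.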

Alternatives

Alternatively, you could keep two copies of each model/optimizer/data loader, with only one of them initialized through the privacy engine, and load the state dict into whichever copy you are switching to.
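A rough sketch of that alternative, starting from fresh, unwrapped components as defined at the top of the pitch (the _module attribute is where current GradSampleModule versions keep the wrapped network; adjust for your release):

# private copies, wrapped by the privacy engine
private_model, private_optimizer, private_loader = privacy_engine.make_private(
    module=model,
    optimizer=optimizer,
    data_loader=data_loader,
    noise_multiplier=1.1,
    max_grad_norm=1.0,
)

# public copies, never touched by the privacy engine
public_model = Net()
public_optimizer = SGD(public_model.parameters(), lr=0.05)

# switching private -> public: copy weights out of the wrapped module
public_model.load_state_dict(private_model._module.state_dict())

# ... train on public data with public_model / public_optimizer ...

# switching public -> private: copy weights back into the wrapped module
private_model._module.load_state_dict(public_model.state_dict())

This only moves the weights: optimizer state (e.g. momentum) stays per copy, and only the steps taken through the private copy are tracked by the privacy accountant.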

Additional context

Based on current research showing that public pre-training followed by private fine-tuning improves performance: https://aclanthology.org/2020.privatenlp-1.5.pdf, https://arxiv.org/pdf/2302.09483.pdf

It would be interesting to test whether including public data during fine-tuning also improves performance: https://arxiv.org/abs/2111.12292