pytorch / opacus

Training PyTorch models with differential privacy
https://opacus.ai
Apache License 2.0

Utility to test the correctness of per-sample gradients #484

Open ffuuugor opened 2 years ago

ffuuugor commented 2 years ago

For custom grad_samplers or functorch-based implementations.

ffuuugor commented 1 year ago

More context.

The first and essential step of a DP-SGD implementation is computing per-sample gradients. While Opacus provides various approaches to do so (hooks, ExpandedWeights, functorch) and performs some model validation, it is still possible to end up with incorrect per-sample gradients without Opacus notifying you.
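As an illustration of the functorch-style approach mentioned above, here is a minimal sketch that computes per-sample gradients with `torch.func` (the successor of functorch; this assumes PyTorch 2.0 or later). It is not Opacus's own code path, and the toy model and data are made up for the example. The idea is to write the loss for a single sample, differentiate it with `grad`, and `vmap` that over the batch.

```python
import torch
import torch.nn as nn
from torch.func import functional_call, grad, vmap

torch.manual_seed(0)

# Toy model and data (illustrative only).
model = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))
loss_fn = nn.CrossEntropyLoss()
x = torch.randn(16, 4)
y = torch.randint(0, 2, (16,))

# Detach parameters so torch.func treats them as plain function inputs.
params = {name: p.detach() for name, p in model.named_parameters()}


def sample_loss(params, xi, yi):
    # Re-add a batch dimension of size 1 so the model sees a normal batch.
    logits = functional_call(model, params, (xi.unsqueeze(0),))
    return loss_fn(logits, yi.unsqueeze(0))


# grad() differentiates w.r.t. the first argument (params); vmap() maps over
# dim 0 of x and y while sharing the same params across all samples.
per_sample_grads = vmap(grad(sample_loss), in_dims=(None, 0, 0))(params, x, y)
# per_sample_grads[name] has shape (16, *param.shape)

# Sanity check: averaging per-sample gradients recovers the usual
# mean-reduced batch gradient.
model.zero_grad()
loss_fn(model(x), y).backward()
max_err = max(
    (per_sample_grads[n].mean(0) - p.grad).abs().max().item()
    for n, p in model.named_parameters()
)
```

Note that the batch-mean check only validates the aggregate; it cannot tell whether each slice actually belongs to the right sample.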

There are multiple ways this can happen, but one stands out as the most important. Opacus has an implicit requirement that any input model keeps the batch dimension as the first dimension throughout the forward() method (or the second if batch_first=False). This means that a simple x.transpose() can break all privacy guarantees. Detecting such behaviour would require deep introspection of the model, and we haven't found a generic enough way to do this for arbitrary input models.
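To make the batch-dimension pitfall concrete, here is a minimal sketch that does not use Opacus itself. The module `TimeMajorBlock` is hypothetical; the hooks and the `"n...o,n...i->noi"` einsum are written in the style of a hook-based grad sampler for `nn.Linear` that assumes dimension 0 is the batch. Because `T == B` here, the wrong result has the right shape, so nothing fails loudly; only a comparison against micro-batches of size 1 exposes it.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)


class TimeMajorBlock(nn.Module):
    """Hypothetical module that runs its Linear layer with batch on dim 1."""

    def __init__(self, d):
        super().__init__()
        self.fc = nn.Linear(d, d, bias=False)

    def forward(self, x):               # x: (B, T, D), batch first
        h = self.fc(x.transpose(0, 1))  # fc sees (T, B, D): batch is now dim 1
        return h.transpose(0, 1)        # back to (B, T, D)


B, T, D = 3, 3, 4  # T == B, so shapes still line up and nothing errors out
model = TimeMajorBlock(D)
x = torch.randn(B, T, D)

# Capture fc's input activations and output gradients, like a grad sampler hook.
saved = {}
model.fc.register_forward_hook(lambda m, inp, out: saved.update(a=inp[0].detach()))
model.fc.register_full_backward_hook(lambda m, g_in, g_out: saved.update(g=g_out[0].detach()))

model(x).pow(2).sum().backward()  # sum reduction: no 1/B rescaling needed
a, g = saved["a"], saved["g"]     # both (T, B, D)

# Batch-first assumption: dim 0 is treated as the batch, everything else summed.
naive = torch.einsum("n...o,n...i->noi", g, a)  # (T, D, D): per time step!
# Correct for this module: batch is dim 1, sum over time steps instead.
correct = torch.einsum("tbo,tbi->boi", g, a)    # (B, D, D)

# Ground truth: one micro-batch of size 1 per sample.
truth = []
for i in range(B):
    model.zero_grad()
    model(x[i : i + 1]).pow(2).sum().backward()
    truth.append(model.fc.weight.grad.clone())
truth = torch.stack(truth)
```

Here `naive` silently returns gradients indexed by time step rather than by sample, which is exactly the kind of error that would go unnoticed while the privacy accounting assumes per-sample clipping.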

There is, however, one approach that we already use to verify the correctness of the per-sample computation: comparing against micro-batches. By definition, a "per-sample gradient" is the gradient you would get if the given sample were the only one in the batch. It is therefore easy to split the batch into N micro-batches of size 1 and use their gradients to verify p.grad_sample.
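A minimal sketch of such a micro-batch check, assuming per-sample gradients are supplied as a dict mapping parameter names to tensors of shape `(batch, *param.shape)`, the same layout as Opacus's `p.grad_sample`. The function name `check_grad_samples` and its signature are hypothetical, not an Opacus API. The loss must use a reduction consistent with the per-sample gradients; here `reduction="sum"` is used so that each micro-batch gradient is exactly one sample's gradient. The candidate gradients are computed in closed form for a single `nn.Linear` under summed MSE.

```python
import torch
import torch.nn as nn


def check_grad_samples(model, loss_fn, x, y, grad_samples, atol=1e-5):
    """Compare candidate per-sample gradients against micro-batches of size 1.

    grad_samples maps parameter names to tensors of shape (batch, *param.shape).
    Returns (all_within_atol, {param_name: max absolute difference}).
    """
    max_diff = {name: 0.0 for name in grad_samples}
    params = dict(model.named_parameters())
    for i in range(x.shape[0]):
        model.zero_grad()
        # By definition, sample i's gradient is that of a batch containing only i.
        loss_fn(model(x[i : i + 1]), y[i : i + 1]).backward()
        for name, gs in grad_samples.items():
            diff = (gs[i] - params[name].grad).abs().max().item()
            max_diff[name] = max(max_diff[name], diff)
    return all(d <= atol for d in max_diff.values()), max_diff


torch.manual_seed(0)
model = nn.Linear(3, 2)
loss_fn = nn.MSELoss(reduction="sum")
x, y = torch.randn(5, 3), torch.randn(5, 2)

# Closed-form per-sample gradients for a Linear layer under summed MSE:
# d loss_i / d out_i = 2 * (out_i - y_i).
with torch.no_grad():
    g = 2 * (model(x) - y)                         # (B, out)
candidate = {
    "weight": torch.einsum("bo,bi->boi", g, x),    # (B, out, in)
    "bias": g,                                     # (B, out)
}
ok, diffs = check_grad_samples(model, loss_fn, x, y, candidate)

# Shuffling samples keeps the summed gradient identical but breaks the
# per-sample correspondence, which the micro-batch check catches.
shuffled = {name: gs.roll(1, dims=0) for name, gs in candidate.items()}
bad_ok, _ = check_grad_samples(model, loss_fn, x, y, shuffled)
```

The shuffled case is the interesting one: any check based only on the aggregate gradient would pass it, while the micro-batch comparison flags it. The cost is N extra forward/backward passes, which is acceptable for a one-off correctness check but not for training.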

Here's how we do it for our unit test: link

The idea of this new feature is to expose a similar utility to end users so they can check that:

aaossa commented 1 year ago

Hi, any updates on #532? (Just lurking, but it seems the PR was left behind for some reason.)

karthikprasad commented 1 year ago

@aaossa, it looks like the PR has been approved and is pending merge. @psolikov, would you be able to rebase and merge it?

psolikov commented 1 year ago

@aaossa @karthikprasad Thanks for the reminder! I'll try to merge it soon.