pytorch / opacus

Training PyTorch models with differential privacy
https://opacus.ai
Apache License 2.0

Pass `num_groups` arg to ModelValidator.fix #580

Closed: ffuuugor closed this pull request 1 year ago

ffuuugor commented 1 year ago

Problem

As highlighted by #567, end users have little control over how exactly `ModelValidator.fix()` deals with BatchNorm layers. For example, our approach to choosing the number of groups is `gcd(32, module.num_features)`, which is fine for most cases but can occasionally break a model (see #567 for a demonstration).
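To make the default concrete, here is a minimal plain-Python sketch of the group-selection rule described above. `choose_num_groups` is a hypothetical helper written for illustration, not Opacus's actual code. Its optional `num_groups` argument mirrors the override this PR proposes, and the divisibility check reflects the requirement that a GroupNorm's channel count be divisible by its number of groups.

```python
from math import gcd
from typing import Optional


def choose_num_groups(num_features: int, num_groups: Optional[int] = None) -> int:
    """Pick the GroupNorm group count for a BatchNorm with `num_features` channels.

    Hypothetical helper for illustration only (not the Opacus implementation).
    """
    if num_groups is None:
        # Default heuristic quoted in the PR description: gcd(32, num_features)
        return gcd(32, num_features)
    # An explicit override must evenly divide the channel count, as GroupNorm requires
    if num_groups <= 0 or num_features % num_groups != 0:
        raise ValueError(
            f"num_groups={num_groups} must be a positive divisor of num_features={num_features}"
        )
    return num_groups


for channels in (64, 48, 30, 33):
    print(channels, choose_num_groups(channels))
# 64 32
# 48 16
# 30 2
# 33 1
```

The default swings from 32 groups down to a single group depending only on the channel count. With one group, GroupNorm normalizes across all channels together, which behaves quite differently from per-channel BatchNorm statistics; this is one plausible way the heuristic can surprise a user.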

Solution

Pass `num_groups` as a kwarg to `ModelValidator.fix()`, allowing clients to control this behaviour.
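A rough sketch of what such an override enables, assuming PyTorch. `replace_batchnorm` is a hypothetical stand-in for a BatchNorm fixer; it is not the Opacus implementation and does not reproduce its signature. It recursively swaps `BatchNorm1d/2d/3d` layers for `nn.GroupNorm`, using a caller-supplied `num_groups` when given and falling back to `gcd(32, num_features)` otherwise.

```python
import copy
from math import gcd
from typing import Optional

import torch
import torch.nn as nn

BATCHNORMS = (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d)


def replace_batchnorm(module: nn.Module, num_groups: Optional[int] = None) -> nn.Module:
    """Recursively replace BatchNorm layers with GroupNorm, in place.

    Hypothetical helper for illustration; not Opacus's ModelValidator.fix().
    """
    for name, child in module.named_children():
        if isinstance(child, BATCHNORMS):
            channels = child.num_features
            # Use the caller's choice if provided, else the gcd(32, C) default
            groups = num_groups if num_groups is not None else gcd(32, channels)
            setattr(module, name, nn.GroupNorm(groups, channels, affine=True))
        else:
            # Descend into containers (Sequential, custom blocks, ...)
            replace_batchnorm(child, num_groups=num_groups)
    return module


model = nn.Sequential(nn.Conv2d(3, 48, kernel_size=3), nn.BatchNorm2d(48), nn.ReLU())

default_fix = replace_batchnorm(copy.deepcopy(model))               # gcd(32, 48) -> 16 groups
custom_fix = replace_batchnorm(copy.deepcopy(model), num_groups=4)  # caller override

out = custom_fix(torch.randn(2, 3, 8, 8))  # shape (2, 48, 6, 6)
```

Threading the value through as a keyword argument keeps the default behaviour unchanged for existing callers while letting anyone hit by a bad default choose a group count that suits their architecture.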

facebook-github-bot commented 1 year ago

@facebook-github-bot has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.

facebook-github-bot commented 1 year ago

@ffuuugor has updated the pull request. You must reimport the pull request before landing.

facebook-github-bot commented 1 year ago

@ffuuugor has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.

facebook-github-bot commented 1 year ago

@ffuuugor has updated the pull request. You must reimport the pull request before landing.

facebook-github-bot commented 1 year ago

@ffuuugor has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.

facebook-github-bot commented 1 year ago

@ffuuugor merged this pull request in pytorch/opacus@93dd307e61c3d6ff284441ae80aa3665b97e7937.