pytorch / opacus

Training PyTorch models with differential privacy
https://opacus.ai
Apache License 2.0
1.67k stars · 332 forks

Make DPMultiheadAttention drop-in compatible with nn.MultiheadAttention #529

Closed Wei-1 closed 1 year ago

Wei-1 commented 1 year ago

Summary: This PR targets issue #123 on GitHub, adding a re-naming mechanism to match the state_dict structure of nn.MultiheadAttention.

Differential Revision: D40671870
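The renaming idea can be sketched as follows. This is an illustrative, simplified sketch, not the PR's actual code: it assumes nn.MultiheadAttention's fused `in_proj_weight` maps onto DPMultiheadAttention's separate `qlinear`/`klinear`/`vlinear` sub-modules, and it uses plain lists in place of tensors so the example is self-contained.

```python
from collections import OrderedDict

# Hypothetical sketch: nn.MultiheadAttention stores the q/k/v projections
# as one fused "in_proj_weight", while DPMultiheadAttention keeps separate
# qlinear/klinear/vlinear sub-modules. A rule-based rename/split bridges
# the two state_dict layouts.

def split_in_proj(state_dict):
    """Convert a fused nn.MultiheadAttention-style dict into a
    DPMultiheadAttention-style dict with separate q/k/v entries."""
    out = OrderedDict()
    for key, value in state_dict.items():
        if key == "in_proj_weight":
            # The fused weight stacks q, k, v along dim 0; split into thirds.
            third = len(value) // 3
            out["qlinear.weight"] = value[:third]
            out["klinear.weight"] = value[third:2 * third]
            out["vlinear.weight"] = value[2 * third:]
        else:
            out[key] = value  # e.g. out_proj.weight passes through unchanged
    return out

fused = OrderedDict()
fused["in_proj_weight"] = [[1], [2], [3], [4], [5], [6]]  # stand-in for a (3E, E) tensor
fused["out_proj.weight"] = [[7], [8]]
print(list(split_in_proj(fused).keys()))
# -> ['qlinear.weight', 'klinear.weight', 'vlinear.weight', 'out_proj.weight']
```

In the real module the split would operate on tensors of shape (3 * embed_dim, embed_dim), and a reverse mapping would be needed so that saved DPMultiheadAttention checkpoints can be loaded back into nn.MultiheadAttention.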

facebook-github-bot commented 1 year ago

This pull request was exported from Phabricator. Differential Revision: D40671870

Wei-1 commented 1 year ago

You can test the edited part by running pytest inside the opacus folder; you should see the following result:

============= 162 passed, 41 skipped, 4411 warnings in 243.18s (0:04:03) =============

Wei-1 commented 1 year ago

On the other hand, I can see potential issues with an unexpected state_dict whose keys don't match the parameter names. That said, I don't see any immediate problems, but I might be missing something.

Since we try to cover the parameter naming logic of both nn.MultiheadAttention and DPMultiheadAttention, I think the major problem will be maintenance. The entire transformation is rule-based, so it will most likely break whenever the naming scheme of nn.MultiheadAttention or DPMultiheadAttention changes.
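To illustrate the maintenance concern, here is a minimal sketch (the rule table and helper below are hypothetical, not the PR's code): a fixed rename table covers only the key names it was written for, so an upstream rename would otherwise slip through silently. Failing loudly on unknown keys is one way to surface such a change.

```python
from collections import OrderedDict

# Hypothetical rule table: a fixed, rule-based key map knows only the
# names that existed when it was written. If upstream renames a key,
# the map has no entry for it.
RENAME_RULES = {
    "in_proj_bias": "qlinear.bias",        # illustrative one-to-one rule
    "out_proj.weight": "out_proj.weight",  # identity rule
}

def rename_keys(state_dict, strict=True):
    out = OrderedDict()
    for key, value in state_dict.items():
        if key in RENAME_RULES:
            out[RENAME_RULES[key]] = value
        elif strict:
            # Failing loudly makes upstream naming changes visible
            # instead of silently producing a mismatched state_dict.
            raise KeyError(f"no rename rule for {key!r}")
        else:
            out[key] = value
    return out

try:
    rename_keys(OrderedDict(renamed_upstream_key=[0.0]))
except KeyError as e:
    print("caught:", e)
```

A strict mode like this turns a quiet breakage into a test failure the next time the upstream naming changes.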

Wei-1 commented 1 year ago

@ffuuugor Changes have been made to address the concerns. Please let me know if this makes sense to you!

ffuuugor commented 1 year ago

Thanks for addressing the comments! I believe this PR is close to landing, but we need to sort out one thing first. Due to a bug in how CircleCI interacts with Phabricator, our main testing pipeline is not being triggered on this PR.

I can see that it won't pass the linter check. Please refer to our Contributor's guide and run the isort/black/flake8 commands to check that your code is formatted properly.
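For reference, a typical invocation might look like the following; the exact configuration comes from the repo's Contributor's guide, and the check flags only report problems without rewriting files:

```shell
# Illustrative lint commands, assuming the tools are installed and
# configured as in the Contributor's guide.
isort --check-only opacus   # reports import-order issues, doesn't modify
black --check opacus        # reports formatting issues, doesn't modify
flake8 opacus               # reports style violations
```

Dropping the check flags (`isort opacus`, `black opacus`) applies the fixes in place.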

Wei-1 commented 1 year ago

Thanks, @ffuuugor! An update has been made to address lint consistency with black/isort/flake8.

ffuuugor commented 1 year ago

Hey, thanks for taking care of this. One last thing: isort sometimes gives different recommendations depending on the version. I have mine set up exactly like CircleCI, and it gives the following:

--- a/opacus/layers/dp_multihead_attention.py
+++ b/opacus/layers/dp_multihead_attention.py
@@ -14,14 +14,13 @@
 # limitations under the License.

 import warnings
+from collections import OrderedDict

 import torch
 import torch.nn as nn
 import torch.nn.functional as F
 from torch.nn.parameter import Parameter

-from collections import OrderedDict
-

 class SequenceBias(nn.Module):

Can you please make this change so the linter is happy?

And I'm really sorry for the back and forth on this. Tests not triggering for fbcode-exported PRs is painful, and we're investigating.

Wei-1 commented 1 year ago

I just pushed a new version to address this issue! Please let me know if everything looks good!