pytorch / opacus

Training PyTorch models with differential privacy
https://opacus.ai
Apache License 2.0

Support for is_causal flag for forward hook of layers.DPMultiheadAttention #617

Closed hueck closed 5 months ago

hueck commented 5 months ago

🚀 Feature

In the forward function of layers.DPMultiheadAttention, the following assertion notifies the user that using a causal mask is not yet implemented:

r"""
As per https://github.com/pytorch/opacus/issues/596, we have to include ``is_causal`` as a dummy parameter of the function,
since it is used in the ``forward`` function of parent class ``nn.TransformerEncoderLayer``.
"""
assert (
   is_causal == False
), "We currently do not support causal mask. Will fix it in the future."

This feature would ensure drop-in compatibility with the torch nn.MultiheadAttention module.
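
For illustration, a minimal sketch of where the incompatibility shows up (shapes and sizes are arbitrary; the commented-out call with is_causal=True is the one the assertion currently rejects):

import torch
from opacus.layers import DPMultiheadAttention

embed_dim, num_heads, seq_len, batch = 16, 4, 8, 2
x = torch.randn(seq_len, batch, embed_dim)  # (L, N, E), batch_first=False

attn = DPMultiheadAttention(embed_dim, num_heads)
out, _ = attn(x, x, x)                    # works
# out, _ = attn(x, x, x, is_causal=True)  # trips the assertion quoted above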

Motivation

#596 already fixed some of the incompatibilities; however, to the best of my knowledge, the implementation gap described above is still not filled and prevents full drop-in compatibility.

I would very much appreciate knowing whether someone is actively working on this. If not, or if help is needed, I would be happy to get some pointers on what is necessary to implement this and try to implement it myself.

HuanyuZhang commented 5 months ago

Hi Hueck, thanks for raising this. We do not currently have plans to implement this feature.

Let me take a quick look to see whether I can provide some code pointers or suggestions to help you implement it.

Just curious: why not just set is_causal = False and apply your own customized attention mask?
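
For example, a minimal sketch of that workaround (assuming the DP layer accepts attn_mask the same way nn.MultiheadAttention does, which is not verified here): keep is_causal at its default and pass an explicit causal mask instead.

import torch
from opacus.layers import DPMultiheadAttention

embed_dim, num_heads, seq_len, batch = 16, 4, 8, 2
x = torch.randn(seq_len, batch, embed_dim)

attn = DPMultiheadAttention(embed_dim, num_heads)

# Float mask with -inf above the diagonal; the same helper nn.Transformer uses.
causal_mask = torch.nn.Transformer.generate_square_subsequent_mask(seq_len)
out, _ = attn(x, x, x, attn_mask=causal_mask)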

HuanyuZhang commented 5 months ago

I would expect very similar logic to that in https://github.com/pytorch/pytorch/blob/main/torch/nn/functional.py. Just search for "causal" to find the relevant code.
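
For context, the pattern there roughly amounts to materializing a causal mask when is_causal is set and then continuing down the usual attn_mask path. A rough, hypothetical sketch of that idea (not the actual torch or Opacus code):

import torch

def merge_causal_mask(attn_mask, is_causal, tgt_len, src_len, dtype, device):
    # Hypothetical helper: turn is_causal=True into an additive float mask
    # with -inf above the diagonal, so downstream code only handles attn_mask.
    if is_causal and attn_mask is None:
        blocked = torch.triu(
            torch.ones(tgt_len, src_len, dtype=torch.bool, device=device),
            diagonal=1,
        )
        attn_mask = torch.zeros(tgt_len, src_len, dtype=dtype, device=device)
        attn_mask.masked_fill_(blocked, float("-inf"))
    return attn_mask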

hueck commented 5 months ago

Hi, thanks very much for your help.

Yes, in the meantime I proceeded to just use a custom attention mask, but I was still wondering whether there might be any unwanted side effects. I am not too familiar with the details of the implementation, but looking at the _detect_is_causal_mask function and the comments in the transformer module, I think this should work fine.
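
As a quick sanity check of the mask-vs-flag equivalence (at the scaled_dot_product_attention level rather than the DP layer itself), a lower-triangular boolean mask should reproduce is_causal=True exactly for self-attention:

import torch
import torch.nn.functional as F

q = torch.randn(2, 4, 8, 16)  # (batch, heads, seq, head_dim)
k, v = torch.randn_like(q), torch.randn_like(q)

seq = q.shape[-2]
causal_mask = torch.tril(torch.ones(seq, seq, dtype=torch.bool))

out_flag = F.scaled_dot_product_attention(q, k, v, is_causal=True)
out_mask = F.scaled_dot_product_attention(q, k, v, attn_mask=causal_mask)
print(torch.allclose(out_flag, out_mask, atol=1e-6))  # expected: True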