Open · netw0rkf10w opened 4 years ago
Instead of MHA, I think the focus here really is on attention functions, which are independent of MHA in general?
@cpuhrsch Yes, you are right: if we adopt the second solution that I proposed above, then this is indeed independent of MHA. This is also the reason why I think it is better than the first solution.
🚀 Feature
Making the new MHA implementation even more modular for easy implementation of different attention layers.
Motivation
The new MHA container implementation is already much more flexible than the one in core PyTorch. However, in the current version, when implementing a new attention layer (other than ScaledDotProduct), one will have to repeat some code of ScaledDotProduct, which is not optimal. In general, an attention layer consists of two steps:

1. Computation of the attention weights.
2. Aggregation of the values based on the computed weights.

Different attention functions may differ only in the first step, only in the second step, or in both.
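For concreteness, the two steps for scaled dot-product attention can be written as follows (a plain-tensor sketch, independent of the torchtext classes; the function name is just for illustration):

```python
import torch
import torch.nn.functional as F

def scaled_dot_product_attention(query, key, value):
    # Step 1: compute the attention weights from queries and keys.
    scores = torch.matmul(query, key.transpose(-2, -1)) / (query.size(-1) ** 0.5)
    attn_weights = F.softmax(scores, dim=-1)
    # Step 2: aggregate the values using the computed weights.
    attn_output = torch.matmul(attn_weights, value)
    return attn_output, attn_weights
```

A different attention function would typically change only how `scores` is computed in step 1, while step 2 stays the same.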
Pitch
I can think of two solutions:
1. Let the attention layers (e.g. ScaledDotProduct) return only the attention weights; the aggregation of the values is then done in the main MHA container (see the brief interface sketch after this list).
2. Keep the MHA container unchanged by using a general template class for all the attention layers, and let each specific layer inherit from this class.
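To make the contrast concrete, the first solution would roughly imply interfaces like the following (a simplified sketch with illustrative names; projections, masking, and the multi-head reshaping are omitted, and the container here is not the actual torchtext class):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class ScaledDotProduct(nn.Module):
    # Solution 1: the attention layer computes and returns only the weights.
    def forward(self, query, key):
        scores = torch.matmul(query, key.transpose(-2, -1)) / (query.size(-1) ** 0.5)
        return F.softmax(scores, dim=-1)

class SimplifiedMHAContainer(nn.Module):
    # Solution 1: the aggregation of the values moves into the container.
    def __init__(self, attention_layer):
        super().__init__()
        self.attention_layer = attention_layer

    def forward(self, query, key, value):
        attn_weights = self.attention_layer(query, key)
        return torch.matmul(attn_weights, value), attn_weights
```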
I've tried both and found that the second solution is much cleaner. I give below an example in which I re-implemented ScaledDotProduct using this approach and, furthermore, added another attention layer called GeneralDotProduct (denoted "general" in Section 3.1 of this paper). (Try adding another attention layer such as GeneralDotProduct yourself in the current implementation and you will see the issue.)

@zhangguanheng66 Are you interested in such a PR?
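The idea, in a minimal sketch (illustrative method names; masking, dropout, and the multi-head plumbing are omitted):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class Attention(nn.Module):
    # Template class: fixes the two steps shared by all attention layers.

    def forward(self, query, key, value):
        attn_weights = self.attention_weights(query, key)
        return self.aggregate(attn_weights, value), attn_weights

    def attention_weights(self, query, key):
        # Step 1: each subclass defines its own scoring function.
        raise NotImplementedError

    def aggregate(self, attn_weights, value):
        # Step 2: default aggregation, shared by all subclasses.
        return torch.matmul(attn_weights, value)

class ScaledDotProduct(Attention):
    def attention_weights(self, query, key):
        scores = torch.matmul(query, key.transpose(-2, -1)) / (query.size(-1) ** 0.5)
        return F.softmax(scores, dim=-1)

class GeneralDotProduct(Attention):
    # "general" attention: bilinear score q^T W k with a learnable square matrix W.
    def __init__(self, embed_dim):
        super().__init__()
        self.weight = nn.Parameter(torch.empty(embed_dim, embed_dim))
        nn.init.xavier_uniform_(self.weight)

    def attention_weights(self, query, key):
        scores = torch.matmul(torch.matmul(query, self.weight), key.transpose(-2, -1))
        return F.softmax(scores, dim=-1)
```

With this template, GeneralDotProduct only overrides the scoring step; the aggregation and the forward logic come from the base class, and the MHA container itself does not need to change.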