lucidrains / h-transformer-1d

Implementation of H-Transformer-1D, Hierarchical Attention for Sequence Learning

add key_value_mask argument for cross attention #19

Closed: jglaser closed this pull request 2 years ago

jglaser commented 2 years ago

This change set enables the use of two different masks for the queries and the keys/values, which is necessary when using the module for cross attention.
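
For context, a rough usage sketch of what the change enables. The constructor arguments follow the repository's README example; the exact name and placement of the new keyword argument are illustrative here and should be checked against the diff.

```python
import torch
from h_transformer_1d import HTransformer1D

# constructor arguments follow the repository's README example
model = HTransformer1D(
    num_tokens = 256,
    dim = 512,
    depth = 6,
    max_seq_len = 8192,
    heads = 8,
    block_size = 128
)

tokens = torch.randint(0, 256, (1, 8192))

mask           = torch.ones(1, 8192).bool()  # mask over the query tokens
key_value_mask = torch.ones(1, 8192).bool()  # independent mask over the keys/values

# with this change set the two masks may differ, e.g. when the key/value side
# comes from a padded context of different effective length
logits = model(tokens, mask = mask, key_value_mask = key_value_mask)
```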

lucidrains commented 2 years ago

@jglaser hello there! looks good :) have you had much luck with this efficient attention variant?

lucidrains commented 2 years ago

hmm actually, how would cross attention work if the tokens and the cross-attended contexts have different dimensions? i thought this scheme requires a square matrix

jglaser commented 2 years ago

> @jglaser hello there! looks good :) have you had much luck with this efficient attention variant?

I am still testing whether it results in better model performance for long sentences, but so far it looks good in terms of memory efficiency.

jglaser commented 2 years ago

> hmm actually, how would cross attention work if the tokens and the cross-attended contexts have different dimensions? i thought this scheme requires a square matrix

You're right, I have only tested this for a square matrix, i.e. with the two inputs equal in both sentence length and hidden dimension.

I achieve the former by extending the shorter of the two input sentences with the pad token, and the latter with a linear layer that projects the keys/values (per token) from one hidden dimension to the other.

Both can be done either in the forward pass outside the self-attention module, or by replacing the to_qkv and to_out member variables with custom layers, without any further changes to this module.
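
A minimal sketch of those two workarounds, applied outside the attention module (pad token id, shapes, and names are illustrative):

```python
import torch
import torch.nn.functional as F
from torch import nn

PAD_ID = 0  # illustrative pad token id

def pad_to_same_length(seq_a, seq_b, pad_id = PAD_ID):
    # extend the shorter of the two token sequences with the pad token
    target_len = max(seq_a.shape[1], seq_b.shape[1])
    seq_a = F.pad(seq_a, (0, target_len - seq_a.shape[1]), value = pad_id)
    seq_b = F.pad(seq_b, (0, target_len - seq_b.shape[1]), value = pad_id)
    return seq_a, seq_b

# per-token linear layer translating the context's hidden dimension into the
# hidden dimension the attention module expects for its keys/values
context_dim, model_dim = 384, 512
to_kv_dim = nn.Linear(context_dim, model_dim)

context = torch.randn(1, 1024, context_dim)
context = to_kv_dim(context)  # (1, 1024, model_dim)
```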

That said, we may eventually want to generalize this self-attention so that it accepts query and key/value inputs of different dimensions directly.

jglaser commented 2 years ago

Closing this: there is no need for a separate mask for the queries and the keys/values. The queries can always be masked by truncating the output of the attention layer appropriately, so only a single mask (for the keys and values) is needed.
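
A short sketch of that truncation trick, with a stand-in for the attention call (all names and shapes are illustrative):

```python
import torch

true_query_len = 700    # actual number of query tokens
padded_len     = 1024   # length after padding the queries up to the key/value length

kv_mask = torch.ones(1, padded_len).bool()    # single mask, for the keys/values only
# out = attn(padded_queries, mask = kv_mask)  # hypothetical attention call
out = torch.randn(1, padded_len, 512)         # stand-in for the attention output

# no query mask needed: simply truncate the output back to the true query
# length, discarding the positions that came from padding
out = out[:, :true_query_len]
```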