Hi, the mask is intended to be 4D to allow maximum flexibility. Can you clarify why the MultiheadAttention module should expect a 2D mask? It is correct that one may commonly use the same mask for all heads, so that extending a 3D tensor to 4D could be integrated into the layer. For simplicity, and since the tutorial does not explicitly use the mask, this is currently not done. Note that 2D masks are not sufficient to represent common masking schemes (e.g. autoregressive prediction), since every query token has a different set of key tokens masked out.
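To illustrate that last point with a small sketch (the shapes are illustrative and not taken from the tutorial): a 2D (batch, seq_len) mask assigns one validity flag per key token, identical for every query, whereas a causal mask differs per query row and therefore needs a full (seq_len, seq_len) pattern per example.

```python
import jax.numpy as jnp

seq_len = 4
# Causal mask: query token i may only attend to key tokens j <= i,
# so each query row masks out a different set of keys.
causal_mask = jnp.tril(jnp.ones((seq_len, seq_len), dtype=bool))
print(causal_mask.astype(int))
# [[1 0 0 0]
#  [1 1 0 0]
#  [1 1 1 0]
#  [1 1 1 1]]
```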
I agree that the input to MHA could be 3D, e.g. to be causal. However, the number of heads is a parameter of the MHA module and is not visible outside of it, so it stands to reason that the mask passed to MHA should have a shape that does not include the number of heads, and should be converted to 4D inside MHA to respect that parameter. In other words, independent of what the mask actually represents, shouldn't there be some code inside MHA that adds a new dimension, tiles, stacks, or otherwise handles the mask?
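For context on why some expansion has to happen somewhere (the shapes follow the error reported in this issue, but the code is only a sketch): a raw (batch, seq_len) mask does not broadcast against (batch, num_heads, seq_len, seq_len) attention logits, because broadcasting aligns shapes from the right and ends up pairing the batch axis with the query axis. Adding singleton head and query dimensions fixes the alignment:

```python
import jax.numpy as jnp

attn_logits = jnp.zeros((3, 4, 16, 16))   # (batch, num_heads, seq_q, seq_k)
mask = jnp.ones((3, 16), dtype=bool)      # (batch, seq_len)

# jnp.where(mask, attn_logits, -9e15) would fail: aligned from the right,
# (3, 16) meets (16, 16), and 3 is incompatible with 16.
expanded = mask[:, None, None, :]         # (batch, 1, 1, seq_len)
masked = jnp.where(expanded, attn_logits, -9e15)
print(masked.shape)                       # (3, 4, 16, 16)
```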
If this can help: using attn_logits.masked_fill(mask.unsqueeze(-2), …) (with the unsqueeze added) makes the code of scaled_dot_product work for arbitrary arguments (when giving a single mask for each batch of queries), with dimensions as follows:
PS: In order for MultiheadAttention to also work, you also need to add, in forward():
if mask is not None:
    # The unsqueeze() is for broadcasting the mask to all heads:
    mask = mask.unsqueeze(-2)
Finally got around to adding it in ac42535de4cd45957778bbe42892f16f7da63d42. Sorry for the delay!
Tutorial: 16
Describe the bug
Passing masks looks like it's supported in the Transformers tutorial, but it actually doesn't work. The core of the issue is that the MultiheadAttention module expects a mask of 2 dimensions (batch_size, seq_length), but scaled_dot_product expects a mask with the same dimensions as the attention logits (batch_size, num_heads, seq_length, seq_length).

To Reproduce (if any steps necessary)
Steps to reproduce the behavior:
Create mask = random.bernoulli(m_rng, shape=(3, 16)) and pass it to the apply fn: out, attn = mh_attn.apply({'params': params}, x, mask=mask). This raises ValueError: Incompatible shapes for broadcasting: shapes=[(3, 16), (), (3, 4, 16, 16)] in the jnp.where line of scaled_dot_product.
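For reference, a minimal sketch of the failing call (the embedding dimension and the exact constructor arguments of the tutorial's MultiheadAttention module are assumptions here; num_heads=4 matches the shapes in the error):

```python
from jax import random

main_rng = random.PRNGKey(42)
main_rng, x_rng, m_rng, init_rng = random.split(main_rng, 4)

x = random.normal(x_rng, (3, 16, 128))          # (batch, seq_len, embed_dim); 128 is an assumed embed_dim
mask = random.bernoulli(m_rng, shape=(3, 16))   # 2D (batch, seq_len) mask

# Assumes the tutorial's Flax MultiheadAttention module with embed_dim/num_heads fields
# and an optional mask argument in __call__.
mh_attn = MultiheadAttention(embed_dim=128, num_heads=4)
params = mh_attn.init(init_rng, x)['params']    # initializing without a mask works
out, attn = mh_attn.apply({'params': params}, x, mask=mask)  # raises the ValueError above
```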
Expected behavior
The MultiheadAttention should transform the 2D mask into a 4D mask. The following lines in the __call__ function fix the code:
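A minimal sketch of one possible expansion for the 2D padding-mask case (broadcast_mask is a hypothetical helper name; the fix that actually landed in the repository may differ):

```python
import jax.numpy as jnp

def broadcast_mask(mask):
    # Hypothetical helper: bring a 2D or 3D mask to the 4D shape expected by
    # scaled_dot_product, letting broadcasting replicate it over all heads.
    if mask.ndim == 2:                  # (batch, seq_len) padding mask
        mask = mask[:, None, None, :]   # -> (batch, 1, 1, seq_len)
    elif mask.ndim == 3:                # (batch, seq_q, seq_k) per-example mask
        mask = mask[:, None, :, :]      # -> (batch, 1, seq_q, seq_k)
    return mask                         # broadcasts against (batch, num_heads, seq_q, seq_k)

# e.g. inside MultiheadAttention.__call__:
#     if mask is not None:
#         mask = broadcast_mask(mask)
print(broadcast_mask(jnp.ones((3, 16), dtype=bool)).shape)   # (3, 1, 1, 16)
```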
I've modified the colab to produce variable-length sequences and pass the sequence mask, and verified that it works. It's interesting to see that to solve this problem with variable lengths, two layers are needed: one to estimate the distance to the end-of-sequence token, and another one to attend in reverse. Feel free to use it to update the code if it's useful: https://colab.research.google.com/drive/1kDoYuwoFSJ1OqnrFHLwW-zAIkZhEwBNs?usp=sharing
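For the variable-length setup, the per-example sequence mask can be built from the lengths with a small helper along these lines (length_to_mask is a hypothetical name, not taken from the tutorial or the colab):

```python
import jax.numpy as jnp

def length_to_mask(lengths, max_len):
    # True for valid token positions, False for padding; the resulting
    # (batch, seq_len) mask is what gets passed to MultiheadAttention.
    positions = jnp.arange(max_len)[None, :]   # (1, max_len)
    return positions < lengths[:, None]        # (batch, max_len)

lengths = jnp.array([16, 9, 12])
print(length_to_mask(lengths, 16).shape)       # (3, 16)
```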