phlippe / uvadlc_notebooks

Repository of Jupyter notebook tutorials for teaching the Deep Learning Course at the University of Amsterdam (MSc AI), Fall 2023
https://uvadlc-notebooks.readthedocs.io/en/latest/
MIT License

Masking in transformer tutorial #67

Closed maksay closed 1 year ago

maksay commented 1 year ago

Tutorial: 16

Describe the bug Passing masks looks like it's supported in the Transformers tutorial, but it doesn't actually work. The crux of the issue is that the MultiheadAttention module expects a 2D mask of shape (batch_size, seq_length), whereas scaled_dot_product expects a mask with the same shape as the logits, (batch_size, num_heads, seq_length, seq_length).

To Reproduce (if any steps necessary) Steps to reproduce the behavior:

  1. Go to cell '## Test MultiheadAttention implementation'
  2. Add a line that creates a sequence mask, mask = random.bernoulli(m_rng, shape=(3, 16)), and pass it to the apply fn: out, attn = mh_attn.apply({'params': params}, x, mask=mask)
  3. Run the cell to see the error: ValueError: Incompatible shapes for broadcasting: shapes=[(3, 16), (), (3, 4, 16, 16)] in jnp.where line of scaled_dot_product.
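
The shape mismatch in these steps can be reproduced in isolation. The following is a minimal sketch, assuming the tutorial's scaled_dot_product masks the logits with a `jnp.where(mask == 0, -9e15, attn_logits)` call; `masked_logits` is a hypothetical helper name, and the shapes (batch 3, 4 heads, sequence length 16) are taken from the error message above.

```python
import jax.numpy as jnp

# Shapes from the reproduction steps: batch=3, heads=4, seq_length=16.
attn_logits = jnp.zeros((3, 4, 16, 16))
mask_2d = jnp.ones((3, 16), dtype=bool)         # per-sequence mask
mask_4d = jnp.ones((3, 4, 16, 16), dtype=bool)  # same shape as the logits


def masked_logits(logits, mask):
    # Assumed masking pattern (hypothetical helper, not the tutorial's code).
    return jnp.where(mask == 0, -9e15, logits)


# The 4D mask matches the logits and works.
ok = masked_logits(attn_logits, mask_4d)

# The 2D mask cannot broadcast: (3, 16) aligns against the last two axes
# (16, 16), so the batch size 3 clashes with seq_length 16.
try:
    masked_logits(attn_logits, mask_2d)
    error = None
except ValueError as exc:
    error = exc

print(type(error).__name__)
```

The failure only shows up when batch_size differs from seq_length; if the two happened to be equal, the 2D mask would broadcast silently along the wrong axis.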

Expected behavior The MultiheadAttention module should transform the 2D mask into a 4D mask. Adding the following lines to the __call__ function fixes the code:

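    if mask is not None:
        mask = jnp.stack([mask] * self.num_heads, axis=-1)  # (batch, seq, heads)
        mask = jnp.stack([mask] * seq_length, axis=-1)      # (batch, seq, heads, seq)
        mask = mask.transpose(0, 2, 1, 3)                   # (batch, heads, seq, seq)
        mask *= mask.transpose(0, 1, 3, 2)                  # valid only if query and key are both valid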
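
For comparison, the same 4D mask can be built with broadcasting instead of stacking. This is only a sketch: `expand_mask_stack` wraps the lines proposed above, `expand_mask_broadcast` is a hypothetical alternative, and neither is the tutorial's actual implementation.

```python
import jax.numpy as jnp


def expand_mask_stack(mask, num_heads):
    # The approach proposed in this issue: stack copies, then combine
    # the mask with its transpose over the last two axes.
    seq_length = mask.shape[-1]
    m = jnp.stack([mask] * num_heads, axis=-1)  # (B, S, H)
    m = jnp.stack([m] * seq_length, axis=-1)    # (B, S, H, S)
    m = m.transpose(0, 2, 1, 3)                 # (B, H, S, S)
    return m * m.transpose(0, 1, 3, 2)


def expand_mask_broadcast(mask, num_heads):
    # Outer product over query and key positions, then a broadcast head axis.
    pair = mask[:, :, None] * mask[:, None, :]  # (B, S, S)
    target_shape = (mask.shape[0], num_heads) + pair.shape[1:]
    return jnp.broadcast_to(pair[:, None], target_shape)  # (B, H, S, S)


# Two sequences of length 3; zeros mark padded positions.
mask = jnp.array([[1, 1, 0], [1, 0, 0]], dtype=jnp.int32)
stacked = expand_mask_stack(mask, num_heads=4)
broadcast = expand_mask_broadcast(mask, num_heads=4)
print(stacked.shape, bool((stacked == broadcast).all()))
```

Both versions mark a (query, key) pair as valid only when both positions are valid, so rows belonging to padded queries end up fully masked.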

Runtime environment (please complete the following information):

I've modified the colab to produce variable length sequences and pass the sequence mask and verified that it works. It's interesting to see that to solve this problem with variable length, two layers are needed: one to estimate the distance to end-of-sequence token, and another one to attend in reverse. Feel free to use it to update the code if it's useful: https://colab.research.google.com/drive/1kDoYuwoFSJ1OqnrFHLwW-zAIkZhEwBNs?usp=sharing

phlippe commented 1 year ago

Hi, the mask is intended to be 4D to allow maximum flexibility. Can you specify why the MultiheadAttention module expects a 2D mask? It is correct that one commonly uses the same mask for all heads, so extending a 3D tensor to 4D could be integrated into the layer. For simplicity, and since the tutorial does not explicitly use the mask, this is currently not being done. Note that 2D masks are not sufficient to represent common masking schemes (e.g. autoregressive prediction), since each token has a different set of tokens masked out.
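
To illustrate the point about autoregressive prediction: a causal mask needs one row per query position, so it cannot be expressed as a (batch_size, seq_length) array. The sketch below is an illustration only; it assumes the logits are masked with a `jnp.where(mask == 0, -9e15, logits)` pattern, and all variable names are hypothetical.

```python
import jax.numpy as jnp

seq_length = 4

# Lower-triangular mask: query i may attend to keys 0..i only.
causal = jnp.tril(jnp.ones((seq_length, seq_length), dtype=bool))

# Add batch and head axes so the mask broadcasts against
# logits of shape (batch_size, num_heads, seq_length, seq_length).
causal_4d = causal[None, None, :, :]

logits = jnp.zeros((2, 3, seq_length, seq_length))
masked = jnp.where(causal_4d == 0, -9e15, logits)

print(causal.astype(jnp.int32))
```

Each row of `causal` differs from the others, which is exactly the per-token information a 2D sequence mask cannot carry.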

maksay commented 1 year ago

I agree that the input to MHA could be 3D, e.g. to be causal. However, the number of heads is a parameter of the MHA module and is not visible outside of it, so it stands to reason that the mask passed to MHA would have a shape that does not include the number of heads, and would be converted to 4D inside MHA to respect this parameter. In other words, independent of what the mask actually represents, shouldn't there be some code inside MHA that adds a new dimension, tiles, stacks, or otherwise deals with the mask?

lebigot commented 1 year ago

If this helps, using attn_logits.masked_fill(mask.unsqueeze(-2),…) (with unsqueeze added) makes the code of scaled_dot_product work for arbitrary arguments (when giving a single mask for each batch of queries), with dimensions as follows:

lebigot commented 1 year ago

PS: For MultiheadAttention to work as well, you also need to add the following in forward():

        if mask is not None:
            # The unsqueeze() is for broadcasting the mask to all heads:
            mask = mask.unsqueeze(-2)

phlippe commented 1 year ago

Finally got around to adding it in ac42535de4cd45957778bbe42892f16f7da63d42. Sorry for the delay!