huggingface / transformers

🤗 Transformers: State-of-the-art Machine Learning for Pytorch, TensorFlow, and JAX.
https://huggingface.co/transformers
Apache License 2.0
131.61k stars 26.19k forks

sdpa for bert should support 4D attention mask. #31036

Status: Open. Leoyzen opened this issue 3 months ago.

Leoyzen commented 3 months ago

Feature request

Currently, only a 2D attention mask is supported when SDPA (scaled dot-product attention) is enabled:

```python
# modeling_bert.py (BertModel.forward)
# Expand the attention mask
if use_sdpa_attention_masks:
    # Expand the attention mask for SDPA.
    # [bsz, seq_len] -> [bsz, 1, seq_len, seq_len]
    if self.config.is_decoder:
        extended_attention_mask = _prepare_4d_causal_attention_mask_for_sdpa(
            attention_mask,
            input_shape,
            embedding_output,
            past_key_values_length,
        )
    else:
        extended_attention_mask = _prepare_4d_attention_mask_for_sdpa(
            attention_mask, embedding_output.dtype, tgt_len=seq_length
        )
else:
    # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
    # ourselves in which case we just need to make it broadcastable to all heads.
    extended_attention_mask = self.get_extended_attention_mask(attention_mask, input_shape)
```
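To make the SDPA branch above concrete, here is a minimal sketch in plain PyTorch (it does not call the transformers helpers; the helper name expand_padding_mask is hypothetical). It expands a [bsz, seq_len] mask of 1s and 0s into an additive [bsz, 1, seq_len, seq_len] float mask and checks that torch.nn.functional.scaled_dot_product_attention with that mask matches an explicit softmax(QK^T / sqrt(d) + mask) V computation.

```python
import math

import torch
import torch.nn.functional as F


def expand_padding_mask(mask: torch.Tensor, dtype: torch.dtype) -> torch.Tensor:
    """Hypothetical helper: [bsz, seq_len] of 1/0 -> additive [bsz, 1, seq_len, seq_len]."""
    bsz, seq_len = mask.shape
    expanded = mask[:, None, None, :].expand(bsz, 1, seq_len, seq_len).to(dtype)
    # Attended positions become 0.0; padded key positions become the most negative finite value.
    return (1.0 - expanded).masked_fill(expanded == 0, torch.finfo(dtype).min)


torch.manual_seed(0)
bsz, heads, seq_len, head_dim = 2, 4, 5, 8
q = torch.randn(bsz, heads, seq_len, head_dim)
k = torch.randn(bsz, heads, seq_len, head_dim)
v = torch.randn(bsz, heads, seq_len, head_dim)

# The second sequence ends with two padding tokens.
attention_mask = torch.tensor([[1, 1, 1, 1, 1], [1, 1, 1, 0, 0]])
additive_mask = expand_padding_mask(attention_mask, q.dtype)

# SDPA broadcasts the [bsz, 1, tgt_len, src_len] mask over the head dimension.
sdpa_out = F.scaled_dot_product_attention(q, k, v, attn_mask=additive_mask)

# Reference: explicit softmax attention with the same additive mask.
scores = q @ k.transpose(-2, -1) / math.sqrt(head_dim) + additive_mask
manual_out = scores.softmax(dim=-1) @ v

print(additive_mask.shape)  # torch.Size([2, 1, 5, 5])
print(torch.allclose(sdpa_out, manual_out, atol=1e-5))
```

Because SDPA only needs a mask broadcastable to [bsz, heads, tgt_len, src_len], a mask that is already 4D could in principle be handed over after the same 1/0-to-additive conversion.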

Motivation

4D attention masks are widely used, for example for padding in BERT.

We construct the 4D attention mask before calling the model and pass it directly to the forward call.
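As an illustration of building a mask ahead of the forward call, the sketch below assumes (hypothetically) that each row holds several sequences placed back to back, identified by a segment_ids tensor in which 0 marks padding. That is one pattern a 2D mask cannot express. It returns a [bsz, 1, seq_len, seq_len] mask of 1s and 0s where a token may attend only to non-padding tokens of its own segment. The helper name block_diagonal_mask and the segment-id convention are assumptions for this example, not part of transformers.

```python
import torch


def block_diagonal_mask(segment_ids: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: [bsz, seq_len] segment ids -> [bsz, 1, seq_len, seq_len] 1/0 mask.

    Token i may attend to token j only if both share a segment id and j is not padding
    (segment id 0 is treated as padding).
    """
    same_segment = segment_ids[:, :, None] == segment_ids[:, None, :]  # [bsz, tgt, src]
    not_padding = (segment_ids != 0)[:, None, :]                       # [bsz, 1, src]
    return (same_segment & not_padding).unsqueeze(1).long()            # [bsz, 1, tgt, src]


# One row with two sequences (lengths 3 and 2) followed by one padding token.
segment_ids = torch.tensor([[1, 1, 1, 2, 2, 0]])
mask_4d = block_diagonal_mask(segment_ids)
print(mask_4d.shape)  # torch.Size([1, 1, 6, 6])
print(mask_4d[0, 0])
```

Passing a mask like this as attention_mask is the case this issue asks the SDPA path to accept; it is in the 1/0 format that the 4D branch of the proposed _expand_mask converts to an additive mask.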

I think it would be easy to extend the current implementation of _expand_mask to support 3D/4D attention masks:

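```python
from typing import Optional

import torch


class AttentionMaskConverter:  # transformers/modeling_attn_mask_utils.py (other members omitted)
    @staticmethod
    def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):
        """
        Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
        A 3D mask `[bsz, tgt_seq_len, src_seq_len]` gains a head dimension, and a 4D mask
        `[bsz, 1, tgt_seq_len, src_seq_len]` is used as-is. Masks hold 1 (attend) or 0 (masked).
        """
        bsz, *_, src_len = mask.size()
        tgt_len = tgt_len if tgt_len is not None else src_len

        if mask.ndim == 2:
            expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)
        elif mask.ndim == 3:
            expanded_mask = mask.unsqueeze(1).to(dtype)
        elif mask.ndim == 4:
            expanded_mask = mask.to(dtype)
        else:
            raise ValueError(f"Expected a 2D, 3D or 4D attention mask, got {mask.ndim}D")

        # 1 -> 0.0 (attend); 0 -> most negative value of dtype (masked)
        inverted_mask = 1.0 - expanded_mask

        return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min)
```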
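One way to sanity-check the proposal is to confirm that a 2D mask and hand-built 3D and 4D equivalents produce identical results. The sketch below repeats the proposed logic as a standalone function so it runs without transformers installed (the name expand_mask is hypothetical). Since every query row of a padding mask sees the same keys, broadcasting the 2D mask along the query axis yields the equivalent 3D mask.

```python
from typing import Optional

import torch


def expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None) -> torch.Tensor:
    # Standalone copy of the proposed _expand_mask logic (hypothetical name).
    bsz, *_, src_len = mask.size()
    tgt_len = tgt_len if tgt_len is not None else src_len
    if mask.ndim == 2:
        expanded = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)
    elif mask.ndim == 3:
        expanded = mask.unsqueeze(1).to(dtype)
    elif mask.ndim == 4:
        expanded = mask.to(dtype)
    else:
        raise ValueError(f"Expected a 2D, 3D or 4D attention mask, got {mask.ndim}D")
    inverted = 1.0 - expanded
    return inverted.masked_fill(inverted.to(torch.bool), torch.finfo(dtype).min)


mask_2d = torch.tensor([[1, 1, 1, 0], [1, 1, 0, 0]])
bsz, seq_len = mask_2d.shape
# Repeat the key mask for every query row: [bsz, seq_len] -> [bsz, seq_len, seq_len].
mask_3d = mask_2d[:, None, :].expand(bsz, seq_len, seq_len)
mask_4d = mask_3d.unsqueeze(1)  # [bsz, 1, seq_len, seq_len]

out_2d = expand_mask(mask_2d, torch.float32)
out_3d = expand_mask(mask_3d, torch.float32)
out_4d = expand_mask(mask_4d, torch.float32)
print(out_2d.shape, torch.equal(out_2d, out_3d), torch.equal(out_2d, out_4d))
# torch.Size([2, 1, 4, 4]) True True
```

Note that for 3D and 4D inputs the tgt_len argument is ignored, since the target length is already encoded in the mask's shape.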

Your contribution

I can open a PR and contribute the change.

amyeroberts commented 3 months ago

cc @ArthurZucker @fxmarty

ArthurZucker commented 3 months ago

Sounds good! Feel free to open a PR for a fix!