pytorch / text

Models, data loaders and abstractions for language processing, powered by PyTorch
https://pytorch.org/text
BSD 3-Clause "New" or "Revised" License

Make MHA even more flexible #884

Open netw0rkf10w opened 4 years ago

netw0rkf10w commented 4 years ago

🚀 Feature

Make the new MHA implementation even more modular, so that different attention layers are easy to implement.

Motivation

The new MHA container implementation is already much more flexible than the one in core PyTorch. However, in its current form, implementing a new attention layer (other than ScaledDotProduct) requires duplicating part of ScaledDotProduct's code, which is not ideal.

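An attention function works in two steps: (1) computing the attention weights from the query and key, and (2) aggregating the values with those weights. Different attention functions may differ only in the first step, only in the second step, or in both.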
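The two steps can be sketched as plain functions. This is a minimal illustration, not the torchtext code: the function names are hypothetical, and it assumes a batch-first (N * H, L, E / H) layout instead of the (L, N * H, E / H) layout used in the code further down. The two score functions differ only in step 1, while step 2 is shared.

```python
import torch


def scaled_dot_weights(query, key):
    # Step 1, scaled dot product: (N*H, L, D) x (N*H, S, D) -> (N*H, L, S)
    scores = torch.matmul(query, key.transpose(-2, -1)) / query.size(-1) ** 0.5
    return torch.softmax(scores, dim=-1)


def general_weights(query, key, W):
    # Step 1, "general" score q^T W k, with W of shape (D, Dk): -> (N*H, L, S)
    scores = torch.matmul(torch.matmul(query, W), key.transpose(-2, -1))
    return torch.softmax(scores, dim=-1)


def aggregate(weights, value):
    # Step 2, shared: weighted sum of values, (N*H, L, S) x (N*H, S, Dv) -> (N*H, L, Dv)
    return torch.matmul(weights, value)


torch.manual_seed(0)
q = torch.randn(4, 7, 8)  # (N*H, L, D)
k = torch.randn(4, 5, 8)  # (N*H, S, D)
v = torch.randn(4, 5, 8)  # (N*H, S, Dv)
W = torch.randn(8, 8)

out_sdp = aggregate(scaled_dot_weights(q, k), v)
out_gen = aggregate(general_weights(q, k, W), v)
print(out_sdp.shape, out_gen.shape)  # torch.Size([4, 7, 8]) torch.Size([4, 7, 8])
```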

Pitch

I can think of two solutions:

  1. Let the attention layers (e.g. ScaledDotProduct) return only the attention weights, and perform the aggregation of the values in the main MHA container.

  2. Keep the MHA container unchanged and use a general template class for all attention layers, from which each specific attention layer inherits.
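For comparison, solution 1 could look roughly like the sketch below. The class names `ScaledDotWeights` and `WeightsOnlyMHA` are hypothetical and not part of torchtext, the batch-first (N, L, E) input layout is chosen for brevity, and masks and dropout are omitted. One visible cost of this design is that step 2 is hard-coded inside the container, so a layer that wants a different aggregation cannot provide one.

```python
import torch


class ScaledDotWeights(torch.nn.Module):
    """Hypothetical solution-1 attention layer: returns only the weights."""

    def forward(self, query, key):
        # query: (N*H, L, D), key: (N*H, S, D) -> weights: (N*H, L, S)
        scores = torch.matmul(query, key.transpose(-2, -1)) / query.size(-1) ** 0.5
        return torch.softmax(scores, dim=-1)


class WeightsOnlyMHA(torch.nn.Module):
    """Hypothetical container that performs the value aggregation itself."""

    def __init__(self, embed_dim, nhead, weight_layer):
        super().__init__()
        self.nhead = nhead
        self.weight_layer = weight_layer
        self.q_proj = torch.nn.Linear(embed_dim, embed_dim)
        self.k_proj = torch.nn.Linear(embed_dim, embed_dim)
        self.v_proj = torch.nn.Linear(embed_dim, embed_dim)
        self.out_proj = torch.nn.Linear(embed_dim, embed_dim)

    def _split(self, x):
        # (N, T, E) -> (N, H, T, E/H) -> (N*H, T, E/H)
        n, t, e = x.shape
        d = e // self.nhead
        return x.reshape(n, t, self.nhead, d).transpose(1, 2).reshape(n * self.nhead, t, d)

    def forward(self, query, key, value):
        n, l, e = query.shape
        q = self._split(self.q_proj(query))
        k = self._split(self.k_proj(key))
        v = self._split(self.v_proj(value))
        weights = self.weight_layer(q, k)  # step 1: delegated to the layer
        out = torch.matmul(weights, v)     # step 2: fixed in the container
        # (N*H, L, E/H) -> (N, H, L, E/H) -> (N, L, H, E/H) -> (N, L, E)
        out = out.reshape(n, self.nhead, l, e // self.nhead).transpose(1, 2).reshape(n, l, e)
        return self.out_proj(out), weights


torch.manual_seed(0)
mha = WeightsOnlyMHA(embed_dim=16, nhead=4, weight_layer=ScaledDotWeights())
x = torch.randn(2, 5, 16)  # (N, L, E)
out, w = mha(x, x, x)
print(out.shape, w.shape)  # torch.Size([2, 5, 16]) torch.Size([8, 5, 5])
```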

I've tried both and found the second solution much cleaner. Below is an example in which I re-implement ScaledDotProduct with this approach and also add another attention layer, GeneralDotProduct (the "general" score in Section 3.1 of this paper). (Try adding an attention layer such as GeneralDotProduct to the current implementation yourself and you will see the issue.)

from typing import Optional, Tuple

import torch


class GeneralAttention(torch.nn.Module):

    def __init__(self, dropout=0.):
        r"""General template for attention layers.

        Args:
            dropout (float): probability of dropping an attention weight.
        """
        super().__init__()
        self.dropout = dropout

    def compute_weights(self, query: torch.Tensor, key: torch.Tensor,
                        attn_mask: Optional[torch.Tensor] = None,
                        ) -> torch.Tensor:
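        r"""Compute the attention weights from the projected query and key.

        Subclasses must override this method. It should return a tensor of
        shape :math:`(N * H, L, S)`.
        """
        raise NotImplementedError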

    def compute_outputs(self, value: torch.Tensor,
                        weights: torch.Tensor,
                        ) -> torch.Tensor:
        r"""Compute the attention outputs from the projected value and the attention weights.

        Args:
            value (Tensor): Projected value
            weights (Tensor): Attention weights
        Shape:
            - value: :math:`(S, N * H, E / H)`
            - weights: :math:`(N * H, L, S)`

            - Output: :math:`(L, N * H, E / H)`

            where L is the target length, S is the source length, H is the number
            of attention heads, N is the batch size, and E is the embedding dimension.
        """
        # Transpose: (S, N*H, E/H) --> (N*H, S, E/H)
        value = value.transpose(-2, -3)
        # (N*H, L, S) times (N*H, S, E/H) --> (N*H, L, E/H)
        attn_output = torch.matmul(weights, value)
        # Back to (L, N*H, E/H)
        attn_output = attn_output.transpose(-2, -3)

        return attn_output

    def forward(self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor,
                attn_mask: Optional[torch.Tensor] = None,
                bias_k: Optional[torch.Tensor] = None,
                bias_v: Optional[torch.Tensor] = None,
                ) -> Tuple[torch.Tensor, torch.Tensor]:
        r"""Attention forward pass.

        Args:
            query (Tensor): Projected query
            key (Tensor): Projected key
            value (Tensor): Projected value
            attn_mask (BoolTensor, optional): 3D mask that prevents attention
                to certain positions.
            bias_k, bias_v (Tensor, optional): an extra key and value position
                appended along the sequence dim (dim=-3), used for incremental
                decoding. Both must be non-None for either to take effect.
        Shape:
            - query: :math:`(L, N * H, E / H)`
            - key: :math:`(S, N * H, E / H)`
            - value: :math:`(S, N * H, E / H)`
            - attn_mask: :math:`(N * H, L, S)`, positions with ``True`` are not
                allowed to attend while ``False`` values will be unchanged.
            - bias_k and bias_v: :math:`(1, N * H, E / H)`

            - Output: :math:`(L, N * H, E / H)`, :math:`(N * H, L, S)`

            where L is the target length, S is the source length, H is the number
            of attention heads, N is the batch size, and E is the embedding dimension.
        """
        if bias_k is not None and bias_v is not None:
            assert (key.size(-1) == bias_k.size(-1) and
                    key.size(-2) == bias_k.size(-2) and
                    bias_k.size(-3) == 1), "Shape of bias_k is not supported"
            assert (value.size(-1) == bias_v.size(-1) and
                    value.size(-2) == bias_v.size(-2) and
                    bias_v.size(-3) == 1), "Shape of bias_v is not supported"
            key = torch.cat([key, bias_k])
            value = torch.cat([value, bias_v])
            if attn_mask is not None:
                attn_mask = torch.nn.functional.pad(attn_mask, (0, 1))

        # Compute attention weights
        attn_weights = self.compute_weights(query, key, attn_mask=attn_mask)
        # Add dropout
        attn_weights = torch.nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)
        # Then compute the attention outputs
        attn_output = self.compute_outputs(value, weights=attn_weights)

        return attn_output, attn_weights

class ScaledDotProduct(GeneralAttention):
    r"""Processes a projected query and key-value pair to apply
        scaled dot product attention.

        Examples::
            >>> SDP = torchtext.modules.ScaledDotProduct(dropout=0.1)
            >>> q = torch.randn(256, 21, 3)
            >>> k = v = torch.randn(256, 21, 3)
            >>> attn_output, attn_weights = SDP(q, k, v)
            >>> print(attn_output.shape, attn_weights.shape)
            torch.Size([256, 21, 3]) torch.Size([21, 256, 256])
        """

    def compute_weights(self, query: torch.Tensor, key: torch.Tensor,
                        attn_mask: Optional[torch.Tensor] = None
                        ) -> torch.Tensor:
        r"""Computes the attention weights as the softmax of the scaled dot
        product between the projected query and key.

        Args:
            query (Tensor): Projected query
            key (Tensor): Projected key
            attn_mask (BoolTensor, optional): 3D mask that prevents attention
                to certain positions.
        Shape:
            - query: :math:`(L, N * H, E / H)`
            - key: :math:`(S, N * H, E / H)`
            - attn_mask: :math:`(N * H, L, S)`, positions with ``True`` are not
                allowed to attend while ``False`` values will be unchanged.

            - Output: :math:`(N * H, L, S)`

            where L is the target length, S is the source length, H is the number
            of attention heads, N is the batch size, and E is the embedding dimension.
        """
        tgt_len, head_dim = query.size(-3), query.size(-1)
        assert query.size(-1) == key.size(-1), "Feature dims of query and key must equal."
        src_len = key.size(-3)
        batch_heads = max(query.size(-2), key.size(-2))

        # Scale query
        query, key = query.transpose(-2, -3), key.transpose(-2, -3)
        query = query * (float(head_dim) ** -0.5)

        # Attention weights: dot product of q, k
        attn_weights = torch.matmul(query, key.transpose(-2, -1))
        if attn_mask is not None:
            if attn_mask.dim() != 3:
                raise RuntimeError('attn_mask must be a 3D tensor.')
            if (attn_mask.size(-1) != src_len) or (attn_mask.size(-2) != tgt_len) or \
               (attn_mask.size(-3) != 1 and attn_mask.size(-3) != batch_heads):
                raise RuntimeError('The size of the attn_mask is not correct.')
            if attn_mask.dtype != torch.bool:
                raise RuntimeError('Only bool tensor is supported for attn_mask')

            attn_weights.masked_fill_(attn_mask, -1e8,)
        attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1)

        return attn_weights

class GeneralDotProduct(GeneralAttention):

    def __init__(self, embed_dim, kdim=None, dropout=0.):
        r"""Processes a projected query and key-value pair to apply the general
        vector-matrix-vector product attention.

        Examples::
            >>> embed_dim, kdim = 5, 3
            >>> GDP = torchtext.modules.GeneralDotProduct(embed_dim, kdim=kdim, dropout=0.1)
            >>> q = torch.randn(256, 21, embed_dim)
            >>> k = v = torch.randn(12, 21, kdim)
            >>> attn_output, attn_weights = GDP(q, k, v)
            >>> print(attn_output.shape, attn_weights.shape)
            torch.Size([256, 21, 3]) torch.Size([21, 256, 12])

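        Args:
            embed_dim (int): feature dimension of the projected query (per head, i.e. E / H).
            kdim (int, optional): feature dimension of the projected key (per head).
                Default: ``embed_dim``.
            dropout (float): probability of dropping an attention weight.
        """
        super().__init__(dropout=dropout)
        kdim = embed_dim if kdim is None else kdim
        # Bilinear weight W of shape (E/H, K/H). torch.empty leaves it uninitialized,
        # so initialize it explicitly (Xavier uniform is one reasonable choice).
        self.W = torch.nn.Parameter(torch.empty(embed_dim, kdim))
        torch.nn.init.xavier_uniform_(self.W)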

    def compute_weights(self, query: torch.Tensor, key: torch.Tensor,
                        attn_mask: Optional[torch.Tensor] = None
                        ) -> torch.Tensor:
        r"""Computes the attention weights as the softmax of the general
        (bilinear) score :math:`q^T W k` between the projected query and key.

        Args:
            query (Tensor): Projected query
            key (Tensor): Projected key
            attn_mask (BoolTensor, optional): 3D mask that prevents attention
                to certain positions.
        Shape:
            - query: :math:`(L, N * H, E / H)`
            - key: :math:`(S, N * H, K / H)`
            - attn_mask: :math:`(N * H, L, S)`, positions with ``True`` are not
                allowed to attend while ``False`` values will be unchanged.

            - Output: :math:`(N * H, L, S)`

            where L is the target length, S is the source length, H is the number
            of attention heads, N is the batch size, E is the embedding dimension,
            and K is the key dimension.
        """
        tgt_len = query.size(-3)
        assert (query.size(-1) == self.W.shape[0] and
                key.size(-1) == self.W.shape[1]), "Feature dims of query/key do not match W."
        src_len = key.size(-3)
        batch_heads = max(query.size(-2), key.size(-2))

        # (L, N * H, E/H) --> (N * H, L, E/H), (S, N * H, K/H) --> (N * H, S, K/H)
        query, key = query.transpose(-2, -3), key.transpose(-2, -3)

        # Attention weights: dot product of q, k
        # W is (E/H, K/H)
        attn_weights = torch.matmul(query, self.W)
        attn_weights = torch.matmul(attn_weights, key.transpose(-2, -1))

        if attn_mask is not None:
            if attn_mask.dim() != 3:
                raise RuntimeError('attn_mask must be a 3D tensor.')
            if (attn_mask.size(-1) != src_len) or (attn_mask.size(-2) != tgt_len) or \
               (attn_mask.size(-3) != 1 and attn_mask.size(-3) != batch_heads):
                raise RuntimeError('The size of the attn_mask is not correct.')
            if attn_mask.dtype != torch.bool:
                raise RuntimeError('Only bool tensor is supported for attn_mask')

            attn_weights.masked_fill_(attn_mask, -1e8,)

        attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1)

        return attn_weights
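To illustrate how little a new layer needs under this template, below is a condensed, self-contained restatement of `GeneralAttention` (no bias_k/bias_v handling and no mask validation) together with a hypothetical `DotProduct` layer that uses an unscaled dot-product score. Only `compute_weights` is written; dropout and the value aggregation are inherited. Masked positions are filled with `-inf` here, which is an assumption (the code above uses `-1e8`), so a row whose positions are all masked would produce NaN.

```python
import torch


class GeneralAttention(torch.nn.Module):
    # Condensed restatement of the template above: no bias_k/bias_v, no mask checks.
    def __init__(self, dropout: float = 0.0):
        super().__init__()
        self.dropout = dropout

    def compute_weights(self, query, key, attn_mask=None):
        raise NotImplementedError

    def compute_outputs(self, value, weights):
        # value (S, N*H, D) -> (N*H, S, D); (N*H, L, S) @ (N*H, S, D) -> (N*H, L, D) -> (L, N*H, D)
        return torch.matmul(weights, value.transpose(-2, -3)).transpose(-2, -3)

    def forward(self, query, key, value, attn_mask=None):
        weights = self.compute_weights(query, key, attn_mask=attn_mask)
        weights = torch.nn.functional.dropout(weights, p=self.dropout, training=self.training)
        return self.compute_outputs(value, weights), weights


class DotProduct(GeneralAttention):
    # Hypothetical new layer: only step 1 (the unscaled dot-product score) is written.
    def compute_weights(self, query, key, attn_mask=None):
        # (N*H, L, D) @ (N*H, D, S) -> (N*H, L, S)
        scores = torch.matmul(query.transpose(-2, -3), key.transpose(-2, -3).transpose(-2, -1))
        if attn_mask is not None:
            scores = scores.masked_fill(attn_mask, float("-inf"))
        return torch.softmax(scores, dim=-1)


torch.manual_seed(0)
layer = DotProduct(dropout=0.1).eval()  # eval() disables dropout
q = torch.randn(7, 6, 4)      # (L, N*H, D)
k = v = torch.randn(5, 6, 4)  # (S, N*H, D)
out, w = layer(q, k, v)
print(out.shape, w.shape)  # torch.Size([7, 6, 4]) torch.Size([6, 7, 5])
```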

@zhangguanheng66 Are you interested in such a PR?

cpuhrsch commented 4 years ago

Rather than MHA, I think the focus here is really on attention functions, which are in general independent of MHA?

netw0rkf10w commented 4 years ago

@cpuhrsch Yes, you are right: if we adopt the second solution I proposed above, this is indeed independent of MHA. That is also why I think it is better than the first solution.
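As a sketch of that independence: an attention function that follows an (L, B, D) shape contract can be called directly as a single-head attention (B = N) or through a wrapper that only splits and merges heads (B = N * H). The helpers `split_heads`, `merge_heads`, `scaled_dot_attention`, and `multi_head` are hypothetical, and the input/output projections of a real MHA container are omitted for brevity.

```python
import torch


def split_heads(x, nhead):
    # (T, N, E) -> (T, N*H, E/H), the layout used by the attention layers above
    t, n, e = x.shape
    return x.reshape(t, n * nhead, e // nhead)


def merge_heads(x, nhead):
    # (T, N*H, E/H) -> (T, N, E)
    t, nh, d = x.shape
    return x.reshape(t, nh // nhead, nhead * d)


def scaled_dot_attention(query, key, value):
    # Knows nothing about heads. query: (L, B, D), key/value: (S, B, D)
    q, k, v = query.transpose(0, 1), key.transpose(0, 1), value.transpose(0, 1)
    weights = torch.softmax(q @ k.transpose(-2, -1) / q.size(-1) ** 0.5, dim=-1)
    return (weights @ v).transpose(0, 1), weights  # (L, B, D), (B, L, S)


def multi_head(attn, query, key, value, nhead):
    # Only reshapes; both attention steps are delegated to `attn`.
    out, weights = attn(split_heads(query, nhead), split_heads(key, nhead), split_heads(value, nhead))
    return merge_heads(out, nhead), weights


torch.manual_seed(0)
q = torch.randn(6, 2, 16)   # (L, N, E)
kv = torch.randn(9, 2, 16)  # (S, N, E)

out1, w1 = scaled_dot_attention(q, kv, kv)  # used directly, no MHA at all
out4, w4 = multi_head(scaled_dot_attention, q, kv, kv, nhead=4)
print(out1.shape, w1.shape)  # torch.Size([6, 2, 16]) torch.Size([2, 6, 9])
print(out4.shape, w4.shape)  # torch.Size([6, 2, 16]) torch.Size([8, 6, 9])
```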