idiap / fast-transformers

Pytorch library for fast transformer implementations

Feature request for relative position encoding #53

Open hadaev8 opened 3 years ago

hadaev8 commented 3 years ago

Seems like my encoder-decoder model fails at inference when it needs to produce a sample of an unseen length.

apoorv2904 commented 3 years ago

Hi,

As discussed in #30, attention is never explicitly computed in any linear model, so it is not possible to use a simple relative positional encoding.
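
To make this concrete, here is a rough sketch of the difference (scaling and normalization omitted; the elu-based feature map is only the usual example):

import torch

L, S, E = 4, 5, 8
Q, K, V = torch.randn(L, E), torch.randn(S, E), torch.randn(S, E)

# Softmax attention materializes the (L, S) score matrix, so a relative
# position bias rel[i, j] could simply be added to it before the softmax.
scores = Q @ K.t()                               # shape (L, S)
out_softmax = torch.softmax(scores, dim=-1) @ V

# Linear attention factorizes the computation and never forms the (L, S)
# matrix, so there is no per-pair score to add a relative bias to.
def phi(x):
    return torch.nn.functional.elu(x) + 1

out_linear = phi(Q) @ (phi(K).t() @ V)           # only an (E, E) intermediate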

With that said, you may want to check out Transformers with convolutional context for ASR, where the authors propose using convolutional layers as a front-end to mimic the effect of relative positional embeddings.
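
For illustration, a minimal sketch of that idea (an assumption about the general shape of such a front-end, not the exact architecture from the paper):

import torch
from torch import nn

class ConvPositionalFrontend(nn.Module):
    """Small convolutional front-end applied to the input embeddings so the
    model can pick up relative order from the convolution's receptive field."""

    def __init__(self, d_model, kernel_size=3, n_layers=2):
        super().__init__()
        layers = []
        for _ in range(n_layers):
            layers += [
                nn.Conv1d(d_model, d_model, kernel_size,
                          padding=kernel_size // 2),
                nn.ReLU(),
            ]
        self.conv = nn.Sequential(*layers)

    def forward(self, x):
        # x: (N, L, d_model); convolve over the length dimension
        return self.conv(x.transpose(1, 2)).transpose(1, 2)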

Thanks, Apoorv

hadaev8 commented 3 years ago

@apoorv2904 It should still be useful for full attention. Yes, RNN and conv layers seem to add positional information. Still, this paper (https://www.aclweb.org/anthology/K19-1031.pdf) claims that positional information from an RNN layer and relative position encoding are both beneficial. This method (https://arxiv.org/abs/2009.13658) seems to be the best relative position embedding. There is also an implementation of it in HuggingFace, but I'm not sure how to add it to recurrent attention.
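
(For reference, the HuggingFace implementation I mean is exposed through BERT's position_embedding_type option; if I recall correctly, something like this:)

from transformers import BertConfig, BertModel

# "relative_key_query" adds both query-to-position and key-to-position terms
config = BertConfig(position_embedding_type="relative_key_query")
model = BertModel(config)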

hadaev8 commented 3 years ago

So I guess it should be something like this. I would be glad if someone could confirm this implementation.

import math

import torch
from torch import nn

# EventDispatcher and AttentionEvent come from the fast-transformers package
from fast_transformers.events import EventDispatcher, AttentionEvent


class FullAttention(nn.Module):
    """Implement the scaled dot product attention with softmax and
    relative position embeddings.
    Arguments
    ---------
        softmax_temp: The temperature to use for the softmax attention.
                      (default: 1/sqrt(d_keys) where d_keys is computed at
                      runtime)
        attention_dropout: The dropout rate to apply to the attention
                           (default: 0.1)
        event_dispatcher: str or EventDispatcher instance to be used by this
                          module for dispatching events (default: the default
                          global dispatcher)
    """

    def __init__(self, attention_head_size, max_position_embeddings=128,
                 softmax_temp=None, attention_dropout=0.1, event_dispatcher=""):
        super(FullAttention, self).__init__()
        self.softmax_temp = softmax_temp
        self.dropout = nn.Dropout(attention_dropout)
        self.event_dispatcher = EventDispatcher.get(event_dispatcher)
        self.max_position_embeddings = max_position_embeddings
        self.distance_embedding = nn.Embedding(
            2 * max_position_embeddings + 1, attention_head_size)

    def forward(self, queries, keys, values, attn_mask, query_lengths,
                key_lengths):
        """Implements the multihead softmax attention.
        Arguments
        ---------
            queries: (N, L, H, E) The tensor containing the queries
            keys: (N, S, H, E) The tensor containing the keys
            values: (N, S, H, D) The tensor containing the values
            attn_mask: An implementation of BaseMask that encodes where each
                       query can attend to
            query_lengths: An implementation of BaseMask that encodes how
                           many queries each sequence in the batch consists of
            key_lengths: An implementation of BaseMask that encodes how
                         many keys each sequence in the batch consists of
        """
        # Extract some shapes and compute the temperature
        N, L, H, E = queries.shape
        _, S, _, D = values.shape
        softmax_temp = self.softmax_temp or 1. / math.sqrt(E)

        # Compute the unnormalized attention and apply the masks
        QK = torch.einsum("nlhe,nshe->nhls", queries, keys)

        # Relative positions, clipped to the supported range. Both axes are
        # indexed with the query length, so this assumes self-attention
        # (L == S).
        position_ids_l = torch.arange(
            L, dtype=torch.long, device=queries.device).view(-1, 1)
        position_ids_r = torch.arange(
            L, dtype=torch.long, device=queries.device).view(1, -1)
        distance = (position_ids_l - position_ids_r).clip(
            -self.max_position_embeddings, self.max_position_embeddings)
        positional_embedding = self.distance_embedding(
            distance + self.max_position_embeddings)

        # Query-to-position and key-to-position scores added to the
        # content-to-content scores
        relative_position_scores_query = torch.einsum(
            "blhd,lrd->bhlr", queries, positional_embedding)
        relative_position_scores_key = torch.einsum(
            "brhd,lrd->bhlr", keys, positional_embedding)
        QK = QK + relative_position_scores_query + relative_position_scores_key

        if not attn_mask.all_ones:
            QK = QK + attn_mask.additive_matrix
        QK = QK + key_lengths.additive_matrix[:, None, None]

        # Compute the attention and the weighted average
        A = self.dropout(torch.softmax(softmax_temp * QK, dim=-1))
        V = torch.einsum("nhls,nshd->nlhd", A, values)

        # Let the world know of the attention matrix
        self.event_dispatcher.dispatch(AttentionEvent(self, A))

        # Make sure that what we return is contiguous
        return V.contiguous()
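
A quick shape check for the module above (the mask classes come from fast_transformers.masking; the sizes are arbitrary):

import torch
from fast_transformers.masking import FullMask, LengthMask

N, L, H, E = 2, 10, 4, 16
attn = FullAttention(attention_head_size=E, max_position_embeddings=128)

queries = torch.randn(N, L, H, E)
keys = torch.randn(N, L, H, E)
values = torch.randn(N, L, H, E)

attn_mask = FullMask(L)  # all-ones (L, L) mask, i.e. unmasked self-attention
lengths = LengthMask(torch.full((N,), L, dtype=torch.long))

out = attn(queries, keys, values, attn_mask, lengths, lengths)
print(out.shape)  # torch.Size([2, 10, 4, 16])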

hadaev8 commented 3 years ago

Recurrent attention

class RecurrentFullAttention(nn.Module):
    """Implement the full softmax attention as a recurrent module.
    Arguments
    ---------
        softmax_temp: The temperature to use for the softmax attention.
                      (default: 1/sqrt(d_keys) where d_keys is computed at
                      runtime)
        attention_dropout: The dropout rate to apply to the attention
                           (default: 0.1)
        event_dispatcher: str or EventDispatcher instance to be used by this
                          module for dispatching events (default: the default
                          global dispatcher)
    """

    def __init__(self, attention_head_size, max_position_embeddings=128,
                 softmax_temp=None, attention_dropout=0.1, event_dispatcher=""):
        super(RecurrentFullAttention, self).__init__()
        self.softmax_temp = softmax_temp
        self.dropout = nn.Dropout(attention_dropout)
        self.event_dispatcher = EventDispatcher.get(event_dispatcher)
        self.max_position_embeddings = max_position_embeddings
        self.distance_embedding = nn.Embedding(
            2 * max_position_embeddings + 1, attention_head_size)

    def forward(self, query, key, value, state=None, step=None):
        # Extract some shapes and compute the temperature
        N, H, E = query.shape
        _, _, D = value.shape
        softmax_temp = self.softmax_temp or 1. / math.sqrt(E)

        # Aggregate the list of keys and values
        if state is not None:
            keys, values = state
            keys = torch.cat([keys, key[:, :, None]], dim=2)
            values = torch.cat([values, value[:, :, None]], dim=2)
        else:
            keys = key[:, :, None]
            values = value[:, :, None]

        # The current position is inferred from the number of cached keys,
        # so the optional `step` argument is not used here.

        # Compute the unnormalized attention
        QK = torch.einsum("nhd,nhsd->nhs", query, keys)

        # Relative distance from the current (last) position to every
        # cached position, clipped to the supported range
        position_ids_l = torch.tensor(
            values.shape[2] - 1, dtype=torch.long, device=query.device)
        position_ids_r = torch.arange(
            values.shape[2], dtype=torch.long, device=query.device)
        distance = (position_ids_l - position_ids_r).clip(
            -self.max_position_embeddings, self.max_position_embeddings)
        positional_embedding = self.distance_embedding(
            distance + self.max_position_embeddings)

        relative_position_scores_query = torch.einsum(
            "nhd,sd->nhs", query, positional_embedding)
        relative_position_scores_key = torch.einsum(
            "nhsd,sd->nhs", keys, positional_embedding)

        QK = QK + relative_position_scores_query + relative_position_scores_key

        # Compute the attention and the weighted average
        A = self.dropout(torch.softmax(softmax_temp * QK, dim=-1))
        V = torch.einsum("nhs,nhsd->nhd", A, values).contiguous()

        # Make sure that what we return is contiguous
        return V, [keys, values]

class RecurrentCrossFullAttention(nn.Module):
    """Implement autoregressive softmax cross attention as a recurrent
    module.
    Arguments
    ---------
        softmax_temp: The temperature to use for the softmax attention.
                      (default: 1/sqrt(d_keys) where d_keys is computed at
                      runtime)
        attention_dropout: The dropout rate to apply to the attention
                           (default: 0.1)
        event_dispatcher: str or EventDispatcher instance to be used by this
                          module for dispatching events (default: the default
                          global dispatcher)
    """

    def __init__(self, attention_head_size, max_position_embeddings=128,
                 softmax_temp=None, attention_dropout=0.1, event_dispatcher=""):
        super(RecurrentCrossFullAttention, self).__init__()
        self.softmax_temp = softmax_temp
        self.dropout = nn.Dropout(attention_dropout)
        self.event_dispatcher = EventDispatcher.get(event_dispatcher)
        self.max_position_embeddings = max_position_embeddings
        self.distance_embedding = nn.Embedding(
            2 * max_position_embeddings + 1, attention_head_size)

    def forward(self, query, keys, values, step, key_lengths, state=None):
        # Extract some shapes and compute the temperature
        N, H, E = query.shape
        softmax_temp = self.softmax_temp or 1. / math.sqrt(E)

        # Extract the keys and values either from the arguments or the state
        if state is not None:
            keys, values = state

        # Compute the unnormalized attention and apply the key length mask
        QK = torch.einsum("nhe,nshe->nsh", query, keys)

        position_ids_l = torch.tensor(
            step, dtype=torch.long, device=query.device)
        position_ids_r = torch.arange(
            values.shape[1], dtype=torch.long, device=query.device)

        distance = (position_ids_l - position_ids_r).clip(-self.max_position_embeddings,
                                                          self.max_position_embeddings)
        positional_embedding = self.distance_embedding(
            distance + self.max_position_embeddings)

        relative_position_scores_query = torch.einsum(
            "nhd,sd->nsh", query, positional_embedding)
        relative_position_scores_key = torch.einsum(
            "nshd,sd->nsh", keys, positional_embedding)

        QK = QK + relative_position_scores_query + relative_position_scores_key

        QK = QK + key_lengths.additive_matrix[:, :, None]

        # Compute the attention and the weighted average
        A = self.dropout(torch.softmax(softmax_temp * QK, dim=1))
        V = torch.einsum("nsh,nshd->nhd", A, values)

        # Make sure that we return a contiguous value
        return V.contiguous(), [keys, values]
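
A rough decoding-loop sketch for RecurrentFullAttention above (sizes are arbitrary; it only exercises the shapes and the state handling):

import torch

N, H, E = 2, 4, 16
attn = RecurrentFullAttention(attention_head_size=E)

state = None
for t in range(5):
    q = torch.randn(N, H, E)
    k = torch.randn(N, H, E)
    v = torch.randn(N, H, E)
    out, state = attn(q, k, v, state=state)
    print(t, out.shape)  # torch.Size([2, 4, 16]) at every step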

hadaev8 commented 3 years ago

@angeloskath Could you take a look?

shi27feng commented 3 years ago

@hadaev8 Thank you for the implementation; at the very least your code can serve as a reference for my project. Best, Feng

imj2185 commented 3 years ago

I have some questions about your FullAttention class.

Did you assume that the number of channels is always split evenly across the number of heads? For example, if my channel count is 32 and the number of heads is 8, then h and d will be 8 and 4 respectively. When I compute the relative position scores I get a dimension mismatch error, since self.distance_embedding is defined with attention_head_size (which I took to be the number of heads).

relative_position_scores_query = torch.einsum( "blhd,lrd->bhlr", queries, positional_embedding)

self.distance_embedding = nn.Embedding( 2 * max_position_embeddings + 1, attention_head_size)

Would you please clarify?

Thank you.

hadaev8 commented 3 years ago

@imj2185 By attention_head_size I meant d, e.g. 4. I also found the name misleading and renamed it to query_dimensions in my code.
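
For the numbers above (32 channels, 8 heads) that means constructing the module with the per-head dimension, e.g.:

d_model, n_heads = 32, 8
attn = FullAttention(attention_head_size=d_model // n_heads)  # 32 // 8 = 4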

hadaev8 commented 3 years ago

@imj2185 Just in case: cross attention should not have relative position encoding.