Open hadaev8 opened 3 years ago
Hi,
As discussed in #30, attention is never explicitly computed in any linear model, so it is not possible to use a simple relative positional encoding.
That said, you may want to check out "Transformers with convolutional context for ASR", where the authors propose using convolutional layers as a front-end to mimic the effect of relative positional embeddings.
Thanks, Apoorv
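To make the point about linear models concrete, here is a minimal sketch of non-causal linear attention next to its quadratic equivalent. It assumes the elu(x) + 1 feature map and the tensor layout used elsewhere in this thread; the function names are illustrative, not the library's API. In the linear version the (L, S) matrix is never formed, so there is no place to add a per-pair relative position term.

```python
import torch
import torch.nn.functional as F


def feature_map(x):
    # elu(x) + 1 keeps the features positive (assumed choice for this sketch)
    return F.elu(x) + 1


def linear_attention(queries, keys, values, eps=1e-6):
    """Inputs are (N, L, H, E), (N, S, H, E) and (N, S, H, D)."""
    Q = feature_map(queries)
    K = feature_map(keys)
    # Summarise keys and values first: (N, H, E, D), no (L, S) matrix appears
    KV = torch.einsum("nshe,nshd->nhed", K, values)
    # Per-query normaliser: (N, L, H)
    Z = 1.0 / (torch.einsum("nlhe,nhe->nlh", Q, K.sum(dim=1)) + eps)
    return torch.einsum("nlhe,nhed,nlh->nlhd", Q, KV, Z)


def explicit_attention(queries, keys, values, eps=1e-6):
    """Same result computed the quadratic way, for comparison only."""
    Q = feature_map(queries)
    K = feature_map(keys)
    A = torch.einsum("nlhe,nshe->nhls", Q, K)  # the explicit (L, S) matrix
    A = A / (A.sum(dim=-1, keepdim=True) + eps)
    return torch.einsum("nhls,nshd->nlhd", A, values)


torch.manual_seed(0)
q = torch.randn(2, 5, 3, 4)
k = torch.randn(2, 7, 3, 4)
v = torch.randn(2, 7, 3, 6)
out_linear = linear_attention(q, k, v)
out_explicit = explicit_attention(q, k, v)
```

Both functions return the same values up to floating point error, but only `explicit_attention` has a tensor with one entry per (query, key) pair that a relative position score could be added to.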
@apoorv2904 It should still be useful for full attention. Yes, RNN and conv layers seem to add positional information. Still, this paper claims that positional information from an RNN layer and relative position encoding are both beneficial: https://www.aclweb.org/anthology/K19-1031.pdf This method seems to be the best relative position embedding: https://arxiv.org/abs/2009.13658 Also, there is an implementation of this method in huggingface, but I am not sure how to add it to recurrent attention.
So I guess it should be something like this. I will be glad if someone can confirm this implementation.
import math

import torch
from torch import nn

# Import path assumed from the fast-transformers package layout
from fast_transformers.events import EventDispatcher, AttentionEvent


class FullAttention(nn.Module):
    """Implement the scaled dot product attention with softmax.

    Arguments
    ---------
        attention_head_size: The feature dimension E of a single head
        max_position_embeddings: The largest relative distance with its own
                                 embedding; larger distances are clipped
                                 (default: 128)
        softmax_temp: The temperature to use for the softmax attention.
                      (default: 1/sqrt(d_keys) where d_keys is computed at
                      runtime)
        attention_dropout: The dropout rate to apply to the attention
                           (default: 0.1)
        event_dispatcher: str or EventDispatcher instance to be used by this
                          module for dispatching events (default: the default
                          global dispatcher)
    """
    def __init__(self, attention_head_size, max_position_embeddings=128,
                 softmax_temp=None, attention_dropout=0.1,
                 event_dispatcher=""):
        super(FullAttention, self).__init__()
        self.softmax_temp = softmax_temp
        self.dropout = nn.Dropout(attention_dropout)
        self.event_dispatcher = EventDispatcher.get(event_dispatcher)
        self.max_position_embeddings = max_position_embeddings
        # One embedding per clipped distance in [-max, max]
        self.distance_embedding = nn.Embedding(
            2 * max_position_embeddings + 1, attention_head_size)

    def forward(self, queries, keys, values, attn_mask, query_lengths,
                key_lengths):
        """Implements the multihead softmax attention.

        Arguments
        ---------
            queries: (N, L, H, E) The tensor containing the queries
            keys: (N, S, H, E) The tensor containing the keys
            values: (N, S, H, D) The tensor containing the values
            attn_mask: An implementation of BaseMask that encodes where each
                       query can attend to
            query_lengths: An implementation of BaseMask that encodes how
                           many queries each sequence in the batch consists of
            key_lengths: An implementation of BaseMask that encodes how
                         many keys each sequence in the batch consists of
        """
        # Extract some shapes and compute the temperature
        N, L, H, E = queries.shape
        _, S, _, D = values.shape
        softmax_temp = self.softmax_temp or 1. / math.sqrt(E)

        # Compute the unnormalized attention and apply the masks
        QK = torch.einsum("nlhe,nshe->nhls", queries, keys)

        # Relative distance (query index - key index), clipped to the table
        position_ids_l = torch.arange(
            L, dtype=torch.long, device=queries.device).view(-1, 1)
        # Use S (not L) for the keys so the shapes also hold when L != S
        position_ids_r = torch.arange(
            S, dtype=torch.long, device=queries.device).view(1, -1)
        distance = (position_ids_l - position_ids_r).clip(
            -self.max_position_embeddings, self.max_position_embeddings)
        # (L, S, E) embeddings, with indices shifted to be non-negative
        positional_embedding = self.distance_embedding(
            distance + self.max_position_embeddings)
        relative_position_scores_query = torch.einsum(
            "blhd,lrd->bhlr", queries, positional_embedding)
        relative_position_scores_key = torch.einsum(
            "brhd,lrd->bhlr", keys, positional_embedding)
        QK = QK + relative_position_scores_query + relative_position_scores_key

        if not attn_mask.all_ones:
            QK = QK + attn_mask.additive_matrix
        QK = QK + key_lengths.additive_matrix[:, None, None]

        # Compute the attention and the weighted average
        A = self.dropout(torch.softmax(softmax_temp * QK, dim=-1))
        V = torch.einsum("nhls,nshd->nlhd", A, values)

        # Let the world know of the attention matrix
        self.event_dispatcher.dispatch(AttentionEvent(self, A))

        # Make sure that what we return is contiguous
        return V.contiguous()
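The index arithmetic in the forward pass above is easy to check by hand. This is a minimal sketch in plain Python, with hypothetical helper names, of how the query-minus-key distance is clipped and then shifted into a row of the embedding table:

```python
def relative_index(query_pos, key_pos, max_dist):
    """Row of the (2 * max_dist + 1)-row embedding table for this pair."""
    distance = query_pos - key_pos
    clipped = max(-max_dist, min(max_dist, distance))
    # Shift so that the most negative distance maps to row 0
    return clipped + max_dist


def index_matrix(num_queries, num_keys, max_dist):
    return [[relative_index(i, j, max_dist) for j in range(num_keys)]
            for i in range(num_queries)]


matrix = index_matrix(4, 4, max_dist=2)
for row in matrix:
    print(row)
# [2, 1, 0, 0]
# [3, 2, 1, 0]
# [4, 3, 2, 1]
# [4, 4, 3, 2]
```

All pairs further apart than `max_dist` share the first or last row of the table, so the lookup stays in range for any sequence length.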
Recurrent attention
class RecurrentFullAttention(nn.Module):
    """Implement the full softmax attention as a recurrent module.

    Arguments
    ---------
        softmax_temp: The temperature to use for the softmax attention.
                      (default: 1/sqrt(d_keys) where d_keys is computed at
                      runtime)
        attention_dropout: The dropout rate to apply to the attention
                           (default: 0.1)
        event_dispatcher: str or EventDispatcher instance to be used by this
                          module for dispatching events (default: the default
                          global dispatcher)
    """
    def __init__(self, attention_head_size, max_position_embeddings=128,
                 softmax_temp=None, attention_dropout=0.1,
                 event_dispatcher=""):
        super(RecurrentFullAttention, self).__init__()
        self.softmax_temp = softmax_temp
        self.dropout = nn.Dropout(attention_dropout)
        self.event_dispatcher = EventDispatcher.get(event_dispatcher)
        self.max_position_embeddings = max_position_embeddings
        self.distance_embedding = nn.Embedding(
            2 * max_position_embeddings + 1, attention_head_size)

    def forward(self, query, key, value, state=None, step=None):
        # Extract some shapes and compute the temperature
        N, H, E = query.shape
        _, _, D = value.shape
        softmax_temp = self.softmax_temp or 1. / math.sqrt(E)

        # Aggregate the list of keys and values
        if state is not None:
            keys, values = state
            keys = torch.cat([keys, key[:, :, None]], dim=2)
            values = torch.cat([values, value[:, :, None]], dim=2)
        else:
            keys = key[:, :, None]
            values = value[:, :, None]

        # Note: step is updated here but not used below; the query position
        # is taken from the number of cached values instead
        if step is None:
            step = -1
        step += 1

        # Compute the unnormalized attention
        QK = torch.einsum("nhd,nhsd->nhs", query, keys)

        # The current query sits at the last cached position
        position_ids_l = torch.tensor(
            values.shape[2] - 1, dtype=torch.long, device=query.device)
        position_ids_r = torch.arange(
            values.shape[2], dtype=torch.long, device=query.device)
        distance = (position_ids_l - position_ids_r).clip(
            -self.max_position_embeddings, self.max_position_embeddings)
        positional_embedding = self.distance_embedding(
            distance + self.max_position_embeddings)
        relative_position_scores_query = torch.einsum(
            "nhd,sd->nhs", query, positional_embedding)
        relative_position_scores_key = torch.einsum(
            "nhsd,sd->nhs", keys, positional_embedding)
        QK = QK + relative_position_scores_query + relative_position_scores_key

        # Compute the attention and the weighted average
        A = self.dropout(torch.softmax(softmax_temp * QK, dim=-1))

        # Make sure that what we return is contiguous
        V = torch.einsum("nhs,nhsd->nhd", A, values).contiguous()

        return V, [keys, values]
class RecurrentCrossFullAttention(nn.Module):
    """Implement autoregressive softmax cross attention as a recurrent
    module.

    Arguments
    ---------
        softmax_temp: The temperature to use for the softmax attention.
                      (default: 1/sqrt(d_keys) where d_keys is computed at
                      runtime)
        attention_dropout: The dropout rate to apply to the attention
                           (default: 0.1)
        event_dispatcher: str or EventDispatcher instance to be used by this
                          module for dispatching events (default: the default
                          global dispatcher)
    """
    def __init__(self, attention_head_size, max_position_embeddings=128,
                 softmax_temp=None, attention_dropout=0.1,
                 event_dispatcher=""):
        super(RecurrentCrossFullAttention, self).__init__()
        self.softmax_temp = softmax_temp
        self.dropout = nn.Dropout(attention_dropout)
        self.event_dispatcher = EventDispatcher.get(event_dispatcher)
        self.max_position_embeddings = max_position_embeddings
        self.distance_embedding = nn.Embedding(
            2 * max_position_embeddings + 1, attention_head_size)

    def forward(self, query, keys, values, step, key_lengths, state=None):
        # Extract some shapes and compute the temperature
        N, H, E = query.shape
        softmax_temp = self.softmax_temp or 1. / math.sqrt(E)

        # Extract the keys and values either from the arguments or the state
        if state is not None:
            keys, values = state

        # Compute the unnormalized attention and apply the key length mask
        QK = torch.einsum("nhe,nshe->nsh", query, keys)

        # The query position is the decoding step; keys are indexed 0..S-1
        position_ids_l = torch.tensor(
            step, dtype=torch.long, device=query.device)
        position_ids_r = torch.arange(
            values.shape[1], dtype=torch.long, device=query.device)
        distance = (position_ids_l - position_ids_r).clip(
            -self.max_position_embeddings, self.max_position_embeddings)
        positional_embedding = self.distance_embedding(
            distance + self.max_position_embeddings)
        relative_position_scores_query = torch.einsum(
            "nhd,sd->nsh", query, positional_embedding)
        relative_position_scores_key = torch.einsum(
            "nshd,sd->nsh", keys, positional_embedding)
        QK = QK + relative_position_scores_query + relative_position_scores_key
        QK = QK + key_lengths.additive_matrix[:, :, None]

        # Compute the attention and the weighted average
        A = self.dropout(torch.softmax(softmax_temp * QK, dim=1))
        V = torch.einsum("nsh,nshd->nhd", A, values)

        # Make sure that we return a contiguous value
        return V.contiguous(), [keys, values]
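One way to check the self-attention part of this is to compare the step-by-step scores against the matching row of the full computation. The sketch below is self-contained and leaves out the library's masks, dropout and event dispatcher. The function names are hypothetical, and it assumes causal self-attention, where query t only sees keys 0..t.

```python
import torch
from torch import nn


def full_scores(queries, keys, emb, max_dist):
    """Scores (N, H, L, L) for self-attention inputs of shape (N, L, H, E)."""
    L = queries.shape[1]
    pos = torch.arange(L)
    distance = (pos.view(-1, 1) - pos.view(1, -1)).clamp(-max_dist, max_dist)
    rel = emb(distance + max_dist)  # (L, L, E)
    qk = torch.einsum("nlhe,nshe->nhls", queries, keys)
    qk = qk + torch.einsum("nlhe,lse->nhls", queries, rel)
    qk = qk + torch.einsum("nshe,lse->nhls", keys, rel)
    return qk


def step_scores(query, keys, emb, max_dist):
    """Scores (N, H, S) for one query (N, H, E) over cached keys (N, H, S, E)."""
    S = keys.shape[2]
    # The query sits at position S - 1, the last cached position
    distance = ((S - 1) - torch.arange(S)).clamp(-max_dist, max_dist)
    rel = emb(distance + max_dist)  # (S, E)
    qk = torch.einsum("nhe,nhse->nhs", query, keys)
    qk = qk + torch.einsum("nhe,se->nhs", query, rel)
    qk = qk + torch.einsum("nhse,se->nhs", keys, rel)
    return qk


torch.manual_seed(0)
N, L, H, E, max_dist = 2, 6, 3, 4, 2
emb = nn.Embedding(2 * max_dist + 1, E)
queries = torch.randn(N, L, H, E)
keys = torch.randn(N, L, H, E)

full = full_scores(queries, keys, emb, max_dist)
max_diff = 0.0
for t in range(L):
    cached = keys[:, : t + 1].transpose(1, 2)  # (N, H, t + 1, E)
    step = step_scores(queries[:, t], cached, emb, max_dist)
    # Compare with the causal part of row t of the full score matrix
    diff = (full[:, :, t, : t + 1] - step).abs().max().item()
    max_diff = max(max_diff, diff)
print(max_diff)
```

If `max_diff` is at the level of floating point noise, the recurrent indexing agrees with the full version for causal self-attention. This says nothing about the cross-attention variant.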
@angeloskath Should you take a look?
@hadaev8 Thank you for the implementation; at the very least, your code can serve as a reference for my project. Best, Feng
I have some questions about your FullAttention class.
Did you assume that the channels are always split evenly across the heads? For example, if I have 32 channels and 8 heads, then h and d will be 8 and 4 respectively. When I then compute the relative position scores, I get a dimension mismatch error, since self.distance_embedding is defined with the number of heads as its size (attention_head_size).
relative_position_scores_query = torch.einsum( "blhd,lrd->bhlr", queries, positional_embedding)
self.distance_embedding = nn.Embedding( 2 * max_position_embeddings + 1, attention_head_size)
Would you please clarify?
Thank you.
@imj2185 By attention_head_size I meant d, e.g. 4. I too found the name misleading and renamed it to query_dimensions in my code.
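To illustrate with the numbers from the question: the embedding width has to equal the per-head feature size (channels divided by heads), because the einsum contracts over the last dimension of the queries. This is a minimal sketch with illustrative variable names:

```python
import torch
from torch import nn

d_model, n_heads = 32, 8
head_dim = d_model // n_heads  # per-head size d, 4 in this example
max_dist = 128

# The embedding width must be head_dim, not n_heads
distance_embedding = nn.Embedding(2 * max_dist + 1, head_dim)

N, L = 2, 5
queries = torch.randn(N, L, n_heads, head_dim)
pos = torch.arange(L)
distance = (pos.view(-1, 1) - pos.view(1, -1)).clamp(-max_dist, max_dist)
positional_embedding = distance_embedding(distance + max_dist)  # (L, L, d)

scores = torch.einsum("blhd,lrd->bhlr", queries, positional_embedding)
print(scores.shape)  # torch.Size([2, 8, 5, 5])
```

Constructing the embedding with `n_heads` as the width would give it a last dimension of 8, which cannot be contracted against the queries' last dimension of 4.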
@imj2185 Just in case: cross attention should not use relative position embeddings.
It seems that my encoder-decoder model fails at inference when it needs to produce a sample of unseen length.