Open long8v opened 2 years ago
class Longformer(RobertaModel):
    """RobertaModel whose encoder self-attention modules are replaced by LongformerSelfAttention.

    With ``config.attention_mode == 'n2'`` nothing is replaced and the inherited
    BertSelfAttention (full attention) is used as-is.
    """

    def __init__(self, config):
        super().__init__(config)
        # 'n2' keeps the stock BertSelfAttention; every other mode swaps in the windowed one.
        if config.attention_mode != 'n2':
            for layer_id, encoder_layer in enumerate(self.encoder.layer):
                encoder_layer.attention.self = LongformerSelfAttention(config, layer_id=layer_id)
RobertaModel을 상속한 뒤, 각 encoder layer의 self-attention을 LongformerSelfAttention으로
바꿔준다. (`config.attention_mode == 'n2'`이면 바꾸지 않고 기존 BertSelfAttention을 그대로 사용한다.)
sliding_chunks 관련 함수들은 https://github.com/allenai/longformer/blob/master/longformer/sliding_chunks.py 에 따로 정의되어 있다.
class LongformerSelfAttention(nn.Module):
    """Longformer self-attention: sliding-window local attention plus optional global attention.

    Local scores are computed only inside a band of width ``attention_window`` around each
    token, using the kernel selected by ``attention_mode``: 'tvm', 'sliding_chunks' or
    'sliding_chunks_no_overlap'. Tokens marked for global attention (attention_mask > 0)
    are handled in two ways:
      (a) every token attends to them, via extra score columns concatenated in front of
          the band;
      (b) they attend to every token, via the separate ``query_global`` / ``key_global`` /
          ``value_global`` projections.
    """

    def __init__(self, config, layer_id):
        super(LongformerSelfAttention, self).__init__()
        # NOTE(review): the constructor body is elided in these notes. Judging from the
        # attributes `forward` reads, it must set at least: embed_dim, num_heads, head_dim,
        # dropout, query/key/value, query_global/key_global/value_global, attention_window,
        # attention_dilation, attention_mode and autoregressive.
        ...
        assert self.attention_mode in ['tvm', 'sliding_chunks', 'sliding_chunks_no_overlap']
        if self.attention_mode in ['sliding_chunks', 'sliding_chunks_no_overlap']:
            assert not self.autoregressive  # not supported
            assert self.attention_dilation == 1  # dilation is not supported

    def forward(
        self,
        hidden_states,
        attention_mask=None,
        head_mask=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        output_attentions=False,
    ):
        '''
        The `attention_mask` is changed in `BertModel.forward` from 0, 1, 2 to
        -ve: no attention
        0: local attention
        +ve: global attention

        Args:
            hidden_states: (bsz, seq_len, embed_dim); transposed to (seq_len, bsz, embed_dim)
                below, and embed_dim must equal self.embed_dim.
            attention_mask: extended mask; dims 2 and 1 are squeezed to get (bsz, seq_len).
                Presumably shaped (bsz, 1, 1, seq_len) -- TODO confirm against the caller.
            head_mask: accepted for interface compatibility; not used here.
            encoder_hidden_states, encoder_attention_mask: must be None (cross-attention is
                not supported).
            output_attentions: if True, also return attention weights (see end of method).

        Returns:
            (context_layer,) or (context_layer, attn_weights); context_layer is
            (bsz, seq_len, embed_dim).

        Raises:
            ValueError: if self.attention_mode is not one of the supported kernels.
        '''
        assert encoder_hidden_states is None, "`encoder_hidden_states` is not supported and should be None"
        assert encoder_attention_mask is None, "`encoder_attention_mask` is not supported and should be None"

        if attention_mask is not None:
            attention_mask = attention_mask.squeeze(dim=2).squeeze(dim=1)
            key_padding_mask = attention_mask < 0
            extra_attention_mask = attention_mask > 0
            # padding (<0) and global (>0) keys are excluded from the windowed scores;
            # global keys are scored separately through the "extra" path below
            remove_from_windowed_attention_mask = attention_mask != 0

            num_extra_indices_per_batch = extra_attention_mask.long().sum(dim=1)
            max_num_extra_indices_per_batch = num_extra_indices_per_batch.max()
            if max_num_extra_indices_per_batch <= 0:
                extra_attention_mask = None
            else:
                # To support the case of variable number of global attention in the rows of a batch,
                # we use the following three selection masks to select global attention embeddings
                # in a 3d tensor and pad it to `max_num_extra_indices_per_batch`
                # 1) selecting embeddings that correspond to global attention
                extra_attention_mask_nonzeros = extra_attention_mask.nonzero(as_tuple=True)
                zero_to_max_range = torch.arange(0, max_num_extra_indices_per_batch,
                                                 device=num_extra_indices_per_batch.device)
                # mask indicating which values are actually going to be padding
                selection_padding_mask = zero_to_max_range < num_extra_indices_per_batch.unsqueeze(dim=-1)
                # 2) location of the non-padding values in the selected global attention
                selection_padding_mask_nonzeros = selection_padding_mask.nonzero(as_tuple=True)
                # 3) location of the padding values in the selected global attention
                selection_padding_mask_zeros = (selection_padding_mask == 0).nonzero(as_tuple=True)
        else:
            remove_from_windowed_attention_mask = None
            extra_attention_mask = None
            key_padding_mask = None

        hidden_states = hidden_states.transpose(0, 1)
        seq_len, bsz, embed_dim = hidden_states.size()
        assert embed_dim == self.embed_dim
        q = self.query(hidden_states)
        k = self.key(hidden_states)
        v = self.value(hidden_states)
        q /= math.sqrt(self.head_dim)

        q = q.view(seq_len, bsz, self.num_heads, self.head_dim).transpose(0, 1)
        k = k.view(seq_len, bsz, self.num_heads, self.head_dim).transpose(0, 1)
        # attn_weights = (bsz, seq_len, num_heads, window*2+1)
        if self.attention_mode == 'tvm':
            # the TVM kernel is fed contiguous fp32 inputs
            q = q.float().contiguous()
            k = k.float().contiguous()
        attn_weights = self._windowed_matmul_qk(q, k)
        mask_invalid_locations(attn_weights, self.attention_window, self.attention_dilation, False)

        if remove_from_windowed_attention_mask is not None:
            # This implementation is fast and takes very little memory because num_heads x hidden_size = 1
            # from (bsz x seq_len) to (bsz x seq_len x num_heads x hidden_size)
            remove_from_windowed_attention_mask = remove_from_windowed_attention_mask.unsqueeze(dim=-1).unsqueeze(dim=-1)
            # cast to float/half, then replace 1's (removed keys) with -10000.0, a stand-in for -inf
            float_mask = remove_from_windowed_attention_mask.type_as(q).masked_fill(remove_from_windowed_attention_mask, -10000.0)
            repeat_size = 1 if isinstance(self.attention_dilation, int) else len(self.attention_dilation)
            float_mask = float_mask.repeat(1, 1, repeat_size, 1)
            ones = float_mask.new_ones(size=float_mask.size())  # tensor of ones
            # banded "ones x float_mask" product: zeros everywhere, -10000.0 where the key is removed
            d_mask = self._windowed_matmul_qk(ones, float_mask)
            attn_weights += d_mask
        assert list(attn_weights.size())[:3] == [bsz, seq_len, self.num_heads]
        assert attn_weights.size(dim=3) in [self.attention_window * 2 + 1, self.attention_window * 3]

        # the extra attention: every token attends to the global tokens
        if extra_attention_mask is not None:
            selected_k = k.new_zeros(bsz, max_num_extra_indices_per_batch, self.num_heads, self.head_dim)
            selected_k[selection_padding_mask_nonzeros] = k[extra_attention_mask_nonzeros]
            # (bsz, seq_len, num_heads, max_num_extra_indices_per_batch)
            selected_attn_weights = torch.einsum('blhd,bshd->blhs', (q, selected_k))
            selected_attn_weights[selection_padding_mask_zeros[0], :, :, selection_padding_mask_zeros[1]] = -10000
            # concat to attn_weights
            # (bsz, seq_len, num_heads, extra attention count + 2*window+1)
            attn_weights = torch.cat((selected_attn_weights, attn_weights), dim=-1)
        attn_weights_float = F.softmax(attn_weights, dim=-1, dtype=torch.float32)  # use fp32 for numerical stability
        if key_padding_mask is not None:
            # zero the rows of padding query tokens: softmax sometimes inserts NaN there
            # when all positions are masked
            attn_weights_float = torch.masked_fill(attn_weights_float, key_padding_mask.unsqueeze(-1).unsqueeze(-1), 0.0)
        attn_weights = attn_weights_float.type_as(attn_weights)
        attn_probs = F.dropout(attn_weights_float.type_as(attn_weights), p=self.dropout, training=self.training)
        v = v.view(seq_len, bsz, self.num_heads, self.head_dim).transpose(0, 1)
        attn = 0
        if extra_attention_mask is not None:
            selected_attn_probs = attn_probs.narrow(-1, 0, max_num_extra_indices_per_batch)
            selected_v = v.new_zeros(bsz, max_num_extra_indices_per_batch, self.num_heads, self.head_dim)
            selected_v[selection_padding_mask_nonzeros] = v[extra_attention_mask_nonzeros]
            # use `matmul` instead of the equivalent einsum('blhs,bshd->blhd')
            # because `einsum` crashes sometimes with fp16
            attn = torch.matmul(selected_attn_probs.transpose(1, 2), selected_v.transpose(1, 2).type_as(selected_attn_probs)).transpose(1, 2)
            attn_probs = attn_probs.narrow(-1, max_num_extra_indices_per_batch, attn_probs.size(-1) - max_num_extra_indices_per_batch).contiguous()

        if self.attention_mode == 'tvm':
            v = v.float().contiguous()
        attn += self._windowed_matmul_pv(attn_probs, v)

        attn = attn.type_as(hidden_states)
        assert list(attn.size()) == [bsz, seq_len, self.num_heads, self.head_dim]
        attn = attn.transpose(0, 1).reshape(seq_len, bsz, embed_dim).contiguous()

        # Global tokens attend to every token. For this case, we'll just recompute the
        # attention for these indices and overwrite the attn tensor.
        # TODO: remove the redundant computation
        if extra_attention_mask is not None:
            selected_hidden_states = hidden_states.new_zeros(max_num_extra_indices_per_batch, bsz, embed_dim)
            selected_hidden_states[selection_padding_mask_nonzeros[::-1]] = hidden_states[extra_attention_mask_nonzeros[::-1]]

            q = self.query_global(selected_hidden_states)
            k = self.key_global(hidden_states)
            v = self.value_global(hidden_states)
            q /= math.sqrt(self.head_dim)

            q = q.contiguous().view(max_num_extra_indices_per_batch, bsz * self.num_heads, self.head_dim).transpose(0, 1)  # (bsz*self.num_heads, max_num_extra_indices_per_batch, head_dim)
            k = k.contiguous().view(-1, bsz * self.num_heads, self.head_dim).transpose(0, 1)  # bsz * self.num_heads, seq_len, head_dim)
            v = v.contiguous().view(-1, bsz * self.num_heads, self.head_dim).transpose(0, 1)  # bsz * self.num_heads, seq_len, head_dim)
            attn_weights = torch.bmm(q, k.transpose(1, 2))
            assert list(attn_weights.size()) == [bsz * self.num_heads, max_num_extra_indices_per_batch, seq_len]

            attn_weights = attn_weights.view(bsz, self.num_heads, max_num_extra_indices_per_batch, seq_len)
            attn_weights[selection_padding_mask_zeros[0], :, selection_padding_mask_zeros[1], :] = -10000.0
            if key_padding_mask is not None:
                attn_weights = attn_weights.masked_fill(
                    key_padding_mask.unsqueeze(1).unsqueeze(2),
                    -10000.0,
                )
            attn_weights = attn_weights.view(bsz * self.num_heads, max_num_extra_indices_per_batch, seq_len)
            attn_weights_float = F.softmax(attn_weights, dim=-1, dtype=torch.float32)  # use fp32 for numerical stability
            attn_probs = F.dropout(attn_weights_float.type_as(attn_weights), p=self.dropout, training=self.training)
            selected_attn = torch.bmm(attn_probs, v)
            assert list(selected_attn.size()) == [bsz * self.num_heads, max_num_extra_indices_per_batch, self.head_dim]

            selected_attn_4d = selected_attn.view(bsz, self.num_heads, max_num_extra_indices_per_batch, self.head_dim)
            nonzero_selected_attn = selected_attn_4d[selection_padding_mask_nonzeros[0], :, selection_padding_mask_nonzeros[1]]
            attn[extra_attention_mask_nonzeros[::-1]] = nonzero_selected_attn.view(len(selection_padding_mask_nonzeros[0]), -1).type_as(hidden_states)

        context_layer = attn.transpose(0, 1)
        if output_attentions:
            if extra_attention_mask is not None:
                # With global attention, return only the global rows:
                # batch_size x num_heads x max_num_global_attention_tokens x sequence_length,
                # i.e. the weights from tokens with global attention to all tokens.
                # Local attention is not returned.
                # NOTE(review): these are the pre-softmax scores computed above (masked keys
                # and padded rows hold -10000.0), not softmax probabilities.
                attn_weights = attn_weights.view(bsz, self.num_heads, max_num_extra_indices_per_batch, seq_len)
            else:
                # without global attention, return local attention probabilities
                # batch_size x num_heads x sequence_length x window_size
                # which is the attention weights of every token attending to its neighbours
                attn_weights = attn_weights.permute(0, 2, 1, 3)
        outputs = (context_layer, attn_weights) if output_attentions else (context_layer,)
        return outputs

    def _windowed_matmul_qk(self, q, k):
        """Banded (sliding-window) q @ k^T using the kernel selected by ``attention_mode``."""
        if self.attention_mode == 'tvm':
            return diagonaled_mm_tvm(q, k, self.attention_window, self.attention_dilation, False, 0, False)
        if self.attention_mode == "sliding_chunks":
            return sliding_chunks_matmul_qk(q, k, self.attention_window, padding_value=0)
        if self.attention_mode == "sliding_chunks_no_overlap":
            return sliding_chunks_no_overlap_matmul_qk(q, k, self.attention_window, padding_value=0)
        raise ValueError(f"Unsupported attention_mode: {self.attention_mode!r}")

    def _windowed_matmul_pv(self, probs, v):
        """Banded probs @ v (inverse of `_windowed_matmul_qk`) for the selected ``attention_mode``."""
        if self.attention_mode == 'tvm':
            return diagonaled_mm_tvm(probs, v, self.attention_window, self.attention_dilation, True, 0, False)
        if self.attention_mode == "sliding_chunks":
            return sliding_chunks_matmul_pv(probs, v, self.attention_window)
        if self.attention_mode == "sliding_chunks_no_overlap":
            return sliding_chunks_no_overlap_matmul_pv(probs, v, self.attention_window)
        raise ValueError(f"Unsupported attention_mode: {self.attention_mode!r}")
def _skew(x, direction, padding_value):
'''Convert diagonals into columns (or columns into diagonals depending on `direction`'''
x_padded = F.pad(x, direction, value=padding_value)
x_padded = x_padded.view(*x_padded.size()[:-2], x_padded.size(-1), x_padded.size(-2))
return x_padded
def _chunk(x, w):
'''convert into overlapping chunkings. Chunk size = 2w, overlap size = w'''
# non-overlapping chunks of size = 2w
x = x.view(x.size(0), x.size(1) // (w * 2), w * 2, x.size(2))
# use `as_strided` to make the chunks overlap with an overlap size = w
chunk_size = list(x.size())
chunk_size[1] = chunk_size[1] * 2 - 1
chunk_stride = list(x.stride())
chunk_stride[1] = chunk_stride[1] // 2
return x.as_strided(size=chunk_size, stride=chunk_stride)
paper, code problem : 트랜스포머는 self-attention의 복잡도가 시퀀스 길이에 대해 quadratic하게 늘어난다. solution : sliding window(+dilated) attention을 구하는 layer를 여러 층 쌓고, 특정 태스크에 맞는 위치의 token들에 대해 global attention을 추가한다. result : text8, enwik8에서 SOTA. 긴 문서 task인 WikiHop과 TriviaQA에서 RoBERTa보다 성능이 좋으며 SOTA. 인코더-디코더 모델은 arXiv 요약 데이터셋에서 효과적임을 확인. details :
windowed local-context self-attention은 문맥적인 표현을 학습하기 위해 사용되고, global attention은 예측을 위해 전체 시퀀스의 표현을 만드는데 사용된다.
auto-regressive(문자 단위 언어 모델링) 태스크로 평가했을 뿐 아니라, MLM objective로 사전학습한 뒤 다운스트림 태스크에 fine-tuning하여 SOTA임을 확인했다.
encoder-decoder 모델인 LED 모델도 제안한다.
long-document transformer의 기존 접근법으로는 1) left-to-right 접근법: 왼쪽에서 오른쪽으로 움직이면서 chunk 단위로 처리하는 방식으로, 다른 태스크에 적용(transfer)할 때 성능이 불안정하다. 2) sparse attention 접근법: Sparse Transformer가 대표적이다.
긴 문서를 다루는 대표적인 방법은 문서를 최대 토큰 개수인 512로 자르거나, 잘라서 각각 처리한 뒤 결과를 결합하는 것이다. 또는 multihop이나 open QA에서 사용되는 방법으로, 먼저 관련 있는 문서를 retrieve한 뒤 answer extraction 단계로 전달하는 방식이 있다.
Attention Pattern
Global attention: 위의 (sliding window / dilated sliding window) 어텐션만으로는 다양한 task에 알맞은 표현을 얻기 어렵기 때문에, 특정 태스크에 맞는 위치에 있는 token들(예: 분류에서는 [CLS] 토큰, QA에서는 question + document를 concat한 입력의 question 토큰 등)에 대해 global attention을 추가하였다.