long8v / PTIR

Paper Today I Read

[14] Longformer: The Long-Document Transformer #14

long8v opened 2 years ago

long8v commented 2 years ago

paper, code

problem : Transformer self-attention cost grows quadratically with sequence length.
solution : Compute attention over a sliding window (optionally dilated) and stack such layers; add global attention at a small number of task-specific token positions.
result : SOTA on text8 and enwik8; on long-document tasks such as WikiHop and TriviaQA it outperforms RoBERTa and reaches SOTA. The encoder-decoder variant is shown to be effective on the arXiv summarization dataset.
details :
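To make the complexity point concrete, here is a minimal sketch (my own, not from the repo) of the sliding-window attention pattern: each token attends only to the w tokens on either side. Note this naive version still materializes the full n x n score matrix, so it only illustrates the pattern; the paper's chunked/CUDA implementations are what actually bring memory down to O(n*w).

import torch

def sliding_window_attention(q, k, v, w):
    """Each position attends only to positions within distance w.
    Naive reference version for illustration: still builds the full
    n x n score matrix before masking out everything outside the band."""
    n, d = q.size(-2), q.size(-1)
    scores = q @ k.transpose(-1, -2) / d ** 0.5       # (..., n, n)
    idx = torch.arange(n, device=q.device)
    band = (idx[None, :] - idx[:, None]).abs() <= w   # banded (sliding-window) mask
    scores = scores.masked_fill(~band, float('-inf'))
    return torch.softmax(scores, dim=-1) @ v

q = k = v = torch.randn(1, 16, 8)   # (batch, seq_len, hidden)
out = sliding_window_attention(q, k, v, w=2)
print(out.shape)                    # torch.Size([1, 16, 8])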

long8v commented 2 years ago

Longformer

longformer.py

class Longformer(RobertaModel):
    def __init__(self, config):
        super(Longformer, self).__init__(config)
        if config.attention_mode == 'n2':
            pass  # do nothing, use BertSelfAttention instead
        else:
            for i, layer in enumerate(self.encoder.layer):
                layer.attention.self = LongformerSelfAttention(config, layer_id=i)

Longformer inherits from RobertaModel and replaces the self-attention of every encoder layer with LongformerSelfAttention (unless attention_mode == 'n2', in which case the stock RoBERTa self-attention is kept).
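For reference, usage roughly follows the repo's README; treat the checkpoint path and attention_mode value below as a sketch rather than exact, version-pinned code.

from longformer.longformer import Longformer, LongformerConfig

# Load a RoBERTa-initialized checkpoint and switch every layer to the
# memory-efficient sliding-window attention ('sliding_chunks');
# with attention_mode='n2' the original full self-attention would be kept.
config = LongformerConfig.from_pretrained('longformer-base-4096/')
config.attention_mode = 'sliding_chunks'
model = Longformer.from_pretrained('longformer-base-4096/', config=config)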

LongformerSelfAttention

long8v commented 2 years ago

sliding_chunks

import torch.nn.functional as F

def _skew(x, direction, padding_value):
    '''Convert diagonals into columns (or columns into diagonals, depending on `direction`).'''
    x_padded = F.pad(x, direction, value=padding_value)
    # reinterpreting the padded memory with the last two sizes swapped shifts
    # successive rows by one element, so diagonals of `x` line up as columns
    x_padded = x_padded.view(*x_padded.size()[:-2], x_padded.size(-1), x_padded.size(-2))
    return x_padded

def _chunk(x, w):
    '''Convert into overlapping chunks. Chunk size = 2w, overlap size = w.'''

    # non-overlapping chunks of size = 2w
    x = x.view(x.size(0), x.size(1) // (w * 2), w * 2, x.size(2))

    # use `as_strided` to make the chunks overlap with an overlap size = w
    chunk_size = list(x.size())
    chunk_size[1] = chunk_size[1] * 2 - 1

    chunk_stride = list(x.stride())
    chunk_stride[1] = chunk_stride[1] // 2
    return x.as_strided(size=chunk_size, stride=chunk_stride)
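A quick way to see what _chunk does (my own toy example, not from the repo): with w = 2 and a length-8 sequence, as_strided turns the 2 non-overlapping chunks of size 2w into 3 chunks that overlap by w tokens.

import torch

x = torch.arange(8).float().view(1, 8, 1)   # (batch, seq_len, hidden)
chunks = _chunk(x, w=2)
print(chunks.shape)         # torch.Size([1, 3, 4, 1])
print(chunks.squeeze(-1))   # [[0,1,2,3], [2,3,4,5], [4,5,6,7]] -> consecutive chunks share w=2 tokens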