ccdv-ai / convert_checkpoint_to_lsg

Efficient Attention for Long Sequence Processing
MIT License

LSG Attention: Extrapolation of pretrained Transformers to long sequences

ArXiv paper. Accepted as a conference paper at PAKDD 2023.

Requires transformers >= 4.36.1

Optional package to convert models:

pip install lsg-converter

Compatible models

Note that for non-compatible models, the LSG attention layer is still available; see the LSG-Layer section and lsg_converter/attention_layers for more information.

This package can convert various HuggingFace checkpoints and architectures to their LSG (Local-Sparse-Global) variant to handle long sequences. Available model types:

Some converted checkpoints are available here.
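For example, a published LSG checkpoint can be loaded directly with the Auto classes (a minimal sketch; ccdv/lsg-bart-base-4096 is assumed here as one example of such a checkpoint, and trust_remote_code=True is required because the LSG modeling code ships with it):

from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

# Example of a converted checkpoint hosted on the HuggingFace Hub (assumed name)
model = AutoModelForSeq2SeqLM.from_pretrained("ccdv/lsg-bart-base-4096", trust_remote_code=True)
tokenizer = AutoTokenizer.from_pretrained("ccdv/lsg-bart-base-4096")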


Efficiency

Memory usage and training speed for a binary classification task with batches of 4x4096 tokens and the Adam optimizer (RTX 8000 GPU).

| Model | Seconds per step | Memory w/ attn dropout | Memory w/o attn dropout |
|---|---|---|---|
| Longformer-base | 3.22 s/step | 34.38 GB | 32.83 GB |
| BigBird-RoBERTa-base | 2.85 s/step | 38.13 GB | 38.13 GB (no effect) |
| LSG-RoBERTa 256/0 | 1.40 s/step | 32.92 GB | 24.80 GB |
| LSG-RoBERTa 128/128 (norm) | 1.51 s/step | 33.80 GB | 27.52 GB |
| LSG-RoBERTa 32/32 (norm) | 1.20 s/step | 24.53 GB | 22.53 GB |
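These numbers can be approximated with a simple training loop like the one below (a minimal sketch, assuming a converted lsg-roberta-base checkpoint is available locally, see the conversion examples below, and a single CUDA GPU; the batch shape and optimizer match the setup above):

import time
import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer

# Hypothetical local path to a converted checkpoint
MODEL_NAME = "lsg-roberta-base"

model = AutoModelForSequenceClassification.from_pretrained(MODEL_NAME, trust_remote_code=True).cuda()
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-5)
model.train()

# Batch of 4 sequences of 4096 tokens with random labels for a binary task
input_ids = torch.randint(5, tokenizer.vocab_size, (4, 4096)).cuda()
labels = torch.randint(0, 2, (4,)).cuda()

torch.cuda.reset_peak_memory_stats()
start = time.time()
for _ in range(10):
    optimizer.zero_grad()
    loss = model(input_ids=input_ids, labels=labels).loss
    loss.backward()
    optimizer.step()
torch.cuda.synchronize()

print(f"{(time.time() - start) / 10:.2f} s/step")
print(f"{torch.cuda.max_memory_allocated() / 1e9:.2f} GB")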


Convert checkpoint to LSG

Models can be converted with or without the lsg-converter package.

With package

BERT example with the lsg-converter package:

from lsg_converter import LSGConverter

converter = LSGConverter(max_sequence_length=4096)

# Example 1
model, tokenizer = converter.convert_from_pretrained("bert-base-uncased", num_global_tokens=7)
print(type(model))
# <class 'lsg_converter.bert.modeling_lsg_bert.LSGBertForMaskedLM'>

# Example 2
model, tokenizer = converter.convert_from_pretrained("bert-base-uncased", architecture="BertForSequenceClassification", use_auth_token=True)
print(type(model))
# <class 'lsg_converter.bert.modeling_lsg_bert.LSGBertForSequenceClassification'>
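The returned model and tokenizer are standard transformers objects, so the converted checkpoint can be saved locally and reloaded later with the Auto classes (the output directory name is just an example):

# Save the converted model and tokenizer for later use
model.save_pretrained("lsg-bert-base-uncased-4096")
tokenizer.save_pretrained("lsg-bert-base-uncased-4096")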

Without package

Use convert_checkpoint.py to convert a model (check model_type in its config.json). The architecture of the model is inferred from the config file, but you can specify a different one if the config is incorrect (which can happen for BART models); see python convert_checkpoint.py --help. To test the converted model, add --run_test (experimental).

RoBERTa example (from RobertaForMaskedLM to RobertaForSequenceClassification) without package:

git clone https://github.com/ccdv-ai/convert_checkpoint_to_lsg.git
cd convert_checkpoint_to_lsg

export MODEL_TO_CONVERT=roberta-base
export MODEL_NAME=lsg-roberta-base
export MAX_LENGTH=4096

python convert_checkpoint.py \
    --initial_model $MODEL_TO_CONVERT \
    --model_name $MODEL_NAME \
    --model_kwargs "{'mask_first_token': true, 'sparsity_type': 'lsh', 'block_size': 32}" \
    --architecture RobertaForSequenceClassification \
    --max_sequence_length $MAX_LENGTH

Model Usage

Converted models work with the transformers Auto classes (with trust_remote_code=True).

from transformers import AutoTokenizer, AutoModelForMaskedLM

# Load created model
MODEL_NAME = "lsg-roberta-base"
SENTENCE = "This is a test sentence."

model = AutoModelForMaskedLM.from_pretrained(MODEL_NAME, trust_remote_code=True)
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)

inputs = tokenizer(SENTENCE, return_tensors="pt")
model(**inputs)

# The same converted checkpoint can be loaded for other tasks, e.g. sequence classification
from transformers import AutoTokenizer, AutoModelForSequenceClassification

# Load created model
MODEL_NAME = "lsg-roberta-base"
SENTENCE = "This is a test sentence."

model = AutoModelForSequenceClassification.from_pretrained(MODEL_NAME, trust_remote_code=True)
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)

inputs = tokenizer(SENTENCE, return_tensors="pt")
model(**inputs)
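Because the converted model extrapolates to long inputs, the same code handles sequences well beyond the 512-token limit of the original RoBERTa checkpoint (a minimal sketch; long_text is an illustrative placeholder):

# Illustrative long input; max_length matches the MAX_LENGTH used at conversion (4096)
long_text = " ".join(["word"] * 5000)
inputs = tokenizer(long_text, return_tensors="pt", truncation=True, max_length=4096)
outputs = model(**inputs)
print(outputs.logits.shape)  # (1, num_labels) with the sequence classification head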

LSG-Layer

You can find a self-contained implementation of LSG attention (with no additional parameters) in lsg_converter/attention_layers. It does not work for cross-attention because the local context is ambiguous to define in that case.

Usage:

import torch
from lsg_converter.attention_layers import BlockLocalSelfAttention, LSGSelfAttention

# batch, num_heads, sequence length, hidden_size
n, h, t, d = 2, 4, 58, 32  

Q, K, V = torch.randn(n, h, t, d), torch.randn(n, h, t, d), torch.randn(n, h, t, d)
attention_mask = torch.zeros(n, 1, 1, t).float()

# Only block local attention with 1 global connection
attn = BlockLocalSelfAttention(block_size=16, compute_global_attention=True, is_causal=False, attention_dropout_prob=0.1)

# LSG attention with 1 global connection
attn = LSGSelfAttention(block_size=32, sparsity_factor=8, sparsity_type="bos_pooling", compute_global_attention=True, is_causal=True)

# expect (batch, num_heads, sequence_length, hidden_size) inputs,
# attention_mask is (batch, 1, 1, sequence_length) 
# causal mask is built on the fly but (batch, 1, sequence_length, sequence_length) mask is possible
outputs = attn(Q, K, V, attention_mask)

print(outputs.shape)
> torch.Size([2, 4, 58, 32])

Example: replacing self-attention in GPT-2 (from Hugging Face):

from transformers.models.gpt2 import * 
from lsg_converter.attention_layers import BlockLocalSelfAttention

class GPT2BlockLocalAttention(modeling_gpt2.GPT2Attention):
    def __init__(self, config, is_cross_attention=False, layer_idx=None):
        super().__init__(config, is_cross_attention, layer_idx)
        self.attn = BlockLocalSelfAttention(block_size=32, compute_global_attention=True, is_causal=True)

    def _attn(self, query, key, value, attention_mask=None, head_mask=None):
        return self.attn(query, key, value, attention_mask), None

# Patch the module-level class so that models instantiated afterwards use block local attention
modeling_gpt2.GPT2Attention = GPT2BlockLocalAttention

Note that for generation (inference on causal language modeling), full attention is used after the first step. This may change in the future.

LSG Attention

LSG relies on 3 components: block local attention, sparse attention and prepended global attention.
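To make the pattern concrete, here is a purely illustrative sketch (not the library implementation) of the connectivity induced by these three components, using a simple strided scheme for the sparse part:

import torch

def lsg_connectivity(seq_len, block_size=32, sparsity_factor=8, num_global_tokens=1):
    """Boolean (seq_len, seq_len) mask: True where a query position may attend to a key.

    Illustrative only: the real LSG layers work block-wise on prepended global tokens
    and support several sparse selection schemes; a strided selection is used here
    just to show the overall shape of the pattern.
    """
    q = torch.arange(seq_len)[:, None]  # query positions
    k = torch.arange(seq_len)[None, :]  # key positions

    # 1. Block local attention: same block and the two neighbouring blocks
    local = (q // block_size - k // block_size).abs() <= 1

    # 2. Sparse attention: a strided subset of the remaining positions
    sparse = (k % (block_size * sparsity_factor)) == 0

    # 3. Global attention: the first tokens see, and are seen by, every position
    global_tokens = (k < num_global_tokens) | (q < num_global_tokens)

    return local | sparse | global_tokens

mask = lsg_connectivity(seq_len=4096)
print(mask.float().mean())  # fraction of the full quadratic pattern actually computed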

Parameters

You can change various parameters when converting a model, such as the local block size, the sparse block size, the sparsity type and factor, the number of global tokens, or whether to mask the first token.
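As a sketch (keyword arguments are assumed to be forwarded to the LSG configuration, as in the num_global_tokens example above; the values below are illustrative, not recommended defaults):

from lsg_converter import LSGConverter

converter = LSGConverter(max_sequence_length=4096)

# Parameter names taken from the other examples in this README; values are illustrative
model, tokenizer = converter.convert_from_pretrained(
    "roberta-base",
    architecture="RobertaForSequenceClassification",
    num_global_tokens=1,
    block_size=32,
    sparse_block_size=32,
    sparsity_factor=4,
    sparsity_type="norm",
    mask_first_token=False,
)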

Sparse selection type

There are 6 different sparse selection patterns; the best type is task-dependent. If sparse_block_size=0 or sparsity_type="none", only local attention is used. Note that for sequences with length < 2*block_size, the sparsity type has no effect.

Experiments

See experiments/