SkipBERT

Code associated with the paper SkipBERT: Efficient Inference with Shallow Layer Skipping, at ACL 2022

Thank you for your interest! The code is still under construction and will be updated frequently.

Download Pre-trained Checkpoints

Quick Start

import psutil, os
import torch
from skipbert import SkipBertModel
from transformers import BertTokenizerFast, BertConfig

p = psutil.Process(os.getpid())
p.nice(100)  # lower the process priority (a large nice value is clamped to the OS maximum)
print('nice:', p.nice())
torch.set_num_threads(1)  # restrict torch to a single CPU thread

# Input Related
tokenizer = BertTokenizerFast.from_pretrained('bert-base-uncased')

inputs = tokenizer(
    ["Good temper decides everything"],
    return_tensors='pt', padding='max_length', max_length=128
)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')  # target device for inference

# input_ids stay on the CPU; the other tensors are moved to the target device
inputs = {
    k: (v.to(device) if isinstance(v, torch.Tensor) and k != 'input_ids' else v) for k, v in inputs.items()
}

# Model Related
PATH_TO_MODEL = '<path to a downloaded SkipBERT checkpoint>'

config = BertConfig.from_pretrained(PATH_TO_MODEL)
config.plot_mode = 'plot_passive'  # fill the PLOT cache on the fly and reuse it on later calls

model = SkipBertModel.from_pretrained(PATH_TO_MODEL, config=config)
model.to(device)
model.eval()

# Inference
with torch.no_grad():
    # first call: the shallow layers are computed and their hidden states are cached in PLOT
    ret = model(**inputs)

    # second call: the shallow-layer hidden states are retrieved from PLOT, skipping that computation
    ret = model(**inputs)
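
To see the effect of the PLOT cache, you can time the two calls. The helper below is a minimal sketch, not part of the repository; it assumes the `model` and `inputs` objects from the Quick Start above are already in scope.

import time

def timed_forward(model, inputs):
    # run one forward pass without gradients and report the wall-clock time
    start = time.perf_counter()
    with torch.no_grad():
        out = model(**inputs)
    return out, time.perf_counter() - start

_, t_first = timed_forward(model, inputs)   # shallow layers computed, PLOT cache filled
_, t_second = timed_forward(model, inputs)  # shallow layers served from the PLOT cache
print(f'first call:  {t_first * 1000:.1f} ms')
print(f'second call: {t_second * 1000:.1f} ms')

Depending on your transformers/SkipBERT version, `ret` likely follows the usual transformers convention (its first element holds the final hidden states), but check the SkipBERT source to be sure.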