THUDM / SwissArmyTransformer

SwissArmyTransformer is a flexible and powerful library to develop your own Transformer variants.
https://THUDM.github.io/SwissArmyTransformer
Apache License 2.0
pretrained-models pytorch transformer

Introduction

sat(SwissArmyTransformer) is a flexible and powerful library to develop your own Transformer variants.

sat is named after the "Swiss Army knife": all the models (e.g. BERT, GPT, T5, GLM, CogView, ViT...) share the same backbone code and support versatile uses through a few extra lightweight mixins.

sat is powered by deepspeed-ZeRO and model parallelism, aiming to provide the best practice for pretraining and finetuning large models (100M~20B parameters).

Install

    pip install SwissArmyTransformer

Features

Quick Tour

A typical Python file for using BERT in sat (for inference) looks like this:

# @File: inference_bert.py
from sat import get_args, get_tokenizer, AutoModel
# Parse args, initialize the environment. This is necessary.
args = get_args() 
# Automatically download and load model. Will also dump model-related hyperparameters to args.
model, args = AutoModel.from_pretrained('bert-base-uncased', args) 
# Get the BertTokenizer according to args.tokenizer_type (automatically set).
tokenizer = get_tokenizer(args) 
# Use BERT as you want from here!
# ...

Then we can run the code via

    SAT_HOME=/path/to/download python inference_bert.py --mode inference

All officially supported model names are in urls.py.

Finetuning or pretraining a transformer is also extremely easy!

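# @File: finetune_bert.py
import torch  # needed for the torch.utils.data.Dataset check below
from sat import get_args, get_tokenizer, AutoModel
from sat.model.mixins import MLPHeadMixin
from sat.training.deepspeed_training import training_main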

def create_dataset_function(path, args):
    # Here to load the dataset
    # ...
    assert isinstance(dataset, torch.utils.data.Dataset)
    return dataset

def forward_step(data_iterator, model, args, timers):
    inputs = next(data_iterator) # from the dataset of create_dataset_function.
    loss, *others = model(inputs)
    return loss

# Parse args, initialize the environment. This is necessary.
args = get_args() 
model, args = AutoModel.from_pretrained('bert-base-uncased', args) 
tokenizer = get_tokenizer(args) 
# Use BERT as you want from here! For example, swap the final head for a classification head.
model.del_mixin('bert-final')
model.add_mixin('classification_head', MLPHeadMixin(args.hidden_size, 2048, 1))
# ONE LINE to train! 
# args already includes hyperparams such as lr, train-iters, zero-stage ...
training_main(args, 
    model_cls=model, 
    forward_step_function=forward_step, # user-defined
    create_dataset_function=create_dataset_function # user-defined
)

Then we can run the code via

deepspeed --include localhost:0,1 finetune_bert.py \
    --experiment-name ftbert \
    --mode finetune --train-iters 1000 --save /path/to/save \
    --train-data /path/to/train --valid-data /path/to/valid \
    --lr 0.00002 --batch-size 8 --zero-stage 1 --fp16

Here we use data parallelism on GPUs 0 and 1. We can also launch the training across many interconnected machines via --hostfile /path/to/hostfile. See the tutorial for more details.
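
For the multi-machine case, a DeepSpeed hostfile lists one machine per line in the form `hostname slots=N`, where N is the number of GPUs available on that machine. The snippet below is a minimal sketch that generates such a file. The helpers `format_hostfile` and `write_hostfile` and the host names `worker-1` and `worker-2` are hypothetical, made up for illustration, and are not part of sat or DeepSpeed.

```python
from pathlib import Path


def format_hostfile(hosts):
    """Render a {hostname: gpu_count} mapping as DeepSpeed hostfile text."""
    lines = []
    for name, slots in hosts.items():
        if slots < 1:
            raise ValueError(f"{name}: slots must be >= 1, got {slots}")
        # One machine per line: "<hostname> slots=<number of GPUs>"
        lines.append(f"{name} slots={slots}")
    return "\n".join(lines) + "\n"


def write_hostfile(hosts, path):
    """Write the hostfile text to `path` and return it as a Path."""
    path = Path(path)
    path.write_text(format_hostfile(hosts))
    return path


if __name__ == "__main__":
    # Hypothetical machines with 8 GPUs each.
    print(format_hostfile({"worker-1": 8, "worker-2": 8}), end="")
```

The resulting file is what the `--hostfile` option of the `deepspeed` launcher expects. The machines listed in it need to be reachable from the launching node over passwordless SSH.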

To write your own model, you only need to consider how it differs from the standard Transformer. For example, if you have an idea to improve the attention operation:

import torch
from sat.model import BaseMixin

class MyAttention(BaseMixin):
    def __init__(self, hidden_size):
        super().__init__()
        # MyAttention may need some new params, e.g. a learnable alpha.
        self.learnable_alpha = torch.nn.Parameter(torch.ones(hidden_size))

    # This is a hook function, the name `attention_fn` is special.
    # It is a method, so it takes `self` and can access self.learnable_alpha.
    def attention_fn(self, q, k, v, mask, dropout=None, **kwargs):
        # Code for my attention.
        # ...
        return attention_results

Here attention_fn is a hook function: it replaces the default behavior with the new function. All available hooks are in transformer_defaults.py. Now we can use add_mixin to apply our change to all the transformers, such as BERT, ViT and CogView. See the tutorial for more details.
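
To make the hook idea concrete, here is a toy sketch in plain Python of name-based hook dispatch. It does not reproduce sat's real implementation. `ToyModel`, `ToyMixin`, `DEFAULT_HOOKS` and the rule that two mixins may not define the same hook are all assumptions made for illustration. The only point is that a mixin method whose name matches a hook takes over from the default function.

```python
def default_attention_fn(q, k, v):
    # Placeholder for the standard attention computation.
    return "default attention"


# Toy registry of default hooks, keyed by their special names.
DEFAULT_HOOKS = {"attention_fn": default_attention_fn}


class ToyMixin:
    """Toy base class; subclasses override hooks by defining same-named methods."""


class ToyModel:
    def __init__(self):
        self.mixins = {}

    def add_mixin(self, name, mixin):
        if name in self.mixins:
            raise ValueError(f"mixin {name!r} is already registered")
        self.mixins[name] = mixin

    def del_mixin(self, name):
        del self.mixins[name]

    def get_hook(self, hook_name):
        # Collect every registered mixin that defines this hook.
        owners = [m for m in self.mixins.values() if hasattr(m, hook_name)]
        if len(owners) > 1:
            raise ValueError(f"hook {hook_name!r} is defined by several mixins")
        if owners:
            return getattr(owners[0], hook_name)
        # No mixin overrides the hook, so fall back to the default.
        return DEFAULT_HOOKS[hook_name]

    def forward(self, q, k, v):
        return self.get_hook("attention_fn")(q, k, v)


class MyToyAttention(ToyMixin):
    def attention_fn(self, q, k, v):
        return "custom attention"


if __name__ == "__main__":
    model = ToyModel()
    print(model.forward(1, 2, 3))   # uses the default hook
    model.add_mixin("my-attention", MyToyAttention())
    print(model.forward(1, 2, 3))   # uses the mixin's hook
```

Because the backbone looks hooks up by name, the same mixin can be attached to any model built on that backbone without touching the model's own code.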

Tutorials

Citation

Currently we don't have a paper, so you don't need to formally cite us!~

If this project helps your research or engineering, use \footnote{https://github.com/THUDM/SwissArmyTransformer} to mention us and recommend SwissArmyTransformer to others.

The tutorial for contributing to sat is on the way!

The project is based on (a user of) DeepSpeed, Megatron-LM and Huggingface transformers. Thanks for their awesome work.