RetNet

A huggingface transformers-compatible implementation of Retentive Networks (https://arxiv.org/pdf/2307.08621.pdf). The implementation is on par with the official implementation in the torchscale repo.

Supports three forward implementations: parallel, recurrent, and chunkwise.

Check play.ipynb for minimal testing of parallel, recurrent, and chunkwise forward.

Getting Started

This project uses PyTorch and huggingface transformers. We also need timm for the DropPath implementation used in torchscale.

pip install torch transformers timm
# pip install apex (optional)
# pip install pytest (to run tests/)
# pip install fire (to run convert_weights.py)

You may want to use conda.

Quick Examples

Take a look at play.ipynb.

import torch
from retnet.modeling_retnet import RetNetModel
from retnet.configuration_retnet import RetNetConfig

config = RetNetConfig(decoder_layers=8,
                      decoder_embed_dim=512,
                      decoder_value_embed_dim=1024,
                      decoder_retention_heads=4,
                      decoder_ffn_embed_dim=1024)
model = RetNetModel(config)

input_ids = torch.LongTensor([[1,2,3,4,5,6,7,8]])

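# parallel forward over the full sequence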
parallel_outputs = model(input_ids, forward_impl='parallel', use_cache=True)
parallel_state = parallel_outputs.last_hidden_state
parallel_cache = parallel_outputs.past_key_values

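# recurrent forward, carrying past_key_values from step to step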
past_kv = None
rnn_state = []
for i in range(input_ids.shape[1]):
    rnn_out = model(input_ids[:, :i+1], forward_impl='recurrent', past_key_values=past_kv, use_cache=True)
    rnn_state.append(rnn_out.last_hidden_state)
    past_kv = rnn_out.past_key_values
rnn_state = torch.cat(rnn_state, dim=1)
rnn_cache = rnn_out.past_key_values

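# chunkwise forward, processing the sequence in chunks of recurrent_chunk_size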
chunk_outputs = model(input_ids, forward_impl='chunkwise', use_cache=True, recurrent_chunk_size=4)
chunk_state = chunk_outputs.last_hidden_state
chunk_cache = chunk_outputs.past_key_values

Language Generation

import torch
from retnet.modeling_retnet import RetNetForCausalLM
from retnet.configuration_retnet import load_config_from_json
from transformers import AutoTokenizer

config = load_config_from_json('configs/retnet-base/config.json')
model = RetNetForCausalLM(config)

tokenizer = AutoTokenizer.from_pretrained("gpt2")
tokenizer.model_max_length = 4096
tokenizer.pad_token = tokenizer.eos_token

inputs = tokenizer("Retention refers to", return_tensors='pt')

# parallel forward
# our custom generate function
generated = model.custom_generate(**inputs, parallel_compute_prompt=True, max_new_tokens=20)
# huggingface's generate. Both should be equivalent
generated = model.generate(**inputs, max_new_tokens=20)
tokenizer.batch_decode(generated)
# NOTE: this should be gibberish, since the model is not trained.

Huggingface Integration

The model now supports full huggingface integration (except for anything I may have overlooked :)). It can be trained with the huggingface Trainer, saved and loaded with save_pretrained or from_pretrained, and can generate with .generate.
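
For example, the save / load round trip looks like the following minimal sketch (save_pretrained and from_pretrained are the standard huggingface methods; the checkpoint directory name is arbitrary):

from retnet.modeling_retnet import RetNetForCausalLM
from retnet.configuration_retnet import load_config_from_json

config = load_config_from_json('configs/retnet-base/config.json')
model = RetNetForCausalLM(config)

model.save_pretrained('checkpoints/retnet-base-test')  # writes config + weights
reloaded = RetNetForCausalLM.from_pretrained('checkpoints/retnet-base-test')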

Minimal Training Example

You can train RetNet with the huggingface Trainer API. Refer to train.py; a minimal Python sketch is also shown after the command below.

export CUDA_VISIBLE_DEVICES=0

python train.py \
    --model_size 300m \
    --output_dir checkpoints \
    --do_train --do_eval \
    --prediction_loss_only \
    --remove_unused_columns False \
    --learning_rate 6e-4 \
    --weight_decay 0.01 \
    --max_steps 20000 \
    --logging_steps 10 \
    --eval_steps 1000 \
    --save_steps 1000 \
    --per_device_train_batch_size 16 \
    --per_device_eval_batch_size 16
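
Alternatively, the same wiring can be sketched directly in Python. The snippet below is a minimal illustration, independent of train.py; the toy dataset and hyperparameters are purely illustrative:

from transformers import (AutoTokenizer, DataCollatorForLanguageModeling,
                          Trainer, TrainingArguments)
from retnet.configuration_retnet import load_config_from_json
from retnet.modeling_retnet import RetNetForCausalLM

config = load_config_from_json('configs/retnet-base/config.json')
model = RetNetForCausalLM(config)

tokenizer = AutoTokenizer.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token

# Toy dataset: a few tokenized strings, just enough to show the plumbing.
texts = ["Retention refers to the ability to keep information.",
         "Retentive networks have parallel, recurrent, and chunkwise forward."]
train_dataset = [tokenizer(t) for t in texts]

args = TrainingArguments(output_dir='checkpoints', max_steps=10,
                         per_device_train_batch_size=2, learning_rate=6e-4,
                         weight_decay=0.01, logging_steps=1,
                         remove_unused_columns=False, report_to="none")
trainer = Trainer(model=model, args=args, train_dataset=train_dataset,
                  data_collator=DataCollatorForLanguageModeling(tokenizer, mlm=False))
trainer.train()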

Some Useful Notes

xpos note

The authors write the position embedding as $e^{in\theta}, e^{-im\theta}$ (Equation 5). At first glance these look like complex numbers, which are difficult to use and interpret. However, this is in fact xpos, which a lecture note made clear to me. The gist is:

$$ R_{\theta} = e^{i\theta} = \begin{bmatrix} \cos(\theta) & -\sin(\theta) \\ \sin(\theta) & \cos(\theta) \end{bmatrix} $$

Since xpos (which builds on RoPE) applies precisely such a rotation, this is in fact xpos. I used the xpos implementation found in the torchscale repo, with one small change: instead of a negative min_pos, I used min_pos=0 (lines 53-54), so that it is recurrence friendly.
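
To make the rotation concrete, here is a minimal sketch (assumed for illustration; it is not the repo's xpos code, which additionally applies the xpos decay scale) of rotating each 2-D pair of a query by its position angle. Keys would be rotated by the negative angle, so the retention score depends only on the relative position $n - m$:

import torch

def rotate_pairs(x, theta):
    # x: (..., d) with d even; theta: angles for each 2-D pair, broadcastable to x[..., 0::2].
    # Applies R_theta to every (x_{2k}, x_{2k+1}) pair, i.e. the real-valued form of
    # multiplying that pair, read as a complex number, by e^{i * theta}.
    x1, x2 = x[..., 0::2], x[..., 1::2]
    cos, sin = torch.cos(theta), torch.sin(theta)
    out = torch.empty_like(x)
    out[..., 0::2] = x1 * cos - x2 * sin
    out[..., 1::2] = x1 * sin + x2 * cos
    return out

q = torch.randn(1, 8, 64)                    # (batch, seq_len, head_dim)
pos = torch.arange(8, dtype=torch.float32)   # token positions n
inv_freq = 1.0 / (10000 ** (torch.arange(0, 64, 2, dtype=torch.float32) / 64))  # RoPE-style angle schedule, illustrative
angles = pos[:, None] * inv_freq[None, :]    # (seq_len, head_dim // 2)
q_rot = rotate_pairs(q, angles)              # rotate queries by +n * theta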

Decay Note

Equation 7 omits an important detail: there should be an extra decay applied to $K^T_{[i]} V_{[i]}$, which is $D_{B}$, i.e. the last row of the inner_chunk decay_mask. So it should be rewritten as:

$$R_i = K^T_{[i]} V_{[i]} \odot D_{B} + \gamma^B R_{i-1}$$

This is implemented in the chunkwise_retention function under the name intra_decay.

The same idea can also be applied in parallel_retention to obtain the correct past_kv, which can then be fed into recurrent or chunkwise retention at the following token steps.
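
Concretely, a minimal single-head sketch of this corrected state update (ignoring xpos and group normalization; the function name and shapes here are illustrative, not the repo's API) could look like:

import torch

def chunk_state_update(K, V, prev_state, gamma):
    # K: (chunk_size, d_k), V: (chunk_size, d_v), prev_state R_{i-1}: (d_k, d_v)
    B = K.shape[0]
    pos = torch.arange(B, dtype=K.dtype)
    intra_decay = gamma ** (B - 1 - pos)          # last row of the decay mask: gamma^(B-1-m)
    new_state = K.T @ (V * intra_decay[:, None])  # K^T_[i] V_[i] with the extra decay D_B
    return new_state + (gamma ** B) * prev_state  # + gamma^B * R_{i-1}

Unrolling the recurrent update $S_n = \gamma S_{n-1} + k_n^\top v_n$ over one chunk of size $B$ yields exactly these per-position weights, which is why the extra decay is needed.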

Configs

The configs/ folder includes example configurations ~listed in the paper~ found in the torchscale repo for different model sizes. For simplicity, I used the GPT2 tokenizer, so the model's default vocab size is 50257 (this may change when Microsoft releases the official weights).
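
As a quick sanity check (a sketch; it assumes the config exposes the usual huggingface vocab_size attribute), the default config should line up with the GPT2 tokenizer:

from transformers import AutoTokenizer
from retnet.configuration_retnet import load_config_from_json

config = load_config_from_json('configs/retnet-base/config.json')
tokenizer = AutoTokenizer.from_pretrained("gpt2")

print(config.vocab_size, len(tokenizer))  # both expected to be 50257
# config.vocab_size = len(my_tokenizer)   # hypothetical override for a different tokenizer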