myscience / retnet-pytorch

Implementation of Retention-Network in PyTorch

Retention Network in Easy PyTorch

A simple and concise implementation of Retentive Networks as introduced in Retentive Network: A Successor to Transformer for Large Language Models (2023). The official implementation is hosted at Ret-Net @ microsoft/torchscale; this repo instead provides a standalone, well-documented version intended for didactic purposes.

[Figure: RetNet performances]

Model Architecture

The Retentive Network (Ret-Net) is a Transformer-like architecture whose attention-like retention mechanism supports two equivalent formulations: a parallel one and a recurrent one.

The two formulations produce the same final output but offer different computational tradeoffs. The parallel formulation offers high training throughput, while the recurrent formulation reduces memory cost and latency at inference time. The Retentive architecture follows a scaling law similar to that of the Transformer, and so achieves the "impossible triangle" of training parallelism (good use of GPUs), low-cost inference and strong performance, all at the same time.
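The equivalence of the two formulations can be checked on a toy example. The following is a minimal single-head sketch based on the equations in the paper, not this repository's code: it assumes a single scalar decay `gamma` and leaves out the position encoding, gating and group normalization of the full layer. The function names `parallel_retention` and `recurrent_retention` are illustrative only.

```python
import torch


def parallel_retention(q, k, v, gamma):
    """Parallel form: (Q K^T * D) V, with D[n, m] = gamma^(n - m) if n >= m else 0."""
    seq_len = q.shape[0]
    pos = torch.arange(seq_len, dtype=q.dtype)
    diff = pos[:, None] - pos[None, :]
    # Clamp the exponent so the masked-out upper triangle never blows up.
    decay = torch.where(diff >= 0, gamma ** diff.clamp(min=0), torch.zeros_like(diff))
    return (q @ k.T * decay) @ v


def recurrent_retention(q, k, v, gamma):
    """Recurrent form: S_n = gamma * S_{n-1} + k_n^T v_n, output o_n = q_n S_n."""
    state = torch.zeros(k.shape[1], v.shape[1], dtype=q.dtype)
    outputs = []
    for n in range(q.shape[0]):
        state = gamma * state + torch.outer(k[n], v[n])
        outputs.append(q[n] @ state)
    return torch.stack(outputs)


torch.manual_seed(0)
# float64 keeps rounding differences between the two paths negligible.
q, k, v = (torch.randn(16, 8, dtype=torch.float64) for _ in range(3))

out_parallel = parallel_retention(q, k, v, gamma=0.9)
out_recurrent = recurrent_retention(q, k, v, gamma=0.9)
assert torch.allclose(out_parallel, out_recurrent)
```

The recurrent form only carries a state of shape (key dim, value dim) from one step to the next, which is why its per-token cost does not grow with the sequence length.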

[Figure: RetNet architecture]

Usage

Basic usage of the RetNet model:

import torch
from src.retnet import RetNet

batch_size = 2
seq_length = 1024
dim_model = 512

model = RetNet(
    num_layer = 6,
    num_heads = 8,
    dim_model = dim_model,
    dropout = 0.1,
    value_factor = 2,
    msr_gate_fn = 'gelu',
    mlp_gate_fn = 'gelu',
    mlp_mult = 4,
    mlp_bias = True,
).cuda()

# The input must live on the same device as the model (moved to the GPU above).
x = torch.randint(0, 100, (batch_size, seq_length, dim_model), dtype=torch.float32, device='cuda')

# Use num_chunks parameter to switch between the parallel and recurrent forward passes.
parallel_forward  = model(x, attn_mask='causal', num_chunks = None)
recurrent_forward = model(x, attn_mask='causal', num_chunks = 8)

# The two formulations should be consistent
assert torch.allclose(parallel_forward, recurrent_forward)
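How chunking can sit between the two passes is sketched below. This illustrates the chunkwise recurrent form described in the paper, under the assumption that `num_chunks` splits the sequence into equal-sized chunks. It is not this repository's implementation: it covers a single head with a scalar decay `gamma`, and the function name `chunkwise_retention` is hypothetical.

```python
import torch


def chunkwise_retention(q, k, v, gamma, num_chunks):
    """Parallel retention inside each chunk, recurrent state across chunks."""
    seq_len, dim_k = q.shape
    assert seq_len % num_chunks == 0, "this sketch assumes equal-sized chunks"
    chunk_len = seq_len // num_chunks

    pos = torch.arange(chunk_len, dtype=q.dtype)
    diff = pos[:, None] - pos[None, :]
    # Causal decay mask used within a chunk.
    inner_decay = torch.where(diff >= 0, gamma ** diff.clamp(min=0), torch.zeros_like(diff))
    # Decay applied to the carried state, one factor per query position.
    cross_decay = gamma ** (pos + 1)
    # Decay applied to each key/value pair when it is folded into the state.
    state_decay = gamma ** (chunk_len - 1 - pos)

    state = torch.zeros(dim_k, v.shape[1], dtype=q.dtype)
    outputs = []
    for qc, kc, vc in zip(q.chunk(num_chunks), k.chunk(num_chunks), v.chunk(num_chunks)):
        inner = (qc @ kc.T * inner_decay) @ vc
        cross = (qc @ state) * cross_decay[:, None]
        outputs.append(inner + cross)
        state = gamma ** chunk_len * state + kc.T @ (vc * state_decay[:, None])
    return torch.cat(outputs)


torch.manual_seed(0)
q, k, v = (torch.randn(64, 8, dtype=torch.float64) for _ in range(3))

# A single chunk reduces to the fully parallel form.
full = chunkwise_retention(q, k, v, gamma=0.9, num_chunks=1)
chunked = chunkwise_retention(q, k, v, gamma=0.9, num_chunks=8)
assert torch.allclose(full, chunked)
```

In float32 the two paths accumulate rounding errors in a different order, so a comparison with an explicit tolerance, for example `torch.allclose(a, b, atol=1e-5)`, is usually more robust than the default one.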

The model supports both attention-mask and retention-matrix normalization, as described in the paper.


# This is the default model behavior
no_normalization_forward = model(x, attn_mask='causal', normalize_attn=False, normalize_retn=False, num_chunks = None)

# Normalization can be switched on independently
# NOTE: normalize_attn=True currently breaks consistency between the parallel and recurrent forward passes
only_attn_norm_forward = model(x, attn_mask='causal', normalize_attn=True, normalize_retn=False, num_chunks = None)
only_retn_norm_forward = model(x, attn_mask='causal', normalize_attn=False, normalize_retn=True, num_chunks = None)
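For reference, the two normalizations from the paper can be written in a few lines. This is a sketch under assumptions: it takes `normalize_attn` to refer to the decay-mask normalization and `normalize_retn` to the retention-score normalization, which is a guess at the mapping rather than something confirmed by the repository. The helper names `decay_mask`, `normalize_decay_mask` and `normalize_retention_scores` are hypothetical.

```python
import torch


def decay_mask(seq_len, gamma, dtype=torch.float32):
    """Causal decay mask D[n, m] = gamma^(n - m) if n >= m else 0."""
    pos = torch.arange(seq_len, dtype=dtype)
    diff = pos[:, None] - pos[None, :]
    return torch.where(diff >= 0, gamma ** diff.clamp(min=0), torch.zeros_like(diff))


def normalize_decay_mask(mask):
    """Mask normalization from the paper: D[n, m] / sqrt(sum_i D[n, i])."""
    return mask / mask.sum(dim=-1, keepdim=True).sqrt()


def normalize_retention_scores(scores):
    """Score normalization from the paper: R[n, m] / max(|sum_i R[n, i]|, 1)."""
    return scores / scores.sum(dim=-1, keepdim=True).abs().clamp(min=1)


torch.manual_seed(0)
dim_head = 8
q, k = torch.randn(16, dim_head), torch.randn(16, dim_head)

mask = normalize_decay_mask(decay_mask(16, gamma=0.9))
# Scores are also scaled by 1 / sqrt(d), as in standard attention.
scores = normalize_retention_scores((q @ k.T) / dim_head ** 0.5 * mask)
```

Both factors rescale whole rows, so they leave the output of a scale-invariant group normalization unchanged in exact arithmetic and mainly help numerical precision.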

Multi-Scale Retention

One can use the MultiScaleRetention module as a standalone replacement for a standard attention layer. Again it supports both parallel and recurrent forward passes.

import torch
from src.retnet import MultiScaleRetention

batch_size = 2
seq_length = 1024
dim_model = 512

msr = MultiScaleRetention(
    dim_model = dim_model,
    num_heads = 8,
    gate_fn = 'gelu',
    value_factor = 2,
)

x = torch.randint(0, 100, (batch_size, seq_length, dim_model), dtype=torch.float32)

# Use num_chunks parameter to switch between the parallel and recurrent forward passes.
parallel_forward  = msr(x, num_chunks = None)
recurrent_forward = msr(x, num_chunks = 8)

# The two formulations should be consistent
assert torch.allclose(parallel_forward, recurrent_forward)
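The "multi-scale" part refers to each head using its own decay rate. The sketch below builds the per-head decays and the resulting causal decay masks using the schedule given in the paper, gamma_h = 1 - 2^(-5 - h). Whether `MultiScaleRetention` in this repository uses exactly this schedule is an assumption, and `head_decays` and `multi_scale_decay_mask` are hypothetical helper names.

```python
import torch


def head_decays(num_heads):
    """One decay rate per head: gamma_h = 1 - 2^(-5 - h) for h = 0 .. num_heads - 1."""
    return 1 - 2.0 ** (-5 - torch.arange(num_heads, dtype=torch.float32))


def multi_scale_decay_mask(num_heads, seq_len):
    """Stack of causal decay masks with shape (num_heads, seq_len, seq_len)."""
    gammas = head_decays(num_heads)
    pos = torch.arange(seq_len, dtype=torch.float32)
    # Clamp so the (later masked) upper triangle uses a harmless exponent of 0.
    diff = (pos[:, None] - pos[None, :]).clamp(min=0)
    mask = gammas[:, None, None] ** diff
    return mask * torch.tril(torch.ones(seq_len, seq_len))


gammas = head_decays(num_heads=8)
mask = multi_scale_decay_mask(num_heads=8, seq_len=16)
print(gammas[0].item(), mask.shape)
```

Heads with a decay close to 1 retain distant tokens for longer, while heads with a smaller decay focus on recent context.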

Known Issues

Citations

This code is based on the authors' official implementation, which is part of the larger torchscale codebase.

@article{sun2023retentive,
  title={Retentive Network: A Successor to Transformer for Large Language Models},
  author={Sun, Yutao and Dong, Li and Huang, Shaohan and Ma, Shuming and Xia, Yuqing and Xue, Jilong and Wang, Jianyong and Wei, Furu},
  journal={arXiv preprint arXiv:2307.08621},
  year={2023}
}