berlino / gated_linear_attention

MIT License

Gated Linear Attention Layer

A standalone module of Gated Linear Attention (GLA), from the paper "Gated Linear Attention Transformers with Hardware-Efficient Training".
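For intuition, here is a minimal NumPy sketch of the core GLA recurrence for a single head. It assumes the formulation from the paper: a matrix-valued state is decayed row-wise by a data-dependent forget gate alpha_t and updated with an outer product, S_t = diag(alpha_t) S_{t-1} + k_t^T v_t, and the output is o_t = q_t S_t. The function name gla_recurrent is hypothetical. The sketch leaves out everything around the recurrence (projections, gate parameterization, normalization, output gating, multiple heads), so it is an illustration and not this module's implementation.

```python
import numpy as np


def gla_recurrent(q, k, v, alpha):
    """Hypothetical reference: step-by-step GLA recurrence for one head.

    q, k, alpha: arrays of shape (T, d_k); v: array of shape (T, d_v).
    alpha holds per-step, per-key-dimension forget gates in [0, 1].
    Returns the outputs (T, d_v) and the final state (d_k, d_v).
    """
    T, d_k = q.shape
    d_v = v.shape[1]
    S = np.zeros((d_k, d_v))  # matrix-valued hidden state
    out = np.zeros((T, d_v))
    for t in range(T):
        # decay each row of the state by its gate, then add k_t^T v_t
        S = alpha[t][:, None] * S + np.outer(k[t], v[t])
        # read the state out with the current query
        out[t] = q[t] @ S
    return out, S
```

With all gates equal to 1 the recurrence reduces to plain (unnormalized) causal linear attention, out = ((Q K^T) * M) V with M a lower-triangular mask of ones; gates below 1 make older key-value contributions fade.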

Install the flash-linear-attention library, which this module depends on:

pip install -U git+https://github.com/sustcsonglin/flash-linear-attention

Warning: the fused_chunk mode requires Triton 2.2 and CUDA 12 (see issue). You can run test to quickly check whether fused_chunk mode works in your environment. If it does not, refer to link and use chunk mode instead.
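As we understand it, both chunk and fused_chunk modes evaluate the GLA recurrence (assuming the paper's formulation S_t = diag(alpha_t) S_{t-1} + k_t^T v_t, o_t = q_t S_t) chunk by chunk instead of token by token. The NumPy sketch below illustrates that general chunkwise idea under our own simplifying assumptions: a hypothetical function gla_chunkwise, a single head, and naive division by cumulative gate products. It is not the library's Triton kernel. Within each chunk, outputs come from a masked, attention-like matrix product plus a contribution from the state carried over from the previous chunk, and the state is updated only once per chunk.

```python
import numpy as np


def gla_chunkwise(q, k, v, alpha, chunk_size=4):
    """Hypothetical sketch: chunkwise-parallel evaluation of the GLA recurrence.

    q, k, alpha: (T, d_k); v: (T, d_v); alpha holds gates in (0, 1].
    Dividing by cumulative gate products can underflow for long chunks;
    practical kernels avoid this (e.g. log-space or sub-chunking).
    """
    T, d_k = q.shape
    d_v = v.shape[1]
    S = np.zeros((d_k, d_v))  # state carried across chunks
    out = np.zeros((T, d_v))
    for start in range(0, T, chunk_size):
        end = min(start + chunk_size, T)
        qc, kc, vc = q[start:end], k[start:end], v[start:end]
        # b[i] = product of the gates from the chunk start up to and including step i
        b = np.cumprod(alpha[start:end], axis=0)
        q_dec = qc * b  # queries scaled by the decay accumulated since the chunk start
        k_dec = kc / b  # keys scaled by the inverse decay
        mask = np.tril(np.ones((end - start, end - start)))
        # inter-chunk term (carried state) + intra-chunk term (masked attention)
        out[start:end] = q_dec @ S + ((q_dec @ k_dec.T) * mask) @ vc
        # advance the state to the end of the chunk in a single update
        b_last = b[-1]
        S = b_last[:, None] * S + (kc * (b_last / b)).T @ vc
    return out, S
```

Every chunk size gives the same result: chunk_size=1 recovers the step-by-step recurrence, and larger chunks trade sequential steps for matrix multiplications, which map well onto GPU tensor cores.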

Usage

Load the pretrained checkpoint from the Hugging Face Hub and run a forward pass:

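import torch
from gla_model import GLAForCausalLM

model = GLAForCausalLM.from_pretrained("bailin28/gla-1B-100B")
vocab_size = model.config.vocab_size
bsz, seq_len = 32, 2048
x = torch.randint(high=vocab_size, size=(bsz, seq_len))
# pass labels so the (Hugging Face-style) causal-LM forward also computes a loss
model_output = model(input_ids=x, labels=x)
loss = model_output.loss
logits = model_output.logits  # shape: (bsz, seq_len, vocab_size)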
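Assuming GLAForCausalLM follows the usual Hugging Face causal-LM convention of shifting labels internally (an assumption, not verified against gla_model), the returned loss is the mean next-token cross-entropy: the logits at position t are scored against token t + 1. The NumPy sketch below recomputes such a loss from logits. causal_lm_loss is a hypothetical helper, and unlike the Transformers implementation it does not skip ignored label positions (e.g. -100).

```python
import numpy as np


def causal_lm_loss(logits, input_ids):
    """Hypothetical helper: mean next-token cross-entropy.

    logits: (batch, seq_len, vocab) floats; input_ids: (batch, seq_len) ints.
    Position t predicts token t + 1, so the last logit and the first token are dropped.
    """
    shift_logits = logits[:, :-1, :]
    shift_labels = input_ids[:, 1:]
    # numerically stable log-softmax over the vocabulary axis
    m = shift_logits.max(axis=-1, keepdims=True)
    log_probs = shift_logits - m - np.log(np.exp(shift_logits - m).sum(axis=-1, keepdims=True))
    # log-probability assigned to each actual next token
    picked = np.take_along_axis(log_probs, shift_labels[..., None], axis=-1)
    return -picked.mean()
```

With uniform logits the loss equals log(vocab_size), so exp(loss), the perplexity, equals the vocabulary size. That gives a handy sanity baseline for an untrained model.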