Standalone module of Gated Linear Attention (GLA) from the paper "Gated Linear Attention Transformers with Hardware-Efficient Training".
```shell
pip install -U git+https://github.com/sustcsonglin/flash-linear-attention
```
Warning: `fused_chunk` mode requires Triton 2.2 + CUDA 12 (see the linked issue). You can run the provided test to quickly check whether `fused_chunk` mode works in your environment; if it does not, please refer to the linked instructions and use `chunk` mode instead.
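As a rough sanity check before relying on `fused_chunk` mode, you can compare your installed versions against the minimums above. The helper below is a hypothetical sketch using only the standard library; in practice the version strings would come from `triton.__version__` and `torch.version.cuda`, which are assumed here as placeholder values.

```python
def version_at_least(installed: str, required: str) -> bool:
    """Numerically compare dotted version strings, e.g. '2.10' >= '2.2'."""
    def parts(v: str) -> tuple:
        # Keep only numeric components so suffixes like '2.2.0.post1' don't crash.
        return tuple(int(p) for p in v.split(".") if p.isdigit())
    return parts(installed) >= parts(required)

# Placeholder values; read triton.__version__ and torch.version.cuda in practice.
triton_version = "2.2.0"
cuda_version = "12.1"

if version_at_least(triton_version, "2.2") and version_at_least(cuda_version, "12"):
    print("fused_chunk mode should be usable")
else:
    print("fall back to chunk mode")
```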
Load the checkpoint from Hugging Face:
```python
import torch

from gla_model import GLAForCausalLM

# Load the pretrained GLA checkpoint from the Hugging Face Hub.
model = GLAForCausalLM.from_pretrained("bailin28/gla-1B-100B")
vocab_size = model.config.vocab_size

# Random token ids standing in for a tokenized batch.
bsz, seq_len = 32, 2048
x = torch.randint(high=vocab_size, size=(bsz, seq_len))

model_output = model(x)
loss = model_output.loss
logits = model_output.logits  # shape: (bsz, seq_len, vocab_size)
```
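The logits give one score per vocabulary entry at each position; the most common next step is to pick the highest-scoring token at the last position. The sketch below illustrates that greedy selection in pure Python over a toy 5-token vocabulary (real code would call `torch.argmax` on the logits tensor instead):

```python
def greedy_next_token(last_logits):
    """Return the index of the largest logit, mimicking an argmax over the vocab dim."""
    return max(range(len(last_logits)), key=lambda i: last_logits[i])

# Toy logits for a 5-token vocabulary; index 2 has the highest score.
next_token = greedy_next_token([0.1, -2.0, 3.5, 0.0, 1.2])
print(next_token)  # 2
```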