Knowledgator / TurboT5

A true flash-attention implementation of T5!

TurboT5 🦾 - boost inference up to 3-5x ⚡ and run T5 models on 100k tokens 📚.


TurboT5 is an optimized version of the T5 model architecture. It supports flash and linear attention mechanisms, significantly reducing memory utilization and latency. This project aims to allow the loading and running of original T5 models on tens of thousands of tokens without retraining while maintaining the same accuracy.

You can run T5-base on 100k tokens on a single L4 GPU (22.5 GB) with meaningful inference speed. This opens applications such as long document summarization, multimodal document processing, and even proteomics and genomics tasks.

Warning: This project is still under development and may contain bugs. Feel free to open an issue if you find a bug or have suggestions for improving the project.

Supported attention mechanism

Flash attention supports inputs of varying length and dynamically chooses the right kernel depending on whether the input lengths within a batch differ. If they differ, the flash_attention_with_fusing_bias_varlen kernel is called. Otherwise, flash_attention_with_fusing_bias is called.
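The dispatch rule above can be pictured with a small sketch. This is not TurboT5's actual code: `pick_flash_kernel` is a hypothetical helper, and it assumes padding is described by a 0/1 attention mask. Only the two kernel names come from the description above.

```python
def pick_flash_kernel(attention_mask):
    """Return the name of the kernel that would handle this batch.

    attention_mask: list of rows, each a list of 0/1 flags (1 = real token).
    Hypothetical helper for illustration only.
    """
    lengths = [sum(row) for row in attention_mask]
    if len(set(lengths)) > 1:
        # Sequences in the batch differ in length: variable-length kernel.
        return "flash_attention_with_fusing_bias_varlen"
    # Every sequence has the same length: fixed-length kernel.
    return "flash_attention_with_fusing_bias"


same_length = [[1, 1, 1, 1], [1, 1, 1, 1]]
mixed_length = [[1, 1, 1, 1], [1, 1, 0, 0]]
print(pick_flash_kernel(same_length))
print(pick_flash_kernel(mixed_length))
```

In practice this means a batch padded to mixed lengths needs no special handling from the caller, under the assumption that the mask is passed through to the model as usual.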

How to use:

First, install the package:

pip install turbot5 -U

Then import the head you need. We released several encoder heads for tasks such as token classification, question answering, and text classification, as well as encoder-decoder heads for conditional generation:

from turbot5 import T5ForConditionalGeneration
from transformers import T5Tokenizer

# Load the tokenizer and the original checkpoint; no retraining is needed.
tokenizer = T5Tokenizer.from_pretrained("google/flan-t5-base")
model = T5ForConditionalGeneration.from_pretrained(
    "google/flan-t5-base",
    attention_type="flash",  # put the attention type you want to use
    use_triton=True,
).to("cuda")

input_text = "translate English to German: How old are you?"
input_ids = tokenizer(input_text, return_tensors="pt").input_ids.to("cuda")

# Generate and decode the output.
outputs = model.generate(input_ids)
print(tokenizer.decode(outputs[0]))

Below is a dictionary with keys for different attention types:

T5ATTENTION_TYPES = {
    "basic": T5Attention,
    "triton-basic": T5TritonBasicAttention,
    "flash": T5FlashAttention,
    "local": T5LocalAttention,
    "block": T5BlockAttention,
    "transient-global": T5TransientGlobalAttention,
}
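Because the attention type is selected by a string key, a typo is easy to make. The sketch below shows one way to validate the key before loading a model. `check_attention_type` is a hypothetical helper, not part of turbot5; the keys are copied from the dictionary above.

```python
# Keys copied from the T5ATTENTION_TYPES dictionary above.
SUPPORTED_ATTENTION_TYPES = (
    "basic",
    "triton-basic",
    "flash",
    "local",
    "block",
    "transient-global",
)


def check_attention_type(name):
    """Hypothetical helper: fail early with a readable message on an unknown key."""
    if name not in SUPPORTED_ATTENTION_TYPES:
        options = ", ".join(SUPPORTED_ATTENTION_TYPES)
        raise ValueError(f"Unknown attention type {name!r}; expected one of: {options}")
    return name


print(check_attention_type("flash"))
```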

Benchmarks

While position bias still requires quadratic memory, the current flash attention implementation can reduce it to linear memory requirements. Moreover, combining a flash attention kernel with linear attention mechanisms, such as block attention, makes it possible to run the model on hundreds of thousands of tokens.
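To get a feel for the quadratic term, the arithmetic below estimates the size of one fully materialized position-bias tensor of shape (heads, query length, key length). The assumptions are illustrative: batch size 1, 12 attention heads as in T5-base, 2-byte (fp16) values, and a single bias tensor. `bias_bytes` is a name introduced here for the example.

```python
def bias_bytes(num_heads, query_len, key_len, bytes_per_value=2):
    """Bytes needed to store one full position-bias tensor of shape (heads, q, k)."""
    return num_heads * query_len * key_len * bytes_per_value


GIB = 1024 ** 3
for n in (1_024, 8_192, 100_000):
    full = bias_bytes(12, n, n)  # assumed: 12 heads, fp16 values
    print(f"{n:>7} tokens: {full / GIB:10.2f} GiB")
```

At 100k tokens the tensor alone would take 12 × 100,000 × 100,000 × 2 bytes = 240 GB, far beyond the 22.5 GB of the L4 GPU mentioned above, which is why avoiding the materialized bias matters at this scale.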

Memory optimization

Below, you can see benchmarks of the different attention implementations at various sequence lengths. T5-base was tested with batch size 1 and a decoder input length of 128 tokens.

[Figure: memory usage]

Block Attention

We benchmarked T5 with a block-attention mechanism on several datasets, including long-context tasks such as summarization. Even without retraining, block attention significantly reduces memory requirements while keeping accuracy nearly the same as full attention, and additional training with block attention can almost completely eliminate the difference.
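The memory saving can be illustrated by counting query-key pairs. The sketch assumes the simplest form of block attention, where each query attends only to keys inside its own fixed-size block; TurboT5's block attention may differ in detail (for example, by also attending to neighboring blocks), and the block size of 128 is an arbitrary illustrative value.

```python
def block_attention_pairs(seq_len, block_size):
    """Count (query, key) pairs when each query sees only keys in its own block."""
    pairs = 0
    for start in range(0, seq_len, block_size):
        width = min(block_size, seq_len - start)  # the last block may be shorter
        pairs += width * width
    return pairs


def full_attention_pairs(seq_len):
    """Count (query, key) pairs for full self-attention."""
    return seq_len * seq_len


seq_len, block_size = 8_192, 128  # illustrative values, not benchmark settings
full = full_attention_pairs(seq_len)
block = block_attention_pairs(seq_len, block_size)
print(f"full: {full}, block: {block}, ratio: {full // block}x")
```

Under these assumptions the pair count grows linearly with sequence length (seq_len × block_size) instead of quadratically.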

Fusing attention bias calculation

Compared to the original flashT5 work, we also fused the attention bias calculation into the kernel. Although this means the attention bias is recalculated, it is still much faster and more memory efficient than passing a precomputed bias to the kernel. Overall, we achieve speedups of 20x or more compared with the original PyTorch implementation of T5.

[Figure: speedup]

This means that T5's main disadvantage disappears: it can be trained and used for inference on considerable sequence lengths even with a full attention mechanism.
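A rough sketch of why recomputing the bias is cheap, assuming the standard T5 relative-position bucketing with its default settings (32 buckets, maximum distance 128). The scalar function below follows that scheme; it is not TurboT5's Triton kernel, and the function name is illustrative.

```python
import math


def relative_position_bucket(relative_position, bidirectional=True,
                             num_buckets=32, max_distance=128):
    """Map a key-minus-query offset to a T5-style bias bucket (scalar version)."""
    bucket = 0
    if bidirectional:
        # Half of the buckets serve positive offsets, half serve the rest.
        num_buckets //= 2
        if relative_position > 0:
            bucket += num_buckets
        relative_position = abs(relative_position)
    else:
        relative_position = max(-relative_position, 0)

    max_exact = num_buckets // 2
    if relative_position < max_exact:
        # Small offsets each get their own bucket.
        return bucket + relative_position

    # Larger offsets share logarithmically spaced buckets up to max_distance.
    large = max_exact + int(
        math.log(relative_position / max_exact)
        / math.log(max_distance / max_exact)
        * (num_buckets - max_exact)
    )
    return bucket + min(large, num_buckets - 1)


seq_len = 4_096
buckets = {relative_position_bucket(d) for d in range(-(seq_len - 1), seq_len)}
print(len(buckets), "distinct buckets cover", seq_len * seq_len, "query-key pairs")
```

Since the bias for a pair depends only on the offset between key and query, a kernel needs just the small per-head bucket table and can derive each entry on the fly rather than reading a full (query × key) tensor from memory.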

Forward-pass execution time

Additionally, we compared the forward pass of T5-base models with different attention types. We varied the input sequence length from 1024 to 8192 tokens, kept the output sequence length fixed at 128, and decreased the batch size linearly with sequence length, starting from 16 at 1024 tokens. The results are shown below.

[Figure: forward-pass execution time]

Future work directions:

Acknowledgment

We use and are inspired by the following works: