
HyperAttention: Long-context Attention in Near-Linear Time

Triton Implementation of HyperAttention Algorithm

Requirements

The code requires PyTorch and Triton. PyTorch version 2.0.1 has been tested, but any version >= 2.0.0 should work. The code also makes use of the Triton implementation of FlashAttention; the FlashAttention kernel has been adapted to compile with Triton version 2.1.0.
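
A quick way to sanity-check the environment is sketched below; the version checks simply mirror the requirements above and are illustrative, not enforced by the repository.

import torch
import triton

print(torch.__version__)   # tested with 2.0.1; any version >= 2.0.0 should work
print(triton.__version__)  # the adapted FlashAttention kernel targets 2.1.0
assert torch.cuda.is_available(), "the Triton kernels require a CUDA-capable GPU"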

How to use

The implementation of HyperAttention can be found in hyper_attention.py. An example of usage (the query, key, and value tensors below are illustrative; the assumed layout is (batch, num_heads, seq_len, head_dim)):

import torch
from hyper_attention import HyperAttention

attn = HyperAttention(
    input_dim=64,
    lsh_num_projs=8,
    block_size=256,
    sample_size=256,
    min_seq_len=2048,
    smooth_block=False,
)

# query, key, and value are CUDA tensors; the layout assumed here is
# (batch, num_heads, seq_len, head_dim), with head_dim equal to input_dim.
query = torch.randn(1, 32, 4096, 64, device='cuda', dtype=torch.float16)
key = torch.randn(1, 32, 4096, 64, device='cuda', dtype=torch.float16)
value = torch.randn(1, 32, 4096, 64, device='cuda', dtype=torch.float16)

attn_output = attn(query, key, value, causal=True)

The module has the following parameters:

input_dim: the dimension of each attention head (head_dim)
lsh_num_projs: the number of random projections used by the LSH bucketing
block_size: the block size of the block-diagonal part of the approximation
sample_size: the number of columns sampled to approximate the remaining attention
min_seq_len: the minimum sequence length; shorter inputs fall back to exact attention
smooth_block: whether to smooth the blocks of the block-diagonal approximation

Speedup on single attention layer

In this section, we showcase the speedup achieved by HyperAttention over the Triton implementation of FlashAttention (v2) across a range of sequence lengths. The configuration uses 32 heads and a head dimension of 64, and the results are obtained by running both methods on NVIDIA A10 Tensor Core GPUs.
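
A timing harness for this kind of comparison might look roughly like the sketch below; the benchmark_forward helper, the tensor shapes, and the use of triton.testing.do_bench are illustrative assumptions rather than part of this repository.

import torch
import triton
from hyper_attention import HyperAttention

def benchmark_forward(seq_len, batch=1, heads=32, head_dim=64, causal=True):
    # Shapes follow the 32-head, head_dim-64 configuration described above.
    q, k, v = (torch.randn(batch, heads, seq_len, head_dim,
                           device='cuda', dtype=torch.float16) for _ in range(3))
    attn = HyperAttention(input_dim=head_dim, lsh_num_projs=8, block_size=256,
                          sample_size=256, min_seq_len=2048, smooth_block=False)
    # Returns the median runtime of the forward pass in milliseconds.
    return triton.testing.do_bench(lambda: attn(q, k, v, causal=causal))

The same harness can be pointed at any attention callable (for example, a FlashAttention wrapper) to compute relative speedups.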

Causal masking (decoder-style attention)

The speedup factors for both the forward pass and the combined forward+backward pass of decoder-style attention with causal masking are plotted below. HyperAttention achieves over a 22x speedup for the forward pass and over a 16x speedup for the combined forward+backward pass at a sequence length of 131k.
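
To time the combined forward+backward pass, the same hypothetical harness can wrap a full autograd step, as sketched below (assuming the imports from the previous sketch).

def benchmark_forward_backward(seq_len, batch=1, heads=32, head_dim=64, causal=True):
    # Same setup as above, but gradients are enabled so the backward pass is timed as well.
    q, k, v = (torch.randn(batch, heads, seq_len, head_dim, device='cuda',
                           dtype=torch.float16, requires_grad=True) for _ in range(3))
    attn = HyperAttention(input_dim=head_dim, lsh_num_projs=8, block_size=256,
                          sample_size=256, min_seq_len=2048, smooth_block=False)

    def fwd_bwd():
        out = attn(q, k, v, causal=causal)
        out.sum().backward()

    # Returns the median runtime of one forward+backward step in milliseconds.
    return triton.testing.do_bench(fwd_bwd)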

No causal masking (encoder-style attention)

The speedup factors for both the forward pass and the combined forward+backward pass of encoder-style attention, without causal masking, are shown below. In the absence of causal masking, HyperAttention reduces to a notably simpler and more efficient algorithm, since no recursive partitioning of the attention matrix is needed. As a result, HyperAttention achieves remarkable speedups, exceeding 270x for both the forward pass and the combined forward+backward pass at a sequence length of 131k.
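
With the hypothetical harness above, the encoder-style measurements correspond simply to passing causal=False:

forward_ms = benchmark_forward(131072, causal=False)
forward_backward_ms = benchmark_forward_backward(131072, causal=False)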