Dao-AILab / flash-attention

Fast and memory-efficient exact attention
BSD 3-Clause "New" or "Revised" License

Performance gap between flash attention, FasterTransformer, TensorRT #5

Open Tengfei09 opened 2 years ago

Tengfei09 commented 2 years ago

After reading your paper, I can see that FlashAttention has indeed achieved a significant speed improvement compared to other algorithms. Thanks for your impressive work!

But in industrial scenarios, we prefer to use FasterTransformer and TensorRT's demoBERT to accelerate transformer-based models. Are you interested in comparing the inference performance of the three?

Under the same conditions, FasterTransformer achieves 65.53 ms, while TensorRT achieves 100.91 ms.

tridao commented 2 years ago

This is awesome, thanks for the pointers on FasterTransformer and TensorRT! We've been focusing on training and on long sequences, but we're definitely excited about inference as well. I might have some time next week to set up the inference benchmark and optimize the forward pass for seqlen 384.

I don't have much experience with FasterTransformer. Do you have some pointers on what model exactly is running with FasterTransformer (e.g., are the model weights in fp16 or fp32? I assume the data is in fp16. Do all the input sequences have the same length, or how are their lengths distributed? Is the model preprocessed by fusing layer norm with other layers?). Having a script that runs the model with some random data through FasterTransformer would be very helpful for us!
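
For reference, a minimal sketch of the kind of standalone forward-pass timing to start from (assumptions: batch 64, seqlen 384, fp16, and BERT-Large's 16 heads of size 64; the flash_attn_func entry point is the one in recent flash-attn releases, and older versions expose the unpadded/QKV-packed variants instead):

    import torch
    from flash_attn import flash_attn_func  # recent releases; older versions differ

    # Assumed BERT-Large-style shapes: batch 64, seqlen 384, 16 heads of size 64, fp16.
    batch, seqlen, nheads, headdim = 64, 384, 16, 64
    q, k, v = [torch.randn(batch, seqlen, nheads, headdim,
                           device="cuda", dtype=torch.float16) for _ in range(3)]

    # Warm up so one-time setup costs don't pollute the measurement.
    for _ in range(10):
        flash_attn_func(q, k, v)
    torch.cuda.synchronize()

    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    iters = 100
    start.record()
    for _ in range(iters):
        flash_attn_func(q, k, v)
    end.record()
    torch.cuda.synchronize()
    print(f"attention forward: {start.elapsed_time(end) / iters:.3f} ms/iter")

The same script can then be swept over batch sizes and sequence lengths to build a comparison table.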

Tengfei09 commented 2 years ago

Thanks for your reply!

FasterTransformer is an inference framework for transformer-based models that implements a highly optimized transformer layer for both the encoder and the decoder. It re-implements each Transformer layer in CUDA and C++ and aggressively fuses layers (MHA, LayerNorm).

When you have a transformer-based model (e.g. BERT or GPT-2), you only need to export the model weights from PyTorch or TensorFlow and then launch this project, so you can set the datatype, batch_size, seqLen, and hidden size as you want.

Taking BERT-Large with batch_size=64, seqLen=384, dtype=fp16 as an example:

    root@96d402509a14:/workspace/FasterTransformer# ./build/bin/bert_example 64 24 384 12 64 1 1
    [INFO] Device: NVIDIA A100-PCIE-40GB
    [FT][INFO] batch_size 64 seq_len 384 layer 24 FT-CPP-time 65.45 ms (10 iterations)

Remove-padding version:

    [FT][WARNING] Async cudaMalloc/Free is not supported before CUDA 11.2. Using Sync cudaMalloc/Free. Note this may lead to hang with NCCL kernels launched in parallel; if so, try NCCL_LAUNCH_MODE=GROUP
    [FT][INFO] batch_size 64 seq_len 384 layer 24 FT-CPP-time 33.34 ms (10 iterations)

By the way, if you want to test FasterTransformer, you could follow this guide.

I also provide my commands to help you launch this project quickly; a Python wrapper sketch follows the steps below. (When measuring its performance under different settings, exporting real model weights is unnecessary; dummy data is enough.)

  1. Set up the environment
    docker run --gpus all -it --rm -v /home/guizili/hantengfei/mma:/workspace/mma_cuda11_6 nvcr.io/nvidia/tensorflow:20.12-tf1-py3
    git clone https://github.com/NVIDIA/FasterTransformer.git
    mkdir -p FasterTransformer/build
    cd FasterTransformer/build
    git submodule init && git submodule update
  2. Compile the project
    cmake -DSM=xx -DCMAKE_BUILD_TYPE=Release ..
    make
  3. Use ./bin/bert_gemm to generate the best GEMM configuration.
    ./bin/bert_gemm <batch_size> <sequence_length> <head_number> <size_per_head> <is_use_fp16> <int8_mode>
  4. Run the BERT benchmark with ./bin/bert_example.
    ./bin/bert_example <batch_size> <num_layers> <sequence_length> <head_number> <size_per_head> <is_use_fp16> <is_remove_padding> <int8_mode> <allow_gemm_test>

More details can be found in the official guide.
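
Putting those steps together, here is a small Python wrapper for running the two binaries and pulling out the timing (a sketch only: the binary paths, argument order, and the FT-CPP-time log format are taken from the commands and logs quoted above, and trailing optional arguments may differ across FasterTransformer versions):

    import re
    import subprocess

    # Parameters from the A100 run above: batch 64, 24 layers, seqlen 384,
    # 12 heads of size 64, fp16, with padding removal enabled.
    batch_size, num_layers, seq_len = 64, 24, 384
    head_num, size_per_head = 12, 64
    is_fp16, remove_padding, int8_mode = 1, 1, 0

    # Step 3: generate the GEMM configuration for this shape.
    subprocess.run(
        ["./bin/bert_gemm", str(batch_size), str(seq_len), str(head_num),
         str(size_per_head), str(is_fp16), str(int8_mode)],
        check=True)

    # Step 4: run the BERT benchmark and capture its log output.
    result = subprocess.run(
        ["./bin/bert_example", str(batch_size), str(num_layers), str(seq_len),
         str(head_num), str(size_per_head), str(is_fp16), str(remove_padding),
         str(int8_mode)],
        capture_output=True, text=True, check=True)

    # Pull out the FT-CPP-time line, matching the log format shown earlier.
    log = result.stdout + result.stderr
    match = re.search(r"FT-CPP-time\s+([\d.]+)\s+ms", log)
    print(f"FasterTransformer BERT forward: {match.group(1)} ms" if match else log)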

tridao commented 2 years ago

Thanks so much for the detailed example! What do you think is the best way to test our method with FasterTransformer? Should we clone FasterTransformer, change their code to call FlashAttention instead of their current attention implementation, then compile the project and run the benchmark? Or is there a better way?

Tengfei09 commented 2 years ago

Maybe it's hard to change their code in such a short time. However, if you find that FlashAttention performs better than the MHA part of FasterTransformer, it's a good opportunity to upstream your implementation to the master branch of FasterTransformer, because FasterTransformer is a well-known, out-of-the-box inference toolbox designed for transformer-based models.

In my opinion, you could provide a detailed performance table of FlashAttention under different settings (e.g. batch size, seqLen, number of layers, precision). Then, following my commands to launch their project, use nsys to profile kernel execution times. This tool exports the average time of each CUDA kernel called, so you can find the kernel name and execution time of the MHA part:

nsys profile --trace cuda --stats=true ./bin/bert_example >> kernel.log
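
To pull the MHA numbers out of that log, a rough filter like the one below might help (a sketch: the kernel-name keywords and the summary-table layout are assumptions, since the exact kernel names and the nsys output format vary across FasterTransformer and Nsight Systems versions):

    # Print CUDA kernel summary rows from kernel.log whose kernel name looks
    # attention-related; read the average-time column off those rows to get
    # the cost of the MHA part.
    KEYWORDS = ("attention", "softmax", "mha")

    with open("kernel.log") as f:
        for line in f:
            if any(keyword in line.lower() for keyword in KEYWORDS):
                print(line.rstrip())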