FasterTransformer #13


zhangjun commented 1 year ago

FasterTransformer

build

The xx in -DSM=xx in the following scripts is the compute capability of your GPU, e.g. 60 (P40), 61 (P4), 70 (V100), 75 (T4), or 80 (A100). The default setting includes 70, 75, 80, and 86.

Build with PyTorch:

git clone https://github.com/NVIDIA/FasterTransformer
cd FasterTransformer
git submodule update --init --recursive
mkdir build && cd build
cmake -DSM=xx -DCMAKE_BUILD_TYPE=Release -DBUILD_PYT=ON ..
make -j12
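
If you are not sure which value to pass for -DSM, you can query the device's compute capability with the CUDA runtime. A minimal helper sketch (not part of FasterTransformer), compiled with nvcc:

#include <cstdio>
#include <cuda_runtime.h>

int main() {
    int device = 0;
    cudaDeviceProp prop;
    cudaGetDeviceProperties(&prop, device);
    // -DSM expects major * 10 + minor, e.g. 70 for V100, 75 for T4, 80 for A100.
    std::printf("compute capability: %d%d\n", prop.major, prop.minor);
    return 0;
}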
zhangjun commented 1 year ago

FasterTransformer optimizations

https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/models/decoder/Decoder.cc

Pros and cons

Pros

Extensive optimizations for the transformer architecture, including the common fused patterns and their corresponding kernels, complete INT8 inference, and direct integration with the TensorFlow and PyTorch frameworks.
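
To make the "fused pattern" point concrete, here is a minimal CPU-side C++ sketch (illustrative only, not FasterTransformer code) contrasting an unfused add-bias followed by GELU with the fused version; on the GPU the same fusion saves a kernel launch and a round trip through global memory for the intermediate tensor.

#include <cmath>
#include <cstdio>
#include <vector>

static float gelu(float v) {
    // tanh approximation of GELU, as commonly used in transformer kernels
    return 0.5f * v * (1.0f + std::tanh(0.7978845608f * (v + 0.044715f * v * v * v)));
}

// Unfused: two passes over the activations (two kernels and extra memory traffic on GPU).
void add_bias(std::vector<float>& x, const std::vector<float>& bias, int hidden) {
    for (size_t i = 0; i < x.size(); ++i) x[i] += bias[i % hidden];
}
void apply_gelu(std::vector<float>& x) {
    for (float& v : x) v = gelu(v);
}

// Fused: one pass, one kernel.
void add_bias_gelu(std::vector<float>& x, const std::vector<float>& bias, int hidden) {
    for (size_t i = 0; i < x.size(); ++i) x[i] = gelu(x[i] + bias[i % hidden]);
}

int main() {
    std::vector<float> x(8, 1.0f), bias(4, 0.5f);
    add_bias_gelu(x, bias, 4);
    std::printf("%f\n", x[0]);
    return 0;
}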

Cons

CUDA-related optimization points

zhangjun commented 1 year ago

Code walkthrough

Decoding

decoding gemm

Generates the optimal GEMM configuration; see src/fastertransformer/models/decoding/decoding_gemm.cc.

decoding_gemm
./bin/decoding_gemm <batch_size> <beam_width> <head_number> <size_per_head> <inter_size> <vocab_size> <max_mem_seq_len> <memory_hidden_units> <data_type>
./bin/decoding_gemm 32 4 8 64 2048 30000 32 512 0

Assume the decoding settings are as follows:

batch_size=32, beam_width=4, head_number=8, size_per_head=64, inter_size=2048, vocabulary_size=30000, sequence_length=32, encoder's hidden dimension=512, data_type=0 (FP32), 1 (FP16), or 2 (BF16)
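
Roughly speaking, decoding_gemm benchmarks the candidate cuBLAS algorithms for each GEMM shape that decoding will issue and records the fastest one (the real tool writes a gemm_config.in that the decoder later reads). A hedged C++/cuBLAS sketch of that selection loop, with a made-up example shape, not the actual FasterTransformer code:

#include <cstdio>
#include <cublas_v2.h>
#include <cuda_runtime.h>

// Time one cublasGemmEx algorithm for an FP32 m x n x k GEMM; returns a huge
// value if the algorithm is not supported for this shape.
static float time_algo(cublasHandle_t handle, int m, int n, int k,
                       const float* A, const float* B, float* C, cublasGemmAlgo_t algo) {
    const float alpha = 1.0f, beta = 0.0f;
    // Warm-up call doubles as a support check.
    cublasStatus_t st = cublasGemmEx(handle, CUBLAS_OP_N, CUBLAS_OP_N, m, n, k,
                                     &alpha, A, CUDA_R_32F, m, B, CUDA_R_32F, k,
                                     &beta, C, CUDA_R_32F, m, CUBLAS_COMPUTE_32F, algo);
    if (st != CUBLAS_STATUS_SUCCESS) return 1e30f;
    cudaEvent_t start, stop;
    cudaEventCreate(&start); cudaEventCreate(&stop);
    cudaEventRecord(start);
    for (int i = 0; i < 10; ++i) {
        cublasGemmEx(handle, CUBLAS_OP_N, CUBLAS_OP_N, m, n, k,
                     &alpha, A, CUDA_R_32F, m, B, CUDA_R_32F, k,
                     &beta, C, CUDA_R_32F, m, CUBLAS_COMPUTE_32F, algo);
    }
    cudaEventRecord(stop); cudaEventSynchronize(stop);
    float ms = 0.0f;
    cudaEventElapsedTime(&ms, start, stop);
    cudaEventDestroy(start); cudaEventDestroy(stop);
    return ms;
}

int main() {
    // Example shape: rows = batch_size * beam_width, hidden = 512 (illustrative only).
    const int m = 32 * 4, n = 512, k = 512;
    cublasHandle_t handle; cublasCreate(&handle);
    float *A, *B, *C;
    cudaMalloc(&A, sizeof(float) * m * k); cudaMemset(A, 0, sizeof(float) * m * k);
    cudaMalloc(&B, sizeof(float) * k * n); cudaMemset(B, 0, sizeof(float) * k * n);
    cudaMalloc(&C, sizeof(float) * m * n);

    int best_algo = CUBLAS_GEMM_DEFAULT;
    float best_ms = 1e30f;
    for (int algo = CUBLAS_GEMM_DEFAULT; algo <= CUBLAS_GEMM_ALGO23; ++algo) {
        float ms = time_algo(handle, m, n, k, A, B, C, static_cast<cublasGemmAlgo_t>(algo));
        if (ms < best_ms) { best_ms = ms; best_algo = algo; }
    }
    std::printf("best algo for %dx%dx%d: %d (%.3f ms / 10 runs)\n", m, n, k, best_algo, best_ms);

    cudaFree(A); cudaFree(B); cudaFree(C);
    cublasDestroy(handle);
    return 0;
}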

decoding_example

examples/cpp/decoding/decoding_example.cc provides two ways to select tokens from the candidates: beam search, or sampling (top-p and top-k); a rough sketch of the sampling modes follows the command lines below.

./bin/decoding_example <batch_size> <beam_width> <head_num> <size_per_head> <inter_size> <vocab_size> <num_layers> <max_seq_len> <memory_max_seq_len> <memory_hidden_units> <top_k> <top_p> <data_type>
./bin/decoding_example 32 4 8 64 2048 30000 6 32 32 512 0 0.0 0
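
As an illustration of the sampling side (a CPU sketch of the idea, not the FasterTransformer GPU kernels), the following C++ snippet filters the softmax distribution by top-k or top-p and samples from what remains, mirroring the top_k/top_p arguments of decoding_example:

#include <algorithm>
#include <cmath>
#include <cstdio>
#include <numeric>
#include <random>
#include <vector>

// Sample a token id from `logits`, keeping only the top_k highest-probability
// tokens (if top_k > 0), or the smallest prefix whose probability mass reaches
// top_p (if 0 < top_p <= 1); otherwise sample from the full distribution.
int sample(const std::vector<float>& logits, int top_k, float top_p, std::mt19937& rng) {
    const int vocab = static_cast<int>(logits.size());

    // Softmax over the logits.
    float max_logit = *std::max_element(logits.begin(), logits.end());
    std::vector<float> probs(vocab);
    float sum = 0.0f;
    for (int i = 0; i < vocab; ++i) { probs[i] = std::exp(logits[i] - max_logit); sum += probs[i]; }
    for (float& p : probs) p /= sum;

    // Token ids sorted by descending probability.
    std::vector<int> ids(vocab);
    std::iota(ids.begin(), ids.end(), 0);
    std::sort(ids.begin(), ids.end(), [&](int a, int b) { return probs[a] > probs[b]; });

    // Decide how many of the top tokens stay in the candidate set.
    int keep = vocab;
    if (top_k > 0) {
        keep = std::min(top_k, vocab);
    } else if (top_p > 0.0f) {
        float cum = 0.0f;
        for (int i = 0; i < vocab; ++i) {
            cum += probs[ids[i]];
            if (cum >= top_p) { keep = i + 1; break; }
        }
    }

    // Sample from the renormalized candidate set.
    std::vector<float> weights(keep);
    for (int i = 0; i < keep; ++i) weights[i] = probs[ids[i]];
    std::discrete_distribution<int> dist(weights.begin(), weights.end());
    return ids[dist(rng)];
}

int main() {
    std::mt19937 rng(0);
    std::vector<float> logits = {2.0f, 1.0f, 0.5f, -1.0f};
    int tok_k = sample(logits, /*top_k=*/2, /*top_p=*/0.0f, rng);  // top-k, k = 2
    int tok_p = sample(logits, /*top_k=*/0, /*top_p=*/0.5f, rng);  // top-p, p = 0.5
    std::printf("top-k sample: %d, top-p sample: %d\n", tok_k, tok_p);
    return 0;
}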

Example: a decoder with 6 transformer layers

Using top-k sampling or top-p sampling

# fp32, sampling (beam_width = 1)
./bin/decoding_gemm 32 1 8 64 2048 30000 32 512 0
./bin/decoding_example 32 1 8 64 2048 30000 6 32 32 512 4 0.0 0 # top_k = 4
./bin/decoding_example 32 1 8 64 2048 30000 6 32 32 512 0 0.5 0 # top_p = 0.5

# fp16, beam search with beam_width = 4 (top_k = 0, top_p = 0.0)
./bin/decoding_gemm 32 4 8 64 2048 30000 32 512 1
./bin/decoding_example 32 4 8 64 2048 30000 6 32 32 512 0 0.0 1