NVIDIA / TensorRT

NVIDIA® TensorRT™ is an SDK for high-performance deep learning inference on NVIDIA GPUs. This repository contains the open source components of TensorRT.
https://developer.nvidia.com/tensorrt
Apache License 2.0

trt8.6.1 Multi-Head Attention(MHA) fusions #2981

Closed: enozhu closed this issue 4 months ago.

enozhu commented 1 year ago

I tested a GPT-2 model with TRT 8.6.1. The GPT-2 ONNX model comes from https://github.com/NVIDIA/TensorRT/tree/release/8.6/demo/HuggingFace/GPT2, and the build command is:

trtexec --onnx=temp/GPT2/gpt2/GPT2-gpt2/decoder/GPT2-gpt2.onnx --fp16 --minShapes=input_ids:1x1 --maxShapes=input_ids:32x256 --optShapes=input_ids:32x64 --saveEngine=gpt2.engine --buildOnly

However, when I check the profiling, I find that no MHA kernel is generated. When I test a Stable Diffusion model instead, I can find the MHA kernel in the profiling. Is the MHA kernel auto-generated or handwritten? Why does the fusion behavior differ between the Stable Diffusion model and GPT-2?

enozhu commented 1 year ago

GPT-2 profiling: [screenshot]

Stable Diffusion profiling: [screenshot]

nvpohanh commented 1 year ago

Did you enable kv-cache? For decoder model, we currently only fuse the MHA if the kv-cache is enabled.

If kv-cache is enabled, there should be a lot of inputs/outputs for the kv-cache. Then, you need to add --inputIOFormats=int32:chw,fp16:chw,fp16:chw,.... --outputIOFormats=fp16:chw,fp16:chw,.... to mark all the kv-cache inputs/outputs as FP16 dtype when using trtexec.
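Writing one format entry per binding by hand is error-prone. A minimal Python sketch that builds both flags, assuming a GPT-2 style kv-cache export whose inputs are `input_ids` (int32) followed by one past key and one past value tensor per layer (FP16), and whose outputs are the logits followed by the present key/value tensors (FP16); `io_format_flags` is a hypothetical helper name:

```python
def io_format_flags(num_layers: int) -> tuple:
    """Build trtexec --inputIOFormats / --outputIOFormats values for a
    GPT-2 style kv-cache decoder.

    Assumed binding order:
      inputs : input_ids (int32), then 2 * num_layers past key/value tensors (fp16)
      outputs: logits (fp16), then 2 * num_layers present key/value tensors (fp16)
    """
    kv_count = 2 * num_layers  # one key and one value tensor per layer
    inputs = ["int32:chw"] + ["fp16:chw"] * kv_count
    outputs = ["fp16:chw"] * (1 + kv_count)
    return (
        "--inputIOFormats=" + ",".join(inputs),
        "--outputIOFormats=" + ",".join(outputs),
    )


if __name__ == "__main__":
    in_flag, out_flag = io_format_flags(num_layers=12)  # GPT-2 small has 12 layers
    print(in_flag)
    print(out_flag)
```

For GPT-2 small this yields one `int32:chw` entry plus 24 `fp16:chw` entries for the inputs and 25 `fp16:chw` entries for the outputs; trtexec matches the entries to the network's inputs and outputs by position, so the assumed order has to match the exported model.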

enozhu commented 1 year ago

Thanks for replying. I changed the model to the key/value-cache model. I first built the engine with [screenshot of the build command], then ran it with [screenshot of the run command].

But the MHA kernel doesn't appear in the profiling. Why? Is something wrong with my build process?

nvpohanh commented 1 year ago

I will take a look today. My current guess is that seq must be 1 because in GPT inference, there are actually two phases:

nvpohanh commented 1 year ago

In the demo, we build an engine with two optimization profiles: one for context phase and one for generation phase.

Using trtexec requires a lot of flags and is a little tedious. That's why the demo uses Python APIs directly.
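With the Python API, both profiles can be registered on one builder config. The sketch below is an outline under stated assumptions, not the demo's code: the shape values mirror the trtexec commands posted later in this thread (GPT-2 small: 12 layers, 12 heads, head dim 64, max batch 32), except that the minimum batch of every input is set to 1. `context_profile`, `generation_profile` and `add_profile` are hypothetical helpers; `create_optimization_profile`, `set_shape` and `add_optimization_profile` are TensorRT Python API calls.

```python
from typing import Dict, Iterator, Tuple

Shape = Tuple[int, ...]
ShapeRange = Tuple[Shape, Shape, Shape]  # (min, opt, max)

# GPT-2 small dimensions, as used in the trtexec commands in this thread.
NUM_LAYERS, NUM_HEADS, HEAD_DIM = 12, 12, 64


def kv_names(num_layers: int = NUM_LAYERS) -> Iterator[str]:
    for kind in ("key", "value"):
        for i in range(num_layers):
            yield f"past_key_values.{i}.decoder.{kind}"


def kv_shape(batch: int, past_len: int) -> Shape:
    return (batch, NUM_HEADS, past_len, HEAD_DIM)


def context_profile(max_batch: int = 32, opt_seq: int = 128,
                    max_seq: int = 128) -> Dict[str, ShapeRange]:
    # Context phase: the prompt can have any length, the kv-cache is empty.
    shapes = {"input_ids": ((1, 1), (max_batch, opt_seq), (max_batch, max_seq))}
    for name in kv_names():
        shapes[name] = (kv_shape(1, 0), kv_shape(max_batch, 0), kv_shape(max_batch, 0))
    return shapes


def generation_profile(max_batch: int = 32, max_past: int = 255) -> Dict[str, ShapeRange]:
    # Generation phase: exactly one new token per call, kv-cache length 1..max_past.
    shapes = {"input_ids": ((1, 1), (max_batch, 1), (max_batch, 1))}
    for name in kv_names():
        shapes[name] = (kv_shape(1, 1), kv_shape(max_batch, max_past),
                        kv_shape(max_batch, max_past))
    return shapes


def add_profile(builder, config, shapes: Dict[str, ShapeRange]) -> int:
    """Register one optimization profile; `builder` is a tensorrt.Builder and
    `config` the tensorrt.IBuilderConfig it created."""
    profile = builder.create_optimization_profile()
    for name, (min_shape, opt_shape, max_shape) in shapes.items():
        profile.set_shape(name, min_shape, opt_shape, max_shape)
    return config.add_optimization_profile(profile)  # index of the new profile
```

Calling `add_profile(builder, config, context_profile())` and then `add_profile(builder, config, generation_profile())` gives profiles 0 and 1 in one engine; at runtime, an execution context selects the matching profile (for example with `IExecutionContext.set_optimization_profile_async`) before running each phase.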

enozhu commented 1 year ago

> I will take a look today. My current guess is that seq must be 1 because in GPT inference, there are actually two phases:
>
>   • The "Context Phase": the first iteration. In this phase, seq can be any number but seq_decoder_length should be 0.
>   • The "Generation Phase": the second iteration and so on. In this phase, seq is 1 but seq_decoder_length can be any number.

Thanks. As you said, seq must be 1 in the generation phase; with that, I found the MHA kernel in the profiling.
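The two phases can be written down as the sequence of decoder calls in a kv-cache generation loop. A minimal Python sketch with a hypothetical `decode_steps` helper, where `past_len` plays the role of `seq_decoder_length` in the two-phase description above, and assuming the single context call already produces the first new token:

```python
from typing import Iterator, Tuple


def decode_steps(prompt_len: int, new_tokens: int) -> Iterator[Tuple[str, int, int]]:
    """Yield (phase, seq, past_len) for each decoder call that generates
    `new_tokens` tokens after a prompt of `prompt_len` tokens."""
    if prompt_len < 1 or new_tokens < 1:
        raise ValueError("prompt_len and new_tokens must be positive")
    # Context phase: the whole prompt in one call, with an empty kv-cache.
    yield ("context", prompt_len, 0)
    # Generation phase: one token per call; the cache grows by one token each step.
    for step in range(1, new_tokens):
        yield ("generation", 1, prompt_len + step - 1)


if __name__ == "__main__":
    for phase, seq, past_len in decode_steps(prompt_len=5, new_tokens=4):
        print(f"{phase:<10} seq={seq} past_len={past_len}")
```

Only the calls with seq = 1 match the generation-phase shapes, which is the phase where the fused MHA kernel showed up in this thread.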

nvpohanh commented 1 year ago

FYI, here are the trtexec commands I used to replicate what the HF demo does:

Context phase:

trtexec --onnx=GPT2-gpt2-fp16-kv_cache.onnx \
    --fp16 --saveEngine=gpt2_context.engine --useCudaGraph --noDataTransfers --useSpinWait --profilingVerbosity=detailed --verbose \
    --minShapes=input_ids:1x1,past_key_values.0.decoder.key:1x12x0x64,past_key_values.1.decoder.key:1x12x0x64,past_key_values.2.decoder.key:1x12x0x64,past_key_values.3.decoder.key:1x12x0x64,past_key_values.4.decoder.key:1x12x0x64,past_key_values.5.decoder.key:1x12x0x64,past_key_values.6.decoder.key:1x12x0x64,past_key_values.7.decoder.key:1x12x0x64,past_key_values.8.decoder.key:1x12x0x64,past_key_values.9.decoder.key:1x12x0x64,past_key_values.10.decoder.key:1x12x0x64,past_key_values.11.decoder.key:1x12x0x64,past_key_values.0.decoder.value:1x12x0x64,past_key_values.1.decoder.value:1x12x0x64,past_key_values.2.decoder.value:1x12x0x64,past_key_values.3.decoder.value:1x12x0x64,past_key_values.4.decoder.value:1x12x0x64,past_key_values.5.decoder.value:1x12x0x64,past_key_values.6.decoder.value:1x12x0x64,past_key_values.7.decoder.value:1x12x0x64,past_key_values.8.decoder.value:1x12x0x64,past_key_values.9.decoder.value:1x12x0x64,past_key_values.10.decoder.value:1x12x0x64,past_key_values.11.decoder.value:1x12x0x64 \
    --optShapes=input_ids:32x128,past_key_values.0.decoder.key:32x12x0x64,past_key_values.1.decoder.key:32x12x0x64,past_key_values.2.decoder.key:32x12x0x64,past_key_values.3.decoder.key:32x12x0x64,past_key_values.4.decoder.key:32x12x0x64,past_key_values.5.decoder.key:32x12x0x64,past_key_values.6.decoder.key:32x12x0x64,past_key_values.7.decoder.key:32x12x0x64,past_key_values.8.decoder.key:32x12x0x64,past_key_values.9.decoder.key:32x12x0x64,past_key_values.10.decoder.key:32x12x0x64,past_key_values.11.decoder.key:32x12x0x64,past_key_values.0.decoder.value:32x12x0x64,past_key_values.1.decoder.value:32x12x0x64,past_key_values.2.decoder.value:32x12x0x64,past_key_values.3.decoder.value:32x12x0x64,past_key_values.4.decoder.value:32x12x0x64,past_key_values.5.decoder.value:32x12x0x64,past_key_values.6.decoder.value:32x12x0x64,past_key_values.7.decoder.value:32x12x0x64,past_key_values.8.decoder.value:32x12x0x64,past_key_values.9.decoder.value:32x12x0x64,past_key_values.10.decoder.value:32x12x0x64,past_key_values.11.decoder.value:32x12x0x64 \
    --maxShapes=input_ids:32x128,past_key_values.0.decoder.key:32x12x0x64,past_key_values.1.decoder.key:32x12x0x64,past_key_values.2.decoder.key:32x12x0x64,past_key_values.3.decoder.key:32x12x0x64,past_key_values.4.decoder.key:32x12x0x64,past_key_values.5.decoder.key:32x12x0x64,past_key_values.6.decoder.key:32x12x0x64,past_key_values.7.decoder.key:32x12x0x64,past_key_values.8.decoder.key:32x12x0x64,past_key_values.9.decoder.key:32x12x0x64,past_key_values.10.decoder.key:32x12x0x64,past_key_values.11.decoder.key:32x12x0x64,past_key_values.0.decoder.value:32x12x0x64,past_key_values.1.decoder.value:32x12x0x64,past_key_values.2.decoder.value:32x12x0x64,past_key_values.3.decoder.value:32x12x0x64,past_key_values.4.decoder.value:32x12x0x64,past_key_values.5.decoder.value:32x12x0x64,past_key_values.6.decoder.value:32x12x0x64,past_key_values.7.decoder.value:32x12x0x64,past_key_values.8.decoder.value:32x12x0x64,past_key_values.9.decoder.value:32x12x0x64,past_key_values.10.decoder.value:32x12x0x64,past_key_values.11.decoder.value:32x12x0x64 \
    --inputIOFormats=int32:chw,fp16:chw,fp16:chw,fp16:chw,fp16:chw,fp16:chw,fp16:chw,fp16:chw,fp16:chw,fp16:chw,fp16:chw,fp16:chw,fp16:chw,fp16:chw,fp16:chw,fp16:chw,fp16:chw,fp16:chw,fp16:chw,fp16:chw,fp16:chw,fp16:chw,fp16:chw,fp16:chw,fp16:chw \
    --outputIOFormats=fp16:chw,fp16:chw,fp16:chw,fp16:chw,fp16:chw,fp16:chw,fp16:chw,fp16:chw,fp16:chw,fp16:chw,fp16:chw,fp16:chw,fp16:chw,fp16:chw,fp16:chw,fp16:chw,fp16:chw,fp16:chw,fp16:chw,fp16:chw,fp16:chw,fp16:chw,fp16:chw,fp16:chw,fp16:chw

Generation phase:

trtexec --onnx=GPT2-gpt2-fp16-kv_cache.onnx \
    --fp16 --saveEngine=gpt2_generation.engine --useCudaGraph --noDataTransfers --useSpinWait --profilingVerbosity=detailed --verbose \
    --minShapes=input_ids:32x1,past_key_values.0.decoder.key:1x12x1x64,past_key_values.1.decoder.key:1x12x1x64,past_key_values.2.decoder.key:1x12x1x64,past_key_values.3.decoder.key:1x12x1x64,past_key_values.4.decoder.key:1x12x1x64,past_key_values.5.decoder.key:1x12x1x64,past_key_values.6.decoder.key:1x12x1x64,past_key_values.7.decoder.key:1x12x1x64,past_key_values.8.decoder.key:1x12x1x64,past_key_values.9.decoder.key:1x12x1x64,past_key_values.10.decoder.key:1x12x1x64,past_key_values.11.decoder.key:1x12x1x64,past_key_values.0.decoder.value:1x12x1x64,past_key_values.1.decoder.value:1x12x1x64,past_key_values.2.decoder.value:1x12x1x64,past_key_values.3.decoder.value:1x12x1x64,past_key_values.4.decoder.value:1x12x1x64,past_key_values.5.decoder.value:1x12x1x64,past_key_values.6.decoder.value:1x12x1x64,past_key_values.7.decoder.value:1x12x1x64,past_key_values.8.decoder.value:1x12x1x64,past_key_values.9.decoder.value:1x12x1x64,past_key_values.10.decoder.value:1x12x1x64,past_key_values.11.decoder.value:1x12x1x64 \
    --optShapes=input_ids:32x1,past_key_values.0.decoder.key:32x12x255x64,past_key_values.1.decoder.key:32x12x255x64,past_key_values.2.decoder.key:32x12x255x64,past_key_values.3.decoder.key:32x12x255x64,past_key_values.4.decoder.key:32x12x255x64,past_key_values.5.decoder.key:32x12x255x64,past_key_values.6.decoder.key:32x12x255x64,past_key_values.7.decoder.key:32x12x255x64,past_key_values.8.decoder.key:32x12x255x64,past_key_values.9.decoder.key:32x12x255x64,past_key_values.10.decoder.key:32x12x255x64,past_key_values.11.decoder.key:32x12x255x64,past_key_values.0.decoder.value:32x12x255x64,past_key_values.1.decoder.value:32x12x255x64,past_key_values.2.decoder.value:32x12x255x64,past_key_values.3.decoder.value:32x12x255x64,past_key_values.4.decoder.value:32x12x255x64,past_key_values.5.decoder.value:32x12x255x64,past_key_values.6.decoder.value:32x12x255x64,past_key_values.7.decoder.value:32x12x255x64,past_key_values.8.decoder.value:32x12x255x64,past_key_values.9.decoder.value:32x12x255x64,past_key_values.10.decoder.value:32x12x255x64,past_key_values.11.decoder.value:32x12x255x64 \
    --maxShapes=input_ids:32x1,past_key_values.0.decoder.key:32x12x255x64,past_key_values.1.decoder.key:32x12x255x64,past_key_values.2.decoder.key:32x12x255x64,past_key_values.3.decoder.key:32x12x255x64,past_key_values.4.decoder.key:32x12x255x64,past_key_values.5.decoder.key:32x12x255x64,past_key_values.6.decoder.key:32x12x255x64,past_key_values.7.decoder.key:32x12x255x64,past_key_values.8.decoder.key:32x12x255x64,past_key_values.9.decoder.key:32x12x255x64,past_key_values.10.decoder.key:32x12x255x64,past_key_values.11.decoder.key:32x12x255x64,past_key_values.0.decoder.value:32x12x255x64,past_key_values.1.decoder.value:32x12x255x64,past_key_values.2.decoder.value:32x12x255x64,past_key_values.3.decoder.value:32x12x255x64,past_key_values.4.decoder.value:32x12x255x64,past_key_values.5.decoder.value:32x12x255x64,past_key_values.6.decoder.value:32x12x255x64,past_key_values.7.decoder.value:32x12x255x64,past_key_values.8.decoder.value:32x12x255x64,past_key_values.9.decoder.value:32x12x255x64,past_key_values.10.decoder.value:32x12x255x64,past_key_values.11.decoder.value:32x12x255x64 \
    --inputIOFormats=int32:chw,fp16:chw,fp16:chw,fp16:chw,fp16:chw,fp16:chw,fp16:chw,fp16:chw,fp16:chw,fp16:chw,fp16:chw,fp16:chw,fp16:chw,fp16:chw,fp16:chw,fp16:chw,fp16:chw,fp16:chw,fp16:chw,fp16:chw,fp16:chw,fp16:chw,fp16:chw,fp16:chw,fp16:chw \
    --outputIOFormats=fp16:chw,fp16:chw,fp16:chw,fp16:chw,fp16:chw,fp16:chw,fp16:chw,fp16:chw,fp16:chw,fp16:chw,fp16:chw,fp16:chw,fp16:chw,fp16:chw,fp16:chw,fp16:chw,fp16:chw,fp16:chw,fp16:chw,fp16:chw,fp16:chw,fp16:chw,fp16:chw,fp16:chw,fp16:chw
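The long shape lists above follow a fixed pattern: `input_ids`, then all 12 `past_key_values.{i}.decoder.key` tensors, then all 12 `past_key_values.{i}.decoder.value` tensors, each kv-cache tensor shaped batch x heads x past length x head dim. A small Python sketch that regenerates the three shape flags; the helper names are hypothetical, and the shape values are copied from the two commands above (GPT-2 small: 12 layers, 12 heads, head dim 64):

```python
from typing import List, Sequence, Tuple

Shape = Sequence[int]


def kv_tensor_names(num_layers: int = 12) -> List[str]:
    # All keys (layers 0..N-1) first, then all values, as in the commands above.
    return [f"past_key_values.{i}.decoder.{kind}"
            for kind in ("key", "value") for i in range(num_layers)]


def dims(shape: Shape) -> str:
    return "x".join(str(d) for d in shape)


def shape_flags(ids: Tuple[Shape, Shape, Shape],
                kv: Tuple[Shape, Shape, Shape],
                num_layers: int = 12) -> List[str]:
    """Return the --minShapes/--optShapes/--maxShapes arguments for trtexec.
    `ids` and `kv` are (min, opt, max) shapes for input_ids and every kv tensor."""
    flags = []
    for flag, ids_shape, kv_shape in zip(("minShapes", "optShapes", "maxShapes"), ids, kv):
        parts = [f"input_ids:{dims(ids_shape)}"]
        parts += [f"{name}:{dims(kv_shape)}" for name in kv_tensor_names(num_layers)]
        flags.append(f"--{flag}=" + ",".join(parts))
    return flags


# Values copied from the context-phase and generation-phase commands above.
CONTEXT = shape_flags(ids=((1, 1), (32, 128), (32, 128)),
                      kv=((1, 12, 0, 64), (32, 12, 0, 64), (32, 12, 0, 64)))
GENERATION = shape_flags(ids=((32, 1), (32, 1), (32, 1)),
                         kv=((1, 12, 1, 64), (32, 12, 255, 64), (32, 12, 255, 64)))

if __name__ == "__main__":
    # Print the context-phase flags with shell line continuations.
    print(" \\\n    ".join(CONTEXT))
```

Changing the number of layers, heads, or sequence limits then means editing a few numbers instead of dozens of comma-separated entries.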

Currently, the Context phase does not have fused MHA; we expect to fix this by the next TRT version.

The Generation phase (which usually takes a large portion of e2e inference runtime) should have fused MHA.
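To check for the fused kernel without screenshots, trtexec can write its per-layer profile to a file with `--exportProfile=<file>`. A minimal Python sketch that lists matching layer names, assuming the exported file is a JSON array in which per-layer entries carry a `"name"` field, and assuming the fused MHA layer or kernel name contains `mha` (adjust the keywords to whatever your profile actually shows); `find_layers` is a hypothetical helper:

```python
import json
from typing import Iterable, List


def find_layers(profile_text: str, keywords: Iterable[str] = ("mha",)) -> List[str]:
    """Return names of profile entries containing any keyword (case-insensitive).

    Assumes a JSON array whose per-layer entries are objects with a "name" field;
    other entries (for example a leading {"count": ...} object) are skipped.
    """
    keys = [k.lower() for k in keywords]
    names = []
    for entry in json.loads(profile_text):
        name = entry.get("name") if isinstance(entry, dict) else None
        if isinstance(name, str) and any(k in name.lower() for k in keys):
            names.append(name)
    return names
```

For example, after building and profiling the generation engine with `--exportProfile=gpt2_generation_profile.json` (a hypothetical file name), `find_layers(pathlib.Path("gpt2_generation_profile.json").read_text())` returns the matching layer names; an empty list suggests no layer with that keyword in its name.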

enozhu commented 1 year ago

So in a TRT engine, to use the MHA kernel, a fixed op pattern and some other conditions (such as seq = 1) must be matched, and the MHA kernel is a hand-optimized kernel rather than an automatically compiled, code-generated one. Am I right?
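To make the "fixed op pattern" idea concrete, here is a toy Python check for the attention subgraph MatMul → (scale/mask ops) → Softmax → MatMul in a model's node list. It only illustrates pattern matching in general and is not TensorRT's fusion logic, which this thread does not describe; the set of allowed in-between ops in `PASS_THROUGH` and the function name are assumptions.

```python
from typing import Dict, List, Sequence, Tuple

Node = Tuple[str, Sequence[str], Sequence[str]]  # (op_type, inputs, outputs)

# Assumed elementwise ops that may sit between Q*K^T and Softmax (scaling, masking).
PASS_THROUGH = {"Div", "Mul", "Add", "Where", "Cast"}


def find_attention_patterns(nodes: List[Node]) -> List[int]:
    """Return indices of Softmax nodes matching
    MatMul -> [PASS_THROUGH ops] -> Softmax -> MatMul."""
    producer: Dict[str, int] = {}
    consumers: Dict[str, List[int]] = {}
    for i, (_, inputs, outputs) in enumerate(nodes):
        for t in outputs:
            producer[t] = i
        for t in inputs:
            consumers.setdefault(t, []).append(i)

    def reaches_matmul_upstream(i: int, depth: int = 4) -> bool:
        # Walk backwards through allowed elementwise ops looking for a MatMul.
        for t in nodes[i][1]:
            j = producer.get(t)
            if j is None:
                continue  # graph input or initializer
            op = nodes[j][0]
            if op == "MatMul":
                return True
            if op in PASS_THROUGH and depth > 0 and reaches_matmul_upstream(j, depth - 1):
                return True
        return False

    matches = []
    for i, (op, _, outputs) in enumerate(nodes):
        if op != "Softmax":
            continue
        feeds_matmul = any(nodes[c][0] == "MatMul"
                           for t in outputs for c in consumers.get(t, []))
        if feeds_matmul and reaches_matmul_upstream(i):
            matches.append(i)
    return matches
```

Nodes can be taken from an ONNX file with `onnx.load(path).graph.node`, converting each to `(node.op_type, list(node.input), list(node.output))`. Finding the pattern only shows the subgraph is present; whether a fused kernel is used still depends on the conditions discussed above, such as seq = 1 in the generation phase.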

ttyio commented 4 months ago

closing since no activity for more than 3 weeks per our policy, thanks all!