[screenshot: gpt2 profiling]
[screenshot: stable diffusion profiling]
Did you enable the kv-cache? For decoder models, we currently only fuse the MHA if the kv-cache is enabled.
If the kv-cache is enabled, there should be a lot of kv-cache inputs/outputs. You then need to add `--inputIOFormats=int32:chw,fp16:chw,fp16:chw,.... --outputIOFormats=fp16:chw,fp16:chw,....`
to mark all the kv-cache inputs/outputs as FP16 dtype when using trtexec.
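For GPT-2 small (12 layers, so 24 kv-cache tensors plus `input_ids` and the logits output), those flag strings get long. Here is a hypothetical helper, purely illustrative and not part of trtexec or the demo, that generates them; the layer count and logits-first output ordering are assumptions based on the commands later in this thread:

```python
# Hypothetical helper (not part of trtexec) that builds the
# --inputIOFormats / --outputIOFormats strings for a decoder with kv-cache.
def kv_cache_io_formats(num_layers: int = 12) -> tuple[str, str]:
    n_kv = 2 * num_layers  # one key and one value tensor per layer
    # input_ids is INT32; all kv-cache inputs are FP16.
    in_fmt = ",".join(["int32:chw"] + ["fp16:chw"] * n_kv)
    # Outputs: logits plus the present kv-cache tensors, all FP16.
    out_fmt = ",".join(["fp16:chw"] * (1 + n_kv))
    return in_fmt, out_fmt

in_fmt, out_fmt = kv_cache_io_formats()
print(f"--inputIOFormats={in_fmt} --outputIOFormats={out_fmt}")
```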
Thanks for replying. I changed the model to the key/value-cache model and ran it like this: first build, then run.
But the MHA kernel doesn't appear in the profiling. Why? Is something wrong with my build process?
I will take a look today. My current guess is that `seq` must be 1, because in GPT inference there are actually two phases:
- The "Context Phase": the first iteration. In this phase, `seq` can be any number but `seq_decoder_length` should be 0.
- The "Generation Phase": the second iteration and onward. In this phase, `seq` is 1 but `seq_decoder_length` can be any number.

In the demo, we build an engine with two optimization profiles: one for the context phase and one for the generation phase (see the sketch below). Using trtexec requires a lot of flags and is a little tedious. That's why the demo uses the Python APIs directly.
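For reference, a minimal sketch of what building one engine with two optimization profiles can look like with the TensorRT Python API. This is not the demo's actual code; the tensor names, 12-layer/12-head/64-dim shapes, and shape ranges are taken from the trtexec commands below:

```python
import tensorrt as trt

logger = trt.Logger(trt.Logger.INFO)
builder = trt.Builder(logger)
network = builder.create_network(
    1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH))
parser = trt.OnnxParser(network, logger)
with open("GPT2-gpt2-fp16-kv_cache.onnx", "rb") as f:
    assert parser.parse(f.read()), "ONNX parsing failed"

config = builder.create_builder_config()
config.set_flag(trt.BuilderFlag.FP16)

# Profile 0: context phase -- seq can be any length, kv-cache length is 0.
ctx = builder.create_optimization_profile()
ctx.set_shape("input_ids", (1, 1), (32, 128), (32, 128))
for i in range(12):
    for kv in ("key", "value"):
        name = f"past_key_values.{i}.decoder.{kv}"
        ctx.set_shape(name, (1, 12, 0, 64), (32, 12, 0, 64), (32, 12, 0, 64))
config.add_optimization_profile(ctx)

# Profile 1: generation phase -- seq is fixed to 1, kv-cache length varies.
gen = builder.create_optimization_profile()
gen.set_shape("input_ids", (1, 1), (32, 1), (32, 1))
for i in range(12):
    for kv in ("key", "value"):
        name = f"past_key_values.{i}.decoder.{kv}"
        gen.set_shape(name, (1, 12, 1, 64), (32, 12, 255, 64), (32, 12, 255, 64))
config.add_optimization_profile(gen)

engine = builder.build_serialized_network(network, config)
```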
Thanks! As you said, `seq` must be 1 in the generation phase; with that, I found the MHA kernel in the profiling.
FYI, here are the trtexec commands I used to replicate what the HF demo does:
Context phase:
trtexec --onnx=GPT2-gpt2-fp16-kv_cache.onnx \
--fp16 --saveEngine=gpt2_context.engine --useCudaGraph --noDataTransfers --useSpinWait --profilingVerbosity=detailed --verbose \
--minShapes=input_ids:1x1,past_key_values.0.decoder.key:1x12x0x64,past_key_values.1.decoder.key:1x12x0x64,past_key_values.2.decoder.key:1x12x0x64,past_key_values.3.decoder.key:1x12x0x64,past_key_values.4.decoder.key:1x12x0x64,past_key_values.5.decoder.key:1x12x0x64,past_key_values.6.decoder.key:1x12x0x64,past_key_values.7.decoder.key:1x12x0x64,past_key_values.8.decoder.key:1x12x0x64,past_key_values.9.decoder.key:1x12x0x64,past_key_values.10.decoder.key:1x12x0x64,past_key_values.11.decoder.key:1x12x0x64,past_key_values.0.decoder.value:1x12x0x64,past_key_values.1.decoder.value:1x12x0x64,past_key_values.2.decoder.value:1x12x0x64,past_key_values.3.decoder.value:1x12x0x64,past_key_values.4.decoder.value:1x12x0x64,past_key_values.5.decoder.value:1x12x0x64,past_key_values.6.decoder.value:1x12x0x64,past_key_values.7.decoder.value:1x12x0x64,past_key_values.8.decoder.value:1x12x0x64,past_key_values.9.decoder.value:1x12x0x64,past_key_values.10.decoder.value:1x12x0x64,past_key_values.11.decoder.value:1x12x0x64 \
--optShapes=input_ids:32x128,past_key_values.0.decoder.key:32x12x0x64,past_key_values.1.decoder.key:32x12x0x64,past_key_values.2.decoder.key:32x12x0x64,past_key_values.3.decoder.key:32x12x0x64,past_key_values.4.decoder.key:32x12x0x64,past_key_values.5.decoder.key:32x12x0x64,past_key_values.6.decoder.key:32x12x0x64,past_key_values.7.decoder.key:32x12x0x64,past_key_values.8.decoder.key:32x12x0x64,past_key_values.9.decoder.key:32x12x0x64,past_key_values.10.decoder.key:32x12x0x64,past_key_values.11.decoder.key:32x12x0x64,past_key_values.0.decoder.value:32x12x0x64,past_key_values.1.decoder.value:32x12x0x64,past_key_values.2.decoder.value:32x12x0x64,past_key_values.3.decoder.value:32x12x0x64,past_key_values.4.decoder.value:32x12x0x64,past_key_values.5.decoder.value:32x12x0x64,past_key_values.6.decoder.value:32x12x0x64,past_key_values.7.decoder.value:32x12x0x64,past_key_values.8.decoder.value:32x12x0x64,past_key_values.9.decoder.value:32x12x0x64,past_key_values.10.decoder.value:32x12x0x64,past_key_values.11.decoder.value:32x12x0x64 \
--maxShapes=input_ids:32x128,past_key_values.0.decoder.key:32x12x0x64,past_key_values.1.decoder.key:32x12x0x64,past_key_values.2.decoder.key:32x12x0x64,past_key_values.3.decoder.key:32x12x0x64,past_key_values.4.decoder.key:32x12x0x64,past_key_values.5.decoder.key:32x12x0x64,past_key_values.6.decoder.key:32x12x0x64,past_key_values.7.decoder.key:32x12x0x64,past_key_values.8.decoder.key:32x12x0x64,past_key_values.9.decoder.key:32x12x0x64,past_key_values.10.decoder.key:32x12x0x64,past_key_values.11.decoder.key:32x12x0x64,past_key_values.0.decoder.value:32x12x0x64,past_key_values.1.decoder.value:32x12x0x64,past_key_values.2.decoder.value:32x12x0x64,past_key_values.3.decoder.value:32x12x0x64,past_key_values.4.decoder.value:32x12x0x64,past_key_values.5.decoder.value:32x12x0x64,past_key_values.6.decoder.value:32x12x0x64,past_key_values.7.decoder.value:32x12x0x64,past_key_values.8.decoder.value:32x12x0x64,past_key_values.9.decoder.value:32x12x0x64,past_key_values.10.decoder.value:32x12x0x64,past_key_values.11.decoder.value:32x12x0x64 \
--inputIOFormats=int32:chw,fp16:chw,fp16:chw,fp16:chw,fp16:chw,fp16:chw,fp16:chw,fp16:chw,fp16:chw,fp16:chw,fp16:chw,fp16:chw,fp16:chw,fp16:chw,fp16:chw,fp16:chw,fp16:chw,fp16:chw,fp16:chw,fp16:chw,fp16:chw,fp16:chw,fp16:chw,fp16:chw,fp16:chw \
--outputIOFormats=fp16:chw,fp16:chw,fp16:chw,fp16:chw,fp16:chw,fp16:chw,fp16:chw,fp16:chw,fp16:chw,fp16:chw,fp16:chw,fp16:chw,fp16:chw,fp16:chw,fp16:chw,fp16:chw,fp16:chw,fp16:chw,fp16:chw,fp16:chw,fp16:chw,fp16:chw,fp16:chw,fp16:chw,fp16:chw
Generation phase:
trtexec --onnx=GPT2-gpt2-fp16-kv_cache.onnx \
--fp16 --saveEngine=gpt2_generation.engine --useCudaGraph --noDataTransfers --useSpinWait --profilingVerbosity=detailed --verbose \
--minShapes=input_ids:32x1,past_key_values.0.decoder.key:1x12x1x64,past_key_values.1.decoder.key:1x12x1x64,past_key_values.2.decoder.key:1x12x1x64,past_key_values.3.decoder.key:1x12x1x64,past_key_values.4.decoder.key:1x12x1x64,past_key_values.5.decoder.key:1x12x1x64,past_key_values.6.decoder.key:1x12x1x64,past_key_values.7.decoder.key:1x12x1x64,past_key_values.8.decoder.key:1x12x1x64,past_key_values.9.decoder.key:1x12x1x64,past_key_values.10.decoder.key:1x12x1x64,past_key_values.11.decoder.key:1x12x1x64,past_key_values.0.decoder.value:1x12x1x64,past_key_values.1.decoder.value:1x12x1x64,past_key_values.2.decoder.value:1x12x1x64,past_key_values.3.decoder.value:1x12x1x64,past_key_values.4.decoder.value:1x12x1x64,past_key_values.5.decoder.value:1x12x1x64,past_key_values.6.decoder.value:1x12x1x64,past_key_values.7.decoder.value:1x12x1x64,past_key_values.8.decoder.value:1x12x1x64,past_key_values.9.decoder.value:1x12x1x64,past_key_values.10.decoder.value:1x12x1x64,past_key_values.11.decoder.value:1x12x1x64 \
--optShapes=input_ids:32x1,past_key_values.0.decoder.key:32x12x255x64,past_key_values.1.decoder.key:32x12x255x64,past_key_values.2.decoder.key:32x12x255x64,past_key_values.3.decoder.key:32x12x255x64,past_key_values.4.decoder.key:32x12x255x64,past_key_values.5.decoder.key:32x12x255x64,past_key_values.6.decoder.key:32x12x255x64,past_key_values.7.decoder.key:32x12x255x64,past_key_values.8.decoder.key:32x12x255x64,past_key_values.9.decoder.key:32x12x255x64,past_key_values.10.decoder.key:32x12x255x64,past_key_values.11.decoder.key:32x12x255x64,past_key_values.0.decoder.value:32x12x255x64,past_key_values.1.decoder.value:32x12x255x64,past_key_values.2.decoder.value:32x12x255x64,past_key_values.3.decoder.value:32x12x255x64,past_key_values.4.decoder.value:32x12x255x64,past_key_values.5.decoder.value:32x12x255x64,past_key_values.6.decoder.value:32x12x255x64,past_key_values.7.decoder.value:32x12x255x64,past_key_values.8.decoder.value:32x12x255x64,past_key_values.9.decoder.value:32x12x255x64,past_key_values.10.decoder.value:32x12x255x64,past_key_values.11.decoder.value:32x12x255x64 \
--maxShapes=input_ids:32x1,past_key_values.0.decoder.key:32x12x255x64,past_key_values.1.decoder.key:32x12x255x64,past_key_values.2.decoder.key:32x12x255x64,past_key_values.3.decoder.key:32x12x255x64,past_key_values.4.decoder.key:32x12x255x64,past_key_values.5.decoder.key:32x12x255x64,past_key_values.6.decoder.key:32x12x255x64,past_key_values.7.decoder.key:32x12x255x64,past_key_values.8.decoder.key:32x12x255x64,past_key_values.9.decoder.key:32x12x255x64,past_key_values.10.decoder.key:32x12x255x64,past_key_values.11.decoder.key:32x12x255x64,past_key_values.0.decoder.value:32x12x255x64,past_key_values.1.decoder.value:32x12x255x64,past_key_values.2.decoder.value:32x12x255x64,past_key_values.3.decoder.value:32x12x255x64,past_key_values.4.decoder.value:32x12x255x64,past_key_values.5.decoder.value:32x12x255x64,past_key_values.6.decoder.value:32x12x255x64,past_key_values.7.decoder.value:32x12x255x64,past_key_values.8.decoder.value:32x12x255x64,past_key_values.9.decoder.value:32x12x255x64,past_key_values.10.decoder.value:32x12x255x64,past_key_values.11.decoder.value:32x12x255x64 \
--inputIOFormats=int32:chw,fp16:chw,fp16:chw,fp16:chw,fp16:chw,fp16:chw,fp16:chw,fp16:chw,fp16:chw,fp16:chw,fp16:chw,fp16:chw,fp16:chw,fp16:chw,fp16:chw,fp16:chw,fp16:chw,fp16:chw,fp16:chw,fp16:chw,fp16:chw,fp16:chw,fp16:chw,fp16:chw,fp16:chw \
--outputIOFormats=fp16:chw,fp16:chw,fp16:chw,fp16:chw,fp16:chw,fp16:chw,fp16:chw,fp16:chw,fp16:chw,fp16:chw,fp16:chw,fp16:chw,fp16:chw,fp16:chw,fp16:chw,fp16:chw,fp16:chw,fp16:chw,fp16:chw,fp16:chw,fp16:chw,fp16:chw,fp16:chw,fp16:chw,fp16:chw
Currently, the Context phase does not have fused MHA; we expect to fix this in the next TRT version.
The Generation phase (which usually takes the largest portion of end-to-end inference runtime) should have fused MHA.
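To make the two-phase split concrete, here is a hypothetical decoding loop showing how the two engines fit together. `run_engine` is a stand-in for your own TensorRT execution wrapper, not a real API, and greedy decoding is assumed:

```python
import numpy as np

def empty_kv(num_layers=12, batch=1, heads=12, head_dim=64):
    # Zero-length kv-cache for the context phase (seq_decoder_length = 0).
    z = np.zeros((batch, heads, 0, head_dim), dtype=np.float16)
    return [(z, z) for _ in range(num_layers)]

def generate(prompt_ids, max_new_tokens, ctx_engine, gen_engine, run_engine):
    # Context phase: one pass over the whole prompt (seq = prompt length).
    logits, kv = run_engine(ctx_engine, input_ids=prompt_ids, past_kv=empty_kv())
    token = int(np.argmax(logits[0, -1]))
    out = [token]
    # Generation phase: one token per step (seq = 1, kv-cache grows).
    # This is the phase where the fused MHA kernel is expected.
    for _ in range(max_new_tokens - 1):
        logits, kv = run_engine(
            gen_engine, input_ids=np.array([[token]], dtype=np.int32), past_kv=kv
        )
        token = int(np.argmax(logits[0, -1]))
        out.append(token)
    return out
```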
So in the TRT engine, if we want the MHA kernel to be used, a fixed op pattern and some other conditions (such as seq = 1) must be matched, and the MHA kernel is hand-optimized rather than an auto-compiled, code-generated kernel. Am I right?
Closing since there has been no activity for more than 3 weeks, per our policy. Thanks all!
I tested a GPT-2 model using TRT 8.6.1. The GPT-2 ONNX model is from https://github.com/NVIDIA/TensorRT/tree/release/8.6/demo/HuggingFace/GPT2. The build cmd is:

trtexec --onnx=temp/GPT2/gpt2/GPT2-gpt2/decoder/GPT2-gpt2.onnx --fp16 --minShapes=input_ids:1x1 --maxShapes=input_ids:32x256 --optShapes=input_ids:32x64 --saveEngine=gpt2.engine --buildOnly

But I find that the MHA kernel is not generated when I check the profiling. I then tested a stable diffusion model, and I can find the MHA kernel in its profiling. So is the MHA kernel auto-generated or handwritten? Why is the fusion behavior different between the stable diffusion model and GPT-2?