NVIDIA / TensorRT

NVIDIA® TensorRT™ is an SDK for high-performance deep learning inference on NVIDIA GPUs. This repository contains the open source components of TensorRT.
https://developer.nvidia.com/tensorrt
Apache License 2.0

Attention did not use MHA kernel (multi-head attention) on Orin TRT 8.6.10 #3575

Closed Nusselder9 closed 7 months ago

Nusselder9 commented 8 months ago

My model has an attention module like this:

[image: attention subgraph in the ONNX model]

It did not use the MHA kernel on my Orin with TensorRT 8.6.10 (OS 6.0.7.0):

[image: engine layer information on Orin, no MHA kernel]

However, on x86 with TensorRT 8.6.1, it does use the MHA kernel:

[image: engine layer information on x86, MHA kernel is used]

I would like to use MHA on Orin. What can I do? Thanks!
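
A minimal sketch (not part of the original report) of how such a check can be scripted with the TensorRT Python API; the file name attention.onnx is a placeholder, and FP16 is enabled because the fused MHA kernels are generally only eligible at reduced precision:

```python
import tensorrt as trt

logger = trt.Logger(trt.Logger.INFO)
builder = trt.Builder(logger)
network = builder.create_network(
    1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH))
parser = trt.OnnxParser(network, logger)

# Parse the ONNX model that contains the attention block.
with open("attention.onnx", "rb") as f:
    if not parser.parse(f.read()):
        for i in range(parser.num_errors):
            print(parser.get_error(i))
        raise SystemExit("ONNX parse failed")

config = builder.create_builder_config()
config.set_flag(trt.BuilderFlag.FP16)  # fused MHA typically needs FP16/INT8

# Build and inspect the engine: a fused attention usually shows up as one
# fused region instead of separate MatMul/Softmax/MatMul layers.
plan = builder.build_serialized_network(network, config)
engine = trt.Runtime(logger).deserialize_cuda_engine(plan)
inspector = engine.create_engine_inspector()
print(inspector.get_engine_information(trt.LayerInformationFormat.JSON))
```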

Nusselder9 commented 8 months ago

To rule out other influences, I exported a toy attention with seq_length=128, batch_size=1, embed_dim=256, like this:

[image: toy attention graph]

However, it still can NOT use the MHA kernel on Orin TRT 8.6.10:

[image: engine layer information on Orin, still no MHA kernel]

Please give me some help. Thanks very much!

zerollzeng commented 8 months ago

@nvpohanh I have a vague memory that this is expected (I've seen an internal bug before) and that the fused MHA kernel doesn't get enabled in TRT 8.6. Am I correct?

nvpohanh commented 8 months ago

Is it possible to try TRT 9.2? https://github.com/NVIDIA/TensorRT/blob/release/9.2/docker/ubuntu-22.04.Dockerfile#L92-L95

Nusselder9 commented 8 months ago

Sorry, but our business scenario does not allow us to upgrade TRT in the near future.

Is there a way to fix or work around the bug and enable fused MHA on TRT 8.6? @zerollzeng @nvpohanh

thx!

nvpohanh commented 8 months ago

Could you try adding a LayerNorm into the network? That will encourage TRT 8.6 to trigger the Transformer-specific fusions.
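
A rough sketch of one way to try that without re-exporting the model, by appending a LayerNormalization node to the existing ONNX with the onnx Python package. This is an illustration, not an NVIDIA recipe; the file names and the hidden size of 256 come from the toy model above, and the identity scale / zero bias are arbitrary since this is only for checking whether the fusion triggers:

```python
import numpy as np
import onnx
from onnx import helper, numpy_helper

model = onnx.load("attention.onnx")          # placeholder file name
graph = model.graph
old_output = graph.output[0].name
H = 256                                      # hidden size of the toy model

# LayerNormalization is an opset-17 op; bump the default opset if needed.
model.opset_import[0].version = max(model.opset_import[0].version, 17)

# Identity scale and zero bias; the values are irrelevant for the fusion check.
graph.initializer.extend([
    numpy_helper.from_array(np.ones(H, dtype=np.float32), "ln_scale"),
    numpy_helper.from_array(np.zeros(H, dtype=np.float32), "ln_bias"),
])

# Append the LayerNorm after the current graph output and rewire the output.
graph.node.append(helper.make_node(
    "LayerNormalization", [old_output, "ln_scale", "ln_bias"], ["ln_out"], axis=-1))
graph.output[0].name = "ln_out"

onnx.save(model, "attention_ln.onnx")
```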

Nusselder9 commented 8 months ago

Thanks for your reply, @nvpohanh. I have tried adding LN after the attention, but it does not work.

[image: attention with LayerNorm] [image: engine layer information, still no MHA kernel]

nvpohanh commented 8 months ago

@Nusselder9 could you share the ONNX with the LayerNorm? TRT 8.6 has quite restrictive MHA pattern-matching code, and we need to find out why it didn't trigger the fusion. TRT 9.2 has much looser checks.

I would also try to make the MHA look like:

Q: [B, S, H] -MatMul-> [B, S, H] -Reshape-> [B, S, N, h] -Transpose-> [B, N, S, h]
K: [B, S, H] -MatMul-> [B, S, H] -Reshape-> [B, S, N, h] -Transpose-> [B, N, h, S]
V: [B, S, H] -MatMul-> [B, S, H] -Reshape-> [B, S, N, h] -Transpose-> [B, N, S, h]

MatMul(Q, K)  -> [B, N, S, S]
MatMul(QK, V) -> [B, N, S, h] -Transpose-> [B, S, N, h] -Reshape-> [B, S, H] -LayerNorm-> ...

where B=1, S=128, H=256, N=8, h=32
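
For completeness, a small PyTorch export sketch of that layout under the stated dimensions; ToyMHA and the output file name are made up for illustration, and the softmax (with the 1/sqrt(h) scale) that a real attention has between the two MatMuls is kept even though the shape diagram omits it:

```python
import torch
import torch.nn as nn

B, S, H, N = 1, 128, 256, 8      # batch, sequence, hidden size, heads
h = H // N                       # per-head size = 32

class ToyMHA(nn.Module):
    def __init__(self):
        super().__init__()
        # bias-free Linear layers export as plain MatMul nodes
        self.q = nn.Linear(H, H, bias=False)
        self.k = nn.Linear(H, H, bias=False)
        self.v = nn.Linear(H, H, bias=False)
        self.ln = nn.LayerNorm(H)        # LayerNorm right after the MHA output

    def forward(self, x):                # x: [B, S, H]
        q = self.q(x).reshape(B, S, N, h).transpose(1, 2)      # [B, N, S, h]
        k = self.k(x).reshape(B, S, N, h).permute(0, 2, 3, 1)  # [B, N, h, S]
        v = self.v(x).reshape(B, S, N, h).transpose(1, 2)      # [B, N, S, h]
        attn = torch.softmax((q @ k) * h ** -0.5, dim=-1)      # [B, N, S, S]
        out = (attn @ v).transpose(1, 2).reshape(B, S, H)      # [B, S, H]
        return self.ln(out)

torch.onnx.export(ToyMHA().eval(), torch.randn(B, S, H), "toy_mha_ln.onnx",
                  opset_version=17)
```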

Nusselder9 commented 8 months ago

Thanks for your kind reply. The attachment is my attention: attention.zip. The zip file has two types of attention, and neither of them can use fMHA.

nvpohanh commented 8 months ago

@Nusselder9 Could you share the ONNX files with the LayerNorm? Thanks!

Nusselder9 commented 8 months ago

Here is the attention with LN, @nvpohanh: attention_ln.zip

nvpohanh commented 8 months ago

Filed internal tracker 4438093. Will let you know if we have any findings. Thanks

nvpohanh commented 8 months ago

Internal investigation shows that TRT 8.6.10 did not have any MHA fusion support on Orin. Could you try TRT 8.6.11?

Nusselder9 commented 7 months ago

thanks, I will try.

lix19937 commented 5 months ago

@nvpohanh I think "MHA kernel" is not an accurate description in this question.

Internal investigation shows that TRT 8.6.10 did not have any MHA fusion support on Orin. Could you try TRT 8.6.11?

If an ONNX model includes a standard transformer structure (like a ViT decoder), can TRT 8.6.11 enable MHA fusion?

In my opinion, CustomQKVToContextPluginDynamic can do some of this fusion, but it needs to match certain conditions if the user uses the plugin.
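
For anyone who wants to check whether that plugin is available in their build and under what name/version, a quick sketch (plugin names and versions vary between TRT releases, so treat the strings below as assumptions to verify):

```python
import tensorrt as trt

logger = trt.Logger(trt.Logger.INFO)
trt.init_libnvinfer_plugins(logger, "")       # register the bundled TRT plugins

registry = trt.get_plugin_registry()

# List every registered creator whose name mentions QKV (the BERT-style
# fused-attention plugins live under names like CustomQKVToContextPluginDynamic).
for creator in registry.plugin_creator_list:
    if "QKV" in creator.name:
        print(creator.name, creator.plugin_version)

# Look up one specific version explicitly; the plugin expects packed QKV input
# and only supports certain head counts / hidden sizes, which is the
# "match some conditions" caveat above.
creator = registry.get_plugin_creator("CustomQKVToContextPluginDynamic", "1", "")
print("available" if creator is not None else "not registered in this build")
```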