NVIDIA / TensorRT

NVIDIA® TensorRT™ is an SDK for high-performance deep learning inference on NVIDIA GPUs. This repository contains the open source components of TensorRT.
https://developer.nvidia.com/tensorrt
Apache License 2.0

Custom Attention implementation not well optimised by TensorRT #3520

Closed david-PHR closed 5 months ago

david-PHR commented 10 months ago

Description

Hi, I am working on a model that uses a modified attention mechanism, which pools K and V to reduce the computational load. However, TensorRT does not optimize this model efficiently. I suspect this is because my attention implementation does not match the patterns that the Myelin/Flash Attention pattern matching in TensorRT supports.

The key difference is that MiniModelV2 appears to be optimized by Myelin and uses the MHA operator at runtime, according to the nsys report. In contrast, MiniModelV1's results, especially its large memory footprint, suggest that it is not optimized effectively, possibly because Myelin and Flash Attention are not applied.

Do you have any suggestions on how to address these optimization challenges, particularly for forcing Myelin support in ModelV1? Any insights or feedback on managing memory more efficiently in this context would be greatly appreciated.
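For readers following the thread, here is a minimal pure-Python sketch of the kind of attention described above, where K and V are pooled along the token axis before the usual scaled dot-product step. The actual model is not shown in this issue, so the average pooling, the `stride` parameter, and every function name below are assumptions made for illustration only.

```python
import math


def avg_pool_rows(x, stride):
    """Average consecutive groups of `stride` rows (tokens) of a 2-D list."""
    pooled = []
    for i in range(0, len(x), stride):
        group = x[i:i + stride]
        pooled.append([sum(col) / len(group) for col in zip(*group)])
    return pooled


def softmax(row):
    """Numerically stable softmax over a list of floats."""
    m = max(row)
    exps = [math.exp(value - m) for value in row]
    total = sum(exps)
    return [e / total for e in exps]


def attention(q, k, v):
    """Plain scaled dot-product attention for a single head."""
    scale = 1.0 / math.sqrt(len(q[0]))
    out = []
    for q_row in q:
        scores = [scale * sum(a * b for a, b in zip(q_row, k_row)) for k_row in k]
        weights = softmax(scores)
        out.append([sum(w * v_row[c] for w, v_row in zip(weights, v))
                    for c in range(len(v[0]))])
    return out


def pooled_attention(q, k, v, stride):
    """Assumed variant: pool K and V first, so each query attends to
    len(k) / stride keys instead of len(k)."""
    return attention(q, avg_pool_rows(k, stride), avg_pool_rows(v, stride))
```

With `stride = 2`, the score matrix has half as many columns as in standard attention, which is where the compute saving comes from. The extra pooling nodes between the K/V projections and the matmuls are also what may stop a fused-MHA pattern matcher from recognising the subgraph.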

ONNX Model V1 and V2:

  1. Model V1:

    model_v1

  2. Model V2:

    model_v2

Environment

TensorRT Version: 8.6

NVIDIA GPU: A10g

MatthieuToulemont commented 10 months ago

I second this. It would also be particularly useful to know how Myelin recognises the attention pattern in order to use the Flash Attention plugin.

Looking at diffusion models, a lot of compute is wasted by reprocessing the text embedding in the cross-attention at every step of the model, when we could compute it only once at the beginning of the loop. But by doing so we might break the pattern matching of Flash Attention and lose a lot of inference speed.
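A minimal sketch of the hoisting idea, in pure Python: the K and V projections of the text embedding are computed once before the denoising loop and reused at every step. `TextKVProjector`, `denoise`, and the weight matrices are hypothetical names for illustration, not part of any real diffusion pipeline. The sketch says nothing about whether TensorRT would still fuse the resulting graph.

```python
def matmul(x, w):
    """Plain matrix product of two 2-D lists (x @ w)."""
    return [[sum(a * b for a, b in zip(row, col)) for col in zip(*w)]
            for row in x]


class TextKVProjector:
    """Hypothetical K/V projection of a fixed text embedding."""

    def __init__(self, w_k, w_v):
        self.w_k = w_k
        self.w_v = w_v
        self.calls = 0  # counts how often the projections are recomputed

    def __call__(self, text_embedding):
        self.calls += 1
        return matmul(text_embedding, self.w_k), matmul(text_embedding, self.w_v)


def denoise(projector, text_embedding, num_steps):
    """Project the text embedding once, then reuse K and V at every step."""
    k, v = projector(text_embedding)  # hoisted out of the loop
    for _ in range(num_steps):
        # A real model would run cross-attention with (k, v) here.
        pass
    return k, v
```

After `denoise(projector, text_embedding, 10)`, `projector.calls` is 1 rather than 10. In an exported graph this would mean K and V become engine inputs instead of being produced by projection layers inside the attention subgraph, which is the change that could affect pattern matching.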

zerollzeng commented 10 months ago

@nvpohanh @zhenhuaw-me ^ ^

david-PHR commented 10 months ago

In addition, I would like to know if there is a way to access the cubin file for the multi-head cross-attention plugin for head dimension == 32, or even the cubin generator for these plugins. In the TensorRT OSS archive there is no support for a head dimension of 32 (64 is the minimum).

nvpohanh commented 10 months ago

Could you try TensorRT 9.2 release? https://github.com/NVIDIA/TensorRT/tree/release/9.2#setting-up-the-build-environment

We have relaxed the MHA pattern matching constraints considerably between 8.6 and 9.2. Thanks

MatthieuToulemont commented 10 months ago

Hello, thanks for the recommendation.

We looked at this already, but unfortunately it seems NVIDIA is not going to include those changes in the NGC containers anytime soon, and @david-PHR and I use the Triton Inference Server container to run our models in production.

Would you have any guidelines or steps that we could follow to make sure that Triton Inference Server is compatible with TensorRT 9.2?

nvpohanh commented 10 months ago

You can follow these commands to install TRT 9.2 in Triton Inference Server container: https://github.com/NVIDIA/TensorRT/blob/release/9.2/docker/ubuntu-20.04.Dockerfile#L92-L95

david-PHR commented 10 months ago

Is it compatible with Torch-TensorRT?

david-PHR commented 10 months ago

@nvpohanh thanks for the link. When I build an engine with the trtexec 9.2 executable (with all the libraries installed), I encounter this error when I try to execute the serialized engine from a TensorRT Python wrapper:

[12/08/2023-13:57:44] [TRT] [E] 1: [stdArchiveReader.cpp::StdArchiveReaderInitCommon::46] Error Code 1: Serialization (Serialization assertion stdVersionRead == serializationVersion failed.Version tag does not match. Note: Current Version: 236, Serialized Engine Version: 228)

Any idea?
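For anyone hitting the same message, the two numbers in the log show that the library reading the engine and the engine file itself carry different serialization version tags. Below is a small sketch, using only the Python standard library, that extracts both numbers from such a log line. `parse_version_mismatch` is a hypothetical helper name, and the regular expression is based only on the message quoted above, so other TensorRT releases may format it differently.

```python
import re

# Pattern derived from the error message quoted in this thread (an assumption
# about the format, not a documented TensorRT interface).
_VERSION_RE = re.compile(
    r"Current Version:\s*(\d+),\s*Serialized Engine Version:\s*(\d+)"
)


def parse_version_mismatch(log_line):
    """Return (current_version, engine_version) or None if no match."""
    match = _VERSION_RE.search(log_line)
    if match is None:
        return None
    return int(match.group(1)), int(match.group(2))


LOG = (
    "[12/08/2023-13:57:44] [TRT] [E] 1: "
    "[stdArchiveReader.cpp::StdArchiveReaderInitCommon::46] Error Code 1: "
    "Serialization (Serialization assertion stdVersionRead == "
    "serializationVersion failed.Version tag does not match. Note: "
    "Current Version: 236, Serialized Engine Version: 228)"
)

result = parse_version_mismatch(LOG)
if result is not None and result[0] != result[1]:
    print("runtime tag", result[0], "does not match engine tag", result[1])
```

A mismatch like this usually means the engine was built by one TensorRT installation and loaded by another, so checking which TensorRT libraries the Python process actually loads is a reasonable first step.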

david-PHR commented 10 months ago

I've found the issue: I had to recompile the plugins from TensorRT OSS 9.2 and set several environment variables in my Docker session. I still have an issue with Torch-TensorRT, which segfaults with this new TensorRT installed. It would be useful to have clear steps somewhere for upgrading each TensorRT component in a Docker session (an NGC container, for example). Guidelines:

david-PHR commented 10 months ago

For some models I am experiencing this issue:

[12/08/2023-14:49:41] [TRT] [W] Using default stream in enqueue()/enqueueV2()/enqueueV3() may lead to performance issues due to additional cudaDeviceSynchronize() calls by TensorRT to ensure correct synchronizations. Please use non-default stream instead.
mem_utils.cpp:72: CHECK(status == cudaSuccess) failed. cudaMemcpy2DAsync failed.
[12/08/2023-14:49:41] [TRT] [E] 1: [runner.cpp::executeMyelinGraph::682] Error Code 1: Myelin (No Myelin Error exists)
[12/08/2023-14:49:41] [TRT] [E] 1: [checkMacros.cpp::catchCudaError::203] Error Code 1: Cuda Runtime (invalid argument)
Traceback (most recent call last):

Is it expected?

nvpohanh commented 10 months ago

It seems that Torch-TRT is using the default stream to call the TRT engine, which is not recommended. Let me ask the Torch-TRT team internally about this issue.

Meanwhile, could you work around this issue by wrapping your PyTorch code with:

import torch

s = torch.cuda.Stream()  # Create a new, non-default CUDA stream.
with torch.cuda.stream(s):
    # Call the PyTorch module's forward() here.
    ...

DefTruth commented 8 months ago

@david-PHR hi ~ would you be willing to share your nsys profile results? I have encountered a similar problem, 'Myelin fused Attn but not run at MHA Kernel', see:

ttyio commented 5 months ago

Closing since there has been no activity for more than 3 weeks, per our policy. Thanks all!