DerryHub / BEVFormer_tensorrt

BEVFormer inference on TensorRT, including INT8 Quantization and Custom TensorRT Plugins (float/half/half2/int8).
Apache License 2.0
410 stars 67 forks

SpatialCrossAttentionTRT doesn't select valid queries, so is there no CUDA memory problem here? #71

Closed sherylwang closed 1 year ago

sherylwang commented 1 year ago

Thanks for this great repo! I attempted to implement the original BEVFormer in TensorRT, but encountered issues with the SpatialCrossAttention module. Some of its operations are not well-supported in TensorRT. To address this, I created a no-filter version of the module. However, this implementation resulted in significantly higher CUDA memory usage compared to the original version. I'm wondering if using the SpatialCrossAttentionTRT module would help alleviate this issue?

Best Regards

DerryHub commented 1 year ago

Due to the custom TensorRT plugin MultiScaleDeformableAttnTRT2, which computes in a single CUDA kernel and supports float/half2/int8, SpatialCrossAttentionTRT can save a lot of CUDA memory. You can see the comparison in the README.

sherylwang commented 1 year ago

Cool! image "FP16 Plugins with nv_half" and "FP16 Plugins with nv_half2" including run MultiScaleDeformableAttnTRT2 in half? Or MultiScaleDeformableAttnTRT2 can reduce cuda memory compared to mmdeploy version even run in float?

DerryHub commented 1 year ago

There is no MultiScaleDeformableAttn plugin in mmdeploy-v0.10.0. The experiments with mmdeploy use several grid_sampler ops to replace MultiScaleDeformableAttn, similar to mmcv.ops.multi_scale_deform_attn.multi_scale_deformable_attn_pytorch, which is inefficient and needs a lot of CUDA memory.
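
For context, a simplified sketch of that grid_sample-based reference (adapted from the mmcv pure-PyTorch implementation; the actual mmcv code differs in minor details) shows where the extra memory goes: each level runs its own grid_sample and materializes a large intermediate tensor of sampled values before the weighted sum, whereas a fused plugin keeps all of this inside one kernel.

```python
import torch
import torch.nn.functional as F


def ms_deform_attn_pytorch_sketch(value, value_spatial_shapes,
                                  sampling_locations, attention_weights):
    """Simplified sketch of the grid_sample-based reference.

    value:               (bs, num_value, num_heads, head_dims)
    value_spatial_shapes:(num_levels, 2)
    sampling_locations:  (bs, num_queries, num_heads, num_levels, num_points, 2)
    attention_weights:   (bs, num_queries, num_heads, num_levels, num_points)
    """
    bs, _, num_heads, head_dims = value.shape
    _, num_queries, _, num_levels, num_points, _ = sampling_locations.shape
    # Split the flattened value tensor back into per-level feature maps.
    value_list = value.split([int(h * w) for h, w in value_spatial_shapes], dim=1)
    # grid_sample expects coordinates in [-1, 1].
    sampling_grids = 2 * sampling_locations - 1
    sampling_value_list = []
    for level, (h, w) in enumerate(value_spatial_shapes):
        # (bs * num_heads, head_dims, h, w)
        value_l = value_list[level].flatten(2).transpose(1, 2).reshape(
            bs * num_heads, head_dims, int(h), int(w))
        # (bs * num_heads, num_queries, num_points, 2)
        grid_l = sampling_grids[:, :, :, level].transpose(1, 2).flatten(0, 1)
        # One grid_sample per level; each call materializes a large
        # intermediate of shape (bs * num_heads, head_dims, num_queries, num_points).
        sampled = F.grid_sample(value_l, grid_l, mode='bilinear',
                                padding_mode='zeros', align_corners=False)
        sampling_value_list.append(sampled)
    # (bs * num_heads, 1, num_queries, num_levels * num_points)
    attention_weights = attention_weights.transpose(1, 2).reshape(
        bs * num_heads, 1, num_queries, num_levels * num_points)
    output = (torch.stack(sampling_value_list, dim=-2).flatten(-2)
              * attention_weights).sum(-1)
    return output.view(bs, num_heads * head_dims,
                       num_queries).transpose(1, 2).contiguous()
```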

sherylwang commented 1 year ago

I see. I have implemented a MultiScaleDeformableAttn plugin based on a kernel file like "mmcv/ops/csrc/common/cuda/ms_deform_attn_cuda_kernel.cuh", and it produces the same results as "mmcv.ops.multi_scale_deform_attn.multi_scale_deformable_attn_pytorch". I am wondering if replacing this plugin with MultiScaleDeformableAttnTRT2 can reduce the CUDA memory usage in float32. Additionally, I am curious if the plugin can also run in fp16 mode while maintaining the calculation precision.

DerryHub commented 1 year ago

MultiScaleDeformableAttnTRT2 supports float/half2/int8 and can maintain the calculation precision in BEVFormer. You can see the details in the README. BTW, due to the int8 implementation, the inputs of MultiScaleDeformableAttnTRT2 differ somewhat from mmcv.ops.multi_scale_deform_attn.multi_scale_deformable_attn_pytorch; see https://github.com/DerryHub/BEVFormer_tensorrt/blob/main/TensorRT/README.md for the differences.

sherylwang commented 1 year ago

Cool! The README.md file shows that MultiScaleDeformableAttnTRT2 will be faster in fp16 or int8 mode compared to MultiScaleDeformableAttnTRT. It's awesome that it also keeps the calculation precision. What about the CUDA memory cost comparison?

DerryHub commented 1 year ago

The comparison of CUDA memory cost is also in the README, in the Memory column.

sherylwang commented 1 year ago

Sorry! I couldn't find the memory column in the README.md file. Could it be possible that it hasn't been updated on the GitHub repository yet?

DerryHub commented 1 year ago

I mean the total CUDA memory of the BEVFormer TensorRT engine in https://github.com/DerryHub/BEVFormer_tensorrt/blob/main/README.md

sherylwang commented 1 year ago

I see. Thanks for the reply.

sherylwang commented 1 year ago

Hi, could you please confirm whether MultiScaleDeformableAttnFunction2 corresponds to MultiScaleDeformableAttnTRT2? I noticed that the forward logic of MultiScaleDeformableAttnFunction and MultiScaleDeformableAttnFunction2 differs from that of MultiScaleDeformableAttnFunction_fp32 in the original BEVFormer.

DerryHub commented 1 year ago

Yes, there are some differences between MultiScaleDeformableAttnTRT and MultiScaleDeformableAttnFunction_fp32. You can see the details in the forward function.

sherylwang commented 1 year ago

Thanks for the replies! Do you have any suggestions for verifying the calculation results of MultiScaleDeformableAttnFunction2 or MultiScaleDeformableAttnTRT2? Due to regulations, I need to modify your implementation to adapt it to our code standards and ensure that we obtain the same results as MSDeformableAttention3D or MSDeformableAttention3DTRT. In my opinion, MultiScaleDeformableAttnFunction2 executes the second half of the MSDeformableAttention3D logic, that is, the part after the attention weights have been computed.

DerryHub commented 1 year ago

Maybe https://github.com/DerryHub/BEVFormer_tensorrt/blob/main/det2trt/models/utils/test_trt_ops/test_multi_scale_deformable_attn.py can help you.
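
As a rough illustration (not the repo's actual test code), a minimal comparison against the mmcv pure-PyTorch reference could look like the sketch below. `custom_attn_fn` is a placeholder for whatever wrapper your adapted op exposes and is assumed to take the same four inputs as the reference; if it instead takes reference points and sampling offsets (as MultiScaleDeformableAttnTRT2 does), compute the sampling locations outside for the reference path and pass the raw pieces to the custom op. Half or int8 runs would need looser tolerances.

```python
import torch
from mmcv.ops.multi_scale_deform_attn import multi_scale_deformable_attn_pytorch


def compare_with_pytorch_reference(custom_attn_fn, device='cuda',
                                   atol=1e-4, rtol=1e-4):
    """Feed identical random inputs to the mmcv reference and a custom op."""
    torch.manual_seed(0)
    bs, num_heads, head_dims = 1, 8, 32
    spatial_shapes = torch.tensor([[64, 64], [32, 32], [16, 16], [8, 8]],
                                  dtype=torch.long, device=device)
    num_levels = spatial_shapes.size(0)
    num_value = int((spatial_shapes[:, 0] * spatial_shapes[:, 1]).sum())
    num_queries, num_points = 900, 4

    value = torch.rand(bs, num_value, num_heads, head_dims, device=device)
    sampling_locations = torch.rand(bs, num_queries, num_heads, num_levels,
                                    num_points, 2, device=device)
    attention_weights = torch.rand(bs, num_queries, num_heads, num_levels,
                                   num_points, device=device)
    # Normalize weights per query/head as the attention module would.
    attention_weights = attention_weights / attention_weights.sum(
        dim=(-2, -1), keepdim=True)

    ref = multi_scale_deformable_attn_pytorch(
        value, spatial_shapes, sampling_locations, attention_weights)
    out = custom_attn_fn(
        value, spatial_shapes, sampling_locations, attention_weights)
    torch.testing.assert_close(out, ref, atol=atol, rtol=rtol)
```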

sherylwang commented 1 year ago

Cool! I see. Much thanks.

sherylwang commented 1 year ago

Hi, a small question: why does MultiScaleDeformableAttnTRT2 not take sampling locations directly as input, but instead take reference points and sampling offsets and compute the sampling locations inside the kernel? Is it to save GPU memory?

DerryHub commented 1 year ago

No, this is to avoid some of the int8 accuracy issues. If the addition of reference points and sampling offsets happens outside the op, the two tensors may be converted to int8 too early, resulting in a serious loss of precision in int8 mode.
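
A toy fake-quantization sketch (not the plugin's actual quantization scheme) of why this matters: the normalized reference points span only [0, 1], so snapping them to INT8 levels before the addition already costs sub-pixel accuracy at typical feature-map resolutions, whereas fusing the addition inside the kernel keeps the location computation in float.

```python
import torch


def fake_quant_int8(x, amax):
    # Toy symmetric per-tensor INT8 fake-quantization.
    scale = amax / 127.0
    return torch.clamp(torch.round(x / scale), -127, 127) * scale


torch.manual_seed(0)
# Normalized reference points live in [0, 1]; the learned offsets are small.
reference_points = torch.rand(100000)
sampling_offsets = 0.01 * torch.randn(100000)

# Sampling locations as computed in float inside the fused kernel.
exact = reference_points + sampling_offsets

# If the addition sits outside the plugin, both tensors end up at an op
# boundary and may be quantized to INT8 before being added.
early = (fake_quant_int8(reference_points, 1.0)
         + fake_quant_int8(sampling_offsets, sampling_offsets.abs().max()))

# One INT8 step over a [0, 1] range is 1/127, roughly 0.008, so on a
# 128-wide feature level the sampling location drifts by about half a pixel.
err_px = ((early - exact).abs().max() * 128).item()
print(f'max sampling-location error: {err_px:.2f} px')
```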

sherylwang commented 1 year ago

Understood, that makes a lot of sense. I feel that your thinking on deployment is very valuable. Respect!