Open zhexinli opened 3 months ago
@nvpohanh ^ ^
I need to check internally, but I don't think SDXL INT8 quantizes the MHAs. IIRC, it only quantizes the Convs and Gemms. cc @rajeevsrao
Thanks. Is there any way to work around the MHA seq_len limitation? Or does NVIDIA plan to extend mha_v2 to larger dimensions so that diffusion-based models can benefit from the int8 mha_v2 kernel?
BTW, I'm also confused about why MHA has this limitation in the first place: the int8 multiplication and accumulation happen along head_dim (i.e., the last dimension), not seq_len, so overflow should only occur if the accumulated result exceeds int32.
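A back-of-the-envelope check supports this point (my own sketch of the arithmetic, not a statement about TensorRT internals): the reduction in Q·Kᵀ runs over head_dim, so only head_dim bounds the worst-case int32 accumulation.

```python
# Worst-case magnitude of an int8 x int8 product: (-128) * (-128) = 16384.
# In Q @ K^T the dot product reduces over head_dim, so after d terms the
# partial sum is at most d * 16384, which must fit in a signed int32.
WORST_PRODUCT = 128 * 128  # 16384
INT32_MAX = 2**31 - 1

def max_safe_reduction_length(acc_bits: int = 32) -> int:
    """Largest reduction length whose worst-case int8 dot product
    still fits in a signed accumulator of `acc_bits` bits."""
    return (2 ** (acc_bits - 1) - 1) // WORST_PRODUCT

# head_dim is typically 64 or 128, far below the overflow bound --
# which is why seq_len alone should not cause int32 overflow here.
print(max_safe_reduction_length())  # 131071
```

So even a pathological head_dim of ~131k would still accumulate safely in int32, while typical head dims (64/128) are nowhere near the bound.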
I encountered the same issue when quantizing wav2vec.
Description
Hi, I noticed from Issue that the int8 MHA_v2 kernel only supports SeqLen <= 512. I also tried my own diffusion model, whose Q shape is (B, N, S, H) with S >> 512. I used pytorch_quantization to insert Q/DQ nodes in the MHA and converted to TRT. As expected, it breaks into 3 kernels (int8 → gemm → FP32 → softmax → int8 → gemm → int8) and runs slower than fp16 MHA_v2.
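For reference, here is a minimal numpy sketch of what that unfused fallback path computes: int8 GEMM for Q·Kᵀ with int32 accumulation, a dequantize to FP32 for the softmax, then requantize and a second int8 GEMM with V. The scales are illustrative assumptions of mine, not what TensorRT's kernels actually choose.

```python
import numpy as np

def quantize(x, scale):
    """Symmetric per-tensor quantization to int8."""
    return np.clip(np.round(x / scale), -128, 127).astype(np.int8)

def softmax(x):
    e = np.exp(x - x.max(axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)

rng = np.random.default_rng(0)
S, H = 1024, 64                      # seq_len >> 512, head_dim
q, k, v = (rng.standard_normal((S, H)).astype(np.float32) for _ in range(3))
sq = sk = sv = 0.05                  # illustrative Q/DQ scales

# Kernel 1: int8 GEMM (Q @ K^T), accumulated in int32, dequantized to FP32.
qi, ki = quantize(q, sq), quantize(k, sk)
scores = qi.astype(np.int32) @ ki.astype(np.int32).T * (sq * sk) / np.sqrt(H)

# Kernel 2: FP32 softmax -- the precision break that splits the fusion.
p = softmax(scores.astype(np.float32))

# Kernel 3: requantize the probabilities, second int8 GEMM with V.
sp = 1.0 / 127                       # probabilities lie in [0, 1]
pi, vi = quantize(p, sp), quantize(v, sv)
out = pi.astype(np.int32) @ vi.astype(np.int32) * (sp * sv)
print(out.shape)  # (1024, 64)
```

The extra dequantize/requantize round trips around the softmax are exactly the memory traffic a fused int8 MHA kernel would avoid, which matches the slowdown versus fp16 MHA_v2.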
But I notice that the TensorRT 9.3 OSS introduces AMMO to quantize SDXL, and from viewing the code I assume the MHA is also quantized, because there is code handling the QKV Q/DQ fusion. So does the demo SDXL manage to invoke the int8 MHA_v2 kernel? I think SDXL's SeqLen is also >> 512. How did the demo quantization manage to utilize int8 MHA? ![image](https://github.com/NVIDIA/TensorRT/assets/44015820/8a19042f-9ef8-49aa-a7a5-aafff0f34adf)