I tested it with batch_num=1, seq_len=128, head_num=5, head_dim=64. It reports: "FMHA Inference took 75.82559204 ms, 17.97742325 GFlop/s, 0.01728598 GB/s INT8 average absolute deviation: 1.552685 %". However, the same workload takes only about 0.5 ms per run in the PyTorch model or TensorRT.