intel / intel-xpu-backend-for-triton

OpenAI Triton backend for Intel® GPUs
MIT License

Large variance in the softmax performance report for N <= 1024 #1566

Closed. chengjunlu closed this issue 3 months ago

chengjunlu commented 4 months ago

There is large variance in the softmax results for N <= 1024 in the micro-benchmark.

A possible reason is that the thread-spawning overhead becomes unstable when the GPU is oversubscribed with a large number of work groups.

We need to investigate this and check whether a persistent softmax kernel can reduce the variance.
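A persistent kernel launches a fixed number of programs (for example, one per core) instead of one work group per row, and each program loops over several rows with a stride equal to the number of programs. The sketch below emulates that schedule serially in NumPy so the row-to-program mapping can be checked against a plain row-wise softmax; it is not a GPU kernel. The function names and the choice of 64 programs are hypothetical. Inside a Triton kernel the same grid-stride loop is typically written with `tl.program_id(0)` as the start row and `tl.num_programs(0)` as the stride.

```python
import numpy as np


def softmax_rows(x: np.ndarray) -> np.ndarray:
    """Reference row-wise softmax (one 'work group' per row)."""
    shifted = x - x.max(axis=1, keepdims=True)  # subtract row max for numerical stability
    e = np.exp(shifted)
    return e / e.sum(axis=1, keepdims=True)


def rows_for_program(pid: int, n_rows: int, num_programs: int) -> list[int]:
    """Grid-stride mapping: program `pid` handles rows pid, pid + P, pid + 2P, ..."""
    return list(range(pid, n_rows, num_programs))


def persistent_softmax(x: np.ndarray, num_programs: int) -> np.ndarray:
    """Serial emulation of a persistent softmax: a fixed number of programs,
    each looping over its share of rows (hypothetical helper, not a real kernel)."""
    n_rows, _ = x.shape
    out = np.empty_like(x)
    for pid in range(num_programs):  # each iteration stands in for one resident program
        for row in rows_for_program(pid, n_rows, num_programs):
            r = x[row] - x[row].max()
            e = np.exp(r)
            out[row] = e / e.sum()
    return out


x = np.random.default_rng(0).standard_normal((4096, 256)).astype(np.float32)
y = persistent_softmax(x, num_programs=64)
print(np.allclose(y, softmax_rows(x), atol=1e-6))
```

With 4096 rows and 64 programs, each program handles 64 rows, so the launch contains 64 work groups instead of 4096, which is the property that could reduce per-launch thread-spawning overhead.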

chengjunlu commented 4 months ago

There are some outliers in the performance benchmark statistics; we need to follow up on them.

softmax-performance:
         N  Triton-GB/s  XeTLA-GB/s  Triton-GB/s-min  XeTLA-GB/s-min  Triton-GB/s-max  XeTLA-GB/s-max  Triton-TFlops  XeTLA-TFlops  Triton-TFlops-min  XeTLA-TFlops-min  Triton-TFlops-max  XeTLA-TFlops-max  Triton-CV  XeTLA-CV
0    256.0   666.959247  751.912338       639.375598      476.625457       708.497308      873.813292       0.666959      0.751912           0.639376          0.476625           0.708497          0.873813   0.020683  0.112134
1   1024.0   852.178476  871.008855       845.625798      794.375734       866.591724     1205.259785       0.852178      0.871009           0.845626          0.794376           0.866592          1.205260   0.006625  0.057388
2   2048.0  1326.681649  924.058975      1152.281316      822.412594      1407.484513     1327.311359       1.326682      0.924059           1.152281          0.822413           1.407485          1.327311   0.047909  0.077342
3   4096.0   777.372987  774.463015       718.202711      716.975062       812.849658     1158.647559       0.777373      0.774463           0.718203          0.716975           0.812850          1.158648   0.030980  0.068119
4   8192.0   797.892135  746.956533       772.431690      724.404855       870.187483      812.062733       0.797892      0.746957           0.772432          0.724405           0.870187          0.812063   0.020727  0.021783
5  16384.0   771.169338  753.176996       761.908050      745.654015       794.752057      782.519359       0.771169      0.753177           0.761908          0.745654           0.794752          0.782519   0.010341  0.009802
6  32768.0   840.465288  839.332881       834.064957      832.409566       848.834653      852.284270       0.840465      0.839333           0.834065          0.832410           0.848835          0.852284   0.005450  0.006801
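The `-CV` columns are presumably the coefficient of variation (standard deviation divided by mean) over the repeated measurements. The benchmark's exact definition, including whether it uses sample or population standard deviation and whether it is computed on times or on throughput, is an assumption here. A minimal sketch that computes a CV and flags table rows above a 10% threshold, using the CV values copied from the table:

```python
import statistics


def coefficient_of_variation(samples: list[float]) -> float:
    """Assumed definition: sample standard deviation divided by the mean."""
    return statistics.stdev(samples) / statistics.mean(samples)


# (N, Triton-CV, XeTLA-CV), copied from the softmax-performance table.
CV_TABLE = [
    (256, 0.020683, 0.112134),
    (1024, 0.006625, 0.057388),
    (2048, 0.047909, 0.077342),
    (4096, 0.030980, 0.068119),
    (8192, 0.020727, 0.021783),
    (16384, 0.010341, 0.009802),
    (32768, 0.005450, 0.006801),
]


def flag_unstable(table, threshold: float = 0.10):
    """Return (N, provider, cv) for every measurement whose CV exceeds the threshold."""
    flagged = []
    for n, triton_cv, xetla_cv in table:
        for provider, cv in (("Triton", triton_cv), ("XeTLA", xetla_cv)):
            if cv > threshold:
                flagged.append((n, provider, cv))
    return flagged


print(flag_unstable(CV_TABLE))  # [(256, 'XeTLA', 0.112134)]
```

Only XeTLA at N = 256 crosses the 10% threshold; the highest Triton CV is 0.047909 at N = 2048.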
chengjunlu commented 3 months ago

We found a CV above 10% for the XeTLA softmax kernel at shape [4096, 256] (XeTLA-CV = 0.112134 at N = 256 in the table above). Thread dispatch in this small configuration of the XeTLA softmax is imbalanced, and the imbalance varies randomly from run to run. The cause is that, on PVC, imbalanced thread dispatch oversubscribes the data port on some XeCores while other XeCores are not fully used.
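To see why dispatch imbalance can show up as run-to-run variance, consider a toy model (an assumption, not a model of the real PVC dispatcher): each work group is assigned to a uniformly random XeCore, each core's time grows linearly with the number of groups it receives because they share that core's data port, and the kernel finishes when the busiest core finishes. The core count of 64 and one work group per row of the [4096, 256] input are hypothetical parameters.

```python
import random
import statistics


def kernel_time(loads: list[int], cost_per_group: float = 1.0) -> float:
    """Toy model: the kernel ends when the busiest XeCore ends; per-core time is
    linear in its work-group count because those groups share one data port."""
    return max(loads) * cost_per_group


def balanced_loads(num_groups: int, num_cores: int) -> list[int]:
    """Ideal dispatch: spread work groups as evenly as possible across cores."""
    base, extra = divmod(num_groups, num_cores)
    return [base + (1 if core < extra else 0) for core in range(num_cores)]


def random_loads(num_groups: int, num_cores: int, rng: random.Random) -> list[int]:
    """Hypothetical imbalanced dispatch: each group lands on a uniformly random core."""
    loads = [0] * num_cores
    for _ in range(num_groups):
        loads[rng.randrange(num_cores)] += 1
    return loads


def simulate(num_groups: int = 4096, num_cores: int = 64, trials: int = 50, seed: int = 0):
    """Return the balanced kernel time and the kernel times of `trials` random dispatches."""
    rng = random.Random(seed)
    ideal = kernel_time(balanced_loads(num_groups, num_cores))
    times = [kernel_time(random_loads(num_groups, num_cores, rng)) for _ in range(trials)]
    return ideal, times


ideal, times = simulate()
cv = statistics.stdev(times) / statistics.mean(times)
print(f"balanced: {ideal}, random mean: {statistics.mean(times):.1f}, CV: {cv:.3f}")
```

In this model the balanced time is 64 units, every random dispatch takes at least that long, and the spread of the random trials plays the role of the measured CV.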

[Figure: imbalanced thread dispatch across XeCores.]

[Figure: timeline of the physical threads on an oversubscribed XeCore, blocked by data-port congestion.]

[Figure: for comparison, timeline of the physical threads on a high-performing XeCore.]