This is the result of the test
Suspicious
I suspect that the issue might be due to the Triton kernel hardcoding the multiplication type, for example:

acc0 += tl.sum(a.to(tl.float32) * x0.to(tl.float32)[None, :], 1)

However, when I tried changing float32 to float16, there was no change.
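For reference, one way to test this suspicion cleanly is to pass the accumulation dtype as a compile-time constant instead of hardcoding it. This is only a minimal sketch, not the repository's actual kernel; the kernel name, shapes, and the ACC_DTYPE parameter are assumptions for illustration:

```python
import torch
import triton
import triton.language as tl

@triton.jit
def gemv_rowsum_kernel(
    a_ptr, x_ptr, out_ptr,
    M, N,
    BLOCK_N: tl.constexpr,
    ACC_DTYPE: tl.constexpr,  # e.g. tl.float32 or tl.float16 (assumption, not in the repo)
):
    # One program per output row: out[m] = sum_n A[m, n] * x[n]
    m = tl.program_id(0)
    offs_n = tl.arange(0, BLOCK_N)
    mask = offs_n < N
    a = tl.load(a_ptr + m * N + offs_n, mask=mask, other=0.0)
    x = tl.load(x_ptr + offs_n, mask=mask, other=0.0)
    # Accumulate in the requested dtype rather than a hardcoded tl.float32.
    acc = tl.sum(a.to(ACC_DTYPE) * x.to(ACC_DTYPE), axis=0)
    # Cast back to the output tensor's element type on store.
    tl.store(out_ptr + m, acc.to(out_ptr.dtype.element_ty))

M, N = 1024, 1024
A = torch.randn(M, N, device="cuda", dtype=torch.float16)
x = torch.randn(N, device="cuda", dtype=torch.float16)
out = torch.empty(M, device="cuda", dtype=torch.float16)
gemv_rowsum_kernel[(M,)](A, x, out, M, N,
                         BLOCK_N=triton.next_power_of_2(N),
                         ACC_DTYPE=tl.float16)
```

Launching once with ACC_DTYPE=tl.float32 and once with ACC_DTYPE=tl.float16 would then isolate whether the accumulation dtype affects throughput at all.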
In summary

Whether I use gather_gemv, gather_transposed_gemv, or mlp_sparse, there is no improvement over the native torch dense GEMM. Was the improvement reported in the paper measured under fp32?
Test code
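The original test code is not included above. Purely as an illustration of the kind of comparison described, and not the author's actual benchmark, a dense torch GEMV can be timed against a naive gathered GEMV as below; the shapes, the ~12.5% row-activation ratio, and the timing loop are all assumptions:

```python
import torch

def bench(fn, iters=100):
    # CUDA-event timing with a warmup; returns mean milliseconds per call.
    for _ in range(10):
        fn()
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    for _ in range(iters):
        fn()
    end.record()
    torch.cuda.synchronize()
    return start.elapsed_time(end) / iters

M, N = 4096, 4096
dtype = torch.float16  # switch to torch.float32 to mirror the fp32 question
W = torch.randn(M, N, device="cuda", dtype=dtype)
x = torch.randn(N, device="cuda", dtype=dtype)
idx = torch.randperm(M, device="cuda")[: M // 8]  # ~12.5% of rows "active"

dense_ms = bench(lambda: W @ x)
gather_ms = bench(lambda: W[idx] @ x)  # naive gather + GEMV baseline
print(f"dense: {dense_ms:.3f} ms, gathered ({len(idx)} rows): {gather_ms:.3f} ms")
```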