hawkinsp opened 10 months ago

This HLO:

gets around 2.83 TOp/s on A100, whereas this HLO:

achieves well over 100 TOp/s.

We should lower the first one as the second one, if it's that much better.
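For concreteness, here is a minimal JAX sketch of the kind of computation at issue (the shapes are illustrative assumptions): a matmul where the inputs and the user-visible output are all int8, so it lowers to a `dot(s8, s8) -> s8`:

```python
import jax
import jax.numpy as jnp

# Illustrative shapes (assumption); the point is the s8 x s8 -> s8 dot.
a = jnp.ones((1024, 1024), dtype=jnp.int8)
b = jnp.ones((1024, 1024), dtype=jnp.int8)

@jax.jit
def s8_matmul(a, b):
    # NumPy-style type promotion keeps int8 throughout, so this lowers
    # to an HLO dot with s8 operands and an s8 result -- the slow case.
    return jnp.matmul(a, b)

print(s8_matmul(a, b).dtype)  # int8
```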
I would guess the problem here is that both parameter layouts (as well as the output layout) are user-visible, so XLA can't optimize them?
I ran CUTLASS GEMM for this computation; please see some data below.
@hwu36, do you see a variant that can do better than the above for this computation on A100?
Note that this issue is specific to what it says on the tin: irrespective of layout, we've chosen a very bad strategy for an s8xs8->s8 matmul. If the user wrote this, you should either: a) error, because this isn't something XLA supports, or b) do it well, and here that means: use tensorcore int matmuls, possibly with transposes if needed, and a cast to int8 at the end.
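A sketch of option (b) in JAX terms, assuming the intent is int32 tensor-core accumulation with a final cast; this is an illustration, not the exact rewrite XLA would perform:

```python
import jax
import jax.numpy as jnp
from jax import lax

@jax.jit
def s8_matmul_via_s32(a, b):
    # Request an int32 result type so the dot can take the tensor-core
    # integer path, then cast back down to int8 at the end.
    acc = lax.dot_general(
        a, b,
        dimension_numbers=(((1,), (0,)), ((), ())),  # plain matmul
        preferred_element_type=jnp.int32,
    )
    return acc.astype(jnp.int8)
```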
> We should lower the first one as the second one, if it's that much better.
Got it. The second lowering also gets ~100 TOp/s, but the runs show that for this computation (batch GEMM A,B,C = 16x1024x1024) we can push up to 369 TOp/s with row-col layout and 422 TOp/s with col-row (interleaved) layout. Thus, we can and should be able to do better than the "second one", correct?
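For reference, a rough JAX harness for this batch GEMM shape (the harness and op count are assumptions of mine; it measures whatever kernel XLA picks, not the CUTLASS interleaved-layout kernels quoted above):

```python
import time
import jax
import jax.numpy as jnp
from jax import lax

B, M, N, K = 16, 1024, 1024, 1024  # batch GEMM A,B,C = 16x1024x1024

a = jnp.ones((B, M, K), dtype=jnp.int8)
b = jnp.ones((B, K, N), dtype=jnp.int8)

@jax.jit
def batch_s8_matmul(a, b):
    acc = lax.dot_general(
        a, b,
        dimension_numbers=(((2,), (1,)), ((0,), (0,))),  # batched matmul
        preferred_element_type=jnp.int32,
    )
    return acc.astype(jnp.int8)

batch_s8_matmul(a, b).block_until_ready()  # warm-up / compile
t0 = time.perf_counter()
out = batch_s8_matmul(a, b).block_until_ready()
dt = time.perf_counter() - t0
ops = 2 * B * M * N * K  # each multiply-add counted as 2 ops
print(f"{ops / dt / 1e12:.2f} TOp/s")
```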
Indeed. Do the best you can, given the user's constraints!