openxla / xla

A machine learning compiler for GPUs, CPUs, and ML accelerators
Apache License 2.0

Slow s8xs8->s8 matrix multiplication on A100 #7117

Open hawkinsp opened 10 months ago

hawkinsp commented 10 months ago

This HLO:

HloModule jit__einsum, entry_computation_layout={(s8[16,1024,1024]{2,1,0}, s8[16,1024,1024]{2,1,0})->s8[16,1024,1024]{2,1,0}}, allow_spmd_sharding_propagation_to_output={true}

ENTRY main.4 {
  Arg_0.1 = s8[16,1024,1024]{2,1,0} parameter(0), sharding={replicated}
  Arg_1.2 = s8[16,1024,1024]{2,1,0} parameter(1), sharding={replicated}
  ROOT dot.3 = s8[16,1024,1024]{2,1,0} dot(Arg_0.1, Arg_1.2), lhs_batch_dims={0}, lhs_contracting_dims={2}, rhs_batch_dims={0}, rhs_contracting_dims={2}, metadata={op_name="jit(_einsum)/jit(main)/dot_general[dimension_numbers=(((2,), (2,)), ((0,), (0,))) precision=None preferred_element_type=int8]" source_file="experimental/users/phawkins/jax/nondet.py" source_line=12}
}

gets around 2.83 TOp/s on A100.
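For context, here is a minimal JAX sketch that produces an s8 x s8 -> s8 batched dot like the first HLO. This is illustrative only (the original nondet.py is not shown); the function name and random inputs are invented for the repro.

  # Hypothetical repro sketch, not the original nondet.py: an int8 batched
  # matmul that asks XLA for an s8 result directly, as in the first HLO.
  import jax
  import jax.numpy as jnp
  import numpy as np

  a = jnp.asarray(np.random.randint(-128, 128, (16, 1024, 1024), dtype=np.int8))
  b = jnp.asarray(np.random.randint(-128, 128, (16, 1024, 1024), dtype=np.int8))

  @jax.jit
  def s8_dot_s8(x, y):
      # Batch over dim 0, contract dim 2 of both operands, request an s8 result.
      return jax.lax.dot_general(
          x, y,
          dimension_numbers=(((2,), (2,)), ((0,), (0,))),
          preferred_element_type=jnp.int8)

  out = s8_dot_s8(a, b)  # s8[16,1024,1024]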

whereas this HLO:

HloModule jit_f, entry_computation_layout={(s8[16,1024,1024]{2,1,0}, s8[16,1024,1024]{2,1,0})->s8[16,1024,1024]{2,1,0}}, allow_spmd_sharding_propagation_to_output={true}

ENTRY main.5 {
  Arg_0.1 = s8[16,1024,1024]{2,1,0} parameter(0), sharding={replicated}
  Arg_1.2 = s8[16,1024,1024]{2,1,0} parameter(1), sharding={replicated}
  dot.3 = s32[16,1024,1024]{2,1,0} dot(Arg_0.1, Arg_1.2), lhs_batch_dims={0}, lhs_contracting_dims={2}, rhs_batch_dims={0}, rhs_contracting_dims={2}, metadata={op_name="jit(f)/jit(main)/bij,bkj->bik/dot_general[dimension_numbers=(((2,), (2,)), ((0,), (0,))) precision=None preferred_element_type=int32]" source_file="experimental/users/phawkins/jax/nondet.py" source_line=14}
  ROOT convert.4 = s8[16,1024,1024]{2,1,0} convert(dot.3), metadata={op_name="jit(f)/jit(main)/convert_element_type[new_dtype=int8 weak_type=False]" source_file="experimental/users/phawkins/jax/nondet.py" source_line=14}
}

achieves well over 100 TOp/s.
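A corresponding sketch for the second formulation (again illustrative, not the original script): accumulate in int32 and cast back to int8 afterwards, mirroring the dot + convert pair in the second HLO.

  # Hypothetical sketch of the fast formulation: int32 accumulation followed
  # by an explicit cast back to int8, mirroring dot.3 + convert.4 above.
  import jax
  import jax.numpy as jnp

  @jax.jit
  def s8_dot_via_s32(x, y):
      acc = jax.lax.dot_general(
          x, y,
          dimension_numbers=(((2,), (2,)), ((0,), (0,))),
          preferred_element_type=jnp.int32)  # s32 accumulation
      return acc.astype(jnp.int8)            # cast back to s8 at the end

  # Usage (reusing a, b from the sketch above): s8_dot_via_s32(a, b)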

We should lower the first one the same way as the second one, if it's that much better.

cheshire commented 10 months ago

I would guess the problem here is that both parameter layouts (as well as the output) are user-visible, so XLA can't optimize them?

manishucsd commented 10 months ago

I ran CUTLASS GEMMs for this computation; please see some data below.

[Chart: Batch GEMM performance with integer tensor cores on NVIDIA A100 SXM 40GB; batch GEMM problem shape batch_count=16, m=1024, n=1024, k=1024]

@hwu36, do you see a variant that can do better than the above for this computation on A100?

hawkinsp commented 10 months ago

Note that this issue is specific to what it says on the tin: irrespective of layout, we've chosen a very bad strategy for an s8xs8->s8 matmul. If the user wrote this, you should either: a) error, because this isn't something XLA supports, or b) do it well, which here means using tensor core int matmuls, possibly with transposes if needed, and a cast to int8 at the end.

manishucsd commented 10 months ago

We should lower the first one the same way as the second one, if it's that much better.

Got it. The second lowering also gets ~100 TOPs, but the runs show that for this computation (batch GEMM with A, B, C of shape 16x1024x1024) we can push up to 369 TOPs with a row-col layout and 422 TOPs with a col-row (interleaved) layout. Thus, we should be able to do better than the second one, correct?

hawkinsp commented 10 months ago

Indeed. Do the best you can, given the user's constraints!