google / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
http://jax.readthedocs.io/
Apache License 2.0

Larger integer matrix multiplies are > 20x slower versus floats #17688

Open ywrt opened 11 months ago

ywrt commented 11 months ago

Description

It looks like integer dot products have poor performance on larger matrix multiplies.

import time

import jax
import jax.numpy as jnp

def measure_type(dtype):
  def fn():
    a = jnp.ones((128, 1024, 256), dtype=dtype)
    b = jnp.ones((256, 512), dtype=dtype)
    return jnp.matmul(a, b)
  jfn = jax.jit(fn)

  # print(jfn.lower().compile().as_text())  # dump the compiled HLO
  jfn().block_until_ready()  # compile and warm up before timing
  LOOPS = 2000
  st = time.time()
  for _ in range(LOOPS):
    x = jfn()
    x.block_until_ready()
  print(f'{LOOPS/(time.time() - st):7.2f} loops/sec for {dtype}')

measure_type(jnp.float32)
measure_type(jnp.float16)
measure_type(jnp.int16)
measure_type(jnp.int8)

gives

1463.92 loops/sec for <class 'jax.numpy.float32'>
3087.09 loops/sec for <class 'jax.numpy.float16'>
 133.42 loops/sec for <class 'jax.numpy.int16'>
 148.95 loops/sec for <class 'jax.numpy.int8'>

Looking at the compiled code, the float versions call a cuBLAS/GEMM matrix multiply ...

  %custom-call = f16[131072,512]{1,0} custom-call(f16[131072,256]{1,0} %wrapped_broadcast.5, f16[256,512]{1,0} %wrapped_broadcast.6), custom_call_target="__cublas$gemm", metadata={op_name="jit(fn)/jit(main)/dot_general[dimension_numbers=(((2,), (0,)), ((), ())) precision=None preferred_element_type=float16]" }, backend_config={"alpha_real":1,"alpha_imag":0,"beta":0,"dot_dimension_numbers":{"lhs_contracting_dimensions":["1"],"rhs_contracting_dimensions":["0"],"lhs_batch_dimensions":[],"rhs_batch_dimensions":[]},"precision_config":{"operand_precision":["DEFAULT","DEFAULT"]},"epilogue":"DEFAULT"}

... but the integer versions just call a generic dot(), which doesn't seem to do any sort of tiling(?)

  ROOT %dot.1 = s16[131072,512]{1,0} dot(s16[131072,256]{1,0} %param_0.2, s16[256,512]{1,0} %param_1), lhs_contracting_dims={1}, rhs_contracting_dims={0}, metadata={op_name="jit(fn)/jit(main)/dot_general[dimension_numbers=(((2,), (0,)), ((), ())) precision=None preferred_element_type=int16]" }
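
For reference, a minimal standalone way to reproduce these dumps, using the same .lower().compile().as_text() pattern as the commented-out line in the benchmark above:

import jax
import jax.numpy as jnp

def dump_hlo(dtype):
  def fn():
    a = jnp.ones((128, 1024, 256), dtype=dtype)
    b = jnp.ones((256, 512), dtype=dtype)
    return jnp.matmul(a, b)
  # In the float output, the matmul appears as a custom-call with
  # custom_call_target="__cublas$gemm"; in the integer output it appears
  # as a plain dot(...) instruction instead.
  print(jax.jit(fn).lower().compile().as_text())

dump_hlo(jnp.float16)
dump_hlo(jnp.int16)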

Am I doing something wrong here? Is the integer performance expected to be this slow on GPU?

What jax/jaxlib version are you using?

jax 0.4.16

Which accelerator(s) are you using?

GPU 4090

Additional system info

No response

NVIDIA GPU info

+-----------------------------------------------------------------------------+
| NVIDIA-SMI 525.125.06   Driver Version: 525.125.06   CUDA Version: 12.0     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|===============================+======================+======================|
|   0  NVIDIA GeForce ...  On   | 00000000:03:00.0 Off |                  Off |
| 33%   61C    P2   384W / 450W |  22691MiB / 24564MiB |    100%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
|   1  NVIDIA GeForce ...  On   | 00000000:09:00.0 Off |                  Off |
| 30%   40C    P8    16W / 450W |      3MiB / 24564MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
hawkinsp commented 11 months ago

I suspect no-one has optimized that path yet! Filed https://github.com/openxla/xla/issues/5796

hawkinsp commented 10 months ago

Apparently the main problem is that you're using integer matmuls that don't exist in cuBLAS, so they aren't hitting an optimized path from NVIDIA.

I also filed https://github.com/openxla/xla/issues/7117 for the case of slow s8xs8->s8 matmuls. Note that s8xs8->s32 matmuls are much faster on GPUs with tensorcores: can you use those?

ywrt commented 9 months ago

In jax, how do you specify that the matmul should be s8xs8->s32 instead of s8xs8->s8?

I see that jax.numpy.matmul has a 'preferred_element_type' argument that's undocumented: is that what I should be using?

On further investigation, it looks like there's a missing 'extra_params=_PREFERRED_ELEMENT_TYPE_DESCRIPTION' in the _wraps wrapper for matmul. And indeed, most of the functions that take 'preferred_element_type' are missing the documentation string.

jakevdp commented 9 months ago

preferred_element_type forwards to the argument of the same name in lax.dot_general. Thanks for pointing out the documentation issue – I'll take a look!

Edit: done in #18647
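
For completeness, a minimal sketch of requesting int32 accumulation for the int8 case via the preferred_element_type argument discussed above. Shapes match the original benchmark; whether this actually hits an optimized kernel still depends on the XLA version, per the discussion above:

import jax
import jax.numpy as jnp

a = jnp.ones((128, 1024, 256), dtype=jnp.int8)
b = jnp.ones((256, 512), dtype=jnp.int8)

# s8 x s8 -> s32: ask for an int32 result instead of the default int8.
c = jnp.matmul(a, b, preferred_element_type=jnp.int32)

# Equivalent call on the underlying primitive that jnp.matmul forwards to;
# dimension_numbers matches the dot_general shown in the HLO dumps above.
d = jax.lax.dot_general(
    a, b,
    dimension_numbers=(((2,), (0,)), ((), ())),
    preferred_element_type=jnp.int32,
)

print(c.dtype, d.dtype)  # int32 int32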