ywrt opened 11 months ago
I suspect no one has optimized that path yet! Filed https://github.com/openxla/xla/issues/5796
Apparently the main problem is that you're using integer matmuls that don't exist in cuBLAS, so they aren't hitting an optimized path from NVIDIA.
I also filed https://github.com/openxla/xla/issues/7117 for the case of slow s8xs8->s8 matmuls. Note that s8xs8->s32 matmuls are much faster on GPUs with tensor cores: can you use those?
In JAX, how do you specify that the matmul should be s8xs8->s32 instead of s8xs8->s8?
I see that `jax.numpy.matmul` has a `preferred_element_type` argument that's undocumented: is that what I should be using?
On further investigation, it looks like there's a missing `extra_params=_PREFERRED_ELEMENT_TYPE_DESCRIPTION` in the `_wraps` wrapper for `matmul`. And indeed, most of the functions that take `preferred_element_type` are missing the documentation string.
`preferred_element_type` forwards to the argument of the same name in `lax.dot_general`. Thanks for pointing out the documentation issue – I'll take a look!
Edit: done in #18647
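For anyone landing here, a minimal sketch of requesting an s32 accumulator; the shapes and values are illustrative, not from the original report:

```python
import jax.numpy as jnp
from jax import lax

x = jnp.ones((128, 128), dtype=jnp.int8)
y = jnp.ones((128, 128), dtype=jnp.int8)

# Ask for an s8 x s8 -> s32 matmul instead of the default s8 result.
out = jnp.matmul(x, y, preferred_element_type=jnp.int32)
assert out.dtype == jnp.int32

# Equivalently, spelled out via lax.dot_general, where the argument is documented:
out2 = lax.dot_general(
    x, y,
    dimension_numbers=(((1,), (0,)), ((), ())),  # contract x's dim 1 with y's dim 0
    preferred_element_type=jnp.int32,
)
assert out2.dtype == jnp.int32
```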
Description
It looks like integer dot products have poor performance on larger matrix multiplies.
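A minimal benchmark sketch of this kind of comparison, assuming 4096x4096 matrices (the size, iteration count, and helper names are illustrative, not the original script):

```python
import time
import jax
import jax.numpy as jnp

n = 4096  # assumed size; larger matmuls show the gap most clearly
key = jax.random.PRNGKey(0)
a_f32 = jax.random.normal(key, (n, n), dtype=jnp.float32)
b_f32 = jax.random.normal(key, (n, n), dtype=jnp.float32)
a_s8 = (a_f32 * 10).astype(jnp.int8)
b_s8 = (b_f32 * 10).astype(jnp.int8)

matmul = jax.jit(jnp.matmul)

def bench(a, b, iters=10):
    matmul(a, b).block_until_ready()  # compile and warm up first
    t0 = time.perf_counter()
    for _ in range(iters):
        out = matmul(a, b)
    out.block_until_ready()
    return (time.perf_counter() - t0) / iters

print(f"f32 x f32: {bench(a_f32, b_f32):.4f} s/iter")
print(f"s8  x s8 : {bench(a_s8, b_s8):.4f} s/iter")
```

On the reporter's setup, the integer version ran far slower than the float one.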
Looking at the compiled code, the float versions call a cuBLAS/GEMM matrix multiply ...
... but the integer version just calls a generic dot(), which doesn't seem to do any sort of tiling(?)
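One way to check which path a matmul takes is to dump the optimized HLO (a sketch; the exact custom-call name may vary across XLA versions):

```python
import jax
import jax.numpy as jnp

a = jnp.ones((1024, 1024), dtype=jnp.int8)
b = jnp.ones((1024, 1024), dtype=jnp.int8)

# The fast float path shows up as a custom-call (something like
# "__cublas$gemm"), while the slow integer path lowers to a plain dot op.
print(jax.jit(jnp.matmul).lower(a, b).compile().as_text())
```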
Am I doing something wrong here? Is the integer performance expected to be this slow on GPU?
What jax/jaxlib version are you using?
jax 0.4.16
Which accelerator(s) are you using?
GPU: NVIDIA RTX 4090
Additional system info
No response
NVIDIA GPU info