Open lgeiger opened 2 years ago
All situations where JAX code causes an XLA crash are bugs (in either JAX or XLA). I'm not actually sure whether the issue here is that XLA doesn't support something or that JAX doesn't forbid it at the Python level (where we aim to identify and error on "this pattern is unimplemented" situations). Maybe @cheshire has more context?
```
Unimplemented: Integer convolutions for CuDNN must have float or int8 output.
```
The error seems to be pretty self-descriptive, right? From the table at https://docs.nvidia.com/deeplearning/cudnn/api/index.html#cudnnConvolutionForward it seems that int8xint8->int32 convolutions are not supported, and int8 output has to be used.
@jekbradbury Thanks for the response. I looked a bit deeper into the XLA cuDNN rewriter, and I can successfully run int8 cuDNN convolutions on a T4 by following its rewrite patterns, using either:
```python
import jax
import jax.numpy as jnp
from jax import lax

@jax.jit
def conv_clamp(x, w):
    # Accumulate in int32, clamp to the int8 range, then narrow to int8.
    y = lax.conv(x, w, window_strides=(1, 1), padding="SAME", preferred_element_type=jnp.int32)
    y = lax.clamp(-128, y, 127)
    return lax.convert_element_type(y, jnp.int8)

conv_clamp(x, w)  # x, w: int8 input and kernel arrays
```
or
```python
@jax.jit
def conv_float(x, w):
    # Accumulate in int32, then cast to float32 so cuDNN accepts the output type.
    y = lax.conv(x, w, window_strides=(1, 1), padding="SAME", preferred_element_type=jnp.int32)
    return lax.convert_element_type(y, jnp.float32)

conv_float(x, w)
```
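For completeness, here is a self-contained sketch exercising the first workaround; the input shapes and random values are my own assumptions, not from the thread:

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax import lax

@jax.jit
def conv_clamp(x, w):
    # Accumulate in int32, clamp to the int8 range, then narrow to int8.
    y = lax.conv(x, w, window_strides=(1, 1), padding="SAME", preferred_element_type=jnp.int32)
    y = lax.clamp(-128, y, 127)
    return lax.convert_element_type(y, jnp.int8)

# Hypothetical NCHW input and OIHW kernel, filled with small int8 values.
rng = np.random.default_rng(0)
x = jnp.asarray(rng.integers(-4, 4, size=(1, 8, 16, 16)), jnp.int8)
w = jnp.asarray(rng.integers(-4, 4, size=(8, 8, 3, 3)), jnp.int8)

print(conv_clamp(x, w).dtype)   # int8
print(conv_clamp(x, w).shape)   # (1, 8, 16, 16)
```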
Performance is still slower than using a float32 convolution, though.
I think it would be much more user-friendly if this could be done directly by setting `preferred_element_type` to either `jnp.int8` or `jnp.float32`. I guess one could add a custom translation for this in JAX, though handling it directly in XLA might be the more natural place.
It would be very nice if this could be fixed soon, as currently I have to maintain different code paths for running on CPU (integer conv) and GPU (float conv followed by a conversion back to int).
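Until then, one way to paper over the difference is a small backend-dispatch helper; this is only a sketch of the pattern described above, and the name `int8_conv` and the strides/padding choices are my own assumptions:

```python
import jax
import jax.numpy as jnp
from jax import lax

def int8_conv(x, w):
    """int8 x int8 -> int32 convolution that runs on both CPU and GPU.

    Sketch only: uses a float32 fallback on GPU, where XLA currently
    rejects integer convolutions, and a direct integer conv elsewhere.
    """
    if jax.default_backend() == "gpu":
        # GPU fallback: convolve in float32, then convert back to int32.
        y = lax.conv(x.astype(jnp.float32), w.astype(jnp.float32),
                     window_strides=(1, 1), padding="SAME")
        return lax.convert_element_type(y, jnp.int32)
    # CPU (and TPU) support integer convolutions directly.
    return lax.conv(x, w, window_strides=(1, 1), padding="SAME",
                    preferred_element_type=jnp.int32)
```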
@lgeiger above seems to point out that on the XLA side the code is already there and a workaround exists? And then the question is one of providing a good UI on the JAX side? Or am I misreading this?
Is there any follow-up on this issue? Like @lgeiger, I also observed that performance is ~3x slower than float32; it would be great if we could speed up the computation by using int8.
CC @SandSnip3r @akuegel
For research on quantized neural networks, and for some performance-critical situations, it would be excellent if JAX supported fast integer convolutions on GPU, bringing the GPU backend in line with CPU and TPU.
This is a follow-up on https://github.com/tensorflow/tensorflow/issues/49140, but I am posting it here as well since the missing pieces seem to lie in the interface between XLA and JAX; please correct me if I am wrong and close this issue.
XLA added support for an `int8,int8->int32` GPU GEMM via cuBLAS in https://github.com/tensorflow/tensorflow/commit/695966634280ff43a84a423ba9a58adc123f710d, but it doesn't seem to be currently usable when calling the `xla.Conv` op: the following snippet throws in the XLA GPU conv rewriter with `int8` or `int32` input, and setting `preferred_element_type` throws in the XLA cuDNN fused conv rewriter. Check out this notebook for a full reproduction.