google / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
http://jax.readthedocs.io/
Apache License 2.0

Support integer convolution on GPU #7637

Open lgeiger opened 2 years ago

lgeiger commented 2 years ago

For research on quantized neural networks, and for some performance-critical situations, it would be excellent if JAX supported fast integer convolutions on GPU, bringing the GPU backend in line with CPU and TPU.

This is a follow-up to https://github.com/tensorflow/tensorflow/issues/49140. I am posting it here as well since the missing pieces seem to lie in the interface between XLA and JAX, but please correct me if I am wrong and close this issue.

XLA added support for an int8,int8->int32 GPU GEMM via cuBLAS in https://github.com/tensorflow/tensorflow/commit/695966634280ff43a84a423ba9a58adc123f710d; however, it doesn't seem to be usable when calling the xla.Conv op:

The following snippet throws in the XLA GPU conv rewriter with int8 or int32 input:

from jax import lax, numpy as jnp

x = jnp.ones((8, 4, 256, 256), dtype=jnp.int8)
w = jnp.ones((64, 4, 3, 3), dtype=jnp.int8)

lax.conv(x, w, window_strides=(1, 1), padding="SAME")
Unimplemented: Integer convolutions for CuDNN must have this pattern:
conv<InputT=int32, ResultT=int32>(convert<int32>(int8_x), convert<int32>(int8_y))

Setting preferred_element_type throws in the XLA cuDNN fused conv rewriter:

lax.conv(x, w, window_strides=(1, 1), padding="SAME", preferred_element_type=jnp.int32)
Unimplemented: Integer convolutions for CuDNN must have float or int8 output. 
Use convert to cast output to float or the following pattern to int8:
clamp(broadcast(-128), conv(int8_x, int8_w, ...), broadcast(127)).

Check out this notebook for a full reproduction.

jekbradbury commented 2 years ago

All situations where JAX code causes an XLA crash are bugs (in either JAX or XLA). I'm not actually sure whether the issue here is that XLA doesn't support something or that JAX doesn't forbid it at the Python level (where we aim to identify and error on "this pattern is unimplemented" situations). Maybe @cheshire has more context?

cheshire commented 2 years ago

Unimplemented: Integer convolutions for CuDNN must have float or int8 output.

The error seems to be pretty self-descriptive, right? From the table at https://docs.nvidia.com/deeplearning/cudnn/api/index.html#cudnnConvolutionForward it seems that int8xint8->int32 convolutions are not supported, and int8 output has to be used.

lgeiger commented 2 years ago

@jekbradbury Thanks for the response. I looked a bit deeper into the XLA cuDNN rewriter and I can successfully run int8 cuDNN convolutions on a T4 by following the rewrite patterns, using either:

import jax

@jax.jit
def conv_clamp(x, w):
    # int32 accumulation, clamped to the int8 range and cast back to int8,
    # matching the pattern the cuDNN rewriter expects.
    y = lax.conv(x, w, window_strides=(1, 1), padding="SAME", preferred_element_type=jnp.int32)
    y = lax.clamp(-128, y, 127)
    return lax.convert_element_type(y, jnp.int8)

conv_clamp(x, w)

or

@jax.jit
def conv_float(x, w):
    # int32 accumulation cast to float32 output, the other pattern the rewriter accepts.
    y = lax.conv(x, w, window_strides=(1, 1), padding="SAME", preferred_element_type=jnp.int32)
    return lax.convert_element_type(y, jnp.float32)

conv_float(x, w)

However, the performance is slower than using a float32 convolution. I think it would be much more user friendly if it were possible to do this directly by setting preferred_element_type to either jnp.int8 or jnp.float32. One could add a custom translation for this in JAX, though handling it directly in XLA might be the more natural place.
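
For reference, a rough micro-benchmark sketch of what I mean (an assumed setup, not taken from the notebook; conv_int8 and conv_f32 are just illustrative names, and the shapes match the reproduction snippet above):

import timeit

import jax
from jax import lax, numpy as jnp

# Shapes from the reproduction snippet above.
x = jnp.ones((8, 4, 256, 256), dtype=jnp.int8)
w = jnp.ones((64, 4, 3, 3), dtype=jnp.int8)

@jax.jit
def conv_int8(x, w):
    # int8 inputs, int32 accumulation, cast to float32 output (one of the
    # patterns the cuDNN rewriter accepts).
    y = lax.conv(x, w, window_strides=(1, 1), padding="SAME",
                 preferred_element_type=jnp.int32)
    return lax.convert_element_type(y, jnp.float32)

@jax.jit
def conv_f32(x, w):
    # Plain float32 convolution for comparison.
    return lax.conv(x, w, window_strides=(1, 1), padding="SAME")

x_f32, w_f32 = x.astype(jnp.float32), w.astype(jnp.float32)

# Warm up to exclude compilation time, then time steady-state execution.
conv_int8(x, w).block_until_ready()
conv_f32(x_f32, w_f32).block_until_ready()
print("int8:", timeit.timeit(lambda: conv_int8(x, w).block_until_ready(), number=100))
print("f32: ", timeit.timeit(lambda: conv_f32(x_f32, w_f32).block_until_ready(), number=100))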

LionSR commented 1 year ago

It would be very nice if this could be fixed soon, as I currently have to use different code paths when running on CPU (integer conv) and GPU (float conv and then convert back to int).
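
Roughly, the dispatch looks like the sketch below (conv_int is just an illustrative helper name; the strides and padding are taken from the snippets above):

import jax
from jax import lax, numpy as jnp

def conv_int(x, w):
    # Illustrative dispatch: integer conv where supported, float fallback on GPU.
    if jax.default_backend() == "gpu":
        # GPU: run the convolution in float32 and cast the result back to int32.
        y = lax.conv(x.astype(jnp.float32), w.astype(jnp.float32),
                     window_strides=(1, 1), padding="SAME")
        return y.astype(jnp.int32)
    # CPU/TPU: integer convolution works directly.
    return lax.conv(x, w, window_strides=(1, 1), padding="SAME",
                    preferred_element_type=jnp.int32)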

cheshire commented 1 year ago

@lgeiger above seems to point out that on the XLA side the code is already there and there is a workaround? So the remaining question is about providing a good UI on the JAX side? Or am I misreading this?

imoneoi commented 1 year ago

Is there any follow-up to this issue? Like @lgeiger, I also observed that the performance is ~3x slower than float32. It would be great if we could speed up the computation by using int8.

cheshire commented 1 year ago

CC @SandSnip3r @akuegel