jax-ml / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
http://jax.readthedocs.io/
Apache License 2.0

lax.conv_transpose takes FOREVER to compile #17464

Open sokrypton opened 1 year ago

sokrypton commented 1 year ago

Description

I initially submitted the issue here: https://github.com/deepmind/dm-haiku/issues/724

But then realized it was a jax issue.

In short, I've been trying to use Conv2DTranspose in my model, and even for a very simple case it takes forever to compile.

import jax
from jax import lax, random
import jax.numpy as jnp
import time

# Directly implement the Conv2DTranspose in JAX
def toy_model_jax(x, params):
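    # dimension_numbers defaults to the TF convention (NHWC inputs, HWIO kernels),
    # so an 8x8 spatial input with a 32x32 kernel and stride 16 ("VALID") yields 144x144 output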
    return lax.conv_transpose(x, params["kernel"], strides=(16, 16), padding="VALID")

# Initialize parameters for the toy model
def initialize_params(key):
    kernel_shape = (32, 32, 128, 32)  # (height, width, in_channels, out_channels)
    kernel = random.normal(key, kernel_shape)
    return {"kernel": kernel}

# Generate random input and params
start_time = time.time()
key = random.PRNGKey(42)
x = random.normal(key, (1, 8, 8, 128))
params = initialize_params(key)
end_time = time.time()
print(f"Initialization Run Time: {end_time - start_time:.6f} seconds")

# JIT-compile and time the model run
toy_model_jax_jitted = jax.jit(toy_model_jax)

# Time the model compilation
start_time = time.time()
# Warm-up call (this compiles the function)
_ = toy_model_jax_jitted(x, params)
end_time = time.time()
print(f"JAX Compilation Time: {end_time - start_time:.6f} seconds")

# Time the model run
start_time = time.time()
o = toy_model_jax_jitted(x, params)
print("input_shape", x.shape)
print("output_shape", o.shape)
end_time = time.time()
print(f"JITted Run Time: {end_time - start_time:.6f} seconds")

output

Initialization Run Time: 2.540971 seconds
JAX Compilation Time: 251.976538 seconds
input_shape (1, 8, 8, 128)
output_shape (1, 144, 144, 32)
JITted Run Time: 0.001842 seconds
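As a side note, the compile time can be isolated from the first dispatch using JAX's ahead-of-time lowering API (available in recent JAX/jaxlib versions). A sketch, reusing the names defined above; the timings it prints are not part of the original report:

lowered = jax.jit(toy_model_jax).lower(x, params)

start_time = time.time()
compiled = lowered.compile()  # compilation only, no execution
print(f"Compile-only Time: {time.time() - start_time:.6f} seconds")

start_time = time.time()
o = compiled(x, params)
o.block_until_ready()  # wait for the asynchronous computation to finish
print(f"Run Time (synchronized): {time.time() - start_time:.6f} seconds")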

For comparison, here is the PyTorch output:

Initialization Time: 0.033582 seconds
input_shape torch.Size([1, 128, 8, 8])
output_shape torch.Size([1, 32, 144, 144])
Run Time: 0.047478 seconds
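For reference, a minimal sketch of a comparable PyTorch model, with hyperparameters inferred from the shapes above; the actual benchmark code is in the notebook linked below:

import time
import torch

# ConvTranspose2d(128 -> 32, 32x32 kernel, stride 16) reproduces the reported shapes.
conv = torch.nn.ConvTranspose2d(in_channels=128, out_channels=32, kernel_size=32, stride=16).cuda()
x_t = torch.randn(1, 128, 8, 8, device="cuda")

start = time.time()
with torch.no_grad():
    o_t = conv(x_t)
torch.cuda.synchronize()  # make the timing include the actual GPU work
print("input_shape", x_t.shape)
print("output_shape", o_t.shape)
print(f"Run Time: {time.time() - start:.6f} seconds")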

Google colab notebook: https://colab.research.google.com/drive/15YkOuK0EjqZdBNaXpF2wpYexGqtjZjLr

What jax/jaxlib version are you using?

Google Colab

Which accelerator(s) are you using?

GPU

Additional system info

Google Colab

NVIDIA GPU info

Wed Sep  6 14:21:28 2023       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 525.105.17   Driver Version: 525.105.17   CUDA Version: 12.0     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|===============================+======================+======================|
|   0  Tesla T4            Off  | 00000000:00:04.0 Off |                    0 |
| N/A   58C    P0    28W /  70W |  11957MiB / 15360MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+

+-----------------------------------------------------------------------------+
| Processes:                                                                  |
|  GPU   GI   CI        PID   Type   Process name                  GPU Memory |
|        ID   ID                                                   Usage      |
|=============================================================================|
+-----------------------------------------------------------------------------+
hawkinsp commented 1 year ago

I think this in turn is an XLA bug. Opened https://github.com/openxla/xla/issues/5541.

hawkinsp commented 10 months ago

This is apparently due to convolution autotuning: some of the algorithms in cuDNN are very slow, and we try them all during autotuning. Once autotuning has run, we will choose a fast algorithm.
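Not stated in the thread, but if autotuning dominates the compile time, one possible workaround (availability and effect depend on the jaxlib/XLA version) is to lower XLA's GPU autotuning level before JAX initializes its backends:

import os

# Reduce or disable cuDNN algorithm autotuning (0 disables it; 4 is the default).
# This trades potentially slower kernels for much faster compilation.
os.environ["XLA_FLAGS"] = os.environ.get("XLA_FLAGS", "") + " --xla_gpu_autotune_level=0"

import jax  # the flag must be set before JAX is imported/initialized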

akuegel commented 10 months ago

It seems that in this case the same algorithms are returned by heuristics_mode_a and heuristics_mode_b, so deduplicating the algorithms tried during autotuning halves the compile time. That still leaves it slow, but it is a step in the right direction. There is also an idea for speeding it up further by aborting an autotuning attempt once it exceeds the best known runtime, but that will take a bit longer to implement.
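For illustration only, a rough sketch of that early-stopping idea; this is not XLA's actual autotuner, and pick_algorithm / run_and_time are hypothetical names:

def pick_algorithm(candidates, run_and_time):
    # Deduplicate candidates while preserving order (the two heuristics may overlap).
    best_algo, best_time = None, float("inf")
    for algo in dict.fromkeys(candidates):
        # Hypothetical measurement helper that gives up and returns None once
        # `budget` is exceeded, so slow algorithms stop wasting compile time.
        t = run_and_time(algo, budget=best_time)
        if t is not None and t < best_time:
            best_algo, best_time = algo, t
    return best_algo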