keras-team / keras

Deep Learning for humans
http://keras.io/

Slower inference time when using grouped convolutions compared to regular convolutions #20471

Open dan-kazbek opened 1 week ago

dan-kazbek commented 1 week ago

As the title says, for some parameter values the inference time for grouped convolutions is slower than for regular convolutions (i.e. number of groups = 1). Standalone code to reproduce the issue:

import tensorflow as tf
from keras import layers

# 1x1 convolution split into 2 groups: each group convolves half of the
# input channels with half of the filters.
conv_layer_groups = layers.Conv2D(
    kernel_size=1,
    filters=128,
    groups=2,
)

# Equivalent dense (ungrouped) 1x1 convolution.
conv_layer_no_groups = layers.Conv2D(
    kernel_size=1,
    filters=128,
    groups=1,
)

batch_size = 32
# NHWC input: 32 samples of 64x64 with 512 channels.
layer_input = tf.random.normal((batch_size, 64, 64, 512))

start = tf.timestamp()
output = conv_layer_groups(layer_input)
end = tf.timestamp()
print(f"Time of inference with 2 groups: {(end - start).numpy():.4f} seconds")

start = tf.timestamp()
output = conv_layer_no_groups(layer_input)
end = tf.timestamp()
print(f"Time of inference without groups: {(end - start).numpy():.4f} seconds")

The output on my setup:

2024-11-08 13:32:16.361610: I tensorflow/core/common_runtime/gpu/gpu_device.cc:2021] Created device /job:localhost/replica:0/task:0/device:GPU:0 with 79078 MB memory:  -> device: 0, name: NVIDIA A100 80GB PCIe, pci bus id: 0000:32:00.0, compute capability: 8.0
WARNING: All log messages before absl::InitializeLog() is called are written to STDERR
I0000 00:00:1731072737.629509   23627 service.cc:146] XLA service 0x719db50 initialized for platform CUDA (this does not guarantee that XLA will be used). Devices:
I0000 00:00:1731072737.629566   23627 service.cc:154]   StreamExecutor device (0): NVIDIA A100 80GB PCIe, Compute Capability 8.0
2024-11-08 13:32:17.700106: I external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:531] Loaded cuDNN version 8906
I0000 00:00:1731072738.129561   23627 device_compiler.h:188] Compiled cluster using XLA!  This line is logged at most once for the lifetime of the process.
Time of inference with 2 groups: 0.5544 seconds
Time of inference without groups: 0.0224 seconds

The inference is 25 times slower if I use number of groups = 2, even though grouping should cut the required number of FLOPs in half.
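As a sanity check on the FLOP claim, here is a back-of-the-envelope multiply-accumulate count using the shapes from the script above (the conv_macs helper is mine, for illustration only):

# With g groups, each of the 128 output channels connects to only
# 512 // g input channels, so the MAC count halves at g=2.
batch, h, w, c_in, c_out, k = 32, 64, 64, 512, 128, 1

def conv_macs(groups):
    return batch * h * w * c_out * (c_in // groups) * k * k

print(conv_macs(1))  # 8589934592 MACs for the dense convolution
print(conv_macs(2))  # 4294967296 MACs with 2 groups -- exactly half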

If I use 3D convolutions, the difference is even larger, and XLA additionally emits a slow-operation warning:

import tensorflow as tf
from keras import layers

# Same comparison with 3D convolutions.
conv_layer_groups = layers.Conv3D(
    kernel_size=1,
    filters=128,
    groups=2,
)

conv_layer_no_groups = layers.Conv3D(
    kernel_size=1,
    filters=128,
    groups=1,
)

batch_size = 32
# NDHWC input: 32 volumes of 64x64x64 with 512 channels.
layer_input = tf.random.normal((batch_size, 64, 64, 64, 512))

start = tf.timestamp()
output = conv_layer_groups(layer_input)
end = tf.timestamp()
print(f"Time of inference with 2 groups: {(end - start).numpy():.4f} seconds")

start = tf.timestamp()
output = conv_layer_no_groups(layer_input)
end = tf.timestamp()
print(f"Time of inference without groups: {(end - start).numpy():.4f} seconds")

The output:

2024-11-08 13:43:01.360942: I tensorflow/core/common_runtime/gpu/gpu_device.cc:2021] Created device /job:localhost/replica:0/task:0/device:GPU:0 with 79078 MB memory:  -> device: 0, name: NVIDIA A100 80GB PCIe, pci bus id: 0000:32:00.0, compute capability: 8.0
WARNING: All log messages before absl::InitializeLog() is called are written to STDERR
I0000 00:00:1731073383.401272   24815 service.cc:146] XLA service 0x8eb4f40 initialized for platform CUDA (this does not guarantee that XLA will be used). Devices:
I0000 00:00:1731073383.401320   24815 service.cc:154]   StreamExecutor device (0): NVIDIA A100 80GB PCIe, Compute Capability 8.0
2024-11-08 13:43:03.480806: I external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:531] Loaded cuDNN version 8906
2024-11-08 13:43:04.817905: E external/local_xla/xla/service/slow_operation_alarm.cc:65] Trying algorithm eng4{} for conv (f32[32,128,64,64,64]{4,3,2,1,0}, u8[0]{0}) custom-call(f32[32,512,64,64,64]{4,3,2,1,0}, f32[128,256,1,1,1]{4,3,2,1,0}), window={size=1x1x1}, dim_labels=bf012_oi012->bf012, feature_group_count=2, custom_call_target="__cudnn$convForward", backend_config={"operation_queue_id":"0","wait_on_operation_queues":[],"cudnn_conv_backend_config":{"conv_result_scale":1,"activation_mode":"kNone","side_input_scale":0,"leakyrelu_alpha":0},"force_earliest_schedule":false} is taking a while...
2024-11-08 13:43:06.818705: E external/local_xla/xla/service/slow_operation_alarm.cc:133] The operation took 3.000939004s
Trying algorithm eng4{} for conv (f32[32,128,64,64,64]{4,3,2,1,0}, u8[0]{0}) custom-call(f32[32,512,64,64,64]{4,3,2,1,0}, f32[128,256,1,1,1]{4,3,2,1,0}), window={size=1x1x1}, dim_labels=bf012_oi012->bf012, feature_group_count=2, custom_call_target="__cudnn$convForward", backend_config={"operation_queue_id":"0","wait_on_operation_queues":[],"cudnn_conv_backend_config":{"conv_result_scale":1,"activation_mode":"kNone","side_input_scale":0,"leakyrelu_alpha":0},"force_earliest_schedule":false} is taking a while...
I0000 00:00:1731073386.965210   24815 device_compiler.h:188] Compiled cluster using XLA!  This line is logged at most once for the lifetime of the process.
Time of inference with 2 groups: 3.7136 seconds
Time of inference without groups: 0.0376 seconds

Some system parameters, in case they are needed:

Ubuntu 22.04
TensorFlow 2.17.0
Keras 3.4.1
NVIDIA A100 80GB
CUDA 12.3, cuDNN 8.9.6
IMvision12 commented 1 week ago

I am facing a similar issue.

divyashreepathihalli commented 4 days ago

cuDNN is highly optimized for specific operations and configurations, especially when the parameters align with its pre-tuned algorithms. When values such as the number of groups in a convolution layer are changed, cuDNN may not perform as efficiently, because the selected algorithm may not be optimal for that configuration.

Have you tried the JAX backend? What are you observing then?
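
Also note that part of the measured gap can be a one-time cost: the logs above show XLA compiling and, in the 3D case, cuDNN trying algorithms, and both happen on the first call. A minimal sketch that warms the layer up before timing (it reuses conv_layer_groups and layer_input from the original script; the loop counts are arbitrary):

import time

# Warm-up: run a few calls so one-time XLA compilation and cuDNN
# algorithm selection are excluded from the measurement.
for _ in range(5):
    _ = conv_layer_groups(layer_input)

start = time.perf_counter()
for _ in range(20):
    output = conv_layer_groups(layer_input)
_ = output.numpy()  # force the GPU work to finish before stopping the clock
elapsed = (time.perf_counter() - start) / 20
print(f"Steady-state time per call: {elapsed * 1000:.3f} ms")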

dan-kazbek commented 3 days ago

> cuDNN is highly optimized for specific operations and configurations, especially when the parameters align with its pre-tuned algorithms. When values such as the number of groups in a convolution layer are changed, cuDNN may not perform as efficiently, because the selected algorithm may not be optimal for that configuration.
>
> Have you tried the JAX backend? What are you observing then?

How do I know which values of the number of groups align well with cuDNN?

The difference on the JAX backend is not as stark, but grouped convolutions are still slower:

Average time with 2 groups:  1.5100 ms
Standard deviation:  0.0195 ms
Average time without groups:  1.1997 ms
Standard deviation:  0.2183 ms

The code that I used to produce these numbers:

import os
from statistics import mean, stdev

# The backend must be selected before Keras is imported.
os.environ["KERAS_BACKEND"] = "jax"
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

import jax
import timeit
from keras import layers

conv_layer_groups = layers.Conv2D(
    kernel_size=1,
    filters=128,
    groups=2,
)

conv_layer_no_groups = layers.Conv2D(
    kernel_size=1,
    filters=128,
    groups=1,
)

batch_size = 32
layer_input = jax.random.normal(
    key=jax.random.key(0), shape=(batch_size, 64, 64, 512)
)

num_runs = 10
num_repeats = 20
# block_until_ready() waits for JAX's asynchronous dispatch to finish,
# so the timer measures actual execution rather than just dispatch.
times = timeit.repeat(
    "conv_layer_groups(layer_input).block_until_ready()",
    setup="from __main__ import conv_layer_groups, layer_input",
    repeat=num_repeats,
    number=num_runs,
)
times = [time / num_runs for time in times]
# Discard the first repeat, which includes JIT compilation.
print(f"Average time with 2 groups: {mean(times[1:])*1000 : .4f} ms")
print(f"Standard deviation: {stdev(times[1:])*1000 : .4f} ms")

times = timeit.repeat(
    "conv_layer_no_groups(layer_input).block_until_ready()",
    setup="from __main__ import conv_layer_no_groups, layer_input",
    repeat=num_repeats,
    number=num_runs,
)
times = [time / num_runs for time in times]
print(f"Average time without groups: {mean(times[1:])*1000 : .4f} ms")
print(f"Standard deviation: {stdev(times[1:])*1000 : .4f} ms")