dan-kazbek opened this issue 1 week ago
I am facing a similar issue.
cuDNN is highly optimized for specific operations and configurations, especially when parameters align well with cuDNN's pre-tuned algorithms. When values like the number of groups in a convolution layer are adjusted, cuDNN may not perform as efficiently, as the algorithm selection may not be optimal for those configurations.
Have you tried the JAX backend? What are you observing there?
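There is no published table of which configurations hit cuDNN's fast paths, so the most reliable check is empirical: sweep the group count and time each variant. A minimal sketch of that approach (assuming the Keras 3 JAX backend and shapes similar to yours; the time_conv_groups helper is illustrative, not an existing API):

import os
os.environ["KERAS_BACKEND"] = "jax"

import timeit
import jax
from keras import layers

def time_conv_groups(groups, x, number=10):
    # Build a pointwise convolution with the given group count,
    # warm up once to trigger compilation, then time steady-state calls.
    conv = layers.Conv2D(filters=128, kernel_size=1, groups=groups)
    conv(x).block_until_ready()
    total = timeit.timeit(lambda: conv(x).block_until_ready(), number=number)
    return total / number

x = jax.random.normal(key=jax.random.key(0), shape=(32, 64, 64, 512))
for groups in (1, 2, 4, 8, 16):
    print(f"groups={groups:2d}: {time_conv_groups(groups, x) * 1000:.4f} ms")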
How do I know which values of the number of groups align well with cuDNN?
The difference on the JAX backend is not as stark, but the grouped convolutions are still slower:
Average time with 2 groups: 1.5100 ms
Standard deviation: 0.0195 ms
Average time without groups: 1.1997 ms
Standard deviation: 0.2183 ms
The code that I used to produce these numbers:
import os
from statistics import mean, stdev

os.environ["KERAS_BACKEND"] = "jax"
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

import jax
import timeit
from keras import layers

# Identical pointwise convolutions, differing only in the number of groups.
conv_layer_groups = layers.Conv2D(
    kernel_size=1,
    filters=128,
    groups=2,
)
conv_layer_no_groups = layers.Conv2D(
    kernel_size=1,
    filters=128,
    groups=1,
)

batch_size = 32
layer_input = jax.random.normal(
    key=jax.random.key(0), shape=(batch_size, 64, 64, 512)
)

num_runs = 10
num_repeats = 20

times = timeit.repeat(
    "conv_layer_groups(layer_input).block_until_ready()",
    setup="from __main__ import conv_layer_groups, layer_input",
    repeat=num_repeats,
    number=num_runs,
)
# Convert each repeat's total into an average time per call.
times = [time / num_runs for time in times]
print(f"Average time with 2 groups: {mean(times[1:])*1000 : .4f} ms")
print(f"Standard deviation: {stdev(times[1:])*1000 : .4f} ms")

times = timeit.repeat(
    "conv_layer_no_groups(layer_input).block_until_ready()",
    setup="from __main__ import conv_layer_no_groups, layer_input",
    repeat=num_repeats,
    number=num_runs,
)
times = [time / num_runs for time in times]
print(f"Average time without groups: {mean(times[1:])*1000 : .4f} ms")
print(f"Standard deviation: {stdev(times[1:])*1000 : .4f} ms")
As the title says, for some parameter values the inference time for grouped convolutions is slower than for regular convolutions (i.e. number of groups = 1). Standalone code to reproduce the issue:
The output on my setup:
The inference time is 25 times slower when I use number of groups = 2, even though grouping should halve the required number of FLOPs.
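For reference, that factor of two follows from the standard grouped-convolution cost model: each output channel only sees C_in / groups input channels, so the multiply-accumulate count scales as 1/groups. A quick sanity check with the shapes from the benchmark above:

# Shapes from the benchmark above.
H, W, C_in, C_out, k = 64, 64, 512, 128, 1

def conv_macs(groups):
    # Multiply-accumulates for a k x k convolution over an H x W map:
    # each output channel sees only C_in / groups input channels.
    return H * W * k * k * (C_in // groups) * C_out

print(conv_macs(1) / conv_macs(2))  # -> 2.0: groups=2 halves the arithmetic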
If I use 3D convolutions, the difference is even larger, and it throws an additional XLA warning:
The output:
Some of the system parameters, in case they are needed: