google / uncertainty-baselines

High-quality implementations of standard and SOTA methods on a variety of tasks.

ResNet50 BatchEnsemble much slower than expected: `Conv2DBatchEnsemble` less optimized than `Conv2D`? #1328

Open arthur-thuy opened 3 months ago

arthur-thuy commented 3 months ago

Hi,

First of all, thank you for sharing this repository; it is really helpful!

I noticed that the runtimes of the ResNet50 BatchEnsemble model are much longer than those of the deterministic ResNet50 model. I have checked my code but can't find a mistake, so I was wondering whether the difference could be because the `tf.keras.layers.Conv2D` layer is heavily optimized, while the `ed.layers.Conv2DBatchEnsemble` layer is not.
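For reference, this is the kind of micro-benchmark I would use to isolate the layer itself from the rest of the model (a minimal sketch with hypothetical shapes; it assumes `ed.layers.Conv2DBatchEnsemble` accepts an `ensemble_size` argument, as in the uncertainty-baselines ResNet-50 code):

```python
import time

import edward2 as ed
import tensorflow as tf

ENSEMBLE_SIZE = 4
# Hypothetical shapes: BatchEnsemble sees the batch tiled once per member.
x = tf.random.normal([8 * ENSEMBLE_SIZE, 32, 32, 64])

layers = {
    'Conv2D': tf.keras.layers.Conv2D(64, 3, padding='same'),
    'Conv2DBatchEnsemble': ed.layers.Conv2DBatchEnsemble(
        64, 3, padding='same', ensemble_size=ENSEMBLE_SIZE),
}

for name, layer in layers.items():
  fn = tf.function(layer)
  fn(x).numpy()  # warm-up: builds the weights and traces the graph
  start = time.perf_counter()
  for _ in range(100):
    fn(x).numpy()  # .numpy() forces device synchronization
  print(f'{name}: {(time.perf_counter() - start) / 100 * 1e3:.2f} ms/call')
```

If the gap already shows up in this isolated setting, that would point at the layer implementation rather than the surrounding model code.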

I also ran experiments with LeNet-5 models, where BatchEnsemble takes about 1.2x longer than the deterministic model. Moving to ResNet50, BatchEnsemble takes about 10x longer than deterministic, a substantial difference from the LeNet-5 results. It could be that the lack of optimization only becomes visible for heavy computations, not for the LeNet-5 toy example.

Any ideas? Thanks!

arthur-thuy commented 2 months ago

I realized that the `ed.layers.Conv2DBatchEnsemble` layer doesn't use cuDNN because it is a custom layer.
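To make the factorization concrete: the paper defines each member's kernel as `W_i = W * (r_i s_i^T)`, which can equivalently be applied as elementwise scalings around a single shared convolution. Below is a minimal sketch of that formulation (the class name and details are mine, not edward2's actual implementation); notably, the inner shared `Conv2D` in this form is still a standard convolution:

```python
import tensorflow as tf

class SketchConv2DBatchEnsemble(tf.keras.layers.Layer):
  """Hypothetical rank-1 BatchEnsemble conv: W_i = W * (r_i s_i^T),
  applied as input/output scalings around one shared convolution."""

  def __init__(self, filters, kernel_size, ensemble_size, **kwargs):
    super().__init__(**kwargs)
    self.filters = filters
    self.ensemble_size = ensemble_size
    # The slow weight: a single convolution shared by all members.
    self.conv = tf.keras.layers.Conv2D(
        filters, kernel_size, padding='same', use_bias=False)

  def build(self, input_shape):
    # The fast weights: one rank-1 pair (r_i, s_i) per ensemble member.
    in_channels = int(input_shape[-1])
    self.alpha = self.add_weight(
        name='alpha', shape=[self.ensemble_size, in_channels],
        initializer='ones')
    self.gamma = self.add_weight(
        name='gamma', shape=[self.ensemble_size, self.filters],
        initializer='ones')

  def call(self, inputs):
    # Inputs are ensemble_size sub-batches stacked along the batch axis.
    sub_batch = tf.shape(inputs)[0] // self.ensemble_size
    alpha = tf.repeat(self.alpha, sub_batch, axis=0)[:, None, None, :]
    gamma = tf.repeat(self.gamma, sub_batch, axis=0)[:, None, None, :]
    # Scale inputs by r_i, run the shared conv once, scale outputs by s_i.
    return self.conv(inputs * alpha) * gamma
```

If the layer is implemented along these lines, the member-specific work is only two elementwise multiplies, so in principle the shared convolution should still dominate the cost.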

For a ResNet-32x4, the BatchEnsemble paper writes: "Although the training duration is longer, BatchEnsemble is still significantly faster than training individual model sequentially." I wonder whether the authors disabled cuDNN entirely in their experiments, in order to keep the comparison among methods fair.