Open sokrypton opened 1 year ago
I think this in turn is an XLA bug. Opened https://github.com/openxla/xla/issues/5541.
This is apparently due to convolution autotuning: some of the algorithms in cudnn are very slow and we try them all during autotuning. Once autotuning has run we will choose a fast algorithm.
It seems in this case the same algorithms are returned by heuristics_mode_a and heuristics_mode_b. So if we deduplicate the algorithms to try during autotuning, we can halve the compile time. That still leaves it slow, but it is a step in the right direction. There is also an idea for speeding it up further: abort an autotuning attempt as soon as it exceeds the best known runtime. That will take a bit longer to implement.
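The two ideas above (deduplicating candidates, and aborting a trial once it exceeds the best runtime so far) can be sketched in plain Python. This is an illustrative model, not actual XLA code; `run_algorithm`, the candidate names, and the `deadline` parameter are all hypothetical.

```python
def autotune(candidates, run_algorithm):
    """Pick the fastest algorithm from `candidates`.

    Skips duplicates (e.g. when heuristics_mode_a and heuristics_mode_b
    return the same algorithm) and passes the best time seen so far as a
    deadline, so a runner can abort a trial that is already too slow.
    """
    seen = set()
    best_algo, best_time = None, float("inf")
    for algo in candidates:
        if algo in seen:
            continue  # deduplication: don't benchmark the same algorithm twice
        seen.add(algo)
        elapsed = run_algorithm(algo, deadline=best_time)
        if elapsed is not None and elapsed < best_time:
            best_algo, best_time = algo, elapsed
    return best_algo, best_time


def fake_run(algo, deadline):
    """Simulated benchmark: returns None if the trial was aborted
    because it blew past the deadline (hypothetical timings)."""
    t = {"a": 0.5, "b": 0.1, "c": 2.0}[algo]
    return None if t > deadline else t


print(autotune(["a", "b", "a", "c"], fake_run))  # → ('b', 0.1)
```

Here the duplicate `"a"` is benchmarked only once, and the slow `"c"` trial is aborted early because `0.1` is already the best known runtime.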
Description
I initially submitted the issue here: https://github.com/deepmind/dm-haiku/issues/724
But then realized it was a jax issue.
In short, I've been trying to use Conv2DTranspose in my model, and even for a very simple case it takes forever to compile.
(screenshot of the compile-time output omitted)
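For readers unfamiliar with the operation being compiled: a transposed convolution can be viewed as a scatter-add, where each input pixel "stamps" the kernel onto a strided output grid. A minimal single-channel NumPy sketch (not the JAX/haiku code from the repro, just an illustration of the op's semantics):

```python
import numpy as np

def conv2d_transpose(x, k, stride=2):
    """Single-channel 2D transposed convolution via scatter-add.

    x: (H, W) input, k: (kH, kW) kernel.
    Output size: stride * (dim - 1) + kernel_dim, per spatial axis.
    """
    h, w = x.shape
    kh, kw = k.shape
    out = np.zeros((stride * (h - 1) + kh, stride * (w - 1) + kw))
    for i in range(h):
        for j in range(w):
            # Each input pixel scales the kernel and adds it at a strided offset.
            out[i * stride:i * stride + kh, j * stride:j * stride + kw] += x[i, j] * k
    return out

# A 2x2 input upsampled with a 2x2 kernel at stride 2 gives a 4x4 output.
print(conv2d_transpose(np.ones((2, 2)), np.ones((2, 2))).shape)  # → (4, 4)
```

Even an op this simple triggers the long cuDNN algorithm search described above, which is why the compile time is surprising.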
For comparison, here is the equivalent PyTorch model, which does not exhibit the slowdown:
Google colab notebook: https://colab.research.google.com/drive/15YkOuK0EjqZdBNaXpF2wpYexGqtjZjLr
What jax/jaxlib version are you using?
Google Colab
Which accelerator(s) are you using?
GPU
Additional system info
Google Colab
NVIDIA GPU info