smartalecH opened 3 years ago
Indeed, JAX is ~200x slower than SciPy here.
I think the difference is that SciPy supports multiple methods for implementing the convolution, which allows it to automatically switch to an asymptotically faster implementation based on FFTs when the convolution window is large: https://docs.scipy.org/doc/scipy/reference/generated/scipy.signal.convolve.html
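For reference, SciPy exposes this choice through the method argument and can report which implementation it would pick; the array size below is just an illustrative placeholder:

import numpy as np
from scipy import signal

x = np.random.rand(512, 512)  # placeholder size for illustration only

# Ask SciPy which implementation it would use for these inputs:
print(signal.choose_conv_method(x, x, mode='same'))  # -> 'fft' for windows this large

# Or force one explicitly; 'direct' becomes dramatically slower as the window grows:
out = signal.convolve(x, x, mode='same', method='fft')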
In contrast, JAX wraps XLA's convolution. XLA is optimized for neural nets, which almost exclusively use very small convolution windows (e.g., typically 3x3), so it only supports the equivalent of SciPy's "direct" method.
It would be a very welcome improvement to add support for FFT-based convolutions to JAX!
Side note: there are two useful tricks for profiling in JAX:
(1) call block_until_ready() to ensure you wait until the computation is done, and
(2) wrap the function in jit and call it once to ensure it is compiled before timing.
In this example, this looks like:
import jax.scipy.signal  # importing the submodule also makes jax itself available

@jax.jit
def convolve(x, y):
    return jax.scipy.signal.convolve(x, y, mode='same')

convolve(x, x).block_until_ready()  # run once so compilation isn't included in the timing
%timeit convolve(x, x).block_until_ready()
That said, neither of these tricks mattered in this case. I still measure ~2 seconds! I think trick (1) may only matter on GPU/TPU, and in this case the computation is so slow that any overhead due to extra tracing/compilation is irrelevant.
If this is on CPU, and with float64, the XLA convolution kernel might just be a really bad one. That is, I bet the f64 XLA:CPU convolution kernel isn't optimized for much of anything (whereas the f32 CPU ones at least have a chance of calling MKL DNN). This might be a similar story to FFT kernels on CPU.
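For reference, getting float64 in JAX at all requires enabling x64 mode before any arrays are created, e.g.:

import jax
jax.config.update("jax_enable_x64", True)  # JAX otherwise defaults to float32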
Interestingly, float32 on CPU seems to be significantly worse -- 30 seconds per loop!
GPUs are faster (~170 ms/loop in Colab), but still much slower than SciPy.
I think the best fix is adding the alternative FFT based convolutions. This could be done either in XLA or JAX, but is probably easier to implement in JAX.
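For concreteness, a minimal sketch of an FFT-based 'same'-mode convolution in JAX could look like the following (this is only an illustration of the idea, not a library implementation; fft_convolve_same is a made-up name and real-valued inputs are assumed):

import jax.numpy as jnp

def fft_convolve_same(x, y):
    # Pad to the full linear-convolution size so the FFT product is not circular.
    full_shape = tuple(sx + sy - 1 for sx, sy in zip(x.shape, y.shape))
    out = jnp.fft.irfftn(jnp.fft.rfftn(x, full_shape) * jnp.fft.rfftn(y, full_shape),
                         full_shape)
    # Crop the centered region so the result has x.shape, matching mode='same'.
    starts = tuple((n - m) // 2 for n, m in zip(full_shape, x.shape))
    return out[tuple(slice(s, s + m) for s, m in zip(starts, x.shape))]

A full version would also need to handle complex inputs, the other modes, and arbitrary dimensionality, but the asymptotic win over the direct method comes entirely from the FFTs above.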
Thanks for the quick feedback!
I built an FFT convolution package using autograd a while back. It only supports 2D, but it's rather easy to generalize. The performance was the same as SciPy/NumPy, as expected (for larger arrays of comparable size, of course).
I can throw a PR together if that's of interest.
Does JAX default to XLA's FFT for all architectures? For CPU, it might be nice to use the FFTW library that comes bundled with NumPy/SciPy. I also noticed that pocketfft is included in the source.
Yes, a PR would be greatly appreciated!
JAX uses pocketfft on CPU, which is faster and more accurate than XLA's FFT via Eigen. On GPU and TPU, it uses XLA's FFT (which wraps cuFFT on GPU).
See https://github.com/google/jax/issues/2952 for discussion on FFT libraries.
It would be a very welcome improvement to add support for FFT-based convolutions to JAX!
I took a stab at adding FFT-based convolutions in JAX at https://github.com/google/jax/pull/6343. I would love to get some feedback.
This is a CPU and a GPU issue.
On a Colab T4 GPU, I get:
%timeit jax.scipy.signal.convolve(x,x,mode='same').block_until_ready()
1 loop, best of 5: 104 ms per loop
%timeit scipy.signal.convolve(x,x,mode='same')
100 loops, best of 5: 11.2 ms per loop
(edited: I forgot .block_until_ready(), without which the timing is invalid.)
I noticed that doing a simple 2D convolution using JAX's scipy.signal wrapper is significantly slower than using SciPy itself:
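A minimal comparison looks roughly like this (the array size here is an arbitrary placeholder):

import numpy as np
import scipy.signal
import jax.scipy.signal

x = np.random.rand(1000, 1000)  # arbitrary illustrative size

%timeit scipy.signal.convolve(x, x, mode='same')
%timeit jax.scipy.signal.convolve(x, x, mode='same').block_until_ready()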
The JAX version ends with
and the SciPy version ends with
Is this expected?