jax-ml / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
http://jax.readthedocs.io/
Apache License 2.0

Jax convolution significantly slower than scipy #5227

Open smartalecH opened 3 years ago

smartalecH commented 3 years ago

I noticed that a simple 2D convolution using JAX's jax.scipy.signal wrapper is significantly slower than using SciPy itself:

import numpy as np
import scipy.signal

import jax
import jax.scipy.signal

# Run JAX in float64 on the CPU backend for an apples-to-apples
# comparison with SciPy.
jax.config.update("jax_enable_x64", True)
jax.config.update("jax_platform_name", "cpu")

np.random.seed(1)
x = np.random.rand(300, 300)

%timeit jax.scipy.signal.convolve(x, x, mode='same')

%timeit scipy.signal.convolve(x, x, mode='same')

JAX reports

2.15 s ± 50.3 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)

while SciPy reports

12.6 ms ± 3.53 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)

Is this expected?

shoyer commented 3 years ago

Indeed, JAX is ~200x slower than SciPy here.

I think the difference is that SciPy supports multiple methods for implementing the convolution, which allows it to automatically switch to an asymptotically faster implementation based on FFTs when the convolution window is large: https://docs.scipy.org/doc/scipy/reference/generated/scipy.signal.convolve.html
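For reference, SciPy exposes this switch in its public API. A quick probe (reusing x from the snippet above) shows which method it would pick, and method='direct' forces the algorithm JAX is effectively limited to:

import numpy as np
import scipy.signal

x = np.random.rand(300, 300)

# For a window this large, SciPy should pick the FFT path.
print(scipy.signal.choose_conv_method(x, x, mode='same'))

# Forcing the direct method approximates what JAX/XLA does here.
%timeit scipy.signal.convolve(x, x, mode='same', method='direct')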

In contrast, JAX wraps XLA's convolution. XLA is optimized for neural nets, which almost exclusively use very small convolution windows (e.g., typically 3x3), so it only supports the equivalent of SciPy's "direct" method.

It would be a very welcome improvement to add support for FFT-based convolutions to JAX!
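A minimal sketch of what such an FFT-based path could look like, built only on jnp.fft (the fftconvolve_same helper and its centering logic are illustrative, not existing JAX API):

import jax
import jax.numpy as jnp

def fftconvolve_same(a, b):
  # Shape of the full linear convolution along each axis.
  full_shape = tuple(sa + sb - 1 for sa, sb in zip(a.shape, b.shape))
  # Convolution theorem: zero-pad, multiply in the frequency domain,
  # and transform back; rfftn/irfftn keep real inputs real.
  out = jnp.fft.irfftn(
      jnp.fft.rfftn(a, full_shape) * jnp.fft.rfftn(b, full_shape),
      full_shape)
  # Crop the centered region so the output shape matches a (mode='same'),
  # mirroring SciPy's centering convention.
  starts = tuple((f - s) // 2 for f, s in zip(full_shape, a.shape))
  return out[tuple(slice(i, i + s) for i, s in zip(starts, a.shape))]

fft_conv = jax.jit(fftconvolve_same)
fft_conv(x, x).block_until_ready()  # compile once, outside the timing
%timeit fft_conv(x, x).block_until_ready()

I haven't benchmarked this sketch, but on arrays of this size the asymptotic win over the direct method should be substantial.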


Side note: there are two useful tricks for profiling in JAX:

  1. Call block_until_ready() on the result, so you actually wait for the asynchronously dispatched computation to finish
  2. Wrap the computation in jit and call it once before timing, so compilation happens outside the timed loop

In this example, this looks like:

@jax.jit
def convolve(x, y):
  return jax.scipy.signal.convolve(x, y, mode='same')

convolve(x, x).block_until_ready()  # compile once before timing
%timeit convolve(x, x).block_until_ready()

That said, none of these tricks mattered in this case. I still measure ~2 seconds! I think trick (1) may only matter on GPU/TPU, and in this case the computation is so slow that any overhead due to extra tracing/compilation is irrelevant.

mattjj commented 3 years ago

If this is on CPU, and with float64, the XLA convolution kernel might just be a really bad one. That is, I bet the f64 XLA:CPU convolution kernel isn't optimized for much of anything (whereas the f32 CPU ones at least have a chance of calling MKL-DNN). This might be a similar story to FFT kernels on CPU.
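One quick way to probe that hypothesis with the arrays from the original snippet (x32 is just a float32 copy of x):

x32 = x.astype(np.float32)
%timeit jax.scipy.signal.convolve(x32, x32, mode='same').block_until_ready()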

shoyer commented 3 years ago

Interestingly, float32 on CPU seems to be significantly worse -- 30 seconds per loop!

GPUs are faster (~170 ms/loop in Colab), but still much slower than SciPy.

I think the best fix is adding an alternative FFT-based convolution path. This could be done either in XLA or in JAX, but it is probably easier to implement in JAX.

smartalecH commented 3 years ago

Thanks for the quick feedback!

I built an FFT-convolution package using autograd a while back. It only supports 2D, but it's rather easy to generalize. As expected, the performance matched scipy/numpy (for larger arrays of this size, of course).

I can throw a PR together if that's of interest.

Does JAX default to XLA's FFT on all platforms? On CPU it might be nice to use the FFTW library that comes bundled with NumPy/SciPy. I also noticed that pocketfft is included in the source.

shoyer commented 3 years ago

Yes, a PR would be greatly appreciated!

JAX uses pocketfft on CPU, which is faster and more accurate than XLA's Eigen-based FFT. On GPU and TPU, it uses XLA's FFT (which wraps cuFFT on GPU).
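As a rough sanity check of the CPU FFT path against NumPy (the vector length here is arbitrary):

import numpy as np
import jax.numpy as jnp

v = np.random.rand(4096)
# On CPU (with x64 enabled as above) this exercises the pocketfft-backed
# FFT; the residual against NumPy's FFT should sit at roundoff level.
print(np.max(np.abs(np.asarray(jnp.fft.fft(v)) - np.fft.fft(v))))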

shoyer commented 3 years ago

See https://github.com/google/jax/issues/2952 for discussion on FFT libraries.

peterroelants commented 3 years ago

> It would be a very welcome improvement to add support for FFT-based convolutions to JAX!

I took a stab at adding FFT-based convolutions in JAX at https://github.com/google/jax/pull/6343. I would love to get some feedback.

hawkinsp commented 2 years ago

This is both a CPU and a GPU issue.

On a Colab T4 GPU, I get:

%timeit jax.scipy.signal.convolve(x,x,mode='same').block_until_ready()
1 loop, best of 5: 104 ms per loop
%timeit scipy.signal.convolve(x,x,mode='same')
100 loops, best of 5: 11.2 ms per loop

(edited: I forgot .block_until_ready(), without which the timing is invalid.)