google / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
http://jax.readthedocs.io/
Apache License 2.0

FFT precision/performance #2952

Closed s-zymon closed 3 years ago

s-zymon commented 4 years ago

I found a noticeable difference between the outputs of numpy.fft.fft and jax.numpy.fft.fft. The difference also depends on the device: the error is larger on CPU than on GPU. By comparison, the mean absolute error of a GPU FFT implementation such as PyTorch's is around 1e-8, which seems reasonable. I guess this might be a minor bug.

The second issue is the performance of jax.numpy.fft.fft on CPU. I am aware that JAX is intended for GPU/TPU, but the overhead of JAX's FFT on CPU seems unusually large.

Below is a simple script that reproduces both issues.

%env JAX_ENABLE_X64=1
%env JAX_PLATFORM_NAME=cpu

import time
import numpy as np

import jax
from jax import numpy as jnp

np.random.seed(0)

signal = np.random.randn(2**20)
signal_jax = jnp.array(signal)

jfft = jax.jit(jnp.fft.fft)

X_np = np.fft.fft(signal)
X_jax = jfft(signal_jax)

print(np.mean(np.abs(X_np)))
print('max:\t', jnp.max(jnp.abs(X_np - X_jax)))
print('mean:\t', jnp.mean(jnp.abs(X_np - X_jax)))
print('min:\t', jnp.min(jnp.abs(X_np - X_jax)))

### CPU
# 907.3490574884647
# max:   2.8773885332210747
# mean:  0.3903197564919141
# min:   2.4697454729898156e-05

### GPU
# 907.3490574884647
# max:   0.001166179716824765
# mean:  0.00020841654559267488
# min:   2.741492442122853e-07

R = 100
ts = time.time()
for i in range(R):
    _ = np.fft.fft(signal)
print('numpy fft execution time [ms]:\t', (time.time()-ts)/R * 1000)

# Compile
_ = jfft(signal_jax).block_until_ready()

ts = time.time()
for i in range(R):
    _ = jfft(signal_jax).block_until_ready()
print('jax fft execution time [ms]:\t', (time.time()-ts)/R * 1000)

### CPU
# numpy fft execution time [ms]:     36.75990343093872
# jax fft execution time [ms]:           219.37960147857666

### GPU
# numpy fft execution time [ms]:     38.53107929229736
# jax fft execution time [ms]:           0.38921356201171875
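The absolute errors above are easier to judge relative to the size of the spectrum (the mean magnitude printed above is about 907). A minimal sketch of a scale-aware comparison follows, assuming 64-bit mode is enabled through `jax.config.update`; the helper name `relative_fft_error` is hypothetical.

```python
import numpy as np
import jax
import jax.numpy as jnp

# Enable 64-bit types before creating any JAX arrays.
jax.config.update("jax_enable_x64", True)


def relative_fft_error(signal):
    """Relative L2 error of jnp.fft.fft against np.fft.fft for a 1-D signal."""
    reference = np.fft.fft(signal)
    result = np.asarray(jax.jit(jnp.fft.fft)(jnp.asarray(signal)))
    return np.linalg.norm(result - reference) / np.linalg.norm(reference)


rng = np.random.default_rng(0)
signal = rng.standard_normal(2**16)
print("relative L2 error:", relative_fft_error(signal))
```

A relative error near float64 machine precision (around 1e-16 to 1e-15) indicates a healthy 64-bit FFT; values many orders of magnitude larger suggest the computation is effectively running at lower precision.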
jakevdp commented 4 years ago

Please see the related discussion in #2874.

mattjj commented 4 years ago

IIUC #2874 is only about a GPU bug, whereas this issue seems to be mainly about the CPU, covering both performance and correctness.

I think XLA:CPU is using Eigen's FFT, which may simply be slow. I'll ping the XLA:CPU folks to see if they know anything about it. On the JAX side, we could possibly do a CustomCall into another FFT implementation on CPU, as we already CustomCall into LAPACK kernels for matrix decompositions.

mattjj commented 4 years ago

I confirmed with the XLA:CPU folks that XLA is just calling into Eigen here, and "it's possible but unlikely that XLA is doing something bad here that triggers slowness." (I'd like to double-check just by executing this benchmark that Eigen is being multithreaded properly for FFTs, but I'm not sure when I'll get a chance to do that.)

Depending on whether XLA:CPU folks have the bandwidth to improve this, we might want to look into JAX-side solutions. I'll update this thread again when I learn more.

MuawizChaudhary commented 4 years ago

I am working on implementing a Jax backend for kymatio (kymat.io), a Python package implementing the scattering transform.

When we compare the Jax FFT implementation against a closed-form expression of the discrete Fourier transform of a box of ones with dtype float32 (that is, the Dirichlet kernel), we note a large deviation. Additionally, comparing the results of the Jax FFT with the results of the NumPy and SciPy FFT shows significant discrepancies.
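For reference, the closed form used in `box_dirichlet` follows from a geometric series. Writing the box as 2n + 1 ones centred at the origin (n = 15 here, since the code keeps |x| < 16) and the DFT frequency as k = 2πx/N, the transform is:

$$
\sum_{m=-n}^{n} e^{-ikm} = \frac{\sin\big((n + \tfrac{1}{2})k\big)}{\sin\big(\tfrac{k}{2}\big)},
\qquad
\lim_{k \to 0} \frac{\sin\big((n + \tfrac{1}{2})k\big)}{\sin\big(\tfrac{k}{2}\big)} = 2n + 1 .
$$

Dividing by 2π gives both the `dirichlet` array and the `fbox/2/np.pi` normalisation in the code, including the special value (n + 1/2)/π at k = 0.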

Are there plans to address this? PR #3290 does not appear to have solved this issue.

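import numpy as np
import scipy.fft
import jax.numpy as jnp

def box_dirichlet(N, FFT):
  # Centred grid x = -N/2, ..., N/2 - 1
  x = np.arange(N)
  x -= len(x) // 2
  # Box of ones of width 2*15 + 1 = 31 centred at the origin
  n = 16
  box = np.abs(x) < n
  # ifftshift moves the centre to index 0; fftshift recentres the spectrum
  fbox = np.fft.fftshift(FFT(np.fft.ifftshift(box.astype('float32'))))
  fbox = fbox/2/np.pi
  # DFT frequencies k = 2*pi*x/N in [-pi, pi)
  k = x / (-x.min()) * np.pi
  # Closed-form Dirichlet kernel; the k = 0 entry is filled in separately
  n = 15
  dirichlet = np.sin((n + .5) * k) / (2 * np.pi * np.sin(.5 * k))
  dirichlet[int(N/2)] = (n + .5) / np.pi
  return dirichlet, fbox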

def comparison(dirichlet, fbox):
  print("The absolute difference is: ", np.linalg.norm(dirichlet - fbox))
  print("The relative difference is: ", np.linalg.norm(dirichlet - fbox)/np.linalg.norm(dirichlet))
#comparison of fft'ed box of ones with dirichlet kernel
dirichlet, jax_fft = box_dirichlet(2**20., jnp.fft.fft)
comparison(dirichlet, jax_fft)

The absolute difference is: 0.47198237618500993
The relative difference is: 0.0005201455019338586

dirichlet, numpy_fft = box_dirichlet(2**20., np.fft.fft)
comparison(dirichlet, numpy_fft)

The absolute difference is: 3.6838866295306374e-13
The relative difference is: 4.0598063755531826e-16

dirichlet, scipy_fft = box_dirichlet(2**20., scipy.fft.fft)
comparison(dirichlet, scipy_fft)

The absolute difference is: 0.00010274223272981098
The relative difference is: 1.132264951183364e-07

Even with smaller arrays we note large differences.

dirichlet, jax_fft = box_dirichlet(2**15., jnp.fft.fft)
comparison(dirichlet, jax_fft)

The absolute difference is: 0.025780742747920873
The relative difference is: 0.00016071983551150154

hawkinsp commented 4 years ago

What version of jaxlib and what hardware platform are you using?

jaxlib 0.1.48 adds 64-bit FFT support on CPU and GPU, which may help if you have accuracy problems. Note also that I believe the NumPy FFT you are comparing it with always computes in 64-bit. Can you verify that you are using a 64-bit FFT in JAX (i.e., that you have 64-bit input types and have JAX_ENABLE_X64 set or similar)?
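One way to check this is to enable 64-bit mode and inspect the output dtype of the FFT. A minimal sketch, assuming a JAX version where `jax.config.update("jax_enable_x64", True)` is available; it should run before any arrays are created.

```python
import numpy as np
import jax
import jax.numpy as jnp

# Equivalent to setting JAX_ENABLE_X64=1 in the environment before import.
jax.config.update("jax_enable_x64", True)

x64 = jnp.asarray(np.ones(8), dtype=jnp.float64)
x32 = jnp.asarray(np.ones(8), dtype=jnp.float32)

# In 64-bit mode a float64 input should give a complex128 spectrum,
# while a float32 input still gives complex64, i.e. a 32-bit FFT.
print(jnp.fft.fft(x64).dtype)
print(jnp.fft.fft(x32).dtype)
```

In other words, enabling 64-bit mode alone is not enough: the input array itself also has to be float64 (or complex128) for JAX to run a 64-bit FFT.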

MuawizChaudhary commented 4 years ago

Hi,

I was using jaxlib 0.1.47 and 0.1.48 on Google Colab, but my collaborators have noticed this on their machines too. We are aware that NumPy upcasts to 64-bit; however, SciPy and, it appears, Jax do not.

The input was a box of ones as float32 values. Testing with a box of float64 values as input, we obtain similar inaccuracies. This is with version 0.1.48, using config.update('jax_enable_x64', True).

#comparison of fft'ed box of ones with dirichlet kernel
dirichlet, jax_fft = box_dirichlet(2**20., jnp.fft.fft)
comparison(dirichlet, jax_fft)

The absolute difference is: 0.47198237618500993
The relative difference is: 0.0005201455019338586

dirichlet, numpy_fft = box_dirichlet(2**20., np.fft.fft)
comparison(dirichlet, numpy_fft)

The absolute difference is: 3.6838866295306374e-13
The relative difference is: 4.0598063755531826e-16

dirichlet, scipy_fft = box_dirichlet(2**20., scipy.fft.fft)
comparison(dirichlet, scipy_fft)

The absolute difference is: 3.781620542617657e-13
The relative difference is: 4.167513480402116e-16

hawkinsp commented 4 years ago

Out of curiosity, do you see the same results from TensorFlow?

JAX uses Eigen for its FFT implementation on CPU, as does TensorFlow, so one hypothesis is that this is simply due to the quality of the Eigen implementation. That might be nice to verify, if you have time. If they did differ, that would be very interesting to know.

MuawizChaudhary commented 4 years ago

It appears that they give the same outputs on both CPU and GPU!

Edit: My interpretation was that Jax is supposed to be as similar as possible to Numpy. Is this interpretation wrong?

hawkinsp commented 4 years ago

"Shouldn't Jax be closer to Numpy?"

Ultimately they are two different pieces of code and they will not act the same in all circumstances. And it's not a goal to precisely match NumPy everywhere.

There are at least three things you could mean:
a) JAX should default to float64 precision even when performing float32 FFTs.
b) JAX should return a better quality float64 result on CPU.
c) JAX should return a better quality float64 result on GPU.

For (a): perhaps. We don't try to follow NumPy precisely, and in a number of cases we default to float32 to be more GPU friendly.

For (b): I suspect we need to find a higher quality implementation of FFT on CPU. The obvious candidate is probably Intel's MKL library.

For (c): JAX uses completely different FFT implementations on CPU and GPU. On GPU it uses cuFFT (which, as far as I am aware, pretty much everyone uses). I would actually expect you to see high-quality results on GPU. Can you confirm that you were actually running in 64-bit mode on GPU?

MuawizChaudhary commented 4 years ago

The other results were on CPU, for both 32- and 64-bit. It looks like the 64-bit results on GPU match up.

dirichlet, jax_fft = box_dirichlet(2**20., jnp.fft.fft)
comparison(dirichlet, jax_fft)

The absolute difference is: 0.00019981246168760387
The relative difference is: 2.2020219063518382e-07

dirichlet, tf_fft = box_dirichlet(2**20., tf.signal.fft)
comparison(dirichlet, tf_fft)

The absolute difference is: 0.00019981246168760387
The relative difference is: 2.2020219063518382e-07

janden commented 4 years ago

Thank you for looking into this, @hawkinsp. From our perspective the best thing would be to have the Jax FFT be more accurate (comparable to NumPy, SciPy, and PyFFTW) on the CPU (for both float32 and float64). Plugging into MKL, as you suggest, might be a good idea here.

If this is the Eigen FFT interface that Jax is using, it looks like switching from the default backend (kissfft) to FFTW should be possible by setting a compiler flag.

shoyer commented 4 years ago

Licensing might be the trickiest part here. FFTW is GPL and MKL is proprietary.

NumPy uses pocketfft these days. Writing a custom call in JAX to use pocketfft on CPU could be a good option -- or perhaps XLA CPU should use pocketfft.
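As a stopgap in recent JAX versions, `jax.pure_callback` can route an FFT through NumPy's own (pocketfft-based) implementation on the host. This is a different and simpler mechanism than the CustomCall discussed here. A minimal sketch, where the name `fft_via_numpy` is hypothetical: it only handles a 1-D `fft` along the last axis, the host round-trip means it will not be fast on GPU, and it does not support differentiation on its own.

```python
import numpy as np
import jax
import jax.numpy as jnp


def fft_via_numpy(x):
    """1-D FFT along the last axis, computed by NumPy on the host."""
    # float32 -> complex64; float64 -> complex128 (when 64-bit mode is on).
    out_dtype = jnp.result_type(x.dtype, jnp.complex64)
    result_shape = jax.ShapeDtypeStruct(x.shape, out_dtype)

    def host_fft(a):
        # The callback receives a concrete NumPy array; cast the result
        # back to the dtype declared in result_shape.
        return np.fft.fft(a).astype(out_dtype)

    return jax.pure_callback(host_fft, result_shape, x)


x = np.random.default_rng(0).standard_normal(1024).astype(np.float32)
y = jax.jit(fft_via_numpy)(x)
print(y.dtype, np.max(np.abs(np.asarray(y) - np.fft.fft(x))))
```

Because the callback is opaque to XLA, this trades performance and transformability for accuracy that matches NumPy.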

Joshuaalbert commented 4 years ago

I can also add that the radio astronomy community would be greatly interested if JAX fft on CPU would be both accurate and fast.

@mattjj Re: is the result the same as with TensorFlow? Yes:

With JAX:
max:     4.362646594533903e-08
mean:    6.237288307614869e-09
min:     0.0
With Tensorflow:
max:     4.362646594533903e-08
mean:    6.237288307614869e-09
min:     0.0
numpy fft execution time [ms]:   44.88363027572632
jax fft execution time [ms]:     84.56079244613647
tensorflow fft execution time [ms]:  84.12498950958252
pacargile commented 3 years ago

@mattjj Has there been any progress in understanding why Jax's fft is ~twice as slow as NumPy on a CPU? I second @Joshuaalbert's comment: other divisions of astrophysics would also be very interested in a fast and accurate Jax FFT.

hawkinsp commented 3 years ago

I think we're pretty clear on what to do here: replace the Eigen FFT on CPU with something else, probably PocketFFT, same as NumPy. We just need someone to actually do it!

Joshuaalbert commented 3 years ago

@hawkinsp What about MKL's FFT? It's the fastest that I've seen. FFTW is currently what radio astronomers use, due to its popularity. An informative comparison of FFT libraries is here: https://github.com/project-gemmi/benchmarking-fft/

shoyer commented 3 years ago

I think optionally using MKL could be viable, but MKL is closed-source software. At the very least, we want to preserve an open source option.

mattjj commented 3 years ago

And FFTW is GPL :/

I think the main limiting factor here is just developer bandwidth on the JAX core team, where we have to balance a lot of considerations (code licenses, ensuring it works in OSS as well as internally at Google, etc).

Until we improve this, it might be useful to look at how you can rig up a call into any implementation you want by registering a custom backend-specific kernel with XLA. One example of how to do that is mpi4jax. You could also look at how JAX calls into LAPACK on CPU and cuSolver on GPU, e.g. starting at lapax.pyx for the CPU stuff.

hawkinsp commented 3 years ago

This issue should be fixed, but it requires a jaxlib rebuild. You can either build from source or wait for us to make a new jaxlib release. Hope that helps!

hawkinsp commented 3 years ago

The updated FFT has been released as part of jaxlib 0.1.57. Hope that helps!