Closed s-zymon closed 4 years ago
Please see the related discussion in #2874.
IIUC #2874 is only about a GPU bug. Perhaps this issue is mainly about CPU, both performance and correctness issues.
I think XLA:CPU is using Eigen's FFT. Maybe it's slow or something. I'll ping the XLA:CPU folks to see if they know anything about it. On the JAX side, we could possibly do a CustomCall into some other FFT implementation on CPU, like we CustomCall into LAPACK kernels for matrix decompositions.
I confirmed with the XLA:CPU folks that XLA is just calling into Eigen here, and "it's possible but unlikely that XLA is doing something bad here that triggers slowness." (I'd like to double-check just by executing this benchmark that Eigen is being multithreaded properly for FFTs, but I'm not sure when I'll get a chance to do that.)
Depending on whether XLA:CPU folks have the bandwidth to improve this, we might want to look into JAX-side solutions. I'll update this thread again when I learn more.
I am working on implementing a Jax backend for kymatio (kymat.io), a Python package implementing the scattering transform.
When we compare the Jax FFT implementation against a closed-form expression of the discrete Fourier transform of a box of ones with dtype float32 (that is, the Dirichlet kernel) we note a large deviation. Additionally, comparing the results of the Jax FFT with the results of the NumPy and SciPy FFT shows significant discrepancies.
Are there plans to address this? PR #3290 does not appear to have solved this issue.
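For reference, the closed form that the comparison relies on can be written out as follows. This is the standard Dirichlet kernel identity, normalized by 2π to match the snippet in this comment; the frequency index j is notation introduced here for illustration, with n = 15 and N the number of samples.

```latex
\frac{1}{2\pi}\sum_{m=-n}^{n} e^{-i m k}
  = \frac{\sin\!\left(\left(n + \tfrac{1}{2}\right) k\right)}
         {2\pi \,\sin\!\left(\tfrac{k}{2}\right)},
\qquad k = \frac{2\pi j}{N},
\qquad
\lim_{k \to 0}
  \frac{\sin\!\left(\left(n + \tfrac{1}{2}\right) k\right)}
       {2\pi \,\sin\!\left(\tfrac{k}{2}\right)}
  = \frac{n + \tfrac{1}{2}}{\pi}.
```

The sum on the left is what an exact DFT of the centered box of 2n + 1 ones produces at frequency k, and the limit is the value the snippet patches in at k = 0.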
import numpy as np
import scipy.fft
import jax.numpy as jnp

def box_dirichlet(N, FFT):
    # Centered grid of N samples
    x = np.arange(N)
    x -= len(x) // 2
    # Box of 2*16 - 1 = 31 ones centered at zero
    n = 16
    box = np.abs(x) < n
    fbox = np.fft.fftshift(FFT(np.fft.ifftshift(box.astype('float32'))))
    fbox = fbox/2/np.pi
    # Closed-form Dirichlet kernel of order n = 15 on the matching frequency grid
    k = x / (-x.min()) * np.pi
    n = 15
    dirichlet = np.sin((n + .5) * k) / (2 * np.pi * np.sin(.5 * k))
    # Replace the 0/0 at k = 0 with its limit
    dirichlet[int(N/2)] = (n + .5) / np.pi
    return dirichlet, fbox

def comparison(dirichlet, fbox):
    print("The absolute difference is: ", np.linalg.norm(dirichlet - fbox))
    print("The relative difference is: ", np.linalg.norm(dirichlet - fbox)/np.linalg.norm(dirichlet))
#comparison of fft'ed box of ones with dirichlet kernel
dirichlet, jax_fft = box_dirichlet(2**20., jnp.fft.fft)
comparison(dirichlet, jax_fft)
The absolute difference is: 0.47198237618500993
The relative difference is: 0.0005201455019338586
dirichlet, numpy_fft = box_dirichlet(2**20., np.fft.fft)
comparison(dirichlet, numpy_fft)
The absolute difference is: 3.6838866295306374e-13
The relative difference is: 4.0598063755531826e-16
dirichlet, scipy_fft = box_dirichlet(2**20., scipy.fft.fft)
comparison(dirichlet, scipy_fft)
The absolute difference is: 0.00010274223272981098
The relative difference is: 1.132264951183364e-07
Even with smaller arrays we note large differences.
dirichlet, jax_fft = box_dirichlet(2**15., jnp.fft.fft)
comparison(dirichlet, jax_fft)
The absolute difference is: 0.025780742747920873
The relative difference is: 0.00016071983551150154
What version of jaxlib and what hardware platform are you using?
jaxlib 0.1.48 adds 64-bit FFT support on CPU and GPU, which may help if you have accuracy problems. Note also that I believe the NumPy FFT you are comparing it with always computes 64-bit. Can you verify you are using a 64-bit FFT in JAX (i.e., you have 64-bit input types and have JAX_ENABLE_X64 set or similar)?
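One way to check this is to enable 64-bit mode and inspect the dtype that the FFT returns. The sketch below assumes a reasonably recent JAX install where `jax.config.update` is available; the array names are made up for illustration and the exact behaviour on older jaxlib versions may differ.

```python
import jax
import jax.numpy as jnp

# Assumption: this runs at startup, before any JAX arrays are created.
# It is intended to have the same effect as setting JAX_ENABLE_X64.
jax.config.update("jax_enable_x64", True)

# Hypothetical test inputs in single and double precision.
x32 = jnp.ones(8, dtype=jnp.float32)
x64 = jnp.ones(8, dtype=jnp.float64)

# The output dtype shows which precision the FFT was computed in.
out32 = jnp.fft.fft(x32)
out64 = jnp.fft.fft(x64)
print("float32 input ->", out32.dtype)
print("float64 input ->", out64.dtype)
```

If the second line does not report a 128-bit complex dtype, the input was most likely downcast to float32 because 64-bit mode was not active when the array was created.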
Hi,
I was using jaxlib 0.1.47 and 0.1.48 on Google Colab, but this is something my collaborators have noticed on their machines too. We are aware that NumPy upcasts to 64-bit; however, SciPy and, it appears, Jax do not.
The input was a box of ones as float32s. Testing with the input as a box of float64s, we obtain similar inaccuracies. This is with version 0.1.48, using config.update('jax_enable_x64', True)
#comparison of fft'ed box of ones with dirichlet kernel
dirichlet, jax_fft = box_dirichlet(2**20., jnp.fft.fft)
comparison(dirichlet, jax_fft)
The absolute difference is: 0.47198237618500993
The relative difference is: 0.0005201455019338586
dirichlet, numpy_fft = box_dirichlet(2**20., np.fft.fft)
comparison(dirichlet, numpy_fft)
The absolute difference is: 3.6838866295306374e-13
The relative difference is: 4.0598063755531826e-16
dirichlet, scipy_fft = box_dirichlet(2**20., scipy.fft.fft)
comparison(dirichlet, scipy_fft)
The absolute difference is: 3.781620542617657e-13
The relative difference is: 4.167513480402116e-16
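To get a feel for how much error the working precision alone contributes, here is a minimal NumPy sketch (not code from this thread) that evaluates the DFT of the same centered box by direct summation in single and double precision and compares each against the closed form. The function name `box_dft_error` and the small size N = 1024 are choices made for illustration. Direct summation is not how any FFT library computes the transform, so the numbers only indicate the order of magnitude of rounding error.

```python
import numpy as np

def box_dft_error(N, n, dtype):
    """Relative error of a direct DFT of a centered box, summed in `dtype`."""
    m = np.arange(-n, n + 1)           # support of the box: 2n + 1 ones
    j = np.arange(N) - N // 2          # centered frequency indices
    k = 2 * np.pi * j / N              # angular frequencies in [-pi, pi)
    # Twiddle factors exp(-i m k), rounded to the requested complex precision.
    phase = np.exp(-1j * np.outer(k, m)).astype(dtype)
    dft = phase.sum(axis=1)            # summation stays in `dtype`
    # Closed-form Dirichlet kernel in float64; k = 0 gives 0/0, patched below.
    with np.errstate(invalid="ignore", divide="ignore"):
        exact = np.sin((n + 0.5) * k) / np.sin(0.5 * k)
    exact[N // 2] = 2 * n + 1
    return np.linalg.norm(dft - exact) / np.linalg.norm(exact)

err32 = box_dft_error(1024, 15, np.complex64)
err64 = box_dft_error(1024, 15, np.complex128)
print("single precision relative error:", err32)
print("double precision relative error:", err64)
```

Comparing the two printed values against the relative differences reported in this thread gives a rough sense of whether a given result is limited by precision or by the FFT implementation.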
Out of curiosity, do you see the same results from TensorFlow?
JAX uses Eigen for its FFT implementation on CPU, as does TensorFlow, so one hypothesis is that this is simply due to the quality of the Eigen implementation. That might be nice to verify, if you have time. If they did differ that would be very interesting to know.
It appears that they give the same outputs on both CPU and GPU!
Edit: My interpretation was that Jax is supposed to be as similar as possible to Numpy. Is this interpretation wrong?
"Shouldn't Jax be closer to Numpy?"
Ultimately they are two different pieces of code and they will not act the same in all circumstances. And it's not a goal to precisely match NumPy everywhere.
There are at least three things you could mean:
(a) JAX should default to float64 precision even when performing float32 FFTs.
(b) JAX should return a better quality float64 result on CPU.
(c) JAX should return a better quality float64 result on GPU.
For (a): perhaps. We don't try to follow NumPy precisely, and in a number of cases we default to float32 to be more GPU friendly.
For (b): I suspect we need to find a higher quality implementation of FFT on CPU. The obvious candidate is probably Intel's MKL library.
For (c): JAX uses completely different FFT implementations on CPU and GPU. On GPU it uses cufft (which pretty much everyone uses as far as I am aware). I would actually expect that you would see high quality results on GPU. Can you confirm that you were actually running in 64-bit mode on GPU?
The other results were on CPU, both 32-bit and 64-bit. It looks like the 64-bit results on GPU match up.
dirichlet, jax_fft = box_dirichlet(2**20., jnp.fft.fft)
comparison(dirichlet, jax_fft)
The absolute difference is: 0.00019981246168760387
The relative difference is: 2.2020219063518382e-07
dirichlet, tf_fft = box_dirichlet(2**20., tf.signal.fft)
comparison(dirichlet, tf_fft)
The absolute difference is: 0.00019981246168760387
The relative difference is: 2.2020219063518382e-07
Thank you for looking into this, @hawkinsp. From our perspective the best thing would be to have the Jax FFT be more accurate (comparable to NumPy, SciPy, and PyFFTW) on the CPU (for both float32 and float64). Plugging into MKL, as you suggest, might be a good idea here.
If this is the Eigen FFT interface that Jax is using, it looks like switching from the default backend (kissfft) to FFTW should be possible by setting a compiler flag.
Licensing might be the trickiest part here. FFTW is GPL and MKL is proprietary.
NumPy uses pocketfft these days. Writing a custom call in JAX to use pocketfft on CPU could be a good option -- or perhaps XLA CPU should use pocketfft.
I can also add that the radio astronomy community would be greatly interested if JAX fft on CPU would be both accurate and fast.
@mattjj Re: is the result the same as with tensorflow? Yes,
With JAX:
max: 4.362646594533903e-08
mean: 6.237288307614869e-09
min: 0.0
With Tensorflow:
max: 4.362646594533903e-08
mean: 6.237288307614869e-09
min: 0.0
numpy fft execution time [ms]: 44.88363027572632
jax fft execution time [ms]: 84.56079244613647
tensorflow fft execution time [ms]: 84.12498950958252
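Error statistics and timings of this kind can be gathered with a small helper along the following lines. This is a sketch, not the script that produced the numbers above: the name `compare_fft`, the choice of a float64 NumPy FFT as the reference, and the best-of-several timing are all assumptions made for illustration.

```python
import time
import numpy as np

def compare_fft(fft, x, reference=np.fft.fft, repeats=5):
    """Error statistics and best wall time (ms) of `fft` against `reference`."""
    # Reference result computed from a float64 copy of the input.
    expected = reference(np.asarray(x, dtype=np.float64))
    result = np.asarray(fft(x))
    abs_err = np.abs(result - expected)

    best = float("inf")
    for _ in range(repeats):
        start = time.perf_counter()
        # Converting to a NumPy array waits for asynchronous backends to finish.
        np.asarray(fft(x))
        best = min(best, time.perf_counter() - start)

    return {
        "max": float(abs_err.max()),
        "mean": float(abs_err.mean()),
        "min": float(abs_err.min()),
        "time_ms": best * 1e3,
    }

# Usage: any callable with the np.fft.fft interface can be passed in,
# for example jnp.fft.fft or tf.signal.fft if those libraries are installed.
signal = np.random.default_rng(0).standard_normal(2**12)
print(compare_fft(np.fft.fft, signal))
```

Taking the best of several runs keeps one-off costs such as compilation or cache warm-up from dominating the reported time.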
@mattjj Has there been any progress in understanding why Jax's FFT is about twice as slow as NumPy's on a CPU? I second @Joshuaalbert's comment: other divisions of astrophysics would also be very interested in a fast and accurate Jax FFT.
I think we're pretty clear on what to do here: replace the Eigen FFT on CPU with something else, probably PocketFFT, same as NumPy. We just need someone to actually do it!
@hawkinsp What about MKL's FFT? It's the fastest that I've seen. FFTW is currently what radio astronomers use, due to its popularity. An informative comparison of FFT libraries is here: https://github.com/project-gemmi/benchmarking-fft/
I think optionally using MKL could be viable, but MKL is closed-source software. At the very least, we want to preserve an open source option.
And FFTW is GPL :/
I think the main limiting factor here is just developer bandwidth on the JAX core team, where we have to balance a lot of considerations (code licenses, ensuring it works in OSS as well as internally at Google, etc).
Until we improve this, it might be useful to look at how you can rig up a call into any implementation you want by registering a custom backend-specific kernel with XLA. One example of how to do that is mpi4jax. You could also look at how JAX calls into LAPACK on CPU and cuSolver on GPU, e.g. starting at lapack.pyx for the CPU stuff.
This issue should be fixed, but it requires a jaxlib rebuild. You can either build from source or wait for us to make a new jaxlib release. Hope that helps!
The updated FFT has been released as part of jaxlib 0.1.57. Hope that helps!
I find a noticeable difference between the outputs of numpy.fft.fft and jax.numpy.fft.fft. The difference also changes with the device: on the cpu device the error is bigger than on the gpu device. On the other hand, the mean absolute error for the gpu implementation of fft from e.g. PyTorch is around 1e-8, which seems reasonable. I guess that might be some minor bug.
The second issue is the performance of jax.numpy.fft.fft on the cpu device. I am aware that jax is intended for GPU/TPU, but the overhead of jax for fft using cpu seems weirdly big.
Below is the simple code for reproduction.