Thanks for asking this, and the beautiful runnable example!
I think the issue is just 64bit vs 32bit. JAX by default maxes out at 32bit values for ints and floats. We chose that as a default policy because a primary use case of JAX is neural network-based machine learning research. That's different from ordinary NumPy, which is very happy to cast things to 64bit values. In fact, that's why we made 32bit the system-wide maximum precision by default: otherwise users might get annoyed by the NumPy API promoting things to 64bit values all the time, when they're trying to stay in 32bit for neural net training!
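For example, here's a quick side-by-side sketch (mine, not part of the original report) of the default dtypes:
import numpy as onp
import jax.numpy as np

print(onp.array([1.0]).dtype)  # float64: NumPy promotes to double precision by default
print(np.array([1.0]).dtype)   # float32: JAX caps at single precision unless x64 is enabled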
To enable 64bit values, I ran your script like this:
$ JAX_ENABLE_X64=1 python issue936.py
Optimization terminated successfully. (Exit mode 0)
Current function value: 7.96982011054e-11
Iterations: 29
Function evaluations: 221
Gradient evaluations: 29
[0.999999 0.99999821 0.9999967 0.99999373 0.9999876 ]
/usr/local/google/home/mattjj/packages/jax/jax/lib/xla_bridge.py:130: UserWarning: No GPU/TPU found, falling back to CPU.
warnings.warn('No GPU/TPU found, falling back to CPU.')
Optimization terminated successfully. (Exit mode 0)
Current function value: 7.96982011054e-11
Iterations: 29
Function evaluations: 221
Gradient evaluations: 29
[0.999999 0.99999821 0.9999967 0.99999373 0.9999876 ]
When that 64bit flag is switched on, jax.numpy follows numpy's precision semantics very closely, and as you can see it causes the numerics to agree here.
Another way to set that flag is by doing something like this at the top of your main .py file:
from jax.config import config
config.update("jax_enable_x64", True)
You can see a bit more in the short gotchas section of the readme, and in the gotchas notebook.
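As a quick sanity check (a sketch, assuming the flag is set before any JAX arrays are created), you can confirm the flag took effect:
from jax.config import config
config.update("jax_enable_x64", True)

import jax.numpy as np
print(np.array([1.0]).dtype)  # float64 with the flag on; float32 otherwise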
That project looks really awesome, and we want to make JAX work great with it. Please don't hesitate to open issues like this one as you run into any issues or bugs!
Thanks very much for your prompt feedback @mattjj! This helps and we really appreciate it. :+1:
I think the issue is just 64bit vs 32bit.
I knew about this policy. I'm curious how we still managed to get 64-bit precision on the numpy side with 32b arrays? Is there a way to force 32b for the optimization then if it's still working in 64b? I would assume setting the array's dtype to 32b would correctly set things up.
x0 = np.array([1.3, 0.7, 0.8, 1.9, 1.2], dtype='float32')
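(For illustration only: one hypothetical way to pin the NumPy-side evaluation to 32 bits would be to cast inside the objective; rosen32 below is a sketch, not part of the original script, and the optimizer itself may still work in float64 internally.)
import numpy as onp

def rosen32(x):
    # cast whatever the optimizer passes in back down to single precision
    x = onp.asarray(x, dtype='float32')
    return onp.sum(100.0 * (x[1:] - x[:-1] ** 2.0) ** 2.0 + (1 - x[:-1]) ** 2.0)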
To enable 64bit values, I ran your script like this:
I'll look into the gotchas as well and run our failing test (more complicated minimization) with the 64b switch and see if that happens to resolve things there.
That project looks really awesome, and we want to make JAX work great with it. Please don't hesitate to open issues like this one as you run into any issues or bugs!
Thanks!
I'm curious how we still managed to get 64-bit precision on the numpy side with 32b arrays?
My guess is that somewhere NumPy or scipy.optimize.minimize is promoting to 64bit precision here, even if you feed in a 32bit input. Indeed, the output in both cases (with JAX_ENABLE_X64=0) is float64, but perhaps more of the calculations are being done in 32bit precision (namely the evaluation of the objective function) when using jax.numpy without enabling 64bit values.
This experiment seems like some evidence in that direction:
import jax.numpy as np
import numpy as onp
from scipy.optimize import minimize

def run(np):
    # `np` is whichever array module is passed in (plain NumPy or jax.numpy)
    def rosen(x):
        return np.sum(100.0 * (x[1:] - x[:-1] ** 2.0) ** 2.0 + (1 - x[:-1]) ** 2.0)
    x0 = np.array([1.3, 0.7, 0.8, 1.9, 1.2], dtype='float32')
    bounds = np.array([[0.7, 1.3]] * 5)
    result = minimize(rosen, x0, method='SLSQP', options={'ftol': 1e-9, 'disp': True})
    print(result.x.dtype)
    print(rosen(result.x).dtype)

run(onp)  # plain NumPy
run(np)   # jax.numpy
In [1]: run issue936.py
Optimization terminated successfully. (Exit mode 0)
Current function value: 7.96982011054e-11
Iterations: 29
Function evaluations: 221
Gradient evaluations: 29
float64
float64
jax/lib/xla_bridge.py:130: UserWarning: No GPU/TPU found, falling back to CPU.
warnings.warn('No GPU/TPU found, falling back to CPU.')
Optimization terminated successfully. (Exit mode 0)
Current function value: 848.219909668
Iterations: 1
Function evaluations: 7
Gradient evaluations: 1
float64
float32 # different!
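That last line is consistent with the guess above; here's a small sketch (not from the thread) of the two halves of it: plain NumPy happily promotes a float32 input up to float64, while jax.numpy with x64 disabled truncates a float64 input back down to float32.
import numpy as onp
import jax.numpy as np  # run with JAX_ENABLE_X64=0

x32 = onp.array([1.3, 0.7, 0.8], dtype='float32')
print(onp.asarray(x32, dtype=float).dtype)  # float64: the kind of cast an optimizer can do internally
print(np.sum(x32.astype('float64')).dtype)  # float32: jax.numpy demotes 64bit inputs by default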
My guess is that somewhere NumPy or scipy.optimize.minimize is promoting to 64bit precision here, even if you feed in a 32bit input. Indeed, the output in both cases (with JAX_ENABLE_X64=0) is float64, but perhaps more of the calculations are being done in 32bit precision (namely the evaluation of the objective function) when using jax.numpy without enabling 64bit values.
Interesting! This is not obvious behavior... I'll dig into this a bit, thanks!
Hi, hopefully I'm actually doing something wrong with jax or scipy here, but... This example has the following output:
where the two results are:
[0.999999 0.99999821 0.9999967 0.99999373 0.9999876 ] (numpy)
[1.29999995 0.69999999 0.80000001 1.89999998 1.20000005] (jax.numpy)
I do not understand (1) why the results are different and (2) why there are so few iterations in the jax.numpy case. I suspect this is what's causing differences in the minimization result.
NB: I got this example from scipy docs https://docs.scipy.org/doc/scipy/reference/tutorial/optimize.html#nelder-mead-simplex-algorithm-method-nelder-mead
/cc @lukasheinrich @matthewfeickert (affects diana-hep/pyhf#377)