jax-ml / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
http://jax.readthedocs.io/
Apache License 2.0

scipy.optimize.minimize with method=SLSQP does not minimize successfully #936

Closed kratsg closed 5 years ago

kratsg commented 5 years ago

Hi, hopefully I'm actually doing something wrong with jax or scipy here, but...

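import jax.numpy as np
import numpy as onp
from scipy.optimize import minimize

def run(np):
    # `np` is either NumPy (onp) or jax.numpy, so the same objective runs on both backends
    def rosen(x):
        # Rosenbrock function, as in the SciPy optimize tutorial
        return np.sum(100.0 * (x[1:] - x[:-1] ** 2.0) ** 2.0 + (1 - x[:-1]) ** 2.0)

    x0 = np.array([1.3, 0.7, 0.8, 1.9, 1.2], dtype='float32')
    bounds = np.array([[0.7, 1.3]] * 5)  # note: defined but never passed to minimize
    return minimize(rosen, x0, method='SLSQP', options={'ftol': 1e-9, 'disp': True})

print(run(onp).x)  # NumPy objective
print(run(np).x)   # jax.numpy objective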

This example has the following output:

Optimization terminated successfully.    (Exit mode 0)
            Current function value: 7.969820110544395e-11
            Iterations: 29
            Function evaluations: 221
            Gradient evaluations: 29
[0.999999   0.99999821 0.9999967  0.99999373 0.9999876 ]
/Users/kratsg/.virtualenvs/pyhf/lib/python3.7/site-packages/jax/lib/xla_bridge.py:130: UserWarning: No GPU/TPU found, falling back to CPU.
  warnings.warn('No GPU/TPU found, falling back to CPU.')
Optimization terminated successfully.    (Exit mode 0)
            Current function value: 848.2199096679688
            Iterations: 1
            Function evaluations: 7
            Gradient evaluations: 1
[1.29999995 0.69999999 0.80000001 1.89999998 1.20000005]

where:

I do not understand (1) why the results are different and (2) why there are so few iterations in the jax.numpy case. I suspect the low iteration count is what causes the difference in the minimization result.

NB: I got this example from the SciPy docs: https://docs.scipy.org/doc/scipy/reference/tutorial/optimize.html#nelder-mead-simplex-algorithm-method-nelder-mead

/cc @lukasheinrich @matthewfeickert (affects diana-hep/pyhf#377)
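One detail in the output above hints at what happened: the JAX run's `result.x` matches the starting point. A quick NumPy-only check (it assumes nothing about SciPy internals) shows that `[1.3, 0.7, 0.8, 1.9, 1.2]` rounded to float32 and then widened to float64 prints as exactly that array, which suggests SLSQP never moved away from `x0`:

```python
import numpy as np

# Starting point from the issue, stored as float32 exactly as in the report.
x0 = np.array([1.3, 0.7, 0.8, 1.9, 1.2], dtype=np.float32)

# Widening to float64 is exact; it only exposes the float32 rounding error.
x0_64 = x0.astype(np.float64)
print(x0_64)  # [1.29999995 0.69999999 0.80000001 1.89999998 1.20000005]
```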

mattjj commented 5 years ago

Thanks for asking this, and the beautiful runnable example!

I think the issue is just 64bit vs 32bit. JAX by default maxes out at 32bit values for ints and floats. We chose that as a default policy because a primary use case of JAX is neural network-based machine learning research. That's different from ordinary NumPy, though, which is very happy to cast things to 64bit values. In fact, that's why we made 32bit the system-wide maximum precision by default: because otherwise users might get annoyed by the NumPy API promoting things to 64bit values all the time, when they're trying to stay in 32bit for neural net training!
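The promotion behaviour described here can be seen with plain NumPy (a minimal sketch; it does not exercise JAX at all): mixing float32 and float64 arrays silently yields float64, while a float32 array combined with Python float literals, as in the Rosenbrock objective, stays float32.

```python
import numpy as np

x32 = np.array([1.3, 0.7, 0.8, 1.9, 1.2], dtype=np.float32)

# float32 array + float64 array: NumPy promotes the result to float64.
mixed = x32 + np.zeros(5, dtype=np.float64)
print(mixed.dtype)  # float64

# float32 array combined with Python float literals stays float32.
term = 100.0 * (x32[1:] - x32[:-1] ** 2.0) ** 2.0
print(term.dtype)  # float32
```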

To enable 64bit values, I ran your script like this:

$ JAX_ENABLE_X64=1 python issue936.py
Optimization terminated successfully.    (Exit mode 0)
            Current function value: 7.96982011054e-11
            Iterations: 29
            Function evaluations: 221
            Gradient evaluations: 29
[0.999999   0.99999821 0.9999967  0.99999373 0.9999876 ]
/usr/local/google/home/mattjj/packages/jax/jax/lib/xla_bridge.py:130: UserWarning: No GPU/TPU found, falling back to CPU.
  warnings.warn('No GPU/TPU found, falling back to CPU.')
Optimization terminated successfully.    (Exit mode 0)
            Current function value: 7.96982011054e-11
            Iterations: 29
            Function evaluations: 221
            Gradient evaluations: 29
[0.999999   0.99999821 0.9999967  0.99999373 0.9999876 ]

When that 64bit flag is switched on, jax.numpy follows NumPy's precision semantics very closely, and as you can see, the numerics then agree here.

Another way to set that flag is by doing something like this at the top of your main .py file:

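from jax.config import config
config.update("jax_enable_x64", True)  # set this at startup, before creating any JAX arrays
# newer JAX releases spell this: import jax; jax.config.update("jax_enable_x64", True)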

You can see a bit more in the short gotchas section of the readme, and in the gotchas notebook.

mattjj commented 5 years ago

That project looks really awesome, and we want to make JAX work great with it. Please don't hesitate to open issues like this one as you run into any issues or bugs!

matthewfeickert commented 5 years ago

Thanks very much for your prompt feedback @mattjj! This helps and we really appreciate it. :+1:

kratsg commented 5 years ago

> I think the issue is just 64bit vs 32bit.

I knew about this policy. I'm curious how we still managed to get 64-bit precision on the NumPy side with 32-bit arrays. If the optimization is still working in 64-bit, is there a way to force it to 32-bit? I would have assumed that setting the array's dtype to float32 would set things up correctly:

    x0 = np.array([1.3, 0.7, 0.8, 1.9, 1.2], dtype='float32')

> To enable 64bit values, I ran your script like this:

I'll look into the gotchas as well, and rerun our failing test (a more complicated minimization) with the 64-bit switch to see whether that resolves things there.

> That project looks really awesome, and we want to make JAX work great with it. Please don't hesitate to open issues like this one as you run into any issues or bugs!

Thanks!

mattjj commented 5 years ago

> I'm curious how we still managed to get 64-bit precision on the numpy side with 32b arrays?

My guess is that somewhere NumPy or scipy.optimize.minimize is promoting to 64bit precision here, even if you feed in a 32bit input. Indeed, the output in both cases (with JAX_ENABLE_X64=0) is float64, but perhaps more of the calculations (namely the evaluations of the objective function) are being done in 32bit precision when using jax.numpy without enabling 64bit values.

This experiment seems like some evidence in that direction:

import jax.numpy as np
import numpy as onp
from scipy.optimize import minimize

def run(np):
    def rosen(x):
        return np.sum(100.0 * (x[1:] - x[:-1] ** 2.0) ** 2.0 + (1 - x[:-1]) ** 2.0)

    x0 = np.array([1.3, 0.7, 0.8, 1.9, 1.2], dtype='float32')
    bounds = np.array([[0.7, 1.3]] * 5)  # unused, as in the original example
    result = minimize(rosen, x0, method='SLSQP', options={'ftol': 1e-9, 'disp': True})
    print(result.x.dtype)         # dtype of the solution returned by SciPy
    print(rosen(result.x).dtype)  # dtype the objective is evaluated in

run(onp)
run(np)

In [1]: run issue936.py
Optimization terminated successfully.    (Exit mode 0)
            Current function value: 7.96982011054e-11
            Iterations: 29
            Function evaluations: 221
            Gradient evaluations: 29
float64
float64
jax/lib/xla_bridge.py:130: UserWarning: No GPU/TPU found, falling back to CPU.
  warnings.warn('No GPU/TPU found, falling back to CPU.')
Optimization terminated successfully.    (Exit mode 0)
            Current function value: 848.219909668
            Iterations: 1
            Function evaluations: 7
            Gradient evaluations: 1
float64
float32  # different!
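The `float32` line above fits one plausible mechanism, sketched here in NumPy under stated assumptions. Without a `jac` argument, SLSQP estimates the gradient by forward differences with a very small step; the sketch assumes a step of sqrt(float64 machine epsilon), about 1.5e-8, which is our understanding of SciPy's default `eps` for SLSQP. The hypothetical `rosen_float32` helper rounds its input to float32 as a stand-in for jax.numpy with 64-bit values disabled. Near the starting point the float32 spacing is roughly 6e-8 to 1.2e-7, so `x + h` rounds back to `x` and every difference quotient comes out as zero.

```python
import numpy as np


def rosen(x):
    # Rosenbrock objective from the issue, evaluated in the dtype of `x`.
    return np.sum(100.0 * (x[1:] - x[:-1] ** 2.0) ** 2.0 + (1 - x[:-1]) ** 2.0)


def rosen_float32(x):
    # Hypothetical stand-in for jax.numpy with 64-bit disabled:
    # the input is rounded to float32 before the objective is evaluated.
    return rosen(np.asarray(x, dtype=np.float32))


def forward_diff_grad(f, x, h):
    # Forward-difference gradient estimate, one coordinate at a time.
    fx = f(x)
    grad = np.empty_like(x)
    for i in range(x.size):
        step = np.zeros_like(x)
        step[i] = h
        grad[i] = (f(x + step) - fx) / h
    return grad


# Starting point as a float64 working copy of the float32 values from the issue.
x0 = np.array([1.3, 0.7, 0.8, 1.9, 1.2], dtype=np.float32).astype(np.float64)

# Assumed step size: sqrt of float64 machine epsilon, about 1.49e-8.
h = np.sqrt(np.finfo(np.float64).eps)

g64 = forward_diff_grad(rosen, x0, h)
g32 = forward_diff_grad(rosen_float32, x0, h)
print(g64)  # nonzero, close to the analytic gradient
print(g32)  # all zeros: x + h rounds back to x in float32
```

A zero gradient estimate would be consistent with the JAX run stopping after one iteration while still reporting exit mode 0; the float64 estimate, by contrast, is nonzero and close to the analytic gradient.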

kratsg commented 5 years ago

> My guess is that somewhere NumPy or scipy.optimize.minimize is promoting to 64bit precision here, even if you feed in a 32bit input. Indeed, the output in both cases (with JAX_ENABLE_X64=0) is float64, but perhaps more of the calculations (namely the evaluations of the objective function) are being done in 32bit precision when using jax.numpy without enabling 64bit values.

Interesting! This is not obvious behavior... I'll dig into this a bit, thanks!