google / jaxopt

Hardware accelerated, batchable and differentiable optimizers in JAX.
https://jaxopt.github.io
Apache License 2.0

weak_type inconsistency causes jit recompilation #451

Closed: mblondel closed this issue 1 year ago

mblondel commented 1 year ago

Each solver in JAXopt maintains a state containing several attributes, some of which are scalar-valued (e.g., stepsize). In some solvers, the state returned by init_state and the state returned by update have inconsistent weak_type values for some of these attributes, which triggers a JIT recompilation of update. This can be seen in the code example below (code by @fllinares):

import jax
import jax.numpy as jnp
import jaxopt
import sklearn.datasets
import functools
import time

def make_fun_with_aux(fun, dtype=None, has_aux=False):
  @functools.wraps(fun)
  def wrapper(*args, **kwargs):
    value = fun(*args, **kwargs)
    if dtype is not None:
      value = value.astype(dtype)
    aux = {'plus_one': value + 1.0, 'times_two': 2.0 * value}
    return (value, aux) if has_aux else value
  return wrapper

N_SAMPLES = 100
N_FEATURES = 20
N_CLASSES = 3
N_INFORMATIVE = 5

PARAMS_DTYPE = jnp.bfloat16
FUN_DTYPE = jnp.float32
HAS_AUX = True

data = sklearn.datasets.make_classification(n_samples=N_SAMPLES,
                                            n_features=N_FEATURES,
                                            n_classes=N_CLASSES,
                                            n_informative=N_INFORMATIVE,
                                            random_state=0)

init_params = jnp.zeros([N_FEATURES, N_CLASSES], dtype=PARAMS_DTYPE)
fun = make_fun_with_aux(
    jaxopt.objective.multiclass_logreg, dtype=FUN_DTYPE, has_aux=HAS_AUX)

solver = jaxopt.LBFGS(fun=fun, stepsize=1e-2, has_aux=HAS_AUX)
#solver = jaxopt.GradientDescent(fun=fun, stepsize=1e-2, has_aux=HAS_AUX)

update = jax.jit(solver.update)

data = jax.tree_util.tree_map(jax.device_put, data)
state0 = jax.jit(solver.init_state)(init_params, data=data)
params0 = init_params

print("state0.stepsize.weak_type", state0.stepsize.weak_type)
print()

tic = time.time()
params1, state1 = update(params0, state0, data)
print('First call:', time.time() - tic)
print("state1.stepsize.weak_type", state1.stepsize.weak_type)
print()

tic = time.time()
params2, state2 = update(params1, state1, data)
print('Second call:', time.time() - tic)
print("state2.stepsize.weak_type", state2.stepsize.weak_type)
print()

tic = time.time()
params3, state3 = update(params2, state2, data)
print('Third call:', time.time() - tic)
print("state3.stepsize.weak_type", state3.stepsize.weak_type)
print()

tic = time.time()
params4, state4 = update(params3, state3, data)
print("state4.stepsize.weak_type", state4.stepsize.weak_type)
print('Fourth call:', time.time() - tic)

Output:

$ python recompilation_issue.py
state0.stepsize.weak_type False

First call: 0.5013151168823242
state1.stepsize.weak_type True

Second call: 0.463458776473999
state2.stepsize.weak_type True

Third call: 0.00026679039001464844
state3.stepsize.weak_type True

state4.stepsize.weak_type True
Fourth call: 0.0002238750457763672

What's happening: state0 is obtained from the init_state call and is passed as input to update. The first call triggers a JIT compilation with state0.stepsize.weak_type = False. update then returns a new state, state1, with state1.stepsize.weak_type = True. When that state is passed back to update, a JIT recompilation happens because weak_type has changed. Subsequent calls to update do not recompile, since weak_type remains True.
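A smaller, self-contained way to observe the same effect is to count how often jax.jit traces a function. The sketch below is illustrative and not part of JAXopt: the function scale and the counter trace_count are hypothetical, and the counter relies on the fact that Python side effects inside a jitted function run only while JAX is tracing, not on cache hits.

```python
import jax
import jax.numpy as jnp

trace_count = 0  # hypothetical counter, incremented only while JAX traces


@jax.jit
def scale(x):
  global trace_count
  trace_count += 1  # Python side effect: runs at trace time, not on cache hits
  return 2 * x


strong = jnp.asarray(0.5, dtype=jnp.float32)  # weak_type=False
weak = jnp.asarray(0.5)                       # weak_type=True, dtype still float32

scale(strong)  # first call: trace and compile
scale(weak)    # same shape and dtype, different weak_type: retrace
scale(weak)    # cache hit: no retrace
print(trace_count)  # 2
```

Only the weak_type bit differs between strong and weak, yet it is part of the abstract signature that jit caches on, just like shape and dtype.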

Similarly to the dtype and aux consistency checks in common_test.py, we need to check weak_type consistency for each solver in a systematic manner, and make fixes if necessary.
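One way such a systematic check could look is sketched below. It does not use JAXopt itself: ToyState, init_state, update_buggy, update_fixed and weak_type_mismatches are hypothetical names, and the helper assumes a flat NamedTuple state whose fields are all JAX arrays. It compares the state returned by a jitted init_state with the state returned by a jitted update and reports every field whose (dtype, weak_type) pair changed.

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp


class ToyState(NamedTuple):
  """Hypothetical solver state with two scalar attributes."""
  iter_num: jax.Array
  stepsize: jax.Array


def init_state(params):
  # Explicit dtypes: both leaves have weak_type=False.
  return ToyState(iter_num=jnp.asarray(0, dtype=jnp.int32),
                  stepsize=jnp.asarray(1e-2, dtype=jnp.float32))


def update_buggy(params, state):
  # No dtype given: stepsize comes back with weak_type=True.
  stepsize = jnp.asarray(1e-2)
  return params - stepsize * 2.0 * params, ToyState(state.iter_num + 1, stepsize)


def update_fixed(params, state):
  # Reuse the dtype stored in the incoming state: weak_type stays False.
  stepsize = jnp.asarray(1e-2, dtype=state.stepsize.dtype)
  return params - stepsize * 2.0 * params, ToyState(state.iter_num + 1, stepsize)


def weak_type_mismatches(state_in, state_out):
  """Names of (flat NamedTuple) state fields whose dtype or weak_type changed."""
  return [name for name, a, b in zip(state_in._fields, state_in, state_out)
          if (a.dtype, a.weak_type) != (b.dtype, b.weak_type)]


params = jnp.ones(3, dtype=jnp.float32)
state0 = jax.jit(init_state)(params)
_, state1_buggy = jax.jit(update_buggy)(params, state0)
_, state1_fixed = jax.jit(update_fixed)(params, state0)
print(weak_type_mismatches(state0, state1_buggy))  # ['stepsize']
print(weak_type_mismatches(state0, state1_fixed))  # []
```

Reusing the dtype already recorded in the incoming state is one possible fix; it is not necessarily the approach taken in the actual JAXopt patch.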

@froystig Your opinion on the best way to fix would be welcome.

jakevdp commented 1 year ago

> Similarly to the dtype and aux consistency checks in common_test.py, we need to check weak_type consistency for each solver in a systematic manner, and make fixes if necessary.

This sounds like the right fix to me. A change in weak type for a function input is like changing the dtype: it can change the function's behavior, and must trigger a re-compilation. If you want to avoid recompilation, you need to make sure the inputs of the second function call match the inputs of the first function call.

mblondel commented 1 year ago

Thanks @jakevdp! I think I tracked down the inconsistency to this behavior:

>>> a = jnp.asarray(0.0, dtype=jnp.float32)
>>> a.weak_type
False
>>> a.dtype
dtype('float32')
>>> b = jnp.asarray(0.0)
>>> b.weak_type
True
>>> b.dtype
dtype('float32')

That is, weak_type is False when we explicitly specify dtype and True when we don't, even though the resulting dtype is float32 in both cases...

jakevdp commented 1 year ago

Yes, that's expected. Roughly, the mental model of "weak type" is that it's a value whose dtype has not been specified by the user. It's the mechanism that allows (x + 1).dtype == x.dtype to hold true within JAX code.
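To make this mental model concrete, the sketch below (variable names are illustrative) uses a bfloat16 array, like PARAMS_DTYPE in the reproduction above. Weakly typed operands, whether Python scalars or arrays created without a dtype, defer to the other operand's dtype, while a strongly typed float32 operand promotes the result.

```python
import jax.numpy as jnp

x = jnp.ones(3, dtype=jnp.bfloat16)

# A Python scalar is weakly typed, so it defers to x's dtype.
print((x + 1).dtype)  # bfloat16
# An array created without an explicit dtype is also weakly typed.
print((x + jnp.asarray(1.0)).dtype)  # bfloat16
# An explicit dtype makes the operand strongly typed, so normal promotion applies.
print((x + jnp.asarray(1.0, dtype=jnp.float32)).dtype)  # float32
```

This is why a stepsize that silently flips between weak and strong typing can matter beyond recompilation: it may also change the dtype of results computed from it.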

mblondel commented 1 year ago

fixed by #458