jax-ml / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
http://jax.readthedocs.io/
Apache License 2.0
30.29k stars 2.78k forks

Improved error messages and documentation regarding what can/cannot be `jit`ed #953

Open samuela opened 5 years ago

samuela commented 5 years ago

tl;dr

Better error messages and documentation for what can and cannot be `jit`-ed would be great. The current behavior is a black box. See https://github.com/google/jax/issues/953#issuecomment-507119726.

See also:

Original issue

I have a particular function that I believe should be `jit`-able, but I'm getting errors when hitting it with `jit`. I have a snippet of code that looks like this:

from typing import Any, Callable, NamedTuple, TypeVar

from jax import jit
from jax.experimental import optimizers

OptState = TypeVar("OptState")

class Optimizer(NamedTuple):
  init: Callable[[Any], OptState]
  update: Callable[[int, Any, OptState], OptState]
  get: Callable[[OptState], Any]

def ddpg_episode(
    optimizer: Optimizer,
    ...
) -> LoopState:
  ...

optimizer = Optimizer(*optimizers.adam(step_size=1e-3))
jit(ddpg_episode)(
    optimizer,
    ...
)

but I get an error:

---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
/usr/local/Cellar/python/3.7.3/Frameworks/Python.framework/Versions/3.7/lib/python3.7/runpy.py in run_module(mod_name, init_globals, run_name, alter_sys)
    203         run_name = mod_name
    204     if alter_sys:
--> 205         return _run_module_code(code, init_globals, run_name, mod_spec)
    206     else:
    207         # Leave the sys module alone

/usr/local/Cellar/python/3.7.3/Frameworks/Python.framework/Versions/3.7/lib/python3.7/runpy.py in _run_module_code(code, init_globals, mod_name, mod_spec, pkg_name, script_name)
     94         mod_globals = temp_module.module.__dict__
     95         _run_code(code, mod_globals, init_globals,
---> 96                   mod_name, mod_spec, pkg_name, script_name)
     97     # Copy the globals of the temporary module, as they
     98     # may be cleared when the temporary module goes away

/usr/local/Cellar/python/3.7.3/Frameworks/Python.framework/Versions/3.7/lib/python3.7/runpy.py in _run_code(code, run_globals, init_globals, mod_name, mod_spec, pkg_name, script_name)
     83                        __package__ = pkg_name,
     84                        __spec__ = mod_spec)
---> 85     exec(code, run_globals)
     86     return run_globals
     87 

~/dev/research/research/estop/ddpg_pendulum.py in <module>()
     73       critic,
     74       episode_length,
---> 75       noise,
     76   )
     77   print(f"Episode {epsiode}, reward = {reward}")

~/.local/share/virtualenvs/research-CJWjSlE2/lib/python3.7/site-packages/jax/api.py in f_jitted(*args, **kwargs)
    119     f, dyn_args = _argnums_partial(f, dyn_argnums, args)
    120     args_flat, in_tree = tree_flatten((dyn_args, kwargs))
--> 121     _check_args(args_flat)
    122     flat_fun, out_tree = flatten_fun_leafout(f, in_tree)
    123     out = xla.xla_call(flat_fun, *args_flat, device_values=device_values)

~/.local/share/virtualenvs/research-CJWjSlE2/lib/python3.7/site-packages/jax/api.py in _check_args(args)
    944     if not (isinstance(arg, core.Tracer) or _valid_jaxtype(arg)):
    945       raise TypeError("Argument '{}' of type {} is not a valid JAX type"
--> 946                       .format(arg, type(arg)))
    947 
    948 def _valid_jaxtype(arg):

TypeError: Argument '<function adam.<locals>.init at 0x11b176c80>' of type <class 'function'> is not a valid JAX type

It's not clear to me what this error message is trying to communicate. What exactly is a valid JAX type? And why is this particular function being rejected when there are plenty of `jit` examples that take functions as arguments?
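For anyone landing here later, a minimal reproduction of this error and two workarounds, sketched against the current `jax` API rather than the 2019 one in the traceback above:

```python
import jax
import jax.numpy as jnp

def apply_fn(f, x):
    return f(x) + 1.0

x = jnp.ones(3)

# Passing a function as a traced argument raises the "not a valid JAX
# type" error, because only arrays (and pytrees of arrays) can be traced:
# jax.jit(apply_fn)(jnp.tanh, x)  # TypeError

# Workaround 1: mark the function argument as static.
out1 = jax.jit(apply_fn, static_argnums=0)(jnp.tanh, x)

# Workaround 2: close over the function instead of passing it in.
out2 = jax.jit(lambda y: apply_fn(jnp.tanh, y))(x)
```

Both calls produce `tanh(1) + 1` elementwise; the difference is that `static_argnums` recompiles whenever a different function is passed, while the closure bakes the function in.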

samuela commented 5 years ago

The same applies if the optimizer functions are not wrapped in a tuple, i.e. passing opt_update and get_params into ddpg_episode separately. I should also add that ddpg_episode works fine, albeit slowly, without jit-ing.

samuela commented 5 years ago

After currying the functions out of the jitted function, I'm seeing a new but equally inscrutable error:

---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
/usr/local/Cellar/python/3.7.3/Frameworks/Python.framework/Versions/3.7/lib/python3.7/runpy.py in run_module(mod_name, init_globals, run_name, alter_sys)
    203         run_name = mod_name
    204     if alter_sys:
--> 205         return _run_module_code(code, init_globals, run_name, mod_spec)
    206     else:
    207         # Leave the sys module alone

/usr/local/Cellar/python/3.7.3/Frameworks/Python.framework/Versions/3.7/lib/python3.7/runpy.py in _run_module_code(code, init_globals, mod_name, mod_spec, pkg_name, script_name)
     94         mod_globals = temp_module.module.__dict__
     95         _run_code(code, mod_globals, init_globals,
---> 96                   mod_name, mod_spec, pkg_name, script_name)
     97     # Copy the globals of the temporary module, as they
     98     # may be cleared when the temporary module goes away

/usr/local/Cellar/python/3.7.3/Frameworks/Python.framework/Versions/3.7/lib/python3.7/runpy.py in _run_code(code, run_globals, init_globals, mod_name, mod_spec, pkg_name, script_name)
     83                        __package__ = pkg_name,
     84                        __spec__ = mod_spec)
---> 85     exec(code, run_globals)
     86     return run_globals
     87 

~/dev/research/research/estop/ddpg_pendulum.py in <module>()
     78       opt_state,
     79       tracking_params,
---> 80       episode_length,
     81   )
     82   print(f"Episode {epsiode}, reward = {reward}")

~/.local/share/virtualenvs/research-CJWjSlE2/lib/python3.7/site-packages/jax/api.py in f_jitted(*args, **kwargs)
    121     _check_args(args_flat)
    122     flat_fun, out_tree = flatten_fun_leafout(f, in_tree)
--> 123     out = xla.xla_call(flat_fun, *args_flat, device_values=device_values)
    124     return out if out_tree() is leaf else tree_unflatten(out_tree(), out)
    125 

~/.local/share/virtualenvs/research-CJWjSlE2/lib/python3.7/site-packages/jax/core.py in call_bind(primitive, f, *args, **params)
    656   if top_trace is None:
    657     with new_sublevel():
--> 658       ans = primitive.impl(f, *args, **params)
    659   else:
    660     tracers = map(top_trace.full_raise, args)

~/.local/share/virtualenvs/research-CJWjSlE2/lib/python3.7/site-packages/jax/interpreters/xla.py in xla_call_impl(fun, *args, **params)
    651 def xla_call_impl(fun, *args, **params):
    652   device_values = FLAGS.jax_device_values and params.pop('device_values')
--> 653   compiled_fun = xla_callable(fun, device_values, *map(abstractify, args))
    654   try:
    655     return compiled_fun(*args)

~/.local/share/virtualenvs/research-CJWjSlE2/lib/python3.7/site-packages/jax/linear_util.py in memoized_fun(f, *args)
    206       if len(cache) > max_size:
    207         cache.popitem(last=False)
--> 208       ans = call(f, *args)
    209       cache[key] = (ans, f)
    210     return ans

~/.local/share/virtualenvs/research-CJWjSlE2/lib/python3.7/site-packages/jax/interpreters/xla.py in xla_callable(fun, device_values, *abstract_args)
    664   pvals = [pe.PartialVal((aval, core.unit)) for aval in abstract_args]
    665   with core.new_master(pe.JaxprTrace, True) as master:
--> 666     jaxpr, (pval, consts, env) = pe.trace_to_subjaxpr(fun, master, False).call_wrapped(pvals)
    667     assert not env  # no subtraces here (though cond might eventually need them)
    668     compiled, result_shape = compile_jaxpr(jaxpr, consts, *abstract_args)

~/.local/share/virtualenvs/research-CJWjSlE2/lib/python3.7/site-packages/jax/linear_util.py in call_wrapped(self, *args, **kwargs)
    145 
    146     del gen
--> 147     ans = self.f(*args, **dict(self.params, **kwargs))
    148     del args
    149     while stack:

~/dev/research/research/estop/ddpg.py in run(rng, init_replay_buffer, batch_size, init_opt_state, init_tracking_params, epside_length)
    173   ) -> LoopState:
    174     rng_start, rng_rest = random.split(rng)
--> 175     rngs = random.split(rng_rest, epside_length)
    176 
    177     def step(i, loop_state: LoopState):

~/.local/share/virtualenvs/research-CJWjSlE2/lib/python3.7/site-packages/jax/random.py in split(key, num)
    170     An array with shape (num, 2) and dtype uint32 representing `num` new keys.
    171   """
--> 172   return _split(key, num)
    173 
    174 @partial(jit, static_argnums=(1,))

~/.local/share/virtualenvs/research-CJWjSlE2/lib/python3.7/site-packages/jax/api.py in f_jitted(*args, **kwargs)
    121     _check_args(args_flat)
    122     flat_fun, out_tree = flatten_fun_leafout(f, in_tree)
--> 123     out = xla.xla_call(flat_fun, *args_flat, device_values=device_values)
    124     return out if out_tree() is leaf else tree_unflatten(out_tree(), out)
    125 

~/.local/share/virtualenvs/research-CJWjSlE2/lib/python3.7/site-packages/jax/core.py in call_bind(primitive, f, *args, **params)
    659   else:
    660     tracers = map(top_trace.full_raise, args)
--> 661     ans = full_lower(top_trace.process_call(primitive, f, tracers, params))
    662   return apply_todos(env_trace_todo(), ans)
    663 

~/.local/share/virtualenvs/research-CJWjSlE2/lib/python3.7/site-packages/jax/interpreters/partial_eval.py in process_call(self, call_primitive, f, tracers, params)
    111     in_pvs, in_consts = unzip2([t.pval for t in tracers])
    112     fun, aux = partial_eval(f, self, in_pvs)
--> 113     out_pv_const, consts = call_primitive.bind(fun, *in_consts, **params)
    114     out_pv, jaxpr, env = aux()
    115     const_tracers = map(self.new_instantiated_const, consts)

~/.local/share/virtualenvs/research-CJWjSlE2/lib/python3.7/site-packages/jax/core.py in call_bind(primitive, f, *args, **params)
    656   if top_trace is None:
    657     with new_sublevel():
--> 658       ans = primitive.impl(f, *args, **params)
    659   else:
    660     tracers = map(top_trace.full_raise, args)

~/.local/share/virtualenvs/research-CJWjSlE2/lib/python3.7/site-packages/jax/interpreters/xla.py in xla_call_impl(fun, *args, **params)
    651 def xla_call_impl(fun, *args, **params):
    652   device_values = FLAGS.jax_device_values and params.pop('device_values')
--> 653   compiled_fun = xla_callable(fun, device_values, *map(abstractify, args))
    654   try:
    655     return compiled_fun(*args)

~/.local/share/virtualenvs/research-CJWjSlE2/lib/python3.7/site-packages/jax/linear_util.py in memoized_fun(f, *args)
    206       if len(cache) > max_size:
    207         cache.popitem(last=False)
--> 208       ans = call(f, *args)
    209       cache[key] = (ans, f)
    210     return ans

~/.local/share/virtualenvs/research-CJWjSlE2/lib/python3.7/site-packages/jax/interpreters/xla.py in xla_callable(fun, device_values, *abstract_args)
    664   pvals = [pe.PartialVal((aval, core.unit)) for aval in abstract_args]
    665   with core.new_master(pe.JaxprTrace, True) as master:
--> 666     jaxpr, (pval, consts, env) = pe.trace_to_subjaxpr(fun, master, False).call_wrapped(pvals)
    667     assert not env  # no subtraces here (though cond might eventually need them)
    668     compiled, result_shape = compile_jaxpr(jaxpr, consts, *abstract_args)

~/.local/share/virtualenvs/research-CJWjSlE2/lib/python3.7/site-packages/jax/linear_util.py in call_wrapped(self, *args, **kwargs)
    145 
    146     del gen
--> 147     ans = self.f(*args, **dict(self.params, **kwargs))
    148     del args
    149     while stack:

~/.local/share/virtualenvs/research-CJWjSlE2/lib/python3.7/site-packages/jax/random.py in _split(key, num)
    174 @partial(jit, static_argnums=(1,))
    175 def _split(key, num):
--> 176   counts = lax.tie_in(key, lax.iota(onp.uint32, num * 2))
    177   return lax.reshape(threefry_2x32(key, counts), (num, 2))
    178 

~/.local/share/virtualenvs/research-CJWjSlE2/lib/python3.7/site-packages/jax/lax/lax.py in iota(dtype, size)
    928   operator.
    929   """
--> 930   return broadcasted_iota(dtype, (int(size),), 0)
    931 
    932 def broadcasted_iota(dtype, shape, dimension):

~/.local/share/virtualenvs/research-CJWjSlE2/lib/python3.7/site-packages/jax/core.py in __int__(self)
    337   def __bool__(self): return self.aval._bool(self)
    338   def __float__(self): return self.aval._float(self)
--> 339   def __int__(self): return self.aval._int(self)
    340   def __long__(self): return self.aval._long(self)
    341   def __complex__(self): return self.aval._complex(self)

~/.local/share/virtualenvs/research-CJWjSlE2/lib/python3.7/site-packages/jax/abstract_arrays.py in error(self, *args)
     36 def concretization_function_error(fun):
     37   def error(self, *args):
---> 38     raise TypeError(concretization_err_msg(fun))
     39   return error
     40 

TypeError: Abstract value passed to `int`, which requires a concrete value. The function to be transformed can't be traced at the required level of abstraction. If using `jit`, try using `static_argnums` or applying `jit` to smaller subfunctions instead.

What is an abstract value? What is a concrete value? Where is this bad thing happening? It looks like random.split may be the culprit, but it's not clear to me why it would be an issue.
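For reference, a minimal reproduction of this class of error, sketched with the current `jax` API: under `jit`, an argument is traced as an abstract value (shape and dtype only, no actual data), so anything that needs the real number, like the `num` argument of `random.split`, fails unless that argument is marked static.

```python
import jax
from jax import random

def make_keys(key, num):
    # Under jit, `num` is an abstract tracer; random.split needs a
    # concrete Python int to build the output shape.
    return random.split(key, num)

key = random.PRNGKey(0)

# jax.jit(make_keys)(key, 4)  # TypeError: abstract value passed to `int`

# Marking `num` static makes it a concrete compile-time constant.
keys = jax.jit(make_keys, static_argnums=1)(key, 4)  # shape (4, 2)
```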

samuela commented 5 years ago

Ok, after aggressively currying things out of the jitted function, I now have things working! So perhaps the more appropriate issue here is a request for better error messages and documentation about what can and cannot be jitted, and why certain things work while others don't.

It would also be great to have some specification of what to expect from using jit. What sorts of optimizations does it perform, and what sorts of performance improvements should (or should not) be expected? My current mental model is that I write code, wrap it in jit, and things magically go faster. I have no idea whether I'm actually writing code that fully utilizes XLA. Perhaps the 10x speedup I observe after mindlessly applying @jit is dwarfed by a potential 100x improvement from slightly different code that is amenable to a different set of XLA optimizations.
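One rough way to answer this empirically (a sketch; the numbers depend entirely on hardware, shapes, and the toy workload chosen here) is to time eager and jitted versions side by side, remembering that the first jitted call includes compilation and that JAX dispatch is asynchronous, so `block_until_ready()` is needed for honest timings:

```python
import time
import jax
import jax.numpy as jnp

def rollout(x):
    # Toy stand-in for an episode body: a chain of small fusable ops,
    # which is where XLA fusion tends to help most.
    for _ in range(10):
        x = jnp.tanh(x @ x.T)
    return x

x = jnp.ones((200, 200))
jitted = jax.jit(rollout)
jitted(x).block_until_ready()  # warm-up: first call pays compilation cost

t0 = time.perf_counter()
rollout(x).block_until_ready()
eager_s = time.perf_counter() - t0

t0 = time.perf_counter()
jitted(x).block_until_ready()
jit_s = time.perf_counter() - t0
print(f"eager: {eager_s:.4f}s  jit: {jit_s:.4f}s")
```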

mattjj commented 5 years ago

Thanks for raising this and the detailed notes.

Regarding the latter error message, have you already read the How it works and What's supported sections of the readme, and the Gotchas notebook, especially the Control Flow section? If so, it'd be useful to think through together how they could be improved (i.e. what they're missing), and if not it'd be useful to figure out how to make them more discoverable!

I remember also we wrote a bit more in a comment on #196, which I thought we also linked from the readme but apparently we don't.

Regarding the former error, actually functions can't be arguments to jit-ed functions. As it says in the jit docstring, "Its arguments and return value should be arrays, scalars, or (nested) standard Python containers (tuple/list/dict) thereof." Other than adding more information to the "valid jaxtype" error message, do you have suggestions for how to improve the documentation?

(Valid jaxtypes are in practice numpy.ndarray and our effective subclasses. There's also a tuple type but user code doesn't use it. The functions in api.py can accept pytrees of jaxtypes, which in practice effectively means tuple/list/dict-trees of arrays.)
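Concretely, that means a jitted function can take arbitrarily nested containers as arguments as long as every leaf is an array or scalar; a function-valued leaf is rejected. A quick sketch:

```python
import jax
import jax.numpy as jnp

@jax.jit
def f(params):
    # tuple/list/dict nests of arrays are valid jit arguments.
    return params["w"] @ params["x"] + params["b"]

params = {"w": jnp.eye(2), "x": jnp.ones(2), "b": jnp.zeros(2)}
out = f(params)  # fine: every pytree leaf is an array

# ...but a function leaf is not a valid jaxtype:
# jax.jit(lambda d: d["g"](1.0))({"g": jnp.tanh})  # TypeError: not a valid JAX type
```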

mattjj commented 5 years ago

So far here are some action items coming out of this issue:

  1. Improve the "valid jaxtypes" error message to explain that jaxtypes are arrays.
  2. Make our current (minimal) explanations of JAX abstract interpretation machinery more discoverable (they may be hidden near the bottom of a long readme), since hopefully those provide a conceptual framework for understanding error messages.

What other items can we come up with?

samuela commented 5 years ago

I've read through the README, including the "How it works" and "What's supported" sections, but in the midst of my debugging it didn't seem as though they applied to the situation I faced, since neither directly addresses passing functions into jitted functions. I actually mistook the nudge toward FP in the "How it works" section as an implicit endorsement of using functions as values (including in jitted code). It took me a minute to realize that the appropriate paradigm is to jit closures that capture the necessary functions in their environments.
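For the record, the closure pattern that eventually worked for me looks roughly like this (sketched with `jax.example_libraries.optimizers`, the current home of `jax.experimental.optimizers`):

```python
import jax
import jax.numpy as jnp
from jax.example_libraries import optimizers

opt_init, opt_update, get_params = optimizers.adam(step_size=1e-3)

# Instead of passing opt_update/get_params in as arguments, close over
# them: jit then only ever traces array-valued arguments. The optimizer
# state itself is a registered pytree, so it passes through jit fine.
@jax.jit
def train_step(i, opt_state, grads):
    return opt_update(i, grads, opt_state)

opt_state = opt_init(jnp.array(1.0))
opt_state = train_step(0, opt_state, jnp.array(1.0))
print(get_params(opt_state))  # one adam step: slightly below 1.0
```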

I've looked through the Gotchas notebook as well, and the section on control flow is interesting, but because I was not using any Python control flow in my function, it didn't match my cmd-F-ing. I also didn't think of the Gotchas notebook when debugging the random.split issue, since those implementation details are internal to jax and random.split is not discussed in the notebook.

The terms "abstract value" and "concrete value" aren't defined in the docs AFAIK, although I can make educated guesses as to what they mean. Any time I see undefined terms in an error message, my brain jumps to the conclusion that it's some sort of internal error.

Yeah, I think writing down a clear definition of a "valid jaxtype" would be very helpful. Even something simple like a table of things that are and aren't jaxtypes, and of things that can and cannot be passed to jitted functions, would be great.

I should add that overall the JAX documentation is great considering the age of the project! Sometimes it's just difficult to find the right information at the right time as an external user.

wbradknox commented 2 years ago

+1 to the suggestion to clearly define "valid JAX type", using that exact phrase so it surfaces in search results. I currently cannot find such a definition when Googling "valid JAX type".

I did find this documentation, which appears helpful with errors related to using invalid JAX types as input.