[Open] samuela opened this issue 5 years ago
Same applies if the optimizer functions are not wrapped in a tuple, i.e. passing both `opt_update` and `get_params` separately to `ddpg_episode`. I should also add that `ddpg_episode` works fine, albeit slowly, without `jit`ing.
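The failure mode described above can be sketched as follows. This is a minimal hypothetical example, not the original `ddpg_episode` code; `update_fn`, `make_step`, and the shapes are all made up for illustration:

```python
import jax
import jax.numpy as jnp

# Passing a function as an argument to a jitted function fails,
# because functions are not valid jaxtypes:
#
#   jax.jit(lambda fn, p, g: fn(p, g))(my_update, params, grads)
#   # -> TypeError: ... is not a valid JAX type
#
# The workaround ("currying the function out"): close over the
# function instead of passing it in as an argument.
def make_step(update_fn):
    @jax.jit
    def step(params, grads):
        return update_fn(params, grads)  # update_fn is captured, not traced
    return step

step = make_step(lambda p, g: p - 0.1 * g)
print(step(jnp.array([1.0, 2.0]), jnp.array([0.5, 0.5])))  # [0.95 1.95]
```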
After currying the functions out of the jitted function, I'm seeing a new but equally inscrutable error:
---------------------------------------------------------------------------
TypeError Traceback (most recent call last)
/usr/local/Cellar/python/3.7.3/Frameworks/Python.framework/Versions/3.7/lib/python3.7/runpy.py in run_module(mod_name, init_globals, run_name, alter_sys)
203 run_name = mod_name
204 if alter_sys:
--> 205 return _run_module_code(code, init_globals, run_name, mod_spec)
206 else:
207 # Leave the sys module alone
/usr/local/Cellar/python/3.7.3/Frameworks/Python.framework/Versions/3.7/lib/python3.7/runpy.py in _run_module_code(code, init_globals, mod_name, mod_spec, pkg_name, script_name)
94 mod_globals = temp_module.module.__dict__
95 _run_code(code, mod_globals, init_globals,
---> 96 mod_name, mod_spec, pkg_name, script_name)
97 # Copy the globals of the temporary module, as they
98 # may be cleared when the temporary module goes away
/usr/local/Cellar/python/3.7.3/Frameworks/Python.framework/Versions/3.7/lib/python3.7/runpy.py in _run_code(code, run_globals, init_globals, mod_name, mod_spec, pkg_name, script_name)
83 __package__ = pkg_name,
84 __spec__ = mod_spec)
---> 85 exec(code, run_globals)
86 return run_globals
87
~/dev/research/research/estop/ddpg_pendulum.py in <module>()
78 opt_state,
79 tracking_params,
---> 80 episode_length,
81 )
82 print(f"Episode {epsiode}, reward = {reward}")
~/.local/share/virtualenvs/research-CJWjSlE2/lib/python3.7/site-packages/jax/api.py in f_jitted(*args, **kwargs)
121 _check_args(args_flat)
122 flat_fun, out_tree = flatten_fun_leafout(f, in_tree)
--> 123 out = xla.xla_call(flat_fun, *args_flat, device_values=device_values)
124 return out if out_tree() is leaf else tree_unflatten(out_tree(), out)
125
~/.local/share/virtualenvs/research-CJWjSlE2/lib/python3.7/site-packages/jax/core.py in call_bind(primitive, f, *args, **params)
656 if top_trace is None:
657 with new_sublevel():
--> 658 ans = primitive.impl(f, *args, **params)
659 else:
660 tracers = map(top_trace.full_raise, args)
~/.local/share/virtualenvs/research-CJWjSlE2/lib/python3.7/site-packages/jax/interpreters/xla.py in xla_call_impl(fun, *args, **params)
651 def xla_call_impl(fun, *args, **params):
652 device_values = FLAGS.jax_device_values and params.pop('device_values')
--> 653 compiled_fun = xla_callable(fun, device_values, *map(abstractify, args))
654 try:
655 return compiled_fun(*args)
~/.local/share/virtualenvs/research-CJWjSlE2/lib/python3.7/site-packages/jax/linear_util.py in memoized_fun(f, *args)
206 if len(cache) > max_size:
207 cache.popitem(last=False)
--> 208 ans = call(f, *args)
209 cache[key] = (ans, f)
210 return ans
~/.local/share/virtualenvs/research-CJWjSlE2/lib/python3.7/site-packages/jax/interpreters/xla.py in xla_callable(fun, device_values, *abstract_args)
664 pvals = [pe.PartialVal((aval, core.unit)) for aval in abstract_args]
665 with core.new_master(pe.JaxprTrace, True) as master:
--> 666 jaxpr, (pval, consts, env) = pe.trace_to_subjaxpr(fun, master, False).call_wrapped(pvals)
667 assert not env # no subtraces here (though cond might eventually need them)
668 compiled, result_shape = compile_jaxpr(jaxpr, consts, *abstract_args)
~/.local/share/virtualenvs/research-CJWjSlE2/lib/python3.7/site-packages/jax/linear_util.py in call_wrapped(self, *args, **kwargs)
145
146 del gen
--> 147 ans = self.f(*args, **dict(self.params, **kwargs))
148 del args
149 while stack:
~/dev/research/research/estop/ddpg.py in run(rng, init_replay_buffer, batch_size, init_opt_state, init_tracking_params, epside_length)
173 ) -> LoopState:
174 rng_start, rng_rest = random.split(rng)
--> 175 rngs = random.split(rng_rest, epside_length)
176
177 def step(i, loop_state: LoopState):
~/.local/share/virtualenvs/research-CJWjSlE2/lib/python3.7/site-packages/jax/random.py in split(key, num)
170 An array with shape (num, 2) and dtype uint32 representing `num` new keys.
171 """
--> 172 return _split(key, num)
173
174 @partial(jit, static_argnums=(1,))
~/.local/share/virtualenvs/research-CJWjSlE2/lib/python3.7/site-packages/jax/api.py in f_jitted(*args, **kwargs)
121 _check_args(args_flat)
122 flat_fun, out_tree = flatten_fun_leafout(f, in_tree)
--> 123 out = xla.xla_call(flat_fun, *args_flat, device_values=device_values)
124 return out if out_tree() is leaf else tree_unflatten(out_tree(), out)
125
~/.local/share/virtualenvs/research-CJWjSlE2/lib/python3.7/site-packages/jax/core.py in call_bind(primitive, f, *args, **params)
659 else:
660 tracers = map(top_trace.full_raise, args)
--> 661 ans = full_lower(top_trace.process_call(primitive, f, tracers, params))
662 return apply_todos(env_trace_todo(), ans)
663
~/.local/share/virtualenvs/research-CJWjSlE2/lib/python3.7/site-packages/jax/interpreters/partial_eval.py in process_call(self, call_primitive, f, tracers, params)
111 in_pvs, in_consts = unzip2([t.pval for t in tracers])
112 fun, aux = partial_eval(f, self, in_pvs)
--> 113 out_pv_const, consts = call_primitive.bind(fun, *in_consts, **params)
114 out_pv, jaxpr, env = aux()
115 const_tracers = map(self.new_instantiated_const, consts)
~/.local/share/virtualenvs/research-CJWjSlE2/lib/python3.7/site-packages/jax/core.py in call_bind(primitive, f, *args, **params)
656 if top_trace is None:
657 with new_sublevel():
--> 658 ans = primitive.impl(f, *args, **params)
659 else:
660 tracers = map(top_trace.full_raise, args)
~/.local/share/virtualenvs/research-CJWjSlE2/lib/python3.7/site-packages/jax/interpreters/xla.py in xla_call_impl(fun, *args, **params)
651 def xla_call_impl(fun, *args, **params):
652 device_values = FLAGS.jax_device_values and params.pop('device_values')
--> 653 compiled_fun = xla_callable(fun, device_values, *map(abstractify, args))
654 try:
655 return compiled_fun(*args)
~/.local/share/virtualenvs/research-CJWjSlE2/lib/python3.7/site-packages/jax/linear_util.py in memoized_fun(f, *args)
206 if len(cache) > max_size:
207 cache.popitem(last=False)
--> 208 ans = call(f, *args)
209 cache[key] = (ans, f)
210 return ans
~/.local/share/virtualenvs/research-CJWjSlE2/lib/python3.7/site-packages/jax/interpreters/xla.py in xla_callable(fun, device_values, *abstract_args)
664 pvals = [pe.PartialVal((aval, core.unit)) for aval in abstract_args]
665 with core.new_master(pe.JaxprTrace, True) as master:
--> 666 jaxpr, (pval, consts, env) = pe.trace_to_subjaxpr(fun, master, False).call_wrapped(pvals)
667 assert not env # no subtraces here (though cond might eventually need them)
668 compiled, result_shape = compile_jaxpr(jaxpr, consts, *abstract_args)
~/.local/share/virtualenvs/research-CJWjSlE2/lib/python3.7/site-packages/jax/linear_util.py in call_wrapped(self, *args, **kwargs)
145
146 del gen
--> 147 ans = self.f(*args, **dict(self.params, **kwargs))
148 del args
149 while stack:
~/.local/share/virtualenvs/research-CJWjSlE2/lib/python3.7/site-packages/jax/random.py in _split(key, num)
174 @partial(jit, static_argnums=(1,))
175 def _split(key, num):
--> 176 counts = lax.tie_in(key, lax.iota(onp.uint32, num * 2))
177 return lax.reshape(threefry_2x32(key, counts), (num, 2))
178
~/.local/share/virtualenvs/research-CJWjSlE2/lib/python3.7/site-packages/jax/lax/lax.py in iota(dtype, size)
928 operator.
929 """
--> 930 return broadcasted_iota(dtype, (int(size),), 0)
931
932 def broadcasted_iota(dtype, shape, dimension):
~/.local/share/virtualenvs/research-CJWjSlE2/lib/python3.7/site-packages/jax/core.py in __int__(self)
337 def __bool__(self): return self.aval._bool(self)
338 def __float__(self): return self.aval._float(self)
--> 339 def __int__(self): return self.aval._int(self)
340 def __long__(self): return self.aval._long(self)
341 def __complex__(self): return self.aval._complex(self)
~/.local/share/virtualenvs/research-CJWjSlE2/lib/python3.7/site-packages/jax/abstract_arrays.py in error(self, *args)
36 def concretization_function_error(fun):
37 def error(self, *args):
---> 38 raise TypeError(concretization_err_msg(fun))
39 return error
40
TypeError: Abstract value passed to `int`, which requires a concrete value. The function to be transformed can't be traced at the required level of abstraction. If using `jit`, try using `static_argnums` or applying `jit` to smaller subfunctions instead.
What is an abstract value? What is a concrete value? Where is this bad thing happening? It looks like `random.split` may be the culprit, but it's not super clear to me why this would be an issue.
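For context, the abstract/concrete distinction can be observed directly: inside a jitted function, arguments are abstract tracers that carry only shape and dtype, not values. A minimal sketch (hypothetical `show` function, not the original code):

```python
import jax
import jax.numpy as jnp

@jax.jit
def show(x):
    # During tracing, x is an abstract tracer (shape/dtype only),
    # so this prints something like Traced<ShapedArray(int32[3])>.
    print("traced with:", x)
    return x * 2

out = show(jnp.arange(3))
print(out)  # the concrete result: [0 2 4]
```

Operations like `int(x)` or `random.split(key, x)` need a concrete value, which is why they fail when handed a tracer.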
Ok, after aggressively currying things out of the jitted function, I now have things working! So perhaps a more appropriate issue here would be a request for better error messages and documentation surrounding what can and cannot be jitted, why certain things work while others don't.
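Besides currying, the `random.split` failure specifically can be addressed with `static_argnums`, since its `num` argument must be a concrete Python int. A sketch with hypothetical `bad`/`good` names:

```python
from functools import partial
import jax
from jax import random

# `random.split` needs a concrete `num`. Under plain jit, `n` arrives
# as an abstract tracer, producing "Abstract value passed to `int`".
@jax.jit
def bad(key, n):
    return random.split(key, n)   # fails if called: n is a tracer here

# Marking the argument static keeps it a concrete Python int
# (at the cost of re-compiling for each distinct value of n).
@partial(jax.jit, static_argnums=(1,))
def good(key, n):
    return random.split(key, n)

key = random.PRNGKey(0)
keys = good(key, 4)
print(keys.shape)  # (4, 2)
```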
It would also be great to have some specification of what can be expected from using `jit`. What sort of optimizations does it do, and what sort of performance improvements should be expected or not expected? My current mental model is that I write code, wrap it in `jit`, and things magically go faster. I have no idea if I'm actually writing code that fully utilizes XLA. Perhaps the 10x speedup I observe after mindlessly applying `@jit` would be dwarfed by a potential 100x speed improvement from some slightly different code that is amenable to a different set of XLA optimizations.
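One concrete pitfall when measuring `jit` speedups: the first call includes tracing and XLA compilation, and dispatch is asynchronous, so naive timing is misleading. A hedged benchmarking sketch (hypothetical function `f`, arbitrary shapes):

```python
import time
import jax
import jax.numpy as jnp

def f(x):
    return jnp.sum(jnp.tanh(x) ** 2)

f_jit = jax.jit(f)
x = jnp.ones((1000, 1000))

# Warm-up call: triggers tracing + XLA compilation, excluded from timing.
f_jit(x).block_until_ready()

t0 = time.perf_counter()
# block_until_ready() waits for the async computation to finish,
# so the measurement covers actual execution rather than just dispatch.
f_jit(x).block_until_ready()
t1 = time.perf_counter()
print(f"jitted call: {t1 - t0:.6f}s")
```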
Thanks for raising this and the detailed notes.
Regarding the latter error message, have you already read the How it works and What's supported sections of the readme, and the Gotchas notebook, especially the Control Flow section? If so, it'd be useful to think through together how they could be improved (i.e. what they're missing), and if not it'd be useful to figure out how to make them more discoverable!
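A sketch of the kind of pitfall that Control Flow section covers, for readers landing here from search (hypothetical `relu_py`; not code from this issue):

```python
import jax
import jax.numpy as jnp

# Python branching on a traced value fails under jit, because the
# abstract tracer has no concrete boolean value:
def relu_py(x):
    return x if x > 0 else 0.0
# jax.jit(relu_py)(1.0)  # -> TypeError: abstract value passed to `bool`

# Structured alternatives like jnp.where (or lax.cond) trace fine:
relu = jax.jit(lambda x: jnp.where(x > 0, x, 0.0))
print(relu(-1.5))  # 0.0
```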
I remember also we wrote a bit more in a comment on #196, which I thought we also linked from the readme but apparently we don't.
Regarding the former error, actually functions can't be arguments to `jit`-ed functions. As it says in the `jit` docstring, "Its arguments and return value should be arrays, scalars, or (nested) standard Python containers (tuple/list/dict) thereof." Other than adding more information to the "valid jaxtype" error message, do you have suggestions for how to improve the documentation?
(Valid jaxtypes are in practice numpy.ndarray and our effective subclasses. There's also a tuple type but user code doesn't use it. The functions in api.py can accept pytrees of jaxtypes, which in practice effectively means tuple/list/dict-trees of arrays.)
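To illustrate the pytree point: arbitrary tuple/list/dict nesting of arrays and scalars is flattened automatically at the jit boundary. A minimal sketch (made-up `loss` function and shapes):

```python
import jax
import jax.numpy as jnp

@jax.jit
def loss(params, batch):
    # params is a dict of arrays, batch a tuple of arrays; both are
    # valid jit arguments because their leaves are jaxtypes.
    x, y = batch
    return jnp.mean((x @ params["w"] + params["b"] - y) ** 2)

params = {"w": jnp.ones((3,)), "b": jnp.float32(0.0)}
batch = (jnp.ones((4, 3)), jnp.zeros((4,)))
print(loss(params, batch))  # 9.0
```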
So far here are some action items coming out of this issue:
What other items can we come up with?
I've read through the README, including the "How it works" and "What's supported" sections but in the midst of my debugging it didn't seem as though they directly applied to the situation I faced since neither directly addresses passing functions into jitted functions. I actually mistook the nudge towards FP in the "How it works" section to be an implicit promotion of the use of functions as values (including in jitted code). It took a minute to realize that the appropriate paradigm was to jit closures that have necessary functions in their environments.
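As an aside, current JAX versions offer an alternative to the closure pattern: a function-valued argument can be marked static with `static_argnums`, since Python functions are hashable. A sketch (hypothetical `apply_twice`; whether this worked at the time of this issue is not confirmed by the thread):

```python
from functools import partial
import jax
import jax.numpy as jnp

# The function argument is static: it is used at trace time rather
# than treated as a traced value, so any hashable callable is allowed.
@partial(jax.jit, static_argnums=(0,))
def apply_twice(fn, x):
    return fn(fn(x))

print(apply_twice(jnp.sin, 2.0))  # sin(sin(2.0)) ≈ 0.789
```

Note that each distinct function passed in triggers a fresh trace and compilation.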
I've looked through the Gotchas notebook as well, and the section on control flow is interesting, but because I was not using any Python control flow in my function it didn't really pattern match with my cmd-f-ing. I also didn't think of the Gotchas notebook when debugging the `random.split` issue since those implementation details are internal to JAX, and `random.split` is not discussed in the notebook.
The terms "abstract value" and "concrete value" aren't defined in the docs AFAIK, although I can make educated guesses as to what they mean. Anytime I see "undefined terms" in an error message, my brain just jumps to the conclusion that it's some sort of internal error.
Yeah, I think writing down a clear definition of a "valid jaxtype" would be very helpful. Even something simple like a table of things that are and aren't jaxtypes, and things that can and cannot be passed to jitted functions, would be great.
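Such a table could be backed by a runnable check. A hedged sketch of the valid/invalid distinction (hypothetical `identity` function):

```python
import jax

@jax.jit
def identity(x):
    return x

identity(1.0)           # scalar: valid jaxtype
identity((1.0, 2.0))    # tuple of scalars: valid (a pytree of jaxtypes)

err = None
try:
    identity(lambda z: z)  # a function: not a valid jaxtype
except TypeError as e:
    err = e
print("rejected:", err)
```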
I should add that overall the JAX documentation is great considering the age of the project! Sometimes it's just difficult to find the right information at the right time as an external user.
+1 the suggestion to clearly define "valid JAX type" using that phrase to make it emerge in search results. I currently cannot find such a definition when Googling "valid JAX type".
I did find this documentation, which appears helpful with errors related to using invalid JAX types as input.
tl;dr
Better error messages and documentation for what can and cannot be `jit`ed would be great. Current behavior is "black box." See https://github.com/google/jax/issues/953#issuecomment-507119726. See also:
Original issue
I have a particular function that I believe should be `jit`-able but I'm getting errors hitting it with `jit`. I have a snippet of code that looks like this, but I get an error:
It's not clear to me what this error message is trying to communicate. What exactly is a valid JAX type? And why is this particular function rejected while there are plenty of `jit` examples that include functions as arguments?