Closed NeilGirdhar closed 3 months ago
Hmm nothing springs to mind, but if you can share a repro (even a non-minimal one) I would love to take a look.
I think #4008 is in 0.2.4 (see commit 4a20eea in the list; for some reason our new source sync process didn't list it as a merged PR), but you're right that #4595 contains the suspicious part of it (landed separately so as not to break some people).
Okay, I'll try to bisect this myself tonight.
Yup, #4595 breaks it. It works at 23352e76 and fails at 22c3684d. That's lucky because 4595 doesn't have many changed lines. Maybe I could print something out for you in _flatten_bwd? Another option is just to give you access to my repo? Otherwise, it'll take me a couple days to produce a minimum working example, since I basically have to remove lines (in a giant 3k line project) until it starts working.
Another option is just to give you access to my repo?
Yeah I'm up for giving that a shot! I'm optimistic that I won't need a minimal repro to figure out what's going on.
Okay, I sent you the invite (just run python demos/encoding.py from within the source tree after installing tjax and efax with pip). I've tracked the problem down to the call to treedef.unflatten in tree_multimap called by _flatten_bwd. It's the unflattening that's raising, not the calls to f. It's hard for me to debug into unflatten since it looks like it's coming from jaxlib?
Indeed unflatten is in jaxlib, though we could dig up a pure Python version if that's necessary. But I'm optimistic it won't be!
I'm getting this error from the pip version of progressbar:
[*] Inferring
Traceback (most recent call last):
File "demos/encoding.py", line 143, in <module>
encoding_demo()
File "demos/encoding.py", line 109, in encoding_demo
training_pt = solution.train(2000)
File "/usr/local/google/home/mattjj/packages/cmm/cmm/structure/solution/solution.py", line 97, in train
with progressbar.ProgressBar(widgets=widgets, max_value=iterations) as progress_bar:
TypeError: __init__() got an unexpected keyword argument 'max_value'
LMK if you have any thoughts on that...
A quick Bing search answered my question!
Hrm still not hitting the right error:
Traceback (most recent call last):
File "demos/encoding.py", line 143, in <module>
encoding_demo()
File "demos/encoding.py", line 109, in encoding_demo
training_pt = solution.train(2000)
File "/usr/local/google/home/mattjj/packages/cmm/cmm/structure/solution/solution.py", line 114, in train
update_progress_bar)
File "/usr/local/google/home/mattjj/packages/jax/jax/traceback_util.py", line 139, in reraise_with_filtered_traceback
return fun(*args, **kwargs)
File "/usr/local/google/home/mattjj/packages/jax/jax/api.py", line 219, in f_jitted
donated_invars=donated_invars)
File "/usr/local/google/home/mattjj/packages/jax/jax/core.py", line 1174, in bind
return call_bind(self, fun, *args, **params)
File "/usr/local/google/home/mattjj/packages/jax/jax/core.py", line 1165, in call_bind
outs = primitive.process(top_trace, fun, tracers, params)
File "/usr/local/google/home/mattjj/packages/jax/jax/core.py", line 1177, in process
return trace.process_call(self, fun, tracers, params)
File "/usr/local/google/home/mattjj/packages/jax/jax/core.py", line 576, in process_call
return primitive.impl(f, *tracers, **params)
File "/usr/local/google/home/mattjj/packages/jax/jax/interpreters/xla.py", line 557, in _xla_call_impl
*unsafe_map(arg_spec, args))
File "/usr/local/google/home/mattjj/packages/jax/jax/linear_util.py", line 247, in memoized_fun
ans = call(fun, *args)
File "/usr/local/google/home/mattjj/packages/jax/jax/interpreters/xla.py", line 632, in _xla_callable
jaxpr, out_avals, consts = pe.trace_to_jaxpr_final(fun, abstract_args)
File "/usr/local/google/home/mattjj/packages/jax/jax/interpreters/partial_eval.py", line 1193, in trace_to_jaxpr_final
jaxpr, out_avals, consts = trace_to_subjaxpr_dynamic(fun, main, in_avals)
File "/usr/local/google/home/mattjj/packages/jax/jax/interpreters/partial_eval.py", line 1174, in trace_to_subjaxpr_dynamic
ans = fun.call_wrapped(*in_tracers)
File "/usr/local/google/home/mattjj/packages/jax/jax/linear_util.py", line 156, in call_wrapped
ans = self.f(*args, **dict(self.params, **kwargs))
File "/usr/local/google/home/mattjj/packages/tjax/tjax/fixed_point/iterated_function.py", line 102, in sample_trajectory
return scan(f, self.initial_augmented(initial_state), None, iteration_limit)
File "/usr/local/google/home/mattjj/packages/jax/jax/_src/lax/control_flow.py", line 1251, in scan
init_flat, carry_avals, carry_avals_out, init_tree, *rest = _create_jaxpr(init)
File "/usr/local/google/home/mattjj/packages/jax/jax/_src/lax/control_flow.py", line 1238, in _create_jaxpr
jaxpr, consts, out_tree = _initial_style_jaxpr(f, in_tree, carry_avals + x_avals)
File "/usr/local/google/home/mattjj/packages/jax/jax/_src/lax/control_flow.py", line 72, in _initial_style_jaxpr
jaxpr, out_avals, consts, out_tree = _initial_style_open_jaxpr(fun, in_tree, in_avals)
File "/usr/local/google/home/mattjj/packages/jax/jax/_src/lax/control_flow.py", line 67, in _initial_style_open_jaxpr
jaxpr, out_avals, consts = pe.trace_to_jaxpr_dynamic(wrapped_fun, in_avals)
File "/usr/local/google/home/mattjj/packages/jax/jax/interpreters/partial_eval.py", line 1164, in trace_to_jaxpr_dynamic
jaxpr, out_avals, consts = trace_to_subjaxpr_dynamic(fun, main, in_avals)
File "/usr/local/google/home/mattjj/packages/jax/jax/interpreters/partial_eval.py", line 1174, in trace_to_subjaxpr_dynamic
ans = fun.call_wrapped(*in_tracers)
File "/usr/local/google/home/mattjj/packages/jax/jax/linear_util.py", line 156, in call_wrapped
ans = self.f(*args, **dict(self.params, **kwargs))
File "/usr/local/google/home/mattjj/packages/tjax/tjax/fixed_point/iterated_function.py", line 97, in f
new_state, trajectory = self.sampled_state_trajectory(theta, augmented)
File "/usr/local/google/home/mattjj/packages/cmm/cmm/structure/solution/runner.py", line 66, in sampled_state_trajectory
return self._sampled_state_trajectory(theta, augmented.current_state)
File "/usr/local/google/home/mattjj/packages/cmm/cmm/structure/solution/runner.py", line 90, in _sampled_state_trajectory
rl_result = self.rl_inference.infer(state.parameter_states, state.rng)
File "/usr/local/google/home/mattjj/packages/cmm/cmm/structure/rl/inference.py", line 92, in infer
rl_state = while_loop(cond_fun, self._body_fun, rl_state)
File "/usr/local/google/home/mattjj/packages/jax/jax/_src/lax/control_flow.py", line 286, in while_loop
init_vals, init_avals, body_jaxpr, in_tree, *rest = _create_jaxpr(init_val)
File "/usr/local/google/home/mattjj/packages/jax/jax/_src/lax/control_flow.py", line 272, in _create_jaxpr
body_jaxpr, body_consts, body_tree = _initial_style_jaxpr(body_fun, in_tree, init_avals)
File "<string>", line 2, in __hash__
File "<string>", line 2, in __hash__
TypeError: unhashable type: 'dict'
Hmmm, try it on Python 3.8? The requirements are listed in "pyproject.toml". I'm running off the master branch and it seems to work for me.
Also, if you point me to the jaxlib code for unflatten, I can also try to figure out why it's raising ValueError.
Also, lol at working for Google and using Bing. I think I did that once when I was there too :rofl:
I mean I see that treedef.num_leaves is 33, I put the generator's output into a list l, which ends up equal to [None] * 33. How can treedef.unflatten(l) fail?
Are you using a newer jaxlib? I'm using the latest release: 0.1.56.
I think I figured it out: https://github.com/NeilGirdhar/tjax/blob/5e5ed9823642b662673374108daf7fa11ae781fa/tjax/generator.py is having trouble with being passed None in its initializer. It's raising ValueError in the new unflattening. How do you think I should correct this? Is it a requirement that all pytree-like objects be default-constructible? I guess I'll replace my constructor with factory methods since JAX seems to have no trouble with regular dataclasses.
Oh, so the try: ... except ValueError: ... is too broad, and obscuring the real issue? Yeah that sounds very plausible. I should revise this code to be more robust to errors in the pytree flatten/unflatten functions.
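To make the "too broad" point concrete, here is a minimal pure-Python sketch of the general pattern. It is not the actual JAX code: check_structure, strict_constructor, too_broad and narrower are all hypothetical stand-ins. The first version wraps both a structure check and user code in one try block, so a ValueError from the user's constructor is silently treated as a structure mismatch; the second keeps the try block around the one call that is expected to raise.

```python
def strict_constructor(x):
    # Stand-in for user code, e.g. a pytree node's constructor.
    if x is None:
        raise ValueError("x must not be None")
    return ("node", x)


def check_structure(expected_len, values):
    # Stand-in for a structure check that signals a mismatch with ValueError.
    if len(values) != expected_len:
        raise ValueError("structure mismatch")


def too_broad(expected_len, values):
    try:
        check_structure(expected_len, values)
        return [strict_constructor(v) for v in values]
    except ValueError:
        # Meant for mismatches only, but it also swallows the constructor's error.
        return "fallback"


def narrower(expected_len, values):
    try:
        check_structure(expected_len, values)
    except ValueError:
        return "fallback"
    # User code runs outside the try block, so its errors surface unchanged.
    return [strict_constructor(v) for v in values]
```

With these definitions, too_broad(1, [None]) quietly returns the fallback, while narrower(1, [None]) lets the constructor's ValueError propagate to the caller.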
Is it a requirement that all pytree-like objects be default-constructible?
No, and I don't think there are any defaults in question here. The assumption here is that pytrees can contain arbitrary Python objects (which is part of them being isomorphic to tuples). I think perhaps one of your pytrees doesn't have that property?
I can probably revise this logic not to rely on that property, though in general it's something we assume about pytrees.
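The assumption described above can be reproduced in a few lines with the public jax.tree_util API. This is only a sketch: StrictNode and TolerantNode are hypothetical classes, not the tjax Generator, and unflatten_with_placeholders is a helper invented for illustration. It rebuilds a tree with every leaf replaced by None, which works only if each node's constructor accepts arbitrary objects.

```python
import jax


class StrictNode:
    """Hypothetical pytree node whose constructor validates its argument."""

    def __init__(self, key):
        if key is None:
            raise ValueError("key must not be None")
        self.key = key


class TolerantNode:
    """Same shape, but the constructor stores whatever it is given."""

    def __init__(self, key):
        self.key = key


for cls in (StrictNode, TolerantNode):
    jax.tree_util.register_pytree_node(
        cls,
        lambda node: ((node.key,), None),               # flatten -> (children, aux_data)
        lambda aux, children, cls=cls: cls(*children),  # unflatten calls the constructor
    )


def unflatten_with_placeholders(tree):
    """Rebuild `tree` with every leaf replaced by None."""
    treedef = jax.tree_util.tree_structure(tree)
    return treedef.unflatten([None] * treedef.num_leaves)
```

Calling unflatten_with_placeholders on a tree containing a TolerantNode succeeds, whereas a tree containing a StrictNode raises the ValueError from its constructor, even though the number of leaves matches.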
Yeah I needed Python 3.8 to repro!
What I did was I selectively chose various children in tree_multimap:
leaves, treedef = pytree.flatten(tree)
all_leaves = [leaves] + [treedef.flatten_up_to(r) for r in rest]
if treedef.num_leaves == 33:
    cs = treedef.children()
    print([c.num_leaves for c in cs])  # prints 11, 11, 11
    # the first two are unflattened fine (cs[0].unflatten([None] * 11) and same for 1), so let's look at the third:
    the_d = cs[2]
    ds = the_d.children()
    print([d.num_leaves for d in ds])  # 10, 1
    # the first one is unflattened fine
    the_e = ds[1]
    es = the_e.children()
    print(the_e, the_e.num_leaves)  # PyTreeDef(<class 'tjax.generator.Generator'>[()], [*]) 1
So the Generator object can't be unflattened. I checked and the ValueError is being raised in Generator's constructor.
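The manual descent through children() above could be automated roughly as follows. This is a sketch under assumptions: find_failing_subtrees is a hypothetical helper, StrictNode is a made-up node type standing in for the failing class, and None is used as the placeholder leaf because that is what the thread reports being passed.

```python
import jax


class StrictNode:
    """Hypothetical pytree node whose constructor rejects None."""

    def __init__(self, key):
        if key is None:
            raise ValueError("key must not be None")
        self.key = key


jax.tree_util.register_pytree_node(
    StrictNode,
    lambda node: ((node.key,), None),
    lambda aux, children: StrictNode(*children),
)


def find_failing_subtrees(treedef):
    """Return the smallest sub-treedefs that raise when unflattened with None leaves."""
    try:
        treedef.unflatten([None] * treedef.num_leaves)
        return []  # this subtree rebuilds fine
    except Exception:
        pass
    failing = []
    for child in treedef.children():
        failing.extend(find_failing_subtrees(child))
    # If every child rebuilds on its own, this node itself is the culprit.
    return failing or [treedef]


tree = {"a": (1.0, 2.0), "b": [3.0, StrictNode(4.0)]}
culprits = find_failing_subtrees(jax.tree_util.tree_structure(tree))
```

Printing culprits should show a single treedef for the StrictNode, which mirrors how the search above narrowed 33 leaves down to the one Generator node.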
I see, yeah. I can think of a quick workaround actually: if there are no Nones returned by the user custom_vjp bwd function, we can bypass this logic and just do the old thing. How does that sound?
Okay, changing the constructor to a classmethod fixes it! Phew. Thanks for staying up with me.
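For readers hitting the same problem, the fix might look something like the sketch below. SketchGenerator is a hypothetical class, not the actual tjax Generator, and the validation in from_seed is invented for illustration. The idea is that __init__ stores whatever it is handed, so unflattening with None, object() or tracers works, and any checking moves into a factory classmethod that user code calls instead.

```python
import jax


@jax.tree_util.register_pytree_node_class
class SketchGenerator:
    """Hypothetical RNG wrapper; not the tjax implementation."""

    def __init__(self, key):
        # No validation here: unflatten may pass None, object(), or tracers.
        self.key = key

    @classmethod
    def from_seed(cls, seed):
        # Validation lives in the factory, which only user code calls.
        if not isinstance(seed, int):
            raise ValueError("seed must be an int")
        return cls(jax.random.PRNGKey(seed))

    def tree_flatten(self):
        return (self.key,), None  # (children, aux_data)

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        return cls(*children)


gen = SketchGenerator.from_seed(0)
leaves, treedef = jax.tree_util.tree_flatten(gen)
placeholder = treedef.unflatten([None])  # succeeds; placeholder.key is None
```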
How does that sound?
Don't you think that will be confusing if someone starts returning None and all of a sudden they get a very weird ValueError? I would just keep things as they are now and improve the error messages. Maybe replace
return treedef.unflatten(f(*xs) for xs in zip(*all_leaves))
with
mapped_leaves = (f(*xs) for xs in zip(*all_leaves))
try:
    return treedef.unflatten(mapped_leaves)
except Exception as e:
    # I'm not sure exactly what, but maybe raise a special exception here that contains e inside it?
    raise UnflattenException(...) from e
It might be helpful in UnflattenException to mention that objects must be default constructible. Maybe even mention which object couldn't be constructed?
I feel like I'm pushing the boundaries of JAX in a lot of ways and I've apparently always been able to provide default constructibility. I feel like that's a pretty minor requirement. WDYT?
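One way the proposed wrapper could be fleshed out is sketched below. UnflattenException, unflatten_with_context and FailingTreeDef are all hypothetical names; nothing here is an existing JAX API. The wrapper works with any object that has an unflatten method, and it chains the original error with "raise ... from" so the constructor's traceback is preserved.

```python
class UnflattenException(Exception):
    """Hypothetical wrapper raised when rebuilding a pytree fails."""


def unflatten_with_context(treedef, leaves):
    leaves = list(leaves)  # materialize generators so len() works in the message
    try:
        return treedef.unflatten(leaves)
    except Exception as e:
        raise UnflattenException(
            f"could not unflatten {len(leaves)} leaves into {treedef}; "
            "a node's unflatten function or constructor raised "
            f"{type(e).__name__}: {e}. Pytree node constructors are "
            "expected to accept arbitrary objects as children."
        ) from e


class FailingTreeDef:
    """Stand-in for a treedef whose node constructor rejects its children."""

    def unflatten(self, leaves):
        raise ValueError("key must not be None")
```

Calling unflatten_with_context(FailingTreeDef(), [None] * 3) raises UnflattenException, and the original ValueError remains available as its __cause__.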
Don't you think that will be confusing if someone starts returning None and all of a sudden they get a very weird ValueError?
Yes certainly, by workaround I meant a temporary fix just to unblock you!
Oh, that's nice of you! I just changed my implementation of Generator to be constructible using None, so I'm rolling again; no need for any workaround. Thanks again!
I'd like to improve the errors and robustness here, so let me leave this issue open until I do :)
Glad you're unblocked!
Never got back to this one! Let's close it.
I tried updating to master as well, but that didn't help.
I will go through the changelog for 0.2.4, but could someone point me to where something might have broken? I checked that the pytrees are in fact the same. (One pytree contains tracers, the other contains some object(), but I think that's a JAX implementation detail.)