jax-ml / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
http://jax.readthedocs.io/
Apache License 2.0

Upgrading from 0.2.3 seems to break something with custom VJP #4673

Closed NeilGirdhar closed 3 months ago

NeilGirdhar commented 4 years ago

I tried updating to master as well, but that didn't help.

I will go through the changelog for 0.2.4, but could someone point me to where something might have broken? I checked that the pytrees are in fact the same. (One pytree contains tracers, the other contains some object(), but I think that's a JAX implementation detail.)

$ python demos/encoding.py
[*] Inferring...

$ pip install jax==0.2.4
Processing /home/neil/.cache/pip/wheels/ce/f6/a8/2075fcce214c29511994904934727cae7c800b21f48524f673/jax-0.2.4-py3-none-any.whl
Requirement already satisfied: numpy>=1.12 in /home/neil/.pyenv/versions/3.8.5/lib/python3.8/site-packages (from jax==0.2.4) (1.20.0.dev0+0bd548e)
Requirement already satisfied: absl-py in /home/neil/.pyenv/versions/3.8.5/lib/python3.8/site-packages (from jax==0.2.4) (0.10.0)
Requirement already satisfied: opt-einsum in /home/neil/.pyenv/versions/3.8.5/lib/python3.8/site-packages (from jax==0.2.4) (3.3.0)
Requirement already satisfied: six in /home/neil/.pyenv/versions/3.8.5/lib/python3.8/site-packages (from absl-py->jax==0.2.4) (1.15.0)
Installing collected packages: jax
  Attempting uninstall: jax
    Found existing installation: jax 0.2.3
    Uninstalling jax-0.2.3:
      Successfully uninstalled jax-0.2.3
Successfully installed jax-0.2.4

$ python demos/encoding.py
[*] Inferring...
Traceback (most recent call last):
  File "demos/encoding.py", line 143, in <module>
    encoding_demo()
  File "demos/encoding.py", line 109, in encoding_demo
    training_pt = solution.train(2000)
  File "/home/neil/src/cmm/cmm/structure/solution/solution.py", line 111, in train
    augmented, trajectory = method(None,
  File "/home/neil/.pyenv/versions/3.8.5/lib/python3.8/site-packages/jax/traceback_util.py", line 139, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
  File "/home/neil/.pyenv/versions/3.8.5/lib/python3.8/site-packages/jax/api.py", line 213, in f_jitted
    out = xla.xla_call(
  File "/home/neil/.pyenv/versions/3.8.5/lib/python3.8/site-packages/jax/core.py", line 1174, in bind
    return call_bind(self, fun, *args, **params)
  File "/home/neil/.pyenv/versions/3.8.5/lib/python3.8/site-packages/jax/core.py", line 1165, in call_bind
    outs = primitive.process(top_trace, fun, tracers, params)
  File "/home/neil/.pyenv/versions/3.8.5/lib/python3.8/site-packages/jax/core.py", line 1177, in process
    return trace.process_call(self, fun, tracers, params)
  File "/home/neil/.pyenv/versions/3.8.5/lib/python3.8/site-packages/jax/core.py", line 576, in process_call
    return primitive.impl(f, *tracers, **params)
  File "/home/neil/.pyenv/versions/3.8.5/lib/python3.8/site-packages/jax/interpreters/xla.py", line 556, in _xla_call_impl
    compiled_fun = _xla_callable(fun, device, backend, name, donated_invars,
  File "/home/neil/.pyenv/versions/3.8.5/lib/python3.8/site-packages/jax/linear_util.py", line 247, in memoized_fun
    ans = call(fun, *args)
  File "/home/neil/.pyenv/versions/3.8.5/lib/python3.8/site-packages/jax/interpreters/xla.py", line 632, in _xla_callable
    jaxpr, out_avals, consts = pe.trace_to_jaxpr_final(fun, abstract_args)
  File "/home/neil/.pyenv/versions/3.8.5/lib/python3.8/site-packages/jax/interpreters/partial_eval.py", line 1183, in trace_to_jaxpr_final
    jaxpr, out_avals, consts = trace_to_subjaxpr_dynamic(fun, main, in_avals)
  File "/home/neil/.pyenv/versions/3.8.5/lib/python3.8/site-packages/jax/interpreters/partial_eval.py", line 1164, in trace_to_subjaxpr_dynamic
    ans = fun.call_wrapped(*in_tracers)
  File "/home/neil/.pyenv/versions/3.8.5/lib/python3.8/site-packages/jax/linear_util.py", line 156, in call_wrapped
    ans = self.f(*args, **dict(self.params, **kwargs))
  File "/home/neil/.pyenv/versions/3.8.5/lib/python3.8/site-packages/tjax/fixed_point/iterated_function.py", line 102, in sample_trajectory
    return scan(f, self.initial_augmented(initial_state), None, iteration_limit)
  File "/home/neil/.pyenv/versions/3.8.5/lib/python3.8/site-packages/jax/_src/lax/control_flow.py", line 1251, in scan
    init_flat, carry_avals, carry_avals_out, init_tree, *rest = _create_jaxpr(init)
  File "/home/neil/.pyenv/versions/3.8.5/lib/python3.8/site-packages/jax/_src/lax/control_flow.py", line 1238, in _create_jaxpr
    jaxpr, consts, out_tree = _initial_style_jaxpr(f, in_tree, carry_avals + x_avals)
  File "/home/neil/.pyenv/versions/3.8.5/lib/python3.8/site-packages/jax/_src/lax/control_flow.py", line 72, in _initial_style_jaxpr
    jaxpr, out_avals, consts, out_tree = _initial_style_open_jaxpr(fun, in_tree, in_avals)
  File "/home/neil/.pyenv/versions/3.8.5/lib/python3.8/site-packages/jax/_src/lax/control_flow.py", line 67, in _initial_style_open_jaxpr
    jaxpr, out_avals, consts = pe.trace_to_jaxpr_dynamic(wrapped_fun, in_avals)
  File "/home/neil/.pyenv/versions/3.8.5/lib/python3.8/site-packages/jax/interpreters/partial_eval.py", line 1154, in trace_to_jaxpr_dynamic
    jaxpr, out_avals, consts = trace_to_subjaxpr_dynamic(fun, main, in_avals)
  File "/home/neil/.pyenv/versions/3.8.5/lib/python3.8/site-packages/jax/interpreters/partial_eval.py", line 1164, in trace_to_subjaxpr_dynamic
    ans = fun.call_wrapped(*in_tracers)
  File "/home/neil/.pyenv/versions/3.8.5/lib/python3.8/site-packages/jax/linear_util.py", line 156, in call_wrapped
    ans = self.f(*args, **dict(self.params, **kwargs))
  File "/home/neil/.pyenv/versions/3.8.5/lib/python3.8/site-packages/tjax/fixed_point/iterated_function.py", line 97, in f
    new_state, trajectory = self.sampled_state_trajectory(theta, augmented)
  File "/home/neil/src/cmm/cmm/structure/solution/runner.py", line 66, in sampled_state_trajectory
    return self._sampled_state_trajectory(theta, augmented.current_state)
  File "/home/neil/src/cmm/cmm/structure/solution/runner.py", line 90, in _sampled_state_trajectory
    rl_result = self.rl_inference.infer(state.parameter_states, state.rng)
  File "/home/neil/src/cmm/cmm/structure/rl/inference.py", line 92, in infer
    rl_state = while_loop(cond_fun, self._body_fun, rl_state)
  File "/home/neil/.pyenv/versions/3.8.5/lib/python3.8/site-packages/jax/_src/lax/control_flow.py", line 286, in while_loop
    init_vals, init_avals, body_jaxpr, in_tree, *rest = _create_jaxpr(init_val)
  File "/home/neil/.pyenv/versions/3.8.5/lib/python3.8/site-packages/jax/_src/lax/control_flow.py", line 272, in _create_jaxpr
    body_jaxpr, body_consts, body_tree = _initial_style_jaxpr(body_fun, in_tree, init_avals)
  File "/home/neil/.pyenv/versions/3.8.5/lib/python3.8/site-packages/jax/_src/lax/control_flow.py", line 72, in _initial_style_jaxpr
    jaxpr, out_avals, consts, out_tree = _initial_style_open_jaxpr(fun, in_tree, in_avals)
  File "/home/neil/.pyenv/versions/3.8.5/lib/python3.8/site-packages/jax/_src/lax/control_flow.py", line 67, in _initial_style_open_jaxpr
    jaxpr, out_avals, consts = pe.trace_to_jaxpr_dynamic(wrapped_fun, in_avals)
  File "/home/neil/.pyenv/versions/3.8.5/lib/python3.8/site-packages/jax/interpreters/partial_eval.py", line 1154, in trace_to_jaxpr_dynamic
    jaxpr, out_avals, consts = trace_to_subjaxpr_dynamic(fun, main, in_avals)
  File "/home/neil/.pyenv/versions/3.8.5/lib/python3.8/site-packages/jax/interpreters/partial_eval.py", line 1164, in trace_to_subjaxpr_dynamic
    ans = fun.call_wrapped(*in_tracers)
  File "/home/neil/.pyenv/versions/3.8.5/lib/python3.8/site-packages/jax/linear_util.py", line 156, in call_wrapped
    ans = self.f(*args, **dict(self.params, **kwargs))
  File "/home/neil/src/cmm/cmm/structure/rl/inference.py", line 157, in _body_fun
    weights_bar, primals = f(rl_state.parameter_states.weights)
  File "/home/neil/.pyenv/versions/3.8.5/lib/python3.8/site-packages/jax/traceback_util.py", line 139, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
  File "/home/neil/.pyenv/versions/3.8.5/lib/python3.8/site-packages/jax/api.py", line 756, in grad_f_aux
    (_, aux), g = value_and_grad_f(*args, **kwargs)
  File "/home/neil/.pyenv/versions/3.8.5/lib/python3.8/site-packages/jax/traceback_util.py", line 139, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
  File "/home/neil/.pyenv/versions/3.8.5/lib/python3.8/site-packages/jax/api.py", line 819, in value_and_grad_f
    g = vjp_py(np.ones((), dtype=dtype))
  File "/home/neil/.pyenv/versions/3.8.5/lib/python3.8/site-packages/jax/api.py", line 1791, in _vjp_pullback_wrapper
    ans = fun(*args)
  File "/home/neil/.pyenv/versions/3.8.5/lib/python3.8/site-packages/jax/interpreters/ad.py", line 120, in unbound_vjp
    arg_cts = backward_pass(jaxpr, consts, dummy_args, cts)
  File "/home/neil/.pyenv/versions/3.8.5/lib/python3.8/site-packages/jax/interpreters/ad.py", line 220, in backward_pass
    cts_out = get_primitive_transpose(eqn.primitive)(cts_in, *invals,
  File "/home/neil/.pyenv/versions/3.8.5/lib/python3.8/site-packages/jax/interpreters/ad.py", line 634, in _custom_lin_transpose
    cts_in = bwd.call_wrapped(*res, *cts_out)
  File "/home/neil/.pyenv/versions/3.8.5/lib/python3.8/site-packages/jax/linear_util.py", line 156, in call_wrapped
    ans = self.f(*args, **dict(self.params, **kwargs))
  File "/home/neil/.pyenv/versions/3.8.5/lib/python3.8/site-packages/tjax/shims.py", line 114, in new_bwd
    input_bar = bwd(*static_args, internal_residuals, output_bar)
  File "/home/neil/src/cmm/cmm/encoding/inference.py", line 171, in _infer_encoding_configuration_bwd
    observation_bars, weights_bars = vmapped_f_vjp(encoding_regularizers_bars)
  File "/home/neil/.pyenv/versions/3.8.5/lib/python3.8/site-packages/jax/traceback_util.py", line 139, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
  File "/home/neil/.pyenv/versions/3.8.5/lib/python3.8/site-packages/jax/api.py", line 1230, in batched_fun
    out_flat = batching.batch(flat_fun, args_flat, in_axes_flat,
  File "/home/neil/.pyenv/versions/3.8.5/lib/python3.8/site-packages/jax/interpreters/batching.py", line 36, in batch
    return batched_fun.call_wrapped(*in_vals)
  File "/home/neil/.pyenv/versions/3.8.5/lib/python3.8/site-packages/jax/linear_util.py", line 156, in call_wrapped
    ans = self.f(*args, **dict(self.params, **kwargs))
  File "/home/neil/.pyenv/versions/3.8.5/lib/python3.8/site-packages/jax/api.py", line 1791, in _vjp_pullback_wrapper
    ans = fun(*args)
  File "/home/neil/.pyenv/versions/3.8.5/lib/python3.8/site-packages/jax/interpreters/ad.py", line 120, in unbound_vjp
    arg_cts = backward_pass(jaxpr, consts, dummy_args, cts)
  File "/home/neil/.pyenv/versions/3.8.5/lib/python3.8/site-packages/jax/interpreters/ad.py", line 220, in backward_pass
    cts_out = get_primitive_transpose(eqn.primitive)(cts_in, *invals,
  File "/home/neil/.pyenv/versions/3.8.5/lib/python3.8/site-packages/jax/interpreters/ad.py", line 634, in _custom_lin_transpose
    cts_in = bwd.call_wrapped(*res, *cts_out)
  File "/home/neil/.pyenv/versions/3.8.5/lib/python3.8/site-packages/jax/linear_util.py", line 169, in call_wrapped
    ans = gen.send(ans)
  File "/home/neil/.pyenv/versions/3.8.5/lib/python3.8/site-packages/jax/custom_derivatives.py", line 544, in _flatten_bwd
    raise TypeError(msg.format(in_tree2, in_tree)) from None
TypeError: Custom VJP rule must produce an output with the same container (pytree) structure as the args tuple of the primal function, and in particular must produce a tuple of length equal to the number of arguments to the primal function, but got VJP output structure PyTreeDef(tuple, [PyTreeDef(<class 'cmm.encoding.iterated_function.EncodingIteratedFunction'>[()], [*,*,*,*,*,*,PyTreeDef(<class 'cmm.encoding.element.EncodingElement'>[(Path(('module', 'observation.x', 'encoding')),)], [PyTreeDef(<class 'cmm.pss.space.exponential_family.ExpFamSpace'>[(Path(('module', 'observation.x', 'encoding', 'space')), NormalUnitVariance(shape=(), num_parameters=5))], []),PyTreeDef(<class 'cmm.encoding.cluster.CodeCluster'>[()], [*]),PyTreeDef(<class 'cmm.encoding.cluster.PresenceScoreCluster'>[(Path(('module', 'observation.x', 'encoding', 'presence_score_cluster')),)], [*,PyTreeDef(<class 'cmm.structure.parameter.bias.Bias'>[(Path(('module', 'observation.x', 'encoding', 'presence_score_cluster', 'decay')), (1,), 'encoding_decay', True)], [])]),PyTreeDef(<class 'cmm.encoding.cluster.ValueCluster'>[()], [*]),PyTreeDef(<class 'cmm.encoding.cluster.MechanicalCluster'>[()], []),PyTreeDef(<class 'cmm.link.presence_view_link.SquaredPresenceViewLink'>[(Path(('module', 'observation.x', 'encoding', 'code_consumption_link')), (1, 4), 'encoding', True, OuterMatching(shape=[(1,), (4,)]), <PresenceViewNormalization.each_source_has_unit_output_weights: 1>)], []),PyTreeDef(<class 'cmm.link.presence_view_link.SquaredPresenceViewLink'>[(Path(('module', 'observation.x', 'encoding', 'value_consumption_link')), (1, 5), 'encoding', True, OuterMatching(shape=[(1,), (5,)]), <PresenceViewNormalization.each_source_has_unit_output_weights: 1>)], []),PyTreeDef(<class 'cmm.link.value_view_link.ValueViewLink'>[(Path(('module', 'observation.x', 'encoding', 'explanation_link')), (5, 4), 'encoding', False, OuterMatching(shape=[(5,), (4,)]))], []),PyTreeDef(<class 
'cmm.link.presence_view_link.SquaredPresenceViewLink'>[(Path(('module', 'observation.x', 'encoding', 'code_gating_link')), (4, 1), 'encoding', True, OuterMatching(shape=[(4,), (1,)]), <PresenceViewNormalization.each_target_has_unit_input_weights: 0>)], []),PyTreeDef(<class 'cmm.link.presence_view_link.SquaredPresenceViewLink'>[(Path(('module', 'observation.x', 'encoding', 'value_gating_link')), (5, 1), 'encoding', True, OuterMatching(shape=[(5,), (1,)]), <PresenceViewNormalization.each_target_has_unit_input_weights: 0>)], []),PyTreeDef(<class 'cmm.link.value_view_link.ValueViewLink'>[(Path(('module', 'observation.x', 'encoding', 'demand_link')), (4, 5), 'encoding', False, OuterMatching(shape=[(4,), (5,)]))], []),*,*,PyTreeDef(<class 'cmm.structure.parameter.bias.Bias'>[(Path(('module', 'observation.x', 'encoding', 'natural_explanation_bias')), (5,), 'deduction', False)], [])])]),PyTreeDef(<class 'cmm.encoding.iterated_function.EncodingIteratedFunctionParameters'>[()], [PyTreeDef(<class 'cmm.pss.observation.Observation'>[()], [*,*]),*,PyTreeDef(<class 'cmm.structure.foundation.parallel.ParallelStructure'>[()], [PyTreeDef(dict[[Path(('module', 'observation.x', 'encoding', 'code_consumption_link')), Path(('module', 'observation.x', 'encoding', 'code_gating_link')), Path(('module', 'observation.x', 'encoding', 'demand_link')), Path(('module', 'observation.x', 'encoding', 'explanation_link')), Path(('module', 'observation.x', 'encoding', 'natural_explanation_bias')), Path(('module', 'observation.x', 'encoding', 'presence_score_cluster', 'decay')), Path(('module', 'observation.x', 'encoding', 'value_consumption_link')), Path(('module', 'observation.x', 'encoding', 'value_gating_link'))]], [*,*,*,*,*,*,*,*])])]),PyTreeDef(<class 'cmm.encoding.configuration.EncodingState'>[()], [PyTreeDef(<class 'cmm.encoding.configuration.EncodingDifferentiand'>[()], [PyTreeDef(<class 'cmm.encoding.configuration.EncodingComparand'>[()], [*,*,*,*,*,*,*,*]),*,*]),PyTreeDef(<class 
'tjax.generator.Generator'>[()], [*])])]) for primal input structure PyTreeDef(tuple, [PyTreeDef(<class 'cmm.encoding.iterated_function.EncodingIteratedFunction'>[()], [*,*,*,*,*,*,PyTreeDef(<class 'cmm.encoding.element.EncodingElement'>[(Path(('module', 'observation.x', 'encoding')),)], [PyTreeDef(<class 'cmm.pss.space.exponential_family.ExpFamSpace'>[(Path(('module', 'observation.x', 'encoding', 'space')), NormalUnitVariance(shape=(), num_parameters=5))], []),PyTreeDef(<class 'cmm.encoding.cluster.CodeCluster'>[()], [*]),PyTreeDef(<class 'cmm.encoding.cluster.PresenceScoreCluster'>[(Path(('module', 'observation.x', 'encoding', 'presence_score_cluster')),)], [*,PyTreeDef(<class 'cmm.structure.parameter.bias.Bias'>[(Path(('module', 'observation.x', 'encoding', 'presence_score_cluster', 'decay')), (1,), 'encoding_decay', True)], [])]),PyTreeDef(<class 'cmm.encoding.cluster.ValueCluster'>[()], [*]),PyTreeDef(<class 'cmm.encoding.cluster.MechanicalCluster'>[()], []),PyTreeDef(<class 'cmm.link.presence_view_link.SquaredPresenceViewLink'>[(Path(('module', 'observation.x', 'encoding', 'code_consumption_link')), (1, 4), 'encoding', True, OuterMatching(shape=[(1,), (4,)]), <PresenceViewNormalization.each_source_has_unit_output_weights: 1>)], []),PyTreeDef(<class 'cmm.link.presence_view_link.SquaredPresenceViewLink'>[(Path(('module', 'observation.x', 'encoding', 'value_consumption_link')), (1, 5), 'encoding', True, OuterMatching(shape=[(1,), (5,)]), <PresenceViewNormalization.each_source_has_unit_output_weights: 1>)], []),PyTreeDef(<class 'cmm.link.value_view_link.ValueViewLink'>[(Path(('module', 'observation.x', 'encoding', 'explanation_link')), (5, 4), 'encoding', False, OuterMatching(shape=[(5,), (4,)]))], []),PyTreeDef(<class 'cmm.link.presence_view_link.SquaredPresenceViewLink'>[(Path(('module', 'observation.x', 'encoding', 'code_gating_link')), (4, 1), 'encoding', True, OuterMatching(shape=[(4,), (1,)]), <PresenceViewNormalization.each_target_has_unit_input_weights: 
0>)], []),PyTreeDef(<class 'cmm.link.presence_view_link.SquaredPresenceViewLink'>[(Path(('module', 'observation.x', 'encoding', 'value_gating_link')), (5, 1), 'encoding', True, OuterMatching(shape=[(5,), (1,)]), <PresenceViewNormalization.each_target_has_unit_input_weights: 0>)], []),PyTreeDef(<class 'cmm.link.value_view_link.ValueViewLink'>[(Path(('module', 'observation.x', 'encoding', 'demand_link')), (4, 5), 'encoding', False, OuterMatching(shape=[(4,), (5,)]))], []),*,*,PyTreeDef(<class 'cmm.structure.parameter.bias.Bias'>[(Path(('module', 'observation.x', 'encoding', 'natural_explanation_bias')), (5,), 'deduction', False)], [])])]),PyTreeDef(<class 'cmm.encoding.iterated_function.EncodingIteratedFunctionParameters'>[()], [PyTreeDef(<class 'cmm.pss.observation.Observation'>[()], [*,*]),*,PyTreeDef(<class 'cmm.structure.foundation.parallel.ParallelStructure'>[()], [PyTreeDef(dict[[Path(('module', 'observation.x', 'encoding', 'code_consumption_link')), Path(('module', 'observation.x', 'encoding', 'code_gating_link')), Path(('module', 'observation.x', 'encoding', 'demand_link')), Path(('module', 'observation.x', 'encoding', 'explanation_link')), Path(('module', 'observation.x', 'encoding', 'natural_explanation_bias')), Path(('module', 'observation.x', 'encoding', 'presence_score_cluster', 'decay')), Path(('module', 'observation.x', 'encoding', 'value_consumption_link')), Path(('module', 'observation.x', 'encoding', 'value_gating_link'))]], [*,*,*,*,*,*,*,*])])]),PyTreeDef(<class 'cmm.encoding.configuration.EncodingState'>[()], [PyTreeDef(<class 'cmm.encoding.configuration.EncodingDifferentiand'>[()], [PyTreeDef(<class 'cmm.encoding.configuration.EncodingComparand'>[()], [*,*,*,*,*,*,*,*]),*,*]),PyTreeDef(<class 'tjax.generator.Generator'>[()], [*])])]).
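For readers unfamiliar with the rule quoted in this error message, here is a minimal sketch of a `custom_vjp` whose backward function returns a tuple with one cotangent per primal argument. The function `scaled_sin` and its helpers are hypothetical names chosen for illustration; they are not taken from the project in this issue.

```python
import jax
import jax.numpy as jnp


@jax.custom_vjp
def scaled_sin(x, scale):
    # Primal function with two arguments, so bwd must return a 2-tuple.
    return scale * jnp.sin(x)


def scaled_sin_fwd(x, scale):
    # Return the primal output and the residuals needed by the backward pass.
    return scaled_sin(x, scale), (x, scale)


def scaled_sin_bwd(residuals, output_bar):
    x, scale = residuals
    # One cotangent per primal argument, in the same order and with the
    # same pytree structure as the args tuple (x, scale).
    return (output_bar * scale * jnp.cos(x), output_bar * jnp.sin(x))


scaled_sin.defvjp(scaled_sin_fwd, scaled_sin_bwd)

grads = jax.grad(scaled_sin, argnums=(0, 1))(0.0, 2.0)
```

If an argument is itself a pytree (as with the classes in the error above), the corresponding entry of the returned tuple has to mirror that argument's structure.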

mattjj commented 4 years ago

Hmm nothing springs to mind, but if you can share a repro (even a non-minimal one) I would love to take a look.

#4008 is the thing that would've broken things, as it involved a significant rewrite. (Sorry, I didn't really update the changelog :P)

NeilGirdhar commented 4 years ago

#4008 is in master, but not in 0.2.4 though? And it's broken in 0.2.4. I'm looking at #4595 right now. I'll try backing that out and seeing if it fixes things.

mattjj commented 4 years ago

I think #4008 is in 0.2.4 (see commit 4a20eea in the list; for some reason our new source sync process didn't list it as a merged PR), but you're right that #4595 contains the suspicious part of it (it landed separately so as not to break some people).

NeilGirdhar commented 4 years ago

Okay, I'll try to bisect this myself tonight.

NeilGirdhar commented 4 years ago

Yup, #4595 breaks it. It works at 23352e76 and fails at 22c3684d. That's lucky because #4595 doesn't have many changed lines. Maybe I could print something out for you in _flatten_bwd? Another option is just to give you access to my repo? Otherwise, it'll take me a couple of days to produce a minimal working example, since I basically have to remove lines (in a giant 3k-line project) until it starts working.

mattjj commented 4 years ago

> Another option is just to give you access to my repo?

Yeah I'm up for giving that a shot! I'm optimistic that I won't need a minimal repro to figure out what's going on.

NeilGirdhar commented 4 years ago

Okay, I sent you the invite (just run python demos/encoding.py from within the source tree after installing tjax and efax with pip). I've tracked the problem down to the call to treedef.unflatten in tree_multimap called by _flatten_bwd. It's the unflattening that's raising—not the calls to f. It's hard for me to debug into unflatten since it looks like it's coming from jaxlib?

mattjj commented 4 years ago

Indeed unflatten is in jaxlib, though we could dig up a pure Python version if that's necessary. But I'm optimistic it won't be!

I'm getting this error from the pip version of progressbar:

[*] Inferring
Traceback (most recent call last):
  File "demos/encoding.py", line 143, in <module>
    encoding_demo()
  File "demos/encoding.py", line 109, in encoding_demo
    training_pt = solution.train(2000)
  File "/usr/local/google/home/mattjj/packages/cmm/cmm/structure/solution/solution.py", line 97, in train
    with progressbar.ProgressBar(widgets=widgets, max_value=iterations) as progress_bar:
TypeError: __init__() got an unexpected keyword argument 'max_value'

LMK if you have any thoughts on that...

mattjj commented 4 years ago

A quick Bing search answered my question!

mattjj commented 4 years ago

Hrm still not hitting the right error:

Traceback (most recent call last):
  File "demos/encoding.py", line 143, in <module>
    encoding_demo()
  File "demos/encoding.py", line 109, in encoding_demo
    training_pt = solution.train(2000)
  File "/usr/local/google/home/mattjj/packages/cmm/cmm/structure/solution/solution.py", line 114, in train
    update_progress_bar)
  File "/usr/local/google/home/mattjj/packages/jax/jax/traceback_util.py", line 139, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
  File "/usr/local/google/home/mattjj/packages/jax/jax/api.py", line 219, in f_jitted
    donated_invars=donated_invars)
  File "/usr/local/google/home/mattjj/packages/jax/jax/core.py", line 1174, in bind
    return call_bind(self, fun, *args, **params)
  File "/usr/local/google/home/mattjj/packages/jax/jax/core.py", line 1165, in call_bind
    outs = primitive.process(top_trace, fun, tracers, params)
  File "/usr/local/google/home/mattjj/packages/jax/jax/core.py", line 1177, in process
    return trace.process_call(self, fun, tracers, params)
  File "/usr/local/google/home/mattjj/packages/jax/jax/core.py", line 576, in process_call
    return primitive.impl(f, *tracers, **params)
  File "/usr/local/google/home/mattjj/packages/jax/jax/interpreters/xla.py", line 557, in _xla_call_impl
    *unsafe_map(arg_spec, args))
  File "/usr/local/google/home/mattjj/packages/jax/jax/linear_util.py", line 247, in memoized_fun
    ans = call(fun, *args)
  File "/usr/local/google/home/mattjj/packages/jax/jax/interpreters/xla.py", line 632, in _xla_callable
    jaxpr, out_avals, consts = pe.trace_to_jaxpr_final(fun, abstract_args)
  File "/usr/local/google/home/mattjj/packages/jax/jax/interpreters/partial_eval.py", line 1193, in trace_to_jaxpr_final
    jaxpr, out_avals, consts = trace_to_subjaxpr_dynamic(fun, main, in_avals)
  File "/usr/local/google/home/mattjj/packages/jax/jax/interpreters/partial_eval.py", line 1174, in trace_to_subjaxpr_dynamic
    ans = fun.call_wrapped(*in_tracers)
  File "/usr/local/google/home/mattjj/packages/jax/jax/linear_util.py", line 156, in call_wrapped
    ans = self.f(*args, **dict(self.params, **kwargs))
  File "/usr/local/google/home/mattjj/packages/tjax/tjax/fixed_point/iterated_function.py", line 102, in sample_trajectory
    return scan(f, self.initial_augmented(initial_state), None, iteration_limit)
  File "/usr/local/google/home/mattjj/packages/jax/jax/_src/lax/control_flow.py", line 1251, in scan
    init_flat, carry_avals, carry_avals_out, init_tree, *rest = _create_jaxpr(init)
  File "/usr/local/google/home/mattjj/packages/jax/jax/_src/lax/control_flow.py", line 1238, in _create_jaxpr
    jaxpr, consts, out_tree = _initial_style_jaxpr(f, in_tree, carry_avals + x_avals)
  File "/usr/local/google/home/mattjj/packages/jax/jax/_src/lax/control_flow.py", line 72, in _initial_style_jaxpr
    jaxpr, out_avals, consts, out_tree = _initial_style_open_jaxpr(fun, in_tree, in_avals)
  File "/usr/local/google/home/mattjj/packages/jax/jax/_src/lax/control_flow.py", line 67, in _initial_style_open_jaxpr
    jaxpr, out_avals, consts = pe.trace_to_jaxpr_dynamic(wrapped_fun, in_avals)
  File "/usr/local/google/home/mattjj/packages/jax/jax/interpreters/partial_eval.py", line 1164, in trace_to_jaxpr_dynamic
    jaxpr, out_avals, consts = trace_to_subjaxpr_dynamic(fun, main, in_avals)
  File "/usr/local/google/home/mattjj/packages/jax/jax/interpreters/partial_eval.py", line 1174, in trace_to_subjaxpr_dynamic
    ans = fun.call_wrapped(*in_tracers)
  File "/usr/local/google/home/mattjj/packages/jax/jax/linear_util.py", line 156, in call_wrapped
    ans = self.f(*args, **dict(self.params, **kwargs))
  File "/usr/local/google/home/mattjj/packages/tjax/tjax/fixed_point/iterated_function.py", line 97, in f
    new_state, trajectory = self.sampled_state_trajectory(theta, augmented)
  File "/usr/local/google/home/mattjj/packages/cmm/cmm/structure/solution/runner.py", line 66, in sampled_state_trajectory
    return self._sampled_state_trajectory(theta, augmented.current_state)
  File "/usr/local/google/home/mattjj/packages/cmm/cmm/structure/solution/runner.py", line 90, in _sampled_state_trajectory
    rl_result = self.rl_inference.infer(state.parameter_states, state.rng)
  File "/usr/local/google/home/mattjj/packages/cmm/cmm/structure/rl/inference.py", line 92, in infer
    rl_state = while_loop(cond_fun, self._body_fun, rl_state)
  File "/usr/local/google/home/mattjj/packages/jax/jax/_src/lax/control_flow.py", line 286, in while_loop
    init_vals, init_avals, body_jaxpr, in_tree, *rest = _create_jaxpr(init_val)
  File "/usr/local/google/home/mattjj/packages/jax/jax/_src/lax/control_flow.py", line 272, in _create_jaxpr
    body_jaxpr, body_consts, body_tree = _initial_style_jaxpr(body_fun, in_tree, init_avals)
  File "<string>", line 2, in __hash__
  File "<string>", line 2, in __hash__
TypeError: unhashable type: 'dict'

NeilGirdhar commented 4 years ago

Hmmm, try it on Python 3.8? The requirements are listed in "pyproject.toml". I'm running off the master branch and it seems to work for me.

Also, if you point me to the jaxlib code for unflatten, I can also try to figure out why it's raising ValueError.

NeilGirdhar commented 4 years ago

Also, lol at working for Google and using Bing. I think I did that once when I was there too :rofl:

NeilGirdhar commented 4 years ago

I mean, I see that treedef.num_leaves is 33. I put the generator's output into a list l, which ends up equal to [None] * 33. How can treedef.unflatten(l) fail?

NeilGirdhar commented 4 years ago

Are you using a newer jaxlib? I'm using the latest release: 0.1.56.

NeilGirdhar commented 4 years ago

I think I figured it out: https://github.com/NeilGirdhar/tjax/blob/5e5ed9823642b662673374108daf7fa11ae781fa/tjax/generator.py is having trouble with being passed None in its initializer. It's raising ValueError in the new unflattening. How do you think I should correct this? Is it a requirement that all pytree-like objects be default-constructible? I guess I'll replace my constructor with factory methods since JAX seems to have no trouble with regular dataclasses.

mattjj commented 4 years ago

Oh, so the try: ... except ValueError: ... is too broad, and obscuring the real issue? Yeah that sounds very plausible. I should revise this code to be more robust to errors in the pytree flatten/unflatten functions.

> Is it a requirement that all pytree-like objects be default-constructible?

No, and I don't think there are any defaults in question here. The assumption here is that pytrees can contain arbitrary Python objects (which is part of them being isomorphic to tuples). I think perhaps one of your pytrees doesn't have that property?

I can probably revise this logic not to rely on that property, though in general it's something we assume about pytrees.
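To make that assumption concrete, the sketch below registers two hypothetical pytree classes (`StrictGenerator` and `LenientGenerator` are illustrative names, not the actual tjax code). The first validates its leaf in `__init__`, so it cannot be rebuilt from placeholder leaves such as `None`. The second stores whatever it is given and moves validation into a classmethod factory, which is roughly the shape of the fix discussed in this thread. Only the public `jax.tree_util` API is assumed.

```python
from jax import tree_util


class StrictGenerator:
    """Hypothetical node: the constructor rejects placeholder leaves."""

    def __init__(self, key):
        if key is None:
            raise ValueError("key must not be None")
        self.key = key


class LenientGenerator:
    """Hypothetical node: __init__ stores any object; a factory validates."""

    def __init__(self, key):
        self.key = key

    @classmethod
    def create(cls, key):
        if key is None:
            raise ValueError("key must not be None")
        return cls(key)


def register(cls):
    # flatten returns (children, aux_data); unflatten receives (aux_data, children).
    tree_util.register_pytree_node(
        cls,
        lambda node: ((node.key,), None),
        lambda aux_data, children: cls(*children),
    )


register(StrictGenerator)
register(LenientGenerator)


def rebuilds_from_placeholders(tree):
    """Return True if the tree's structure can be unflattened with None leaves."""
    treedef = tree_util.tree_structure(tree)
    try:
        treedef.unflatten([None] * treedef.num_leaves)
        return True
    except ValueError:
        return False


strict_ok = rebuilds_from_placeholders(StrictGenerator(1.0))
lenient_ok = rebuilds_from_placeholders(LenientGenerator.create(1.0))
```

Here `strict_ok` comes out false and `lenient_ok` true: the structure is identical in both cases, and only the constructor's tolerance for arbitrary leaves differs.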

mattjj commented 4 years ago

Yeah I needed Python 3.8 to repro!

NeilGirdhar commented 4 years ago

What I did was selectively inspect various children in tree_multimap:

  leaves, treedef = pytree.flatten(tree)
  all_leaves = [leaves] + [treedef.flatten_up_to(r) for r in rest]
  if treedef.num_leaves == 33:
      cs = treedef.children()
      print([c.num_leaves for c in cs])  # prints 11, 11, 11
      # the first two are unflattened fine (cs[0].unflatten([None] * 11) and same for 1), so let's look at the third:
      the_d = cs[2]  
      ds = the_d.children()
      print([d.num_leaves for d in ds])  # 10, 1
      # the first one is unflattened fine
      the_e = ds[1]
      es = the_e.children()
      print(the_e, the_e.num_leaves)   # PyTreeDef(<class 'tjax.generator.Generator'>[()], [*]) 1

So the Generator object can't be unflattened. I checked and the ValueError is being raised in Generator's constructor.
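The manual narrowing above can be automated. The following is a rough sketch, assuming only `PyTreeDef.children()`, `num_leaves`, and `unflatten` as used in the snippet above; the `Strict` class and the helper `find_unflatten_failures` are hypothetical names introduced for illustration.

```python
from jax import tree_util


class Strict:
    """Hypothetical pytree node whose constructor rejects None."""

    def __init__(self, value):
        if value is None:
            raise ValueError("value must not be None")
        self.value = value


tree_util.register_pytree_node(
    Strict,
    lambda node: ((node.value,), None),
    lambda aux_data, children: Strict(*children),
)


def find_unflatten_failures(treedef, path=()):
    """Return (path, treedef, error) for the deepest failing subtrees."""
    try:
        treedef.unflatten([None] * treedef.num_leaves)
        return []
    except Exception as error:
        failures = []
        for index, child in enumerate(treedef.children()):
            failures.extend(find_unflatten_failures(child, path + (index,)))
        # If every child rebuilds fine, this node itself is the culprit.
        return failures or [(path, treedef, error)]


tree = {"a": (1.0, 2.0), "b": [Strict(3.0)]}
failures = find_unflatten_failures(tree_util.tree_structure(tree))
failing_paths = [path for path, _, _ in failures]
```

Each path is a tuple of child indices from the root, so the result points at the one node whose own unflatten function raises, along with the original exception.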

mattjj commented 4 years ago

I see, yeah. I can think of a quick workaround actually: if there are no Nones returned by the user custom_vjp bwd function, we can bypass this logic and just do the old thing. How does that sound?

NeilGirdhar commented 4 years ago

Okay, changing the constructor to a classmethod fixes it! Phew. Thanks for staying up with me.

> How does that sound?

Don't you think that will be confusing if someone starts returning None and all of a sudden they get a very weird ValueError? I would keep things as they are now and just improve the error messages. Maybe replace

  return treedef.unflatten(f(*xs) for xs in zip(*all_leaves))

with

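  # Evaluate f eagerly so that only unflattening errors are caught below.
  mapped_leaves = [f(*xs) for xs in zip(*all_leaves)]
  try:
    return treedef.unflatten(mapped_leaves)
  except Exception as e:
    # I'm not sure exactly what, but maybe raise a special exception here that contains e inside it?
    # UnflattenException is a proposed name; it would need to be defined.
    raise UnflattenException(...) from e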

It might be helpful in UnflattenException to mention that objects must be default constructible. Maybe even mention which object couldn't be constructed?

I feel like I'm pushing the boundaries of JAX in a lot of ways and I've apparently always been able to provide default constructibility. I feel like that's a pretty minor requirement. WDYT?
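As a runnable version of the proposal above, here is a small sketch. `UnflattenException`, `checked_unflatten`, and `Validated` are hypothetical names and are not part of JAX; the idea is only to show exception chaining so that the constructor's original error stays visible.

```python
from jax import tree_util


class UnflattenException(Exception):
    """Hypothetical wrapper error; not part of JAX."""


class Validated:
    """Hypothetical pytree node whose constructor rejects None."""

    def __init__(self, value):
        if value is None:
            raise ValueError("value must not be None")
        self.value = value


tree_util.register_pytree_node(
    Validated,
    lambda node: ((node.value,), None),
    lambda aux_data, children: Validated(*children),
)


def checked_unflatten(treedef, leaves):
    # Materialize the leaves first so errors from producing them are not
    # mistaken for unflattening errors.
    leaves = list(leaves)
    try:
        return treedef.unflatten(leaves)
    except Exception as error:
        raise UnflattenException(
            f"could not rebuild pytree with structure {treedef}"
        ) from error


treedef = tree_util.tree_structure(Validated(1.0))
caught = None
try:
    checked_unflatten(treedef, [None])
except UnflattenException as error:
    caught = error
```

With `raise ... from error`, the wrapper's `__cause__` holds the `ValueError` raised inside the constructor, so the traceback shows both the pytree structure and the real failure.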

mattjj commented 4 years ago

> Don't you think that will be confusing if someone starts returning None and all of a sudden they get a very weird ValueError?

Yes certainly, by workaround I meant a temporary fix just to unblock you!

NeilGirdhar commented 4 years ago

Oh, that's nice of you! I just changed my implementation of Generator to be constructible using None, so I'm rolling again—no need for any workaround. Thanks again!

mattjj commented 4 years ago

I'd like to improve the errors and robustness here, so let me leave this issue open until I do :)

Glad you're unblocked!

mattjj commented 3 months ago

Never got back to this one! Let's close it.