google / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
http://jax.readthedocs.io/
Apache License 2.0

UnshapedArray has no shape! #13152

Closed: NeilGirdhar closed this issue 1 year ago.

NeilGirdhar commented 1 year ago

Description

I don't have an MWE (minimal working example), but I can give someone access to my repository if that helps. What I do know is that adding a stop-gradient eliminates the crash, which suggests that the problem is a cotangent that has no shape.
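The stop-gradient workaround can be illustrated in isolation. This is a minimal sketch, not the reporter's code: `loss` and `loss_stopped` are hypothetical functions that reuse the `softplus` call from the bottom of the first traceback. It only shows that `jax.lax.stop_gradient` cuts the backward pass at that point, which would explain why the workaround masks the crash without addressing its cause.

```python
import jax
import jax.numpy as jnp


def loss(isp_attention):
    # Hypothetical stand-in for the softplus call at the bottom of the first traceback.
    attention = jax.nn.softplus(isp_attention)
    return jnp.sum(attention)


def loss_stopped(isp_attention):
    # stop_gradient treats its argument as a constant in the backward pass,
    # so no cotangent reaches softplus (logaddexp) for this value.
    attention = jax.lax.stop_gradient(jax.nn.softplus(isp_attention))
    return jnp.sum(attention)


x = jnp.array([0.0, 1.0])
grad_full = jax.grad(loss)(x)             # derivative of softplus is sigmoid
grad_stopped = jax.grad(loss_stopped)(x)  # all zeros, same shape as x
print(grad_full, grad_stopped)
```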

Exception in thread Thread-2:
Traceback (most recent call last):
  File "/home/neil/.pyenv/versions/3.10.4/lib/python3.10/threading.py", line 966, in _bootstrap
    self._bootstrap_inner()
  File "/home/neil/.pyenv/versions/3.10.4/lib/python3.10/threading.py", line 1001, in _bootstrap_inner
    del _limbo[self]
  File "/home/neil/src/cmm/cmm/visualization/plot_thread.py", line 50, in run
    assert self.solver_index is not None
  File "/home/neil/src/cmm/cmm/plot/morphism/plot.py", line 64, in plot
    plot_data = self.extract_data(solver, hard=hard, log=log, progress_manager=progress_manager)
  File "/home/neil/src/cmm/cmm/plot/morphism/plot.py", line 154, in extract_data
    if selected_distillation is None:
  File "/home/neil/src/cmm/cmm/plot/single_episode.py", line 37, in single_episode
    log.error(f"Training example ({self.training_example_index}) larger than maximum "
  File "/home/neil/src/cmm/cmm/structure/solver/cached_method.py", line 118, in __call__
    raise CacheMiss()
  File "/home/neil/src/cmm/cmm/structure/solver/cached_method.py", line 98, in create_kwargs
    if value is None:
  File "/home/neil/src/cmm/cmm/structure/solver/cached_method.py", line 100, in <dictcomp>
    methods = {name: getattr(caching_obj, name)(hard=hard,
  File "/home/neil/src/cmm/cmm/structure/solver/cached_method.py", line 120, in __call__
    with log.context(f"{self.method.__name__}"):
  File "/home/neil/src/cmm/cmm/structure/solver/solver.py", line 34, in training_result
    return train_episodes(solution, training_segments, batch_size, log=log,
  File "/home/neil/src/cmm/cmm/structure/solution/training.py", line 108, in train_episodes
    train_one_episode = jit(RLInference.train_one_episode, static_argnums=(3,))
  File "/home/neil/src/cmm/cmm/structure/rl/inference.py", line 111, in train_one_episode
    condition_function = partial(self._training_cond_fun, self.problem.max_episode_steps())
  File "/home/neil/src/cmm/cmm/structure/rl/inference.py", line 293, in _training_body_fun
    weights_bar, infer_outputs = self._v_infer_gradient_and_value(
  File "/home/neil/src/cmm/cmm/structure/rl/inference.py", line 225, in _v_infer_gradient_and_value
    f = vmap(self._infer_gradient_and_value,
  File "/home/neil/src/cmm/cmm/structure/rl/inference.py", line 204, in _infer_gradient_and_value
    f = grad(bound_infer, has_aux=True)
  File "/home/neil/src/cmm/cmm/structure/rl/inference.py", line 170, in _infer
    inference_rng, new_inference_rng = split(inference_rng)
  File "/home/neil/src/cmm/cmm/structure/model/model.py", line 171, in infer_one_time_step
    for example_rng_, inference_rng_, (node_name, node) in zip(example_rngs, inference_rngs,
  File "/home/neil/src/cmm/cmm/distillation/node.py", line 32, in infer
    prediction, memory = self._collect_prediction(model, model_weights, model_configuration,
  File "/home/neil/src/cmm/cmm/distillation/base.py", line 96, in _infer_encoding
    inference_result = infer_encoding_configuration(self.instant_encoding, observation,
  File "/home/neil/src/tjax/tjax/_src/shims.py", line 62, in __call__
    return self.vjp(u, *args, **kwargs)
  File "/home/neil/src/cmm/cmm/encoding/inference/configuration.py", line 46, in infer_encoding_configuration_fwd
    internal_result, weight_vjp = vjp(partial(internal_infer_encoding, encoding, observation, rng,
  File "/home/neil/src/cmm/cmm/encoding/inference/internal.py", line 183, in internal_infer_encoding
    geometry_rng, deconfounder_rng, intention_rng = split(rng, 3)
  File "/home/neil/src/cmm/cmm/encoding/inference/internal.py", line 61, in geometry_phase
    def seeker(weights: hk.Params
  File "/home/neil/src/tjax/tjax/_src/cotangent_tools.py", line 99, in _cotangent_combinator_fwd
    return vjp(f, *args_tuples[0])
  File "/home/neil/src/cmm/cmm/encoding/inference/internal.py", line 65, in seeker
    parameters.seeker_iterations, False, weights)
  File "/home/neil/src/cmm/cmm/encoding/inference/seeker.py", line 61, in seeker_inference
    use_code_noise=False,
  File "/home/neil/src/tjax/tjax/_src/fixed_point/base.py", line 49, in sample_trajectory
    def f(augmented: TheAugmentedState, x: None) -> Tuple[TheAugmentedState, Trajectory]:
  File "/home/neil/src/tjax/tjax/_src/fixed_point/base.py", line 51, in f
    new_state, trajectory = self.sampled_state_trajectory(theta, augmented)
  File "/home/neil/src/cmm/cmm/encoding/sampler/sampler.py", line 115, in sampled_state_trajectory
    state = self.sampled_state(theta, augmented)
  File "/home/neil/src/cmm/cmm/encoding/sampler/sampler.py", line 72, in sampled_state
    time_step = phase_parameters.time_step()
  File "/home/neil/src/cmm/cmm/encoding/foundation/rival_message_sp.py", line 46, in to_regular_message
    attention = softplus(self.isp_attention)
  File "/home/neil/.cache/pypoetry/virtualenvs/cmm-tspD8tmv-py3.10/lib/python3.10/site-packages/jax/_src/nn/functions.py", line 75, in softplus
    return jnp.logaddexp(x, 0)
jax._src.source_info_util.JaxStackTraceBeforeTransformation: TypeError: UnshapedArray has no shape. Please open an issue at https://github.com/google/jax/issues because it's unexpected for UnshapedArray instances to ever be produced.

The preceding stack trace is the source of the JAX operation that, once transformed by JAX, triggered the following exception.

--------------------

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/home/neil/.pyenv/versions/3.10.4/lib/python3.10/threading.py", line 966, in _bootstrap
    self._bootstrap_inner()
  File "/home/neil/.pyenv/versions/3.10.4/lib/python3.10/threading.py", line 1009, in _bootstrap_inner
    self.run()
  File "/home/neil/src/cmm/cmm/visualization/plot_thread.py", line 54, in run
    self.plotter.plot(self.plot, self.solvers, self.solver_index, hard=self.hard,
  File "/home/neil/src/cmm/cmm/plot/morphism/plot.py", line 64, in plot
    plot_data = self.extract_data(solver, hard=hard, log=log, progress_manager=progress_manager)
  File "/home/neil/src/cmm/cmm/plot/morphism/plot.py", line 157, in extract_data
    internal_result = self.single_episode(solver, hard=hard, log=log,
  File "/home/neil/src/cmm/cmm/plot/single_episode.py", line 45, in single_episode
    return solver.single_episode(distillation_name=self.selected_distillation,
  File "/home/neil/src/cmm/cmm/structure/solver/cached_method.py", line 121, in __call__
    method_kwargs = self.create_kwargs(self.bound, kwargs, hard, log, progress_manager)
  File "/home/neil/src/cmm/cmm/structure/solver/cached_method.py", line 100, in create_kwargs
    methods = {name: getattr(caching_obj, name)(hard=hard,
  File "/home/neil/src/cmm/cmm/structure/solver/cached_method.py", line 100, in <dictcomp>
    methods = {name: getattr(caching_obj, name)(hard=hard,
  File "/home/neil/src/cmm/cmm/structure/solver/cached_method.py", line 123, in __call__
    self.result[value_tuple] = self.method(type(self.bound), **method_kwargs)
  File "/home/neil/src/cmm/cmm/structure/solver/solver.py", line 34, in training_result
    return train_episodes(solution, training_segments, batch_size, log=log,
  File "/home/neil/src/cmm/cmm/structure/solution/training.py", line 116, in train_episodes
    training_result = train_one_episode(solution.rl_inference, data_mapping,
  File "/home/neil/.cache/pypoetry/virtualenvs/cmm-tspD8tmv-py3.10/lib/python3.10/site-packages/jax/_src/traceback_util.py", line 162, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
  File "/home/neil/.cache/pypoetry/virtualenvs/cmm-tspD8tmv-py3.10/lib/python3.10/site-packages/jax/_src/api.py", line 620, in cache_miss
    execute = dispatch._xla_call_impl_lazy(fun_, *tracers, **params)
  File "/home/neil/.cache/pypoetry/virtualenvs/cmm-tspD8tmv-py3.10/lib/python3.10/site-packages/jax/_src/dispatch.py", line 236, in _xla_call_impl_lazy
    return xla_callable(fun, device, backend, name, donated_invars, keep_unused,
  File "/home/neil/.cache/pypoetry/virtualenvs/cmm-tspD8tmv-py3.10/lib/python3.10/site-packages/jax/linear_util.py", line 300, in memoized_fun
    ans = call(fun, *args)
  File "/home/neil/.cache/pypoetry/virtualenvs/cmm-tspD8tmv-py3.10/lib/python3.10/site-packages/jax/_src/dispatch.py", line 359, in _xla_callable_uncached
    return lower_xla_callable(fun, device, backend, name, donated_invars, False,
  File "/home/neil/.cache/pypoetry/virtualenvs/cmm-tspD8tmv-py3.10/lib/python3.10/site-packages/jax/_src/profiler.py", line 314, in wrapper
    return func(*args, **kwargs)
  File "/home/neil/.cache/pypoetry/virtualenvs/cmm-tspD8tmv-py3.10/lib/python3.10/site-packages/jax/_src/dispatch.py", line 445, in lower_xla_callable
    jaxpr, out_type, consts = pe.trace_to_jaxpr_final2(
  File "/home/neil/.cache/pypoetry/virtualenvs/cmm-tspD8tmv-py3.10/lib/python3.10/site-packages/jax/_src/profiler.py", line 314, in wrapper
    return func(*args, **kwargs)
  File "/home/neil/.cache/pypoetry/virtualenvs/cmm-tspD8tmv-py3.10/lib/python3.10/site-packages/jax/interpreters/partial_eval.py", line 2081, in trace_to_jaxpr_final2
    jaxpr, out_type, consts = trace_to_subjaxpr_dynamic2(fun, main, debug_info)
  File "/home/neil/.cache/pypoetry/virtualenvs/cmm-tspD8tmv-py3.10/lib/python3.10/site-packages/jax/interpreters/partial_eval.py", line 2031, in trace_to_subjaxpr_dynamic2
    ans = fun.call_wrapped(*in_tracers_)
  File "/home/neil/.cache/pypoetry/virtualenvs/cmm-tspD8tmv-py3.10/lib/python3.10/site-packages/jax/linear_util.py", line 167, in call_wrapped
    ans = self.f(*args, **dict(self.params, **kwargs))
  File "/home/neil/src/cmm/cmm/structure/rl/inference.py", line 114, in train_one_episode
    training_state = while_loop(condition_function, body_function, training_state)
  File "/home/neil/.cache/pypoetry/virtualenvs/cmm-tspD8tmv-py3.10/lib/python3.10/site-packages/jax/_src/traceback_util.py", line 162, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
  File "/home/neil/.cache/pypoetry/virtualenvs/cmm-tspD8tmv-py3.10/lib/python3.10/site-packages/jax/_src/lax/control_flow/loops.py", line 1110, in while_loop
    init_vals, init_avals, body_jaxpr, in_tree, *rest = _create_jaxpr(init_val)
  File "/home/neil/.cache/pypoetry/virtualenvs/cmm-tspD8tmv-py3.10/lib/python3.10/site-packages/jax/_src/lax/control_flow/loops.py", line 1093, in _create_jaxpr
    body_jaxpr, body_consts, body_tree = _initial_style_jaxpr(
  File "/home/neil/.cache/pypoetry/virtualenvs/cmm-tspD8tmv-py3.10/lib/python3.10/site-packages/jax/_src/lax/control_flow/common.py", line 60, in _initial_style_jaxpr
    jaxpr, consts, out_tree = _initial_style_open_jaxpr(
  File "/home/neil/.cache/pypoetry/virtualenvs/cmm-tspD8tmv-py3.10/lib/python3.10/site-packages/jax/_src/lax/control_flow/common.py", line 54, in _initial_style_open_jaxpr
    jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(wrapped_fun, in_avals, debug)
  File "/home/neil/.cache/pypoetry/virtualenvs/cmm-tspD8tmv-py3.10/lib/python3.10/site-packages/jax/_src/profiler.py", line 314, in wrapper
    return func(*args, **kwargs)
  File "/home/neil/.cache/pypoetry/virtualenvs/cmm-tspD8tmv-py3.10/lib/python3.10/site-packages/jax/interpreters/partial_eval.py", line 1985, in trace_to_jaxpr_dynamic
    jaxpr, out_avals, consts = trace_to_subjaxpr_dynamic(
  File "/home/neil/.cache/pypoetry/virtualenvs/cmm-tspD8tmv-py3.10/lib/python3.10/site-packages/jax/interpreters/partial_eval.py", line 2002, in trace_to_subjaxpr_dynamic
    ans = fun.call_wrapped(*in_tracers_)
  File "/home/neil/.cache/pypoetry/virtualenvs/cmm-tspD8tmv-py3.10/lib/python3.10/site-packages/jax/linear_util.py", line 167, in call_wrapped
    ans = self.f(*args, **dict(self.params, **kwargs))
  File "/home/neil/src/cmm/cmm/structure/rl/inference.py", line 293, in _training_body_fun
    weights_bar, infer_outputs = self._v_infer_gradient_and_value(
  File "/home/neil/src/cmm/cmm/structure/rl/inference.py", line 228, in _v_infer_gradient_and_value
    weights_bars, infer_outputs = f(model_memories, example_rngs, inference_rngs, model_weights)
  File "/home/neil/.cache/pypoetry/virtualenvs/cmm-tspD8tmv-py3.10/lib/python3.10/site-packages/jax/_src/traceback_util.py", line 162, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
  File "/home/neil/.cache/pypoetry/virtualenvs/cmm-tspD8tmv-py3.10/lib/python3.10/site-packages/jax/_src/api.py", line 1680, in vmap_f
    out_flat = batching.batch(
  File "/home/neil/.cache/pypoetry/virtualenvs/cmm-tspD8tmv-py3.10/lib/python3.10/site-packages/jax/linear_util.py", line 167, in call_wrapped
    ans = self.f(*args, **dict(self.params, **kwargs))
  File "/home/neil/src/cmm/cmm/structure/rl/inference.py", line 205, in _infer_gradient_and_value
    return f(model_weights)
  File "/home/neil/.cache/pypoetry/virtualenvs/cmm-tspD8tmv-py3.10/lib/python3.10/site-packages/jax/_src/traceback_util.py", line 162, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
  File "/home/neil/.cache/pypoetry/virtualenvs/cmm-tspD8tmv-py3.10/lib/python3.10/site-packages/jax/_src/api.py", line 1095, in grad_f_aux
    (_, aux), g = value_and_grad_f(*args, **kwargs)
  File "/home/neil/.cache/pypoetry/virtualenvs/cmm-tspD8tmv-py3.10/lib/python3.10/site-packages/jax/_src/traceback_util.py", line 162, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
  File "/home/neil/.cache/pypoetry/virtualenvs/cmm-tspD8tmv-py3.10/lib/python3.10/site-packages/jax/_src/api.py", line 1171, in value_and_grad_f
    g = vjp_py(lax_internal._one(ans))
  File "/home/neil/.cache/pypoetry/virtualenvs/cmm-tspD8tmv-py3.10/lib/python3.10/site-packages/jax/_src/tree_util.py", line 301, in __call__
    return self.fun(*args, **kw)
  File "/home/neil/.cache/pypoetry/virtualenvs/cmm-tspD8tmv-py3.10/lib/python3.10/site-packages/jax/_src/api.py", line 2575, in _vjp_pullback_wrapper
    ans = fun(*args)
  File "/home/neil/.cache/pypoetry/virtualenvs/cmm-tspD8tmv-py3.10/lib/python3.10/site-packages/jax/_src/tree_util.py", line 301, in __call__
    return self.fun(*args, **kw)
  File "/home/neil/.cache/pypoetry/virtualenvs/cmm-tspD8tmv-py3.10/lib/python3.10/site-packages/jax/interpreters/ad.py", line 142, in unbound_vjp
    arg_cts = backward_pass(jaxpr, reduce_axes, True, consts, dummy_args, cts)
  File "/home/neil/.cache/pypoetry/virtualenvs/cmm-tspD8tmv-py3.10/lib/python3.10/site-packages/jax/interpreters/ad.py", line 248, in backward_pass
    cts_out = get_primitive_transpose(eqn.primitive)(
  File "/home/neil/.cache/pypoetry/virtualenvs/cmm-tspD8tmv-py3.10/lib/python3.10/site-packages/jax/interpreters/ad.py", line 751, in _custom_lin_transpose
    cts_in = bwd.call_wrapped(*res, *cts_out)
  File "/home/neil/.cache/pypoetry/virtualenvs/cmm-tspD8tmv-py3.10/lib/python3.10/site-packages/jax/linear_util.py", line 167, in call_wrapped
    ans = self.f(*args, **dict(self.params, **kwargs))
  File "/home/neil/.cache/pypoetry/virtualenvs/cmm-tspD8tmv-py3.10/lib/python3.10/site-packages/jax/_src/custom_derivatives.py", line 673, in <lambda>
    bwd_ = lu.wrap_init(lambda *args: bwd.call_wrapped(*args))
  File "/home/neil/.cache/pypoetry/virtualenvs/cmm-tspD8tmv-py3.10/lib/python3.10/site-packages/jax/linear_util.py", line 167, in call_wrapped
    ans = self.f(*args, **dict(self.params, **kwargs))
  File "/home/neil/src/cmm/cmm/encoding/inference/configuration.py", line 65, in infer_encoding_configuration_bwd
    weights_bar, = residuals.weight_vjp(internal_result_bar)
  File "/home/neil/.cache/pypoetry/virtualenvs/cmm-tspD8tmv-py3.10/lib/python3.10/site-packages/jax/_src/tree_util.py", line 301, in __call__
    return self.fun(*args, **kw)
  File "/home/neil/.cache/pypoetry/virtualenvs/cmm-tspD8tmv-py3.10/lib/python3.10/site-packages/jax/_src/api.py", line 2575, in _vjp_pullback_wrapper
    ans = fun(*args)
  File "/home/neil/.cache/pypoetry/virtualenvs/cmm-tspD8tmv-py3.10/lib/python3.10/site-packages/jax/_src/tree_util.py", line 301, in __call__
    return self.fun(*args, **kw)
  File "/home/neil/.cache/pypoetry/virtualenvs/cmm-tspD8tmv-py3.10/lib/python3.10/site-packages/jax/interpreters/ad.py", line 142, in unbound_vjp
    arg_cts = backward_pass(jaxpr, reduce_axes, True, consts, dummy_args, cts)
  File "/home/neil/.cache/pypoetry/virtualenvs/cmm-tspD8tmv-py3.10/lib/python3.10/site-packages/jax/interpreters/ad.py", line 248, in backward_pass
    cts_out = get_primitive_transpose(eqn.primitive)(
  File "/home/neil/.cache/pypoetry/virtualenvs/cmm-tspD8tmv-py3.10/lib/python3.10/site-packages/jax/interpreters/ad.py", line 751, in _custom_lin_transpose
    cts_in = bwd.call_wrapped(*res, *cts_out)
  File "/home/neil/.cache/pypoetry/virtualenvs/cmm-tspD8tmv-py3.10/lib/python3.10/site-packages/jax/linear_util.py", line 167, in call_wrapped
    ans = self.f(*args, **dict(self.params, **kwargs))
  File "/home/neil/.cache/pypoetry/virtualenvs/cmm-tspD8tmv-py3.10/lib/python3.10/site-packages/jax/_src/custom_derivatives.py", line 673, in <lambda>
    bwd_ = lu.wrap_init(lambda *args: bwd.call_wrapped(*args))
  File "/home/neil/.cache/pypoetry/virtualenvs/cmm-tspD8tmv-py3.10/lib/python3.10/site-packages/jax/linear_util.py", line 167, in call_wrapped
    ans = self.f(*args, **dict(self.params, **kwargs))
  File "/home/neil/src/tjax/tjax/_src/cotangent_tools.py", line 116, in _cotangent_combinator_bwd
    this_args_bar = f_vjp(this_result_bar)
  File "/home/neil/.cache/pypoetry/virtualenvs/cmm-tspD8tmv-py3.10/lib/python3.10/site-packages/jax/_src/tree_util.py", line 301, in __call__
    return self.fun(*args, **kw)
  File "/home/neil/.cache/pypoetry/virtualenvs/cmm-tspD8tmv-py3.10/lib/python3.10/site-packages/jax/_src/api.py", line 2575, in _vjp_pullback_wrapper
    ans = fun(*args)
  File "/home/neil/.cache/pypoetry/virtualenvs/cmm-tspD8tmv-py3.10/lib/python3.10/site-packages/jax/_src/tree_util.py", line 301, in __call__
    return self.fun(*args, **kw)
  File "/home/neil/.cache/pypoetry/virtualenvs/cmm-tspD8tmv-py3.10/lib/python3.10/site-packages/jax/interpreters/ad.py", line 142, in unbound_vjp
    arg_cts = backward_pass(jaxpr, reduce_axes, True, consts, dummy_args, cts)
  File "/home/neil/.cache/pypoetry/virtualenvs/cmm-tspD8tmv-py3.10/lib/python3.10/site-packages/jax/interpreters/ad.py", line 245, in backward_pass
    cts_out = reducing_transposes[eqn.primitive](
  File "/home/neil/.cache/pypoetry/virtualenvs/cmm-tspD8tmv-py3.10/lib/python3.10/site-packages/jax/_src/lax/control_flow/loops.py", line 674, in _scan_transpose
    jaxpr_trans = _transpose_scan_jaxpr(
  File "/home/neil/.cache/pypoetry/virtualenvs/cmm-tspD8tmv-py3.10/lib/python3.10/site-packages/jax/_src/lax/control_flow/loops.py", line 712, in _transpose_scan_jaxpr
    return _make_closed_jaxpr(transposed, res1_avals + c_avals + b_avals + res2_avals)
  File "/home/neil/.cache/pypoetry/virtualenvs/cmm-tspD8tmv-py3.10/lib/python3.10/site-packages/jax/_src/lax/control_flow/common.py", line 133, in _make_closed_jaxpr
    jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(traceable, in_avals)
  File "/home/neil/.cache/pypoetry/virtualenvs/cmm-tspD8tmv-py3.10/lib/python3.10/site-packages/jax/_src/profiler.py", line 314, in wrapper
    return func(*args, **kwargs)
  File "/home/neil/.cache/pypoetry/virtualenvs/cmm-tspD8tmv-py3.10/lib/python3.10/site-packages/jax/interpreters/partial_eval.py", line 1985, in trace_to_jaxpr_dynamic
    jaxpr, out_avals, consts = trace_to_subjaxpr_dynamic(
  File "/home/neil/.cache/pypoetry/virtualenvs/cmm-tspD8tmv-py3.10/lib/python3.10/site-packages/jax/interpreters/partial_eval.py", line 2002, in trace_to_subjaxpr_dynamic
    ans = fun.call_wrapped(*in_tracers_)
  File "/home/neil/.cache/pypoetry/virtualenvs/cmm-tspD8tmv-py3.10/lib/python3.10/site-packages/jax/linear_util.py", line 167, in call_wrapped
    ans = self.f(*args, **dict(self.params, **kwargs))
  File "/home/neil/.cache/pypoetry/virtualenvs/cmm-tspD8tmv-py3.10/lib/python3.10/site-packages/jax/_src/lax/control_flow/loops.py", line 705, in transposed
    cbar_abar = ad.backward_pass(jaxpr.jaxpr, reduce_axes, False, jaxpr.consts,
  File "/home/neil/.cache/pypoetry/virtualenvs/cmm-tspD8tmv-py3.10/lib/python3.10/site-packages/jax/interpreters/ad.py", line 242, in backward_pass
    cts_out = get_primitive_transpose(eqn.primitive)(
  File "/home/neil/.cache/pypoetry/virtualenvs/cmm-tspD8tmv-py3.10/lib/python3.10/site-packages/jax/interpreters/ad.py", line 627, in call_transpose
    out_flat = primitive.bind(fun, *all_args, **params)
  File "/home/neil/.cache/pypoetry/virtualenvs/cmm-tspD8tmv-py3.10/lib/python3.10/site-packages/jax/core.py", line 2006, in bind
    outs = top_trace.process_call(self, fun_, tracers, params)
  File "/home/neil/.cache/pypoetry/virtualenvs/cmm-tspD8tmv-py3.10/lib/python3.10/site-packages/jax/interpreters/partial_eval.py", line 1743, in process_call
    jaxpr, out_type, consts = trace_to_subjaxpr_dynamic2(f, self.main, debug_info=dbg)
  File "/home/neil/.cache/pypoetry/virtualenvs/cmm-tspD8tmv-py3.10/lib/python3.10/site-packages/jax/interpreters/partial_eval.py", line 2031, in trace_to_subjaxpr_dynamic2
    ans = fun.call_wrapped(*in_tracers_)
  File "/home/neil/.cache/pypoetry/virtualenvs/cmm-tspD8tmv-py3.10/lib/python3.10/site-packages/jax/linear_util.py", line 167, in call_wrapped
    ans = self.f(*args, **dict(self.params, **kwargs))
  File "/home/neil/.cache/pypoetry/virtualenvs/cmm-tspD8tmv-py3.10/lib/python3.10/site-packages/jax/interpreters/ad.py", line 248, in backward_pass
    cts_out = get_primitive_transpose(eqn.primitive)(
  File "/home/neil/.cache/pypoetry/virtualenvs/cmm-tspD8tmv-py3.10/lib/python3.10/site-packages/jax/_src/lax/lax.py", line 2146, in _add_transpose
    return [_unbroadcast(x_aval, t), _unbroadcast(y_aval, t)]
  File "/home/neil/.cache/pypoetry/virtualenvs/cmm-tspD8tmv-py3.10/lib/python3.10/site-packages/jax/_src/lax/lax.py", line 1622, in _unbroadcast
    x_shape = np.shape(x)
  File "<__array_function__ internals>", line 180, in shape
  File "/home/neil/.cache/pypoetry/virtualenvs/cmm-tspD8tmv-py3.10/lib/python3.10/site-packages/numpy/core/fromnumeric.py", line 2007, in shape
    result = a.shape
  File "/home/neil/.cache/pypoetry/virtualenvs/cmm-tspD8tmv-py3.10/lib/python3.10/site-packages/jax/core.py", line 649, in __getattr__
    attr = getattr(self.aval, name)
  File "/home/neil/.cache/pypoetry/virtualenvs/cmm-tspD8tmv-py3.10/lib/python3.10/site-packages/jax/core.py", line 1354, in shape
    raise TypeError(msg)
jax._src.traceback_util.UnfilteredStackTrace: TypeError: UnshapedArray has no shape. Please open an issue at https://github.com/google/jax/issues because it's unexpected for UnshapedArray instances to ever be produced.

What jax/jaxlib version are you using?

JAX master (0.3.25), jaxlib 0.3.24

Which accelerator(s) are you using?

CPU

hawkinsp commented 1 year ago

We're going to need some sort of repro, I think.

NeilGirdhar commented 1 year ago

@hawkinsp Okay, I'll invest the time in that. It means starting with my 8k-line program and chopping it down page by page until I have an MWE.
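One possible starting point for that reduction is a skeleton that stacks the same transformations as the second traceback: `jit` around a `lax.while_loop`, whose body `vmap`s a `grad(..., has_aux=True)` that differentiates through a `custom_vjp` function containing a `lax.scan` (with no scanned inputs) that calls `softplus`. This is a hypothetical sketch: every name in it is invented, a single `custom_vjp` stands in for the two nested ones in the trace, and it runs cleanly as written. The idea is to transplant pieces of the real program into it until the `UnshapedArray` error reappears.

```python
import jax
import jax.numpy as jnp
from jax import lax


def _inner_impl(w, x):
    # Hypothetical scan whose body calls softplus, loosely mirroring the sampler loop.
    def step(carry, _):
        carry = carry + jax.nn.softplus(w) * x
        return carry, carry

    final, _ = lax.scan(step, jnp.zeros_like(x), None, length=3)
    return final


@jax.custom_vjp
def inner(w, x):
    return _inner_impl(w, x)


def inner_fwd(w, x):
    # Save the primal inputs; the backward rule recomputes the pullback with jax.vjp.
    return _inner_impl(w, x), (w, x)


def inner_bwd(res, out_bar):
    w, x = res
    _, pullback = jax.vjp(_inner_impl, w, x)
    return pullback(out_bar)  # (w_bar, x_bar): same structure as the primal args


inner.defvjp(inner_fwd, inner_bwd)


def loss(w, x):
    out = inner(w, x)
    return jnp.sum(out), out  # (value, aux) for grad(..., has_aux=True)


# Per-example gradients with respect to the shared weights w.
batched_grad = jax.vmap(jax.grad(loss, has_aux=True), in_axes=(None, 0))


@jax.jit
def train(w, xs):
    def cond(state):
        i, _ = state
        return i < 2

    def body(state):
        i, w = state
        w_bars, _ = batched_grad(w, xs)
        return i + 1, w - 0.1 * jnp.mean(w_bars, axis=0)

    _, w = lax.while_loop(cond, body, (0, w))
    return w


w_final = train(jnp.array([0.5, -0.5]), jnp.ones((4, 2)))
print(w_final)
```

Keeping the skeleton runnable at every step makes it easy to tell which transplanted piece reintroduces the failure.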

NeilGirdhar commented 1 year ago

After significantly rewriting my code, I can no longer reproduce this.

I think it may have been caused by a custom VJP returning a cotangent with a different pytree structure than the primal. For some reason, JAX isn't catching the structure mismatch, so I had to add some assertions.
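The thread does not show which assertions were added. A minimal sketch of one way to write such a check, assuming the mismatch arises in a `jax.custom_vjp` backward rule: the hypothetical helper `check_cotangent_structure` compares the pytree structure and leaf shapes of the cotangents against the primal inputs before returning them. `scale` and the `params` dict are likewise invented for illustration.

```python
import jax
import jax.numpy as jnp


def check_cotangent_structure(primals, cotangents):
    """Hypothetical helper: assert cotangents mirror the primals' pytree and leaf shapes."""
    primal_tree = jax.tree_util.tree_structure(primals)
    cotangent_tree = jax.tree_util.tree_structure(cotangents)
    assert primal_tree == cotangent_tree, f"cotangent tree {cotangent_tree} != primal tree {primal_tree}"
    for p, c in zip(jax.tree_util.tree_leaves(primals), jax.tree_util.tree_leaves(cotangents)):
        assert jnp.shape(p) == jnp.shape(c), f"leaf shape {jnp.shape(c)} != {jnp.shape(p)}"
    return cotangents


@jax.custom_vjp
def scale(params, x):
    return params["a"] * x


def scale_fwd(params, x):
    return params["a"] * x, (params, x)


def scale_bwd(res, y_bar):
    params, x = res
    cotangents = ({"a": jnp.sum(y_bar * x)}, params["a"] * y_bar)
    # Fail loudly here instead of deep inside a later transpose rule.
    return check_cotangent_structure((params, x), cotangents)


scale.defvjp(scale_fwd, scale_bwd)

params = {"a": jnp.array(2.0)}
x = jnp.array([1.0, 2.0, 3.0])
grads = jax.jit(jax.grad(lambda p: jnp.sum(scale(p, x))))(params)
print(grads["a"])  # equals sum(x) = 6.0
```

Because plain `assert` statements are stripped when Python runs with `-O`, an explicit `raise` may be preferable if the check should survive in optimized runs.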