google / brax

Massively parallel rigidbody physics simulation on accelerator hardware.
Apache License 2.0

TracerArrayConversionError when jitting env.step #321

Closed: jamesheald closed this issue 1 year ago

jamesheald commented 1 year ago

I am new to brax. When I try to perform an environment step using a jitted version of env.step (using code taken from the brax environments notebook), I get a TracerArrayConversionError. Any idea how I can resolve this?

Code:

from brax import envs
from brax import jumpy as jp
import jax
from jax import numpy as jnp

environment = "reacher"
env = envs.create(env_name=environment)
state = env.reset(rng=jp.random_prngkey(seed=0))

jit_env_step = jax.jit(env.step)
state = jit_env_step(state, jnp.ones((env.action_size,)))

Traceback:

Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/Users/James/opt/anaconda3/envs/hMPC/lib/python3.9/site-packages/jax/_src/traceback_util.py", line 166, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
  File "/Users/James/opt/anaconda3/envs/hMPC/lib/python3.9/site-packages/jax/_src/pjit.py", line 235, in cache_miss
    outs, out_flat, out_tree, args_flat = _python_pjit_helper(
  File "/Users/James/opt/anaconda3/envs/hMPC/lib/python3.9/site-packages/jax/_src/pjit.py", line 179, in _python_pjit_helper
    args_flat, _, params, in_tree, out_tree, _ = infer_params_fn(
  File "/Users/James/opt/anaconda3/envs/hMPC/lib/python3.9/site-packages/jax/_src/api.py", line 442, in infer_params
    return pjit.common_infer_params(pjit_info_args, *args, **kwargs)
  File "/Users/James/opt/anaconda3/envs/hMPC/lib/python3.9/site-packages/jax/_src/pjit.py", line 515, in common_infer_params
    jaxpr, consts, canonicalized_out_shardings_flat = _pjit_jaxpr(
  File "/Users/James/opt/anaconda3/envs/hMPC/lib/python3.9/site-packages/jax/_src/pjit.py", line 967, in _pjit_jaxpr
    jaxpr, final_consts, global_out_avals = _create_pjit_jaxpr(
  File "/Users/James/opt/anaconda3/envs/hMPC/lib/python3.9/site-packages/jax/_src/linear_util.py", line 301, in memoized_fun
    ans = call(fun, *args)
  File "/Users/James/opt/anaconda3/envs/hMPC/lib/python3.9/site-packages/jax/_src/pjit.py", line 925, in _create_pjit_jaxpr
    jaxpr, global_out_avals, consts = pe.trace_to_jaxpr_dynamic(
  File "/Users/James/opt/anaconda3/envs/hMPC/lib/python3.9/site-packages/jax/_src/profiler.py", line 314, in wrapper
    return func(*args, **kwargs)
  File "/Users/James/opt/anaconda3/envs/hMPC/lib/python3.9/site-packages/jax/interpreters/partial_eval.py", line 2029, in trace_to_jaxpr_dynamic
    jaxpr, out_avals, consts = trace_to_subjaxpr_dynamic(
  File "/Users/James/opt/anaconda3/envs/hMPC/lib/python3.9/site-packages/jax/interpreters/partial_eval.py", line 2046, in trace_to_subjaxpr_dynamic
    ans = fun.call_wrapped(*in_tracers_)
  File "/Users/James/opt/anaconda3/envs/hMPC/lib/python3.9/site-packages/jax/_src/linear_util.py", line 165, in call_wrapped
    ans = self.f(*args, **dict(self.params, **kwargs))
  File "/Users/James/opt/anaconda3/envs/hMPC/lib/python3.9/site-packages/brax/envs/wrappers.py", line 138, in step
    state = self.env.step(state, action)
  File "/Users/James/opt/anaconda3/envs/hMPC/lib/python3.9/site-packages/brax/envs/wrappers.py", line 110, in step
    state, rewards = jp.scan(f, state, (), self.action_repeat)
  File "/Users/James/opt/anaconda3/envs/hMPC/lib/python3.9/site-packages/brax/jumpy.py", line 115, in scan
    carry, y = f(carry, jax.tree_util.tree_unflatten(xs_tree, xs_slice))
  File "/Users/James/opt/anaconda3/envs/hMPC/lib/python3.9/site-packages/brax/envs/wrappers.py", line 107, in f
    nstate = self.env.step(state, action)
  File "/Users/James/opt/anaconda3/envs/hMPC/lib/python3.9/site-packages/brax/envs/reacher.py", line 177, in step
    qp, info = self.sys.step(state.qp, action)
  File "/Users/James/opt/anaconda3/envs/hMPC/lib/python3.9/site-packages/brax/physics/system.py", line 247, in step
    return step_funs[self.config.dynamics_mode](qp, act)
  File "/Users/James/opt/anaconda3/envs/hMPC/lib/python3.9/site-packages/brax/physics/system.py", line 324, in _pbd_step
    (qp, info), _ = jp.scan(substep, (qp, info), (), self.config.substeps // 2)
  File "/Users/James/opt/anaconda3/envs/hMPC/lib/python3.9/site-packages/brax/jumpy.py", line 115, in scan
    carry, y = f(carry, jax.tree_util.tree_unflatten(xs_tree, xs_slice))
  File "/Users/James/opt/anaconda3/envs/hMPC/lib/python3.9/site-packages/brax/physics/system.py", line 268, in substep
    dp_a = sum([a.apply(qp, act) for a in self.actuators], zero)
  File "/Users/James/opt/anaconda3/envs/hMPC/lib/python3.9/site-packages/brax/physics/system.py", line 268, in <listcomp>
    dp_a = sum([a.apply(qp, act) for a in self.actuators], zero)
  File "/Users/James/opt/anaconda3/envs/hMPC/lib/python3.9/site-packages/brax/physics/actuators.py", line 65, in apply
    dang_p, dang_c = jp.vmap(type(self).apply_reduced)(self, act, qp_p, qp_c)
  File "/Users/James/opt/anaconda3/envs/hMPC/lib/python3.9/site-packages/brax/jumpy.py", line 87, in _batched
    rets.append(fun(*b_args))
  File "/Users/James/opt/anaconda3/envs/hMPC/lib/python3.9/site-packages/brax/physics/actuators.py", line 107, in apply_reduced
    torque = jp.sum(jp.vmap(jp.multiply)(axis, torque), axis=0)
  File "/Users/James/opt/anaconda3/envs/hMPC/lib/python3.9/site-packages/brax/jumpy.py", line 89, in _batched
    return jax.tree_util.tree_map(lambda *x: onp.stack(x), *rets)
  File "/Users/James/opt/anaconda3/envs/hMPC/lib/python3.9/site-packages/jax/_src/tree_util.py", line 209, in tree_map
    return treedef.unflatten(f(*xs) for xs in zip(*all_leaves))
  File "/Users/James/opt/anaconda3/envs/hMPC/lib/python3.9/site-packages/jax/_src/tree_util.py", line 209, in <genexpr>
    return treedef.unflatten(f(*xs) for xs in zip(*all_leaves))
  File "/Users/James/opt/anaconda3/envs/hMPC/lib/python3.9/site-packages/brax/jumpy.py", line 89, in <lambda>
    return jax.tree_util.tree_map(lambda *x: onp.stack(x), *rets)
  File "<__array_function__ internals>", line 200, in stack
  File "/Users/James/opt/anaconda3/envs/hMPC/lib/python3.9/site-packages/numpy/core/shape_base.py", line 458, in stack
    arrays = [asanyarray(arr) for arr in arrays]
  File "/Users/James/opt/anaconda3/envs/hMPC/lib/python3.9/site-packages/numpy/core/shape_base.py", line 458, in <listcomp>
    arrays = [asanyarray(arr) for arr in arrays]
  File "/Users/James/opt/anaconda3/envs/hMPC/lib/python3.9/site-packages/jax/_src/core.py", line 575, in __array__
    raise TracerArrayConversionError(self)
jax._src.traceback_util.UnfilteredStackTrace: jax._src.errors.TracerArrayConversionError: The numpy.ndarray conversion method __array__() was called on the JAX Tracer object Traced<ShapedArray(float32[3])>with<DynamicJaxprTrace(level=1/0)>
The error occurred while tracing the function step at /Users/James/opt/anaconda3/envs/hMPC/lib/python3.9/site-packages/brax/envs/wrappers.py:132 for jit. This concrete value was not available in Python because it depends on the values of the arguments state.qp.rot and action.
See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.TracerArrayConversionError
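The innermost brax frame in the traceback above (jumpy.py, line 89) calls `onp.stack`, which is plain NumPy, on values that are abstract Tracers while `jax.jit` is tracing the function. The sketch below reproduces that failure mode in isolation with plain JAX and NumPy. It does not use brax, and the function names are hypothetical.

```python
import jax
import jax.numpy as jnp
import numpy as onp

def stack_with_numpy(x):
    # onp.stack converts its inputs with __array__(); under jax.jit, x is an
    # abstract Tracer with no concrete value, so that conversion fails.
    return onp.stack([x, x])

def stack_with_jnp(x):
    # jnp.stack is traceable, so it works both eagerly and under jit.
    return jnp.stack([x, x])

x = jnp.ones((3,))

# Eager call: x holds concrete values, so NumPy can convert it.
print(stack_with_numpy(x).shape)  # (2, 3)

try:
    jax.jit(stack_with_numpy)(x)
    numpy_under_jit_failed = False
except jax.errors.TracerArrayConversionError:
    numpy_under_jit_failed = True

print(numpy_under_jit_failed)            # True
print(jax.jit(stack_with_jnp)(x).shape)  # (2, 3)
```

The same NumPy call succeeds eagerly and fails only when traced, which matches the error message's note that the value "depends on the values of the arguments state.qp.rot and action".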

I am using brax 0.1.1 and jax 0.4.6.

Thanks.

btaba commented 1 year ago

Hi @jamesheald, can you try running state = jax.jit(env.reset)(rng=jp.random_prngkey(seed=0)) before calling jit_env_step?

jamesheald commented 1 year ago

When I run

state = jax.jit(env.reset)(rng=jp.random_prngkey(seed=0))

I get what looks like the same type of error:

Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/Users/James/opt/anaconda3/envs/hMPC/lib/python3.9/site-packages/jax/_src/traceback_util.py", line 166, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
  File "/Users/James/opt/anaconda3/envs/hMPC/lib/python3.9/site-packages/jax/_src/pjit.py", line 235, in cache_miss
    outs, out_flat, out_tree, args_flat = _python_pjit_helper(
  File "/Users/James/opt/anaconda3/envs/hMPC/lib/python3.9/site-packages/jax/_src/pjit.py", line 179, in _python_pjit_helper
    args_flat, _, params, in_tree, out_tree, _ = infer_params_fn(
  File "/Users/James/opt/anaconda3/envs/hMPC/lib/python3.9/site-packages/jax/_src/api.py", line 442, in infer_params
    return pjit.common_infer_params(pjit_info_args, *args, **kwargs)
  File "/Users/James/opt/anaconda3/envs/hMPC/lib/python3.9/site-packages/jax/_src/pjit.py", line 515, in common_infer_params
    jaxpr, consts, canonicalized_out_shardings_flat = _pjit_jaxpr(
  File "/Users/James/opt/anaconda3/envs/hMPC/lib/python3.9/site-packages/jax/_src/pjit.py", line 967, in _pjit_jaxpr
    jaxpr, final_consts, global_out_avals = _create_pjit_jaxpr(
  File "/Users/James/opt/anaconda3/envs/hMPC/lib/python3.9/site-packages/jax/_src/linear_util.py", line 301, in memoized_fun
    ans = call(fun, *args)
  File "/Users/James/opt/anaconda3/envs/hMPC/lib/python3.9/site-packages/jax/_src/pjit.py", line 925, in _create_pjit_jaxpr
    jaxpr, global_out_avals, consts = pe.trace_to_jaxpr_dynamic(
  File "/Users/James/opt/anaconda3/envs/hMPC/lib/python3.9/site-packages/jax/_src/profiler.py", line 314, in wrapper
    return func(*args, **kwargs)
  File "/Users/James/opt/anaconda3/envs/hMPC/lib/python3.9/site-packages/jax/interpreters/partial_eval.py", line 2029, in trace_to_jaxpr_dynamic
    jaxpr, out_avals, consts = trace_to_subjaxpr_dynamic(
  File "/Users/James/opt/anaconda3/envs/hMPC/lib/python3.9/site-packages/jax/interpreters/partial_eval.py", line 2046, in trace_to_subjaxpr_dynamic
    ans = fun.call_wrapped(*in_tracers_)
  File "/Users/James/opt/anaconda3/envs/hMPC/lib/python3.9/site-packages/jax/_src/linear_util.py", line 165, in call_wrapped
    ans = self.f(*args, **dict(self.params, **kwargs))
  File "/Users/James/opt/anaconda3/envs/hMPC/lib/python3.9/site-packages/brax/envs/wrappers.py", line 127, in reset
    state = self.env.reset(rng)
  File "/Users/James/opt/anaconda3/envs/hMPC/lib/python3.9/site-packages/brax/envs/wrappers.py", line 100, in reset
    state = self.env.reset(rng)
  File "/Users/James/opt/anaconda3/envs/hMPC/lib/python3.9/site-packages/brax/envs/reacher.py", line 163, in reset
    qp = self.sys.default_qp(joint_angle=qpos, joint_velocity=qvel)
  File "/Users/James/opt/anaconda3/envs/hMPC/lib/python3.9/site-packages/brax/physics/system.py", line 190, in default_qp
    _, (local_rot, local_ang) = jp.scan(local_rot_ang, (), xs, len(joint))
  File "/Users/James/opt/anaconda3/envs/hMPC/lib/python3.9/site-packages/brax/jumpy.py", line 117, in scan
    stacked_y = jax.tree_util.tree_map(lambda *y: onp.stack(y),
  File "/Users/James/opt/anaconda3/envs/hMPC/lib/python3.9/site-packages/jax/_src/tree_util.py", line 209, in tree_map
    return treedef.unflatten(f(*xs) for xs in zip(*all_leaves))
  File "/Users/James/opt/anaconda3/envs/hMPC/lib/python3.9/site-packages/jax/_src/tree_util.py", line 209, in <genexpr>
    return treedef.unflatten(f(*xs) for xs in zip(*all_leaves))
  File "/Users/James/opt/anaconda3/envs/hMPC/lib/python3.9/site-packages/brax/jumpy.py", line 117, in <lambda>
    stacked_y = jax.tree_util.tree_map(lambda *y: onp.stack(y),
  File "<__array_function__ internals>", line 200, in stack
  File "/Users/James/opt/anaconda3/envs/hMPC/lib/python3.9/site-packages/numpy/core/shape_base.py", line 458, in stack
    arrays = [asanyarray(arr) for arr in arrays]
  File "/Users/James/opt/anaconda3/envs/hMPC/lib/python3.9/site-packages/numpy/core/shape_base.py", line 458, in <listcomp>
    arrays = [asanyarray(arr) for arr in arrays]
  File "/Users/James/opt/anaconda3/envs/hMPC/lib/python3.9/site-packages/jax/_src/core.py", line 575, in __array__
    raise TracerArrayConversionError(self)
jax._src.traceback_util.UnfilteredStackTrace: jax._src.errors.TracerArrayConversionError: The numpy.ndarray conversion method __array__() was called on the JAX Tracer object Traced<ShapedArray(float32[4])>with<DynamicJaxprTrace(level=1/0)>
The error occurred while tracing the function reset at /Users/James/opt/anaconda3/envs/hMPC/lib/python3.9/site-packages/brax/envs/wrappers.py:126 for jit. This concrete value was not available in Python because it depends on the value of the argument rng.
See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.TracerArrayConversionError

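The `reset` failure comes from the same pattern inside `jp.scan` (jumpy.py, line 117), which collects per-step outputs in a Python loop and stacks them with `onp.stack`. As a hedged sketch that assumes nothing about brax's real implementation beyond those traceback lines, here is a hypothetical loop-based scan that stacks with `jnp.stack` instead, so it survives `jax.jit`. It is checked against `jax.lax.scan` on a toy running-sum step.

```python
import jax
import jax.numpy as jnp

def loop_scan(f, init, xs, length):
    """Hypothetical Python-loop scan (not brax's actual code).

    Same loop-then-stack structure as the traceback frame, except the
    per-step outputs are stacked with jnp.stack, which accepts Tracers.
    """
    carry = init
    ys = []
    for i in range(length):
        x_i = jax.tree_util.tree_map(lambda a: a[i], xs)
        carry, y = f(carry, x_i)
        ys.append(y)
    stacked_y = jax.tree_util.tree_map(lambda *y: jnp.stack(y), *ys)
    return carry, stacked_y

def cumsum_step(carry, x):
    # Running sum: the new carry is also emitted as this step's output.
    new_carry = carry + x
    return new_carry, new_carry

xs = jnp.arange(1.0, 5.0)  # [1., 2., 3., 4.]
init = jnp.zeros(())       # float32 scalar carry

loop_out = jax.jit(lambda xs: loop_scan(cumsum_step, init, xs, 4))(xs)
lax_out = jax.jit(lambda xs: jax.lax.scan(cumsum_step, init, xs, length=4))(xs)

print(loop_out[1].tolist())                        # [1.0, 3.0, 6.0, 10.0]
print(bool(jnp.allclose(loop_out[1], lax_out[1])))  # True
```

Using `jax.lax.scan` directly avoids unrolling the loop into the compiled graph; the loop version is shown only to isolate the effect of swapping `onp.stack` for `jnp.stack`.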

jamesheald commented 1 year ago

If I install brax from source (previously I was using pip install brax), the error goes away. So this is no longer a problem for me, though the underlying issue remains.
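Because the behaviour here differs between the pip release and a source install, it helps to report exactly which distributions are installed. A minimal sketch using the standard library's `importlib.metadata`; the helper name `installed_version` is hypothetical. Note that a source install may report the same version string as a release, so this only narrows things down.

```python
from importlib import metadata
from typing import Optional

def installed_version(dist_name: str) -> Optional[str]:
    """Return the installed version of a distribution, or None if it is absent."""
    try:
        return metadata.version(dist_name)
    except metadata.PackageNotFoundError:
        return None

# Distributions relevant to this issue (reported above as brax 0.1.1, jax 0.4.6).
for dist in ("brax", "jax", "jaxlib", "numpy"):
    print(f"{dist}: {installed_version(dist)}")
```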

btaba commented 1 year ago

Can you try state = jax.jit(env.reset)(rng=jax.random.PRNGKey(seed=0))?

alexandre-eymael commented 1 year ago

Hi all,

I have the exact same problem.

The code:

from brax import envs
from brax import jumpy as jp
import jax
from jax import numpy as jnp

environment = "reacher"
env = envs.create(env_name=environment)
state = env.reset(rng=jp.random_prngkey(seed=0))

jit_env_step = jax.jit(env.step)

state = jax.jit(env.reset)(rng=jax.random.PRNGKey(seed=0))

state = jit_env_step(state, jnp.ones((env.action_size,)))

Neither state = jax.jit(env.reset)(rng=jax.random.PRNGKey(seed=0)) nor state = jax.jit(env.reset)(rng=jp.random_prngkey(seed=0)) before calling jit_env_step solves the issue.

Traceback:

/home/a/mambaforge/envs/BRAX/lib/python3.8/site-packages/flax/struct.py:132: FutureWarning: jax.tree_util.register_keypaths is deprecated, and will be removed in a future release. Please use `register_pytree_with_keys()` instead.
  jax.tree_util.register_keypaths(data_clz, keypaths)
Traceback (most recent call last):
  File "/home/a/MISC/brax/test.py", line 11, in <module>
    state = jax.jit(env.reset)(rng=jax.random.PRNGKey(seed=0))
  File "/home/a/mambaforge/envs/BRAX/lib/python3.8/site-packages/jax/_src/traceback_util.py", line 166, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
  File "/home/a/mambaforge/envs/BRAX/lib/python3.8/site-packages/jax/_src/pjit.py", line 235, in cache_miss
    outs, out_flat, out_tree, args_flat = _python_pjit_helper(
  File "/home/a/mambaforge/envs/BRAX/lib/python3.8/site-packages/jax/_src/pjit.py", line 179, in _python_pjit_helper
    args_flat, _, params, in_tree, out_tree, _ = infer_params_fn(
  File "/home/a/mambaforge/envs/BRAX/lib/python3.8/site-packages/jax/_src/api.py", line 442, in infer_params
    return pjit.common_infer_params(pjit_info_args, *args, **kwargs)
  File "/home/a/mambaforge/envs/BRAX/lib/python3.8/site-packages/jax/_src/pjit.py", line 515, in common_infer_params
    jaxpr, consts, canonicalized_out_shardings_flat = _pjit_jaxpr(
  File "/home/a/mambaforge/envs/BRAX/lib/python3.8/site-packages/jax/_src/pjit.py", line 967, in _pjit_jaxpr
    jaxpr, final_consts, global_out_avals = _create_pjit_jaxpr(
  File "/home/a/mambaforge/envs/BRAX/lib/python3.8/site-packages/jax/_src/linear_util.py", line 301, in memoized_fun
    ans = call(fun, *args)
  File "/home/a/mambaforge/envs/BRAX/lib/python3.8/site-packages/jax/_src/pjit.py", line 925, in _create_pjit_jaxpr
    jaxpr, global_out_avals, consts = pe.trace_to_jaxpr_dynamic(
  File "/home/a/mambaforge/envs/BRAX/lib/python3.8/site-packages/jax/_src/profiler.py", line 314, in wrapper
    return func(*args, **kwargs)
  File "/home/a/mambaforge/envs/BRAX/lib/python3.8/site-packages/jax/interpreters/partial_eval.py", line 2029, in trace_to_jaxpr_dynamic
    jaxpr, out_avals, consts = trace_to_subjaxpr_dynamic(
  File "/home/a/mambaforge/envs/BRAX/lib/python3.8/site-packages/jax/interpreters/partial_eval.py", line 2046, in trace_to_subjaxpr_dynamic
    ans = fun.call_wrapped(*in_tracers_)
  File "/home/a/mambaforge/envs/BRAX/lib/python3.8/site-packages/jax/_src/linear_util.py", line 165, in call_wrapped
    ans = self.f(*args, **dict(self.params, **kwargs))
  File "/home/a/mambaforge/envs/BRAX/lib/python3.8/site-packages/brax/envs/wrappers.py", line 127, in reset
    state = self.env.reset(rng)
  File "/home/a/mambaforge/envs/BRAX/lib/python3.8/site-packages/brax/envs/wrappers.py", line 100, in reset
    state = self.env.reset(rng)
  File "/home/a/mambaforge/envs/BRAX/lib/python3.8/site-packages/brax/envs/humanoid.py", line 230, in reset
    qp = self.sys.default_qp(joint_angle=qpos, joint_velocity=qvel)
  File "/home/a/mambaforge/envs/BRAX/lib/python3.8/site-packages/brax/physics/system.py", line 190, in default_qp
    _, (local_rot, local_ang) = jp.scan(local_rot_ang, (), xs, len(joint))
  File "/home/a/mambaforge/envs/BRAX/lib/python3.8/site-packages/brax/jumpy.py", line 117, in scan
    stacked_y = jax.tree_util.tree_map(lambda *y: onp.stack(y),
  File "/home/a/mambaforge/envs/BRAX/lib/python3.8/site-packages/jax/_src/tree_util.py", line 209, in tree_map
    return treedef.unflatten(f(*xs) for xs in zip(*all_leaves))
  File "/home/a/mambaforge/envs/BRAX/lib/python3.8/site-packages/jax/_src/tree_util.py", line 209, in <genexpr>
    return treedef.unflatten(f(*xs) for xs in zip(*all_leaves))
  File "/home/a/mambaforge/envs/BRAX/lib/python3.8/site-packages/brax/jumpy.py", line 117, in <lambda>
    stacked_y = jax.tree_util.tree_map(lambda *y: onp.stack(y),
  File "<__array_function__ internals>", line 200, in stack
  File "/home/a/mambaforge/envs/BRAX/lib/python3.8/site-packages/numpy/core/shape_base.py", line 458, in stack
    arrays = [asanyarray(arr) for arr in arrays]
  File "/home/a/mambaforge/envs/BRAX/lib/python3.8/site-packages/numpy/core/shape_base.py", line 458, in <listcomp>
    arrays = [asanyarray(arr) for arr in arrays]
  File "/home/a/mambaforge/envs/BRAX/lib/python3.8/site-packages/jax/_src/core.py", line 575, in __array__
    raise TracerArrayConversionError(self)
jax._src.traceback_util.UnfilteredStackTrace: jax._src.errors.TracerArrayConversionError: The numpy.ndarray conversion method __array__() was called on the JAX Tracer object Traced<ShapedArray(float32[4])>with<DynamicJaxprTrace(level=1/0)>
The error occurred while tracing the function reset at /home/a/mambaforge/envs/BRAX/lib/python3.8/site-packages/brax/envs/wrappers.py:126 for jit. This concrete value was not available in Python because it depends on the value of the argument rng.
See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.TracerArrayConversionError

Thank you

btaba commented 1 year ago

I opened the brax basics colab and, after the pip install, ran:

from brax.v1 import envs
from brax.v1 import jumpy as jp
import jax
from jax import numpy as jnp

environment = "reacher"
env = envs.create(env_name=environment)
state = env.reset(rng=jp.random_prngkey(seed=0))

jit_env_step = jax.jit(env.step)

state = jax.jit(env.reset)(rng=jax.random.PRNGKey(seed=0))

state = jit_env_step(state, jnp.ones((env.action_size,)))

and cannot reproduce the issue. Please let me know if that works.

(Also note that brax.v1 is being deprecated.)

jamesheald commented 1 year ago

Yes, it seems the issue arises when you import packages (e.g. envs) from brax (with brax installed via pip), but not when you import them from brax.v1.

btaba commented 1 year ago

I did the same thing with brax v2 (loaded the basics colab and, after the pip install, ran):

from brax import envs
import jax
from jax import numpy as jp

environment = "reacher"
env = envs.create(env_name=environment)
state = env.reset(rng=jax.random.PRNGKey(seed=0))

jit_env_step = jax.jit(env.step)

state = jax.jit(env.reset)(rng=jax.random.PRNGKey(seed=0))

state = jit_env_step(state, jp.ones((env.action_size,)))

and that works as well. Can this issue be closed?

jamesheald commented 1 year ago

Sure, if it's been resolved.