Closed: jamesheald closed this issue 1 year ago.
Hi @jamesheald, can you try running `state = jax.jit(env.reset)(rng=jp.random_prngkey(seed=0))` before calling `jit_env_step`?
When I run `state = jax.jit(env.reset)(rng=jp.random_prngkey(seed=0))`, I get what looks like the same type of error:
Traceback (most recent call last):
File "<stdin>", line 1, in <module>
File "/Users/James/opt/anaconda3/envs/hMPC/lib/python3.9/site-packages/jax/_src/traceback_util.py", line 166, in reraise_with_filtered_traceback
return fun(*args, **kwargs)
File "/Users/James/opt/anaconda3/envs/hMPC/lib/python3.9/site-packages/jax/_src/pjit.py", line 235, in cache_miss
outs, out_flat, out_tree, args_flat = _python_pjit_helper(
File "/Users/James/opt/anaconda3/envs/hMPC/lib/python3.9/site-packages/jax/_src/pjit.py", line 179, in _python_pjit_helper
args_flat, _, params, in_tree, out_tree, _ = infer_params_fn(
File "/Users/James/opt/anaconda3/envs/hMPC/lib/python3.9/site-packages/jax/_src/api.py", line 442, in infer_params
return pjit.common_infer_params(pjit_info_args, *args, **kwargs)
File "/Users/James/opt/anaconda3/envs/hMPC/lib/python3.9/site-packages/jax/_src/pjit.py", line 515, in common_infer_params
jaxpr, consts, canonicalized_out_shardings_flat = _pjit_jaxpr(
File "/Users/James/opt/anaconda3/envs/hMPC/lib/python3.9/site-packages/jax/_src/pjit.py", line 967, in _pjit_jaxpr
jaxpr, final_consts, global_out_avals = _create_pjit_jaxpr(
File "/Users/James/opt/anaconda3/envs/hMPC/lib/python3.9/site-packages/jax/_src/linear_util.py", line 301, in memoized_fun
ans = call(fun, *args)
File "/Users/James/opt/anaconda3/envs/hMPC/lib/python3.9/site-packages/jax/_src/pjit.py", line 925, in _create_pjit_jaxpr
jaxpr, global_out_avals, consts = pe.trace_to_jaxpr_dynamic(
File "/Users/James/opt/anaconda3/envs/hMPC/lib/python3.9/site-packages/jax/_src/profiler.py", line 314, in wrapper
return func(*args, **kwargs)
File "/Users/James/opt/anaconda3/envs/hMPC/lib/python3.9/site-packages/jax/interpreters/partial_eval.py", line 2029, in trace_to_jaxpr_dynamic
jaxpr, out_avals, consts = trace_to_subjaxpr_dynamic(
File "/Users/James/opt/anaconda3/envs/hMPC/lib/python3.9/site-packages/jax/interpreters/partial_eval.py", line 2046, in trace_to_subjaxpr_dynamic
ans = fun.call_wrapped(*in_tracers_)
File "/Users/James/opt/anaconda3/envs/hMPC/lib/python3.9/site-packages/jax/_src/linear_util.py", line 165, in call_wrapped
ans = self.f(*args, **dict(self.params, **kwargs))
File "/Users/James/opt/anaconda3/envs/hMPC/lib/python3.9/site-packages/brax/envs/wrappers.py", line 127, in reset
state = self.env.reset(rng)
File "/Users/James/opt/anaconda3/envs/hMPC/lib/python3.9/site-packages/brax/envs/wrappers.py", line 100, in reset
state = self.env.reset(rng)
File "/Users/James/opt/anaconda3/envs/hMPC/lib/python3.9/site-packages/brax/envs/reacher.py", line 163, in reset
qp = self.sys.default_qp(joint_angle=qpos, joint_velocity=qvel)
File "/Users/James/opt/anaconda3/envs/hMPC/lib/python3.9/site-packages/brax/physics/system.py", line 190, in default_qp
_, (local_rot, local_ang) = jp.scan(local_rot_ang, (), xs, len(joint))
File "/Users/James/opt/anaconda3/envs/hMPC/lib/python3.9/site-packages/brax/jumpy.py", line 117, in scan
stacked_y = jax.tree_util.tree_map(lambda *y: onp.stack(y),
File "/Users/James/opt/anaconda3/envs/hMPC/lib/python3.9/site-packages/jax/_src/tree_util.py", line 209, in tree_map
return treedef.unflatten(f(*xs) for xs in zip(*all_leaves))
File "/Users/James/opt/anaconda3/envs/hMPC/lib/python3.9/site-packages/jax/_src/tree_util.py", line 209, in <genexpr>
return treedef.unflatten(f(*xs) for xs in zip(*all_leaves))
File "/Users/James/opt/anaconda3/envs/hMPC/lib/python3.9/site-packages/brax/jumpy.py", line 117, in <lambda>
stacked_y = jax.tree_util.tree_map(lambda *y: onp.stack(y),
File "<__array_function__ internals>", line 200, in stack
File "/Users/James/opt/anaconda3/envs/hMPC/lib/python3.9/site-packages/numpy/core/shape_base.py", line 458, in stack
arrays = [asanyarray(arr) for arr in arrays]
File "/Users/James/opt/anaconda3/envs/hMPC/lib/python3.9/site-packages/numpy/core/shape_base.py", line 458, in <listcomp>
arrays = [asanyarray(arr) for arr in arrays]
File "/Users/James/opt/anaconda3/envs/hMPC/lib/python3.9/site-packages/jax/_src/core.py", line 575, in __array__
raise TracerArrayConversionError(self)
jax._src.traceback_util.UnfilteredStackTrace: jax._src.errors.TracerArrayConversionError: The numpy.ndarray conversion method __array__() was called on the JAX Tracer object Traced<ShapedArray(float32[4])>with<DynamicJaxprTrace(level=1/0)>
The error occurred while tracing the function reset at /Users/James/opt/anaconda3/envs/hMPC/lib/python3.9/site-packages/brax/envs/wrappers.py:126 for jit. This concrete value was not available in Python because it depends on the value of the argument rng.
See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.TracerArrayConversionError
The stack trace below excludes JAX-internal frames.
The preceding is the original exception that occurred, unmodified.
--------------------
The above exception was the direct cause of the following exception:
Traceback (most recent call last):
File "<stdin>", line 1, in <module>
File "/Users/James/opt/anaconda3/envs/hMPC/lib/python3.9/site-packages/brax/envs/wrappers.py", line 127, in reset
state = self.env.reset(rng)
File "/Users/James/opt/anaconda3/envs/hMPC/lib/python3.9/site-packages/brax/envs/wrappers.py", line 100, in reset
state = self.env.reset(rng)
File "/Users/James/opt/anaconda3/envs/hMPC/lib/python3.9/site-packages/brax/envs/reacher.py", line 163, in reset
qp = self.sys.default_qp(joint_angle=qpos, joint_velocity=qvel)
File "/Users/James/opt/anaconda3/envs/hMPC/lib/python3.9/site-packages/brax/physics/system.py", line 190, in default_qp
_, (local_rot, local_ang) = jp.scan(local_rot_ang, (), xs, len(joint))
File "/Users/James/opt/anaconda3/envs/hMPC/lib/python3.9/site-packages/brax/jumpy.py", line 117, in scan
stacked_y = jax.tree_util.tree_map(lambda *y: onp.stack(y),
File "/Users/James/opt/anaconda3/envs/hMPC/lib/python3.9/site-packages/brax/jumpy.py", line 117, in <lambda>
stacked_y = jax.tree_util.tree_map(lambda *y: onp.stack(y),
File "<__array_function__ internals>", line 200, in stack
File "/Users/James/opt/anaconda3/envs/hMPC/lib/python3.9/site-packages/numpy/core/shape_base.py", line 458, in stack
arrays = [asanyarray(arr) for arr in arrays]
File "/Users/James/opt/anaconda3/envs/hMPC/lib/python3.9/site-packages/numpy/core/shape_base.py", line 458, in <listcomp>
arrays = [asanyarray(arr) for arr in arrays]
jax._src.errors.TracerArrayConversionError: The numpy.ndarray conversion method __array__() was called on the JAX Tracer object Traced<ShapedArray(float32[4])>with<DynamicJaxprTrace(level=1/0)>
The error occurred while tracing the function reset at /Users/James/opt/anaconda3/envs/hMPC/lib/python3.9/site-packages/brax/envs/wrappers.py:126 for jit. This concrete value was not available in Python because it depends on the value of the argument rng.
See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.TracerArrayConversionError
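If I install brax from source (previously I was using `pip install brax`), the error goes away. So this is no longer a problem for me, though the underlying issue in the pip release remains. The traceback shows `scan` in `brax/jumpy.py` stacking its outputs with NumPy's `onp.stack`, which calls `__array__()` on a JAX tracer while `reset` is being traced by `jit`.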
Can you try `state = jax.jit(env.reset)(rng=jax.random.PRNGKey(seed=0))`?
Hi all,
I have the exact same problem.
The code:
from brax import envs
from brax import jumpy as jp
import jax
from jax import numpy as jnp
environment = "reacher"
env = envs.create(env_name=environment)
state = env.reset(rng=jp.random_prngkey(seed=0))
jit_env_step = jax.jit(env.step)
state = jax.jit(env.reset)(rng=jax.random.PRNGKey(seed=0))
state = jit_env_step(state, jnp.ones((env.action_size,)))
`state = jax.jit(env.reset)(rng=jax.random.PRNGKey(seed=0))` does not solve the issue, and neither does calling `state = jax.jit(env.reset)(rng=jp.random_prngkey(seed=0))` before calling `jit_env_step`.
Traceback:
/home/a/mambaforge/envs/BRAX/lib/python3.8/site-packages/flax/struct.py:132: FutureWarning: jax.tree_util.register_keypaths is deprecated, and will be removed in a future release. Please use `register_pytree_with_keys()` instead.
jax.tree_util.register_keypaths(data_clz, keypaths)
/home/a/mambaforge/envs/BRAX/lib/python3.8/site-packages/flax/struct.py:132: FutureWarning: jax.tree_util.register_keypaths is deprecated, and will be removed in a future release. Please use `register_pytree_with_keys()` instead.
jax.tree_util.register_keypaths(data_clz, keypaths)
/home/a/mambaforge/envs/BRAX/lib/python3.8/site-packages/flax/struct.py:132: FutureWarning: jax.tree_util.register_keypaths is deprecated, and will be removed in a future release. Please use `register_pytree_with_keys()` instead.
jax.tree_util.register_keypaths(data_clz, keypaths)
Traceback (most recent call last):
File "/home/a/MISC/brax/test.py", line 11, in <module>
state = jax.jit(env.reset)(rng=jax.random.PRNGKey(seed=0))
File "/home/a/mambaforge/envs/BRAX/lib/python3.8/site-packages/jax/_src/traceback_util.py", line 166, in reraise_with_filtered_traceback
return fun(*args, **kwargs)
File "/home/a/mambaforge/envs/BRAX/lib/python3.8/site-packages/jax/_src/pjit.py", line 235, in cache_miss
outs, out_flat, out_tree, args_flat = _python_pjit_helper(
File "/home/a/mambaforge/envs/BRAX/lib/python3.8/site-packages/jax/_src/pjit.py", line 179, in _python_pjit_helper
args_flat, _, params, in_tree, out_tree, _ = infer_params_fn(
File "/home/a/mambaforge/envs/BRAX/lib/python3.8/site-packages/jax/_src/api.py", line 442, in infer_params
return pjit.common_infer_params(pjit_info_args, *args, **kwargs)
File "/home/a/mambaforge/envs/BRAX/lib/python3.8/site-packages/jax/_src/pjit.py", line 515, in common_infer_params
jaxpr, consts, canonicalized_out_shardings_flat = _pjit_jaxpr(
File "/home/a/mambaforge/envs/BRAX/lib/python3.8/site-packages/jax/_src/pjit.py", line 967, in _pjit_jaxpr
jaxpr, final_consts, global_out_avals = _create_pjit_jaxpr(
File "/home/a/mambaforge/envs/BRAX/lib/python3.8/site-packages/jax/_src/linear_util.py", line 301, in memoized_fun
ans = call(fun, *args)
File "/home/a/mambaforge/envs/BRAX/lib/python3.8/site-packages/jax/_src/pjit.py", line 925, in _create_pjit_jaxpr
jaxpr, global_out_avals, consts = pe.trace_to_jaxpr_dynamic(
File "/home/a/mambaforge/envs/BRAX/lib/python3.8/site-packages/jax/_src/profiler.py", line 314, in wrapper
return func(*args, **kwargs)
File "/home/a/mambaforge/envs/BRAX/lib/python3.8/site-packages/jax/interpreters/partial_eval.py", line 2029, in trace_to_jaxpr_dynamic
jaxpr, out_avals, consts = trace_to_subjaxpr_dynamic(
File "/home/a/mambaforge/envs/BRAX/lib/python3.8/site-packages/jax/interpreters/partial_eval.py", line 2046, in trace_to_subjaxpr_dynamic
ans = fun.call_wrapped(*in_tracers_)
File "/home/a/mambaforge/envs/BRAX/lib/python3.8/site-packages/jax/_src/linear_util.py", line 165, in call_wrapped
ans = self.f(*args, **dict(self.params, **kwargs))
File "/home/a/mambaforge/envs/BRAX/lib/python3.8/site-packages/brax/envs/wrappers.py", line 127, in reset
state = self.env.reset(rng)
File "/home/a/mambaforge/envs/BRAX/lib/python3.8/site-packages/brax/envs/wrappers.py", line 100, in reset
state = self.env.reset(rng)
File "/home/a/mambaforge/envs/BRAX/lib/python3.8/site-packages/brax/envs/humanoid.py", line 230, in reset
qp = self.sys.default_qp(joint_angle=qpos, joint_velocity=qvel)
File "/home/a/mambaforge/envs/BRAX/lib/python3.8/site-packages/brax/physics/system.py", line 190, in default_qp
_, (local_rot, local_ang) = jp.scan(local_rot_ang, (), xs, len(joint))
File "/home/a/mambaforge/envs/BRAX/lib/python3.8/site-packages/brax/jumpy.py", line 117, in scan
stacked_y = jax.tree_util.tree_map(lambda *y: onp.stack(y),
File "/home/a/mambaforge/envs/BRAX/lib/python3.8/site-packages/jax/_src/tree_util.py", line 209, in tree_map
return treedef.unflatten(f(*xs) for xs in zip(*all_leaves))
File "/home/a/mambaforge/envs/BRAX/lib/python3.8/site-packages/jax/_src/tree_util.py", line 209, in <genexpr>
return treedef.unflatten(f(*xs) for xs in zip(*all_leaves))
File "/home/a/mambaforge/envs/BRAX/lib/python3.8/site-packages/brax/jumpy.py", line 117, in <lambda>
stacked_y = jax.tree_util.tree_map(lambda *y: onp.stack(y),
File "<__array_function__ internals>", line 200, in stack
File "/home/a/mambaforge/envs/BRAX/lib/python3.8/site-packages/numpy/core/shape_base.py", line 458, in stack
arrays = [asanyarray(arr) for arr in arrays]
File "/home/a/mambaforge/envs/BRAX/lib/python3.8/site-packages/numpy/core/shape_base.py", line 458, in <listcomp>
arrays = [asanyarray(arr) for arr in arrays]
File "/home/a/mambaforge/envs/BRAX/lib/python3.8/site-packages/jax/_src/core.py", line 575, in __array__
raise TracerArrayConversionError(self)
jax._src.traceback_util.UnfilteredStackTrace: jax._src.errors.TracerArrayConversionError: The numpy.ndarray conversion method __array__() was called on the JAX Tracer object Traced<ShapedArray(float32[4])>with<DynamicJaxprTrace(level=1/0)>
The error occurred while tracing the function reset at /home/a/mambaforge/envs/BRAX/lib/python3.8/site-packages/brax/envs/wrappers.py:126 for jit. This concrete value was not available in Python because it depends on the value of the argument rng.
See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.TracerArrayConversionError
The stack trace below excludes JAX-internal frames.
The preceding is the original exception that occurred, unmodified.
--------------------
The above exception was the direct cause of the following exception:
Traceback (most recent call last):
File "/home/a/MISC/brax/test.py", line 11, in <module>
state = jax.jit(env.reset)(rng=jax.random.PRNGKey(seed=0))
File "/home/a/mambaforge/envs/BRAX/lib/python3.8/site-packages/brax/envs/wrappers.py", line 127, in reset
state = self.env.reset(rng)
File "/home/a/mambaforge/envs/BRAX/lib/python3.8/site-packages/brax/envs/wrappers.py", line 100, in reset
state = self.env.reset(rng)
File "/home/a/mambaforge/envs/BRAX/lib/python3.8/site-packages/brax/envs/humanoid.py", line 230, in reset
qp = self.sys.default_qp(joint_angle=qpos, joint_velocity=qvel)
File "/home/a/mambaforge/envs/BRAX/lib/python3.8/site-packages/brax/physics/system.py", line 190, in default_qp
_, (local_rot, local_ang) = jp.scan(local_rot_ang, (), xs, len(joint))
File "/home/a/mambaforge/envs/BRAX/lib/python3.8/site-packages/brax/jumpy.py", line 117, in scan
stacked_y = jax.tree_util.tree_map(lambda *y: onp.stack(y),
File "/home/a/mambaforge/envs/BRAX/lib/python3.8/site-packages/brax/jumpy.py", line 117, in <lambda>
stacked_y = jax.tree_util.tree_map(lambda *y: onp.stack(y),
File "<__array_function__ internals>", line 200, in stack
File "/home/a/mambaforge/envs/BRAX/lib/python3.8/site-packages/numpy/core/shape_base.py", line 458, in stack
arrays = [asanyarray(arr) for arr in arrays]
File "/home/a/mambaforge/envs/BRAX/lib/python3.8/site-packages/numpy/core/shape_base.py", line 458, in <listcomp>
arrays = [asanyarray(arr) for arr in arrays]
jax._src.errors.TracerArrayConversionError: The numpy.ndarray conversion method __array__() was called on the JAX Tracer object Traced<ShapedArray(float32[4])>with<DynamicJaxprTrace(level=1/0)>
The error occurred while tracing the function reset at /home/a/mambaforge/envs/BRAX/lib/python3.8/site-packages/brax/envs/wrappers.py:126 for jit. This concrete value was not available in Python because it depends on the value of the argument rng.
See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.TracerArrayConversionError
Thank you
I opened the brax basics colab and, after the pip install, ran:
from brax.v1 import envs
from brax.v1 import jumpy as jp
import jax
from jax import numpy as jnp
environment = "reacher"
env = envs.create(env_name=environment)
state = env.reset(rng=jp.random_prngkey(seed=0))
jit_env_step = jax.jit(env.step)
state = jax.jit(env.reset)(rng=jax.random.PRNGKey(seed=0))
state = jit_env_step(state, jnp.ones((env.action_size,)))
and I cannot reproduce the issue. Please let me know if that works. (Also note that `brax.v1` is being deprecated.)
Yes, it seems the issue arises when you import modules (e.g. `envs`) from `brax` (having installed brax via pip), but not when you import them from `brax.v1`.
I did the same thing with brax v2 (loaded the basics colab and, after the pip install, ran):
from brax import envs
import jax
from jax import numpy as jp
environment = "reacher"
env = envs.create(env_name=environment)
state = env.reset(rng=jax.random.PRNGKey(seed=0))
jit_env_step = jax.jit(env.step)
state = jax.jit(env.reset)(rng=jax.random.PRNGKey(seed=0))
state = jit_env_step(state, jp.ones((env.action_size,)))
and that works as well. Can this issue be closed?
Sure, if it's been resolved.
I am new to brax. When I try to perform an environment step using a jitted version of `env.step` (with code taken from the brax environments notebook), I get a `TracerArrayConversionError`. Any idea how I can resolve this?
Code:
Traceback:
I am using brax 0.1.1 and jax 0.4.6.
Thanks.