wbrenton closed this issue 1 year ago.
Hi @wbrenton, from your traceback it looks like the issue is here: `online_muzero/brax/muzero.py:110`. We don't generate the argument `mzs`. Let us know if there is more info to help us sort out your issue.
Hi! I'm having a similar issue (not related to MuZero, just basic brax + JAX).
When I run the following in my environment (Python 3.10, brax 0.1.1, jax 0.4.4), I get an error:
```python
import jax
import jax.numpy as jnp
import brax
import brax.envs

env = brax.envs.create(env_name="reacher")
state = env.reset(rng=jax.random.PRNGKey(0))

def step(state, _):
    next_state = env.step(state, jnp.ones((env.action_size,)))
    return next_state, next_state

_, rollout = jax.lax.scan(step, state, None, length=10)
```
```
  File "<python-path>/site-packages/brax/envs/wrappers.py", line 138, in step
    state = self.env.step(state, action)
  File "<python-path>/site-packages/brax/envs/wrappers.py", line 110, in step
    state, rewards = jp.scan(f, state, (), self.action_repeat)
  File "<python-path>/site-packages/brax/jumpy.py", line 115, in scan
    carry, y = f(carry, jax.tree_util.tree_unflatten(xs_tree, xs_slice))
  File "<python-path>/site-packages/brax/envs/wrappers.py", line 107, in f
    nstate = self.env.step(state, action)
  File "<python-path>/site-packages/brax/envs/reacher.py", line 177, in step
    qp, info = self.sys.step(state.qp, action)
  File "<python-path>/site-packages/brax/physics/system.py", line 247, in step
    return step_funs[self.config.dynamics_mode](qp, act)
  File "<python-path>/site-packages/brax/physics/system.py", line 324, in _pbd_step
    (qp, info), _ = jp.scan(substep, (qp, info), (), self.config.substeps // 2)
  File "<python-path>/site-packages/brax/jumpy.py", line 115, in scan
    carry, y = f(carry, jax.tree_util.tree_unflatten(xs_tree, xs_slice))
  File "<python-path>/site-packages/brax/physics/system.py", line 268, in substep
    dp_a = sum([a.apply(qp, act) for a in self.actuators], zero)
  File "<python-path>/site-packages/brax/physics/system.py", line 268, in <listcomp>
    dp_a = sum([a.apply(qp, act) for a in self.actuators], zero)
  File "<python-path>/site-packages/brax/physics/actuators.py", line 65, in apply
    dang_p, dang_c = jp.vmap(type(self).apply_reduced)(self, act, qp_p, qp_c)
  File "<python-path>/site-packages/brax/jumpy.py", line 87, in _batched
    rets.append(fun(*b_args))
  File "<python-path>/site-packages/brax/physics/actuators.py", line 107, in apply_reduced
    torque = jp.sum(jp.vmap(jp.multiply)(axis, torque), axis=0)
  File "<python-path>/site-packages/brax/jumpy.py", line 89, in _batched
    return jax.tree_util.tree_map(lambda *x: onp.stack(x), *rets)
  File "<python-path>/site-packages/brax/jumpy.py", line 89, in <lambda>
    return jax.tree_util.tree_map(lambda *x: onp.stack(x), *rets)
  File "<__array_function__ internals>", line 200, in stack
  File "<python-path>/site-packages/numpy/core/shape_base.py", line 458, in stack
    arrays = [asanyarray(arr) for arr in arrays]
  File "<python-path>/site-packages/numpy/core/shape_base.py", line 458, in <listcomp>
    arrays = [asanyarray(arr) for arr in arrays]
jax._src.errors.TracerArrayConversionError: The numpy.ndarray conversion method __array__() was called on the JAX Tracer object Traced<ShapedArray(float32[3])>with<DynamicJaxprTrace(level=1/0)>
The error occurred while tracing the function step at <stdin>:1 for scan. This concrete value was not available in Python because it depends on the value of the argument 'state'.
See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.TracerArrayConversionError
```
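For anyone landing here later: the bottom of that traceback is jumpy's Python-loop `vmap` (`_batched`) stacking per-row results with `onp.stack` while those results are JAX tracers, because `jax.lax.scan` is tracing `step`. Below is a simplified, hypothetical reproduction of that pattern (`batched_multiply` and `make_step` are illustrative names, not brax code). It shows `np.stack` raising `TracerArrayConversionError` under `lax.scan`, while `jnp.stack` traces fine.

```python
import jax
import jax.numpy as jnp
import numpy as np


def batched_multiply(stack_fn, axis, torque):
    """Hypothetical stand-in for a Python-loop `vmap`: multiply row by row,
    then combine the per-row results with `stack_fn`."""
    rows = [jnp.multiply(a, t) for a, t in zip(axis, torque)]
    return stack_fn(rows)


def make_step(stack_fn):
    def step(carry, _):
        # Loosely mirrors apply_reduced: sum over a batched multiply.
        out = jnp.sum(batched_multiply(stack_fn, carry, carry), axis=0)
        return carry, out
    return step


init = jnp.ones((3, 3), dtype=jnp.float32)

# NumPy stacking: inside lax.scan the rows are tracers, so np.stack fails.
try:
    jax.lax.scan(make_step(np.stack), init, None, length=2)
    numpy_raised = False
except jax.errors.TracerArrayConversionError:
    numpy_raised = True

# jax.numpy stacking stays inside the trace, so the same loop works.
_, ys = jax.lax.scan(make_step(jnp.stack), init, None, length=2)
print(numpy_raised, ys.shape)  # True (2, 3)
```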
Interestingly, when I run it in the "Brax Environments.ipynb" notebook (brax 0.1.1, jax 0.3.25, Python 3.8; not sure if the versions are relevant), I don't get an error.
FWIW, it's not a JAX version issue: I tried downgrading JAX to match the notebook environment and the issue persists. Another difference between the notebook and my local setup is that I'm not running on a GPU locally.
I also tested with Python 3.9 and it still fails.
More debugging results: it fails regardless of whether I `jax.jit` the `step` function, the entire `scan`, or neither. It also fails with a similar error if I try `jax.jit(env.reset)`.
Another update (sorry for the comment spam): this issue does not occur when I build brax from `main`, so maybe I just need to cool my jets and wait for the next release :grin:
@dawsonc I found that using `brax.v2.envs` resolved the issue, as did building from source.
@wbrenton , glad your issue is resolved!
Hi @dawsonc, since you're using old brax in your example (not `brax.v2`), this should fix your error (`brax.v2` won't be using jumpy, so we won't have similar confusion around numpy/jax.numpy arrays):
```python
from brax import jumpy as jp

_, rollout = jp.scan(step, state, None, length=10)
```
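A guess at why `jp.scan` sidesteps the error, assuming (as the `jumpy.py` frames in the traceback suggest) that jumpy's `scan` falls back to a plain Python loop when it doesn't detect a JAX trace: the step function then only ever sees concrete arrays, so NumPy conversions inside it succeed. A minimal sketch of that idea with a hypothetical `loop_scan` (not jumpy's actual implementation):

```python
import jax.numpy as jnp
import numpy as np


def loop_scan(f, init, length):
    """Hypothetical simplified scan: a plain Python loop over concrete values."""
    carry, ys = init, []
    for _ in range(length):
        carry, y = f(carry, None)
        ys.append(y)
    # Safe here: every y is a concrete array, not a tracer.
    return carry, np.stack(ys)


host_log = []


def step(carry, _):
    next_carry = carry + 1.0
    # NumPy conversion: fine on concrete arrays, would fail under lax.scan.
    host_log.append(np.asarray(next_carry))
    return next_carry, next_carry


carry, ys = loop_scan(step, jnp.zeros(3), length=4)
print(ys[:, 0])  # [1. 2. 3. 4.]
```

The tradeoff is that each step runs eagerly in Python instead of being traced once by `lax.scan`, and any NumPy conversions in `step` would fail again if the whole loop were wrapped in `jax.jit`.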
Thanks! I'll probably switch to v2 to avoid the issue (I'll need `grad` in addition to `scan`, so I don't want to just side-step the issue with jumpy).
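For reference, `jax.grad` composes with `jax.lax.scan` as long as the step function is pure `jax.numpy`. A toy sketch (hypothetical point-mass dynamics, not a brax environment) that differentiates a loss on the final position of a 10-step rollout with respect to a constant action:

```python
import jax
import jax.numpy as jnp


def rollout_loss(action, init_pos):
    """Toy differentiable rollout with hypothetical point-mass dynamics."""
    def step(pos, _):
        next_pos = pos + 0.1 * action  # pure jax.numpy, so it is traceable
        return next_pos, next_pos

    final_pos, _ = jax.lax.scan(step, init_pos, None, length=10)
    return jnp.sum(final_pos ** 2)


action = jnp.ones(2)
# final_pos is about 1.0 * action, so d(loss)/d(action) is about 2 * action.
grad_action = jax.grad(rollout_loss)(action, jnp.zeros(2))
print(grad_action)  # approximately [2. 2.]
```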
Is this issue fixed? I seem to be running into a few problems.
When using v2, I get:
```
Python 3.8.16 | packaged by conda-forge | (default, Feb 1 2023, 16:01:55)
[GCC 11.3.0] on linux
Type "help", "copyright", "credits" or "license" for more information.
>>> import brax.v2.envs
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/home/stao/miniconda3/envs/robojax/lib/python3.8/site-packages/brax/v2/envs/__init__.py", line 20, in <module>
    from brax.v2.envs import ant
  File "/home/stao/miniconda3/envs/robojax/lib/python3.8/site-packages/brax/v2/envs/ant.py", line 21, in <module>
    from brax.v2.io import mjcf
ModuleNotFoundError: No module named 'brax.v2.io'
```
When using v1, I get the same error as OP whenever I use a wrapper around the reset function (e.g. `VectorGymWrapper`):
```
jax._src.errors.TracerArrayConversionError: The numpy.ndarray conversion method __array__() was called on the JAX Tracer object Traced<ShapedArray(float32[4])>with<DynamicJaxprTrace(level=1/0)>
```
Another note: are there docs on what brax v2 is? I wasn't aware of this change until recently, and some of my code is now failing.
Hi @StoneT2000! The `brax.v2.io` import should be fixed now. If you run into any other problems with imports, let us know in a new issue!
FYI, we just moved `brax.v2` to `brax`. Here is a reference to the v2 announcement: https://github.com/google/brax/discussions/286. The current README in the repo describes brax v2.
Thanks! I'll take a look today
@btaba, the code is working again, thanks! However, it looks like `image` is no longer part of brax v2's `io` module?
At some point while executing `self.sys.step` (line 224 of ant), brax tries to convert a traced array into a regular NumPy array, which JAX does not allow. I'm having a difficult time finding where exactly this conversion happens.
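One way to pin down the conversion site: the traceback for this error still contains the NumPy frames where the conversion was attempted, so the frame just before execution enters NumPy is the offending call. A minimal sketch with a hypothetical `find_numpy_caller` helper and a toy `step`; it uses JAX's `jax_traceback_filtering` config option to keep the full, unfiltered stack:

```python
import os
import traceback

import jax
import jax.numpy as jnp
import numpy as np

# Keep JAX-internal frames so the whole call chain is visible.
jax.config.update("jax_traceback_filtering", "off")

NUMPY_DIR = os.path.dirname(np.__file__) + os.sep


def find_numpy_caller(exc):
    """Return the last non-NumPy frame before execution entered NumPy."""
    caller = None
    for frame in traceback.extract_tb(exc.__traceback__):
        if frame.filename.startswith(NUMPY_DIR):
            return caller
        # Older NumPy adds a synthetic dispatch frame; skip it.
        if frame.filename != "<__array_function__ internals>":
            caller = frame
    return None


def step(carry, _):
    return carry, np.stack([carry, carry])  # NumPy call on a traced value


try:
    jax.lax.scan(step, jnp.ones(3), None, length=2)
except jax.errors.TracerArrayConversionError as err:
    culprit = find_numpy_caller(err)
    print(f"{culprit.filename}:{culprit.lineno} in {culprit.name}")
```

In the brax traceback from this thread, the same walk would stop at the `onp.stack` lambda in `jumpy.py`.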
Is it possible this is a bug? Any assistance would be greatly appreciated!
Here is the error message for reference.