google / brax

Massively parallel rigidbody physics simulation on accelerator hardware.
Apache License 2.0

brax trying to convert traced array into numpy.ndarray #306

Closed wbrenton closed 1 year ago

wbrenton commented 1 year ago

At some point when executing self.sys.step (line 224 of ant), brax tries to convert a traced array into an ordinary NumPy array, which JAX does not allow. I'm having a difficult time finding where exactly this call is being made.

Is it possible this is a bug? Any assistance would be greatly appreciated!

Here is the error message for reference.

The numpy.ndarray conversion method __array__() was called on the JAX Tracer object Traced<ShapedArray(float32[3])>with<DynamicJaxprTrace(level=4/0)>
The error occurred while tracing the function step_fn at /Users/will/podracers/online_muzero/brax/muzero.py:110 for scan. This concrete value was not available in Python because it depends on the values of the arguments 'mzs' and 'rng'.

btaba commented 1 year ago

Hi @wbrenton, judging from your traceback, the issue looks like it is in your own code at online_muzero/brax/muzero.py:110. Brax doesn't generate the argument mzs. Let us know if you have more info that would help us sort out your issue.

dawsonc commented 1 year ago

Hi! I'm having a similar issue (not related to muzero, just with basic brax+jax).

When I run the following in my environment (Python 3.10, brax 0.1.1, jax 0.4.4), I get an error:

import jax
import jax.numpy as jnp

import brax
import brax.envs

env = brax.envs.create(env_name="reacher")
state = env.reset(rng=jax.random.PRNGKey(0))

def step(state, _):
  next_state = env.step(state, jnp.ones((env.action_size,)))
  return next_state, next_state

_, rollout = jax.lax.scan(step, state, None, length=10)

  File "<python-path>/site-packages/brax/envs/wrappers.py", line 138, in step
    state = self.env.step(state, action)
  File "<python-path>/site-packages/brax/envs/wrappers.py", line 110, in step
    state, rewards = jp.scan(f, state, (), self.action_repeat)
  File "<python-path>/site-packages/brax/jumpy.py", line 115, in scan
    carry, y = f(carry, jax.tree_util.tree_unflatten(xs_tree, xs_slice))
  File "<python-path>/site-packages/brax/envs/wrappers.py", line 107, in f
    nstate = self.env.step(state, action)
  File "<python-path>/site-packages/brax/envs/reacher.py", line 177, in step
    qp, info = self.sys.step(state.qp, action)
  File "<python-path>/site-packages/brax/physics/system.py", line 247, in step
    return step_funs[self.config.dynamics_mode](qp, act)
  File "<python-path>/site-packages/brax/physics/system.py", line 324, in _pbd_step
    (qp, info), _ = jp.scan(substep, (qp, info), (), self.config.substeps // 2)
  File "<python-path>/site-packages/brax/jumpy.py", line 115, in scan
    carry, y = f(carry, jax.tree_util.tree_unflatten(xs_tree, xs_slice))
  File "<python-path>/site-packages/brax/physics/system.py", line 268, in substep
    dp_a = sum([a.apply(qp, act) for a in self.actuators], zero)
  File "<python-path>/site-packages/brax/physics/system.py", line 268, in <listcomp>
    dp_a = sum([a.apply(qp, act) for a in self.actuators], zero)
  File "<python-path>/site-packages/brax/physics/actuators.py", line 65, in apply
    dang_p, dang_c = jp.vmap(type(self).apply_reduced)(self, act, qp_p, qp_c)
  File "<python-path>/site-packages/brax/jumpy.py", line 87, in _batched
    rets.append(fun(*b_args))
  File "<python-path>/site-packages/brax/physics/actuators.py", line 107, in apply_reduced
    torque = jp.sum(jp.vmap(jp.multiply)(axis, torque), axis=0)
  File "<python-path>/site-packages/brax/jumpy.py", line 89, in _batched
    return jax.tree_util.tree_map(lambda *x: onp.stack(x), *rets)
  File "<python-path>/site-packages/brax/jumpy.py", line 89, in <lambda>
    return jax.tree_util.tree_map(lambda *x: onp.stack(x), *rets)
  File "<__array_function__ internals>", line 200, in stack
  File "<python-path>/site-packages/numpy/core/shape_base.py", line 458, in stack
    arrays = [asanyarray(arr) for arr in arrays]
  File "<python-path>/site-packages/numpy/core/shape_base.py", line 458, in <listcomp>
    arrays = [asanyarray(arr) for arr in arrays]
jax._src.errors.TracerArrayConversionError: The numpy.ndarray conversion method __array__() was called on the JAX Tracer object Traced<ShapedArray(float32[3])>with<DynamicJaxprTrace(level=1/0)>
The error occurred while tracing the function step at <stdin>:1 for scan. This concrete value was not available in Python because it depends on the value of the argument 'state'.
See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.TracerArrayConversionError

Interestingly, when I run it in the "Brax Environments.ipynb" notebook (brax 0.1.1, jax 0.3.25, Python 3.8; not sure if the versions are relevant), I don't get an error.
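
The last frames of the traceback above show numpy's stack being handed JAX tracers. A minimal sketch of that failure mode in plain JAX, with no brax involved: the helper names bad_step and good_step are made up for illustration, and this is not a diagnosis of why brax 0.1.1 takes the NumPy path here.

```python
import jax
import jax.numpy as jnp
import numpy as onp  # same alias as in the traceback above

def bad_step(carry, _):
    # onp.stack calls __array__() on its inputs, which a tracer cannot provide.
    out = onp.stack([carry, carry])
    return carry, out

def good_step(carry, _):
    # jnp.stack stays inside JAX, so it also works on tracers.
    out = jnp.stack([carry, carry])
    return carry, out

init = jnp.ones(3)

try:
    jax.lax.scan(bad_step, init, None, length=2)
    raised = False
except jax.errors.TracerArrayConversionError:
    raised = True

_, ys = jax.lax.scan(good_step, init, None, length=2)
print(raised, ys.shape)
```

The only difference between the two step functions is which stack is called, which is why the traceback ends inside numpy rather than inside JAX.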

dawsonc commented 1 year ago

FWIW it's not a version issue; tested downgrading my JAX version to match the nb environment and the issue persists. Another difference between the nb and local is that I'm not running on a GPU locally.

I also tested with python 3.9 and it still fails.

dawsonc commented 1 year ago

More debugging results: it fails regardless of whether I jax.jit the step function, the entire scan, or neither. It also fails with a similar error if I try jax.jit(env.reset).
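
One possible reading of this result, offered as an assumption rather than a confirmed explanation of the brax behavior: jax.lax.scan traces its body function on its own, so the values inside are tracers with or without an outer jax.jit. The sketch below uses a hypothetical probe function to record what Python sees in each case.

```python
import jax
import jax.numpy as jnp

seen = []

def probe(x):
    # Record whether the value visible to Python is an abstract tracer.
    seen.append(isinstance(x, jax.core.Tracer))
    return x + 1.0

x0 = jnp.float32(0.0)

probe(x0)            # eager call: x is a concrete array
jax.jit(probe)(x0)   # jit traces probe, so x is a tracer

# scan traces its body even though nothing here is wrapped in jax.jit
jax.lax.scan(lambda c, _: (probe(c), c), x0, None, length=3)

# The first entry is False (eager); every later entry is True (traced).
print(seen)
```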

dawsonc commented 1 year ago

Another update (sorry for the comment spam): this issue does not occur when I build brax from main, so maybe I just need to cool my jets and wait for the next release :grin:

wbrenton commented 1 year ago

@dawsonc I found that using brax.v2.envs resolved the issue as well as building from source.

btaba commented 1 year ago

@wbrenton , glad your issue is resolved!

Hi @dawsonc, since you're using old brax in your example (not brax.v2), this should fix your error (brax.v2 won't be using jumpy, so we won't have similar confusion between numpy and jax.numpy arrays):

from brax import jumpy as jp
_, rollout = jp.scan(step, state, None, length=10)

dawsonc commented 1 year ago

Thanks! I'll probably switch to v2 to avoid the issue (I'll need grad in addition to scan so I don't want to just side-step the issue with jumpy).
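
For context on combining grad with scan, here is a minimal sketch in plain JAX. The dynamics are a toy stand-in (x_next = x + 0.1 * action), not brax physics, and rollout_loss is a hypothetical name used only for illustration.

```python
import jax
import jax.numpy as jnp

def rollout_loss(action, x0):
    # Toy dynamics (hypothetical, not brax): x_{t+1} = x_t + 0.1 * action.
    def step(x, _):
        x_next = x + 0.1 * action
        return x_next, x_next

    x_final, _ = jax.lax.scan(step, x0, None, length=10)
    # Squared distance of the final state from the origin.
    return jnp.sum(x_final ** 2)

x0 = jnp.zeros(2)
action = jnp.array([1.0, -2.0])

# Differentiate the whole 10-step rollout with respect to the action.
grad = jax.grad(rollout_loss)(action, x0)
print(grad)  # approximately [ 2. -4.]
```

Because every operation in the step stays in jax.numpy, jax.grad can differentiate through all ten scan iterations; after ten steps x_final is approximately equal to action, so the gradient is approximately 2 * action.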

StoneT2000 commented 1 year ago

Is this issue fixed? I seem to have a few issues.

When using v2 I get:

Python 3.8.16 | packaged by conda-forge | (default, Feb  1 2023, 16:01:55) 
[GCC 11.3.0] on linux
Type "help", "copyright", "credits" or "license" for more information.
>>> import brax.v2.envs
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/home/stao/miniconda3/envs/robojax/lib/python3.8/site-packages/brax/v2/envs/__init__.py", line 20, in <module>
    from brax.v2.envs import ant
  File "/home/stao/miniconda3/envs/robojax/lib/python3.8/site-packages/brax/v2/envs/ant.py", line 21, in <module>
    from brax.v2.io import mjcf
ModuleNotFoundError: No module named 'brax.v2.io'

When using v1, I have the same error as OP when using any wrapper around reset functions (e.g. the VectorGymWrapper).

jax._src.errors.TracerArrayConversionError: The numpy.ndarray conversion method __array__() was called on the JAX Tracer object Traced<ShapedArray(float32[4])>with<DynamicJaxprTrace(level=1/0)>

Another note: are there docs on what brax v2 is? I wasn't aware this had happened until recently, and some of my code is failing.

btaba commented 1 year ago

Hi @StoneT2000! The brax.v2.io import should be fixed. If you have any other issues with imports, let us know in a new issue! FYI we just moved brax.v2 to brax. Here is a reference to the v2 announcement: https://github.com/google/brax/discussions/286. The current README in the repo is about brax v2.

StoneT2000 commented 1 year ago

Thanks! I'll take a look today

StoneT2000 commented 1 year ago

@btaba code is working again, thanks! Although now image is no longer a part of brax's v2 io module?