google / brax

Massively parallel rigidbody physics simulation on accelerator hardware.
Apache License 2.0

Cannot `vmap` methods of Brax environments #205

Closed davmre closed 2 years ago

davmre commented 2 years ago

Hi, thanks for all your work with Brax! I'm just getting into it, and am running into surprising behavior on a simple example:

import jax
from brax import envs

env = envs.create('inverted_pendulum')  # Other envs give similar errors.

batch_rng = jax.random.split(jax.random.PRNGKey(0), 64)
batch_state = jax.vmap(env.reset)(batch_rng)
print(batch_state.obs)

(runnable version here: https://colab.research.google.com/drive/1c8jDkShgSRdeBp9fcC9R3PTfMRNbjOlj?usp=sharing)

Expected: prints a batch of initial observations of shape (64, 4).

Actual: raises a TracerArrayConversionError inside of a brax.jumpy.scan call:

/usr/local/lib/python3.7/dist-packages/jax/_src/traceback_util.py in reraise_with_filtered_traceback(*args, **kwargs)
    161     try:
--> 162       return fun(*args, **kwargs)
    163     except Exception as e:

/usr/local/lib/python3.7/dist-packages/jax/_src/api.py in vmap_f(*args, **kwargs)
   1473         lambda: flatten_axes("vmap out_axes", out_tree(), out_axes)
-> 1474     ).call_wrapped(*args_flat)
   1475     return tree_unflatten(out_tree(), out_flat)

/usr/local/lib/python3.7/dist-packages/jax/linear_util.py in call_wrapped(self, *args, **kwargs)
    167     try:
--> 168       ans = self.f(*args, **dict(self.params, **kwargs))
    169     except:

/usr/local/lib/python3.7/dist-packages/brax/envs/wrappers.py in reset(self, rng)
     86   def reset(self, rng: jp.ndarray) -> brax_env.State:
---> 87     state = self.env.reset(rng)
     88     state.info['first_qp'] = state.qp

/usr/local/lib/python3.7/dist-packages/brax/envs/wrappers.py in reset(self, rng)
     65   def reset(self, rng: jp.ndarray) -> brax_env.State:
---> 66     state = self.env.reset(rng)
     67     state.info['steps'] = jp.zeros(())

/usr/local/lib/python3.7/dist-packages/brax/envs/inverted_pendulum.py in reset(self, rng)
    134 
--> 135     qp = self.sys.default_qp(joint_angle=qpos, joint_velocity=qvel)
    136     obs = self._get_obs(qp, self.sys.info(qp))

/usr/local/lib/python3.7/dist-packages/brax/physics/system.py in default_qp(self, default_index, joint_angle, joint_velocity)
    170       xs = (joint_angle, joint_velocity, joint_rot, joint_ref)
--> 171       _, (local_rot, local_ang) = jp.scan(local_rot_ang, (), xs, len(joint))
    172 

/usr/local/lib/python3.7/dist-packages/brax/jumpy.py in scan(f, init, xs, length, reverse, unroll)
    116       ys.append(y)
--> 117     stacked_y = jax.tree_map(lambda *y: onp.vstack(y), *maybe_reversed(ys))
    118     return carry, stacked_y

/usr/local/lib/python3.7/dist-packages/jax/_src/tree_util.py in tree_map(f, tree, is_leaf, *rest)
    183   all_leaves = [leaves] + [treedef.flatten_up_to(r) for r in rest]
--> 184   return treedef.unflatten(f(*xs) for xs in zip(*all_leaves))
    185 

/usr/local/lib/python3.7/dist-packages/jax/_src/tree_util.py in <genexpr>(.0)
    183   all_leaves = [leaves] + [treedef.flatten_up_to(r) for r in rest]
--> 184   return treedef.unflatten(f(*xs) for xs in zip(*all_leaves))
    185 

/usr/local/lib/python3.7/dist-packages/brax/jumpy.py in <lambda>(*y)
    116       ys.append(y)
--> 117     stacked_y = jax.tree_map(lambda *y: onp.vstack(y), *maybe_reversed(ys))
    118     return carry, stacked_y

<__array_function__ internals> in vstack(*args, **kwargs)

/usr/local/lib/python3.7/dist-packages/numpy/core/shape_base.py in vstack(tup)
    278         _arrays_for_stack_dispatcher(tup, stacklevel=2)
--> 279     arrs = atleast_2d(*tup)
    280     if not isinstance(arrs, list):

<__array_function__ internals> in atleast_2d(*args, **kwargs)

/usr/local/lib/python3.7/dist-packages/numpy/core/shape_base.py in atleast_2d(*arys)
    120     for ary in arys:
--> 121         ary = asanyarray(ary)
    122         if ary.ndim == 0:

/usr/local/lib/python3.7/dist-packages/jax/core.py in __array__(self, *args, **kw)
    536   def __array__(self, *args, **kw):
--> 537     raise TracerArrayConversionError(self)
    538 

UnfilteredStackTrace: jax._src.errors.TracerArrayConversionError: The numpy.ndarray conversion method __array__() was called on the JAX Tracer object Traced<ShapedArray(float32[4])>with<BatchTrace(level=1/0)> with
  val = [...]

A similar error arises when env.step is wrapped in vmap (see notebook link above).

I understand that the recommended approach to working with batches of states is to pass the batch_size argument to envs.create (which does work), but I would also have expected env.reset and env.step to transform and compose like normal JAX functions. Is this behavior intended?

btaba commented 2 years ago

Hi davmre, thanks for using brax and for pointing out this weirdness! It looks like jumpy is calling onp.vstack on a traced array in jumpy.scan since _in_jit is False, which causes the error. In other words, vmap creates a traced array but not inside a jitted function, so jumpy assumes it can materialize the traced array with numpy, which it can't. Either of the following two approaches resolves the discrepancy by being explicit about jit vs. non-jit:
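To make the failure mode concrete without Brax, here is a minimal sketch of a scan that dispatches between a JAX branch and a NumPy branch. `dual_scan`, `step`, and the explicit `in_jit` flag are hypothetical stand-ins, not jumpy's actual code: the NumPy branch only mirrors the loop-plus-`onp.vstack` pattern visible in the traceback above. Under a bare `jax.vmap` the NumPy branch raises the same `TracerArrayConversionError`, while the `jax.lax.scan` branch batches cleanly.

```python
import jax
import jax.numpy as jnp
import numpy as onp


def dual_scan(f, xs, in_jit):
    """Hypothetical stand-in for a NumPy/JAX dispatching scan (not jumpy's code).

    in_jit=True  -> jax.lax.scan, which accepts tracers.
    in_jit=False -> Python loop + onp.vstack, like the branch in the traceback.
    """
    if in_jit:
        # Carry is unused (None); ys is the stacked per-step output.
        _, ys = jax.lax.scan(lambda carry, x: (carry, f(x)), None, xs)
        return ys
    ys = [f(x) for x in xs]
    # onp.vstack calls __array__ on each element, which fails for JAX tracers.
    return onp.vstack(ys)


def step(x):
    return 2.0 * x


xs_batch = jnp.ones((64, 3, 4))  # 64 independent sequences of length 3

# JAX branch composes with vmap.
ok = jax.vmap(lambda xs: dual_scan(step, xs, in_jit=True))(xs_batch)

# NumPy branch under vmap: the per-example inputs are tracers, so vstack fails.
try:
    jax.vmap(lambda xs: dual_scan(step, xs, in_jit=False))(xs_batch)
    numpy_branch_failed = False
except jax.errors.TracerArrayConversionError:
    numpy_branch_failed = True

# On concrete arrays, with no transformation, the NumPy branch is fine.
plain = dual_scan(step, onp.ones((3, 4), dtype=onp.float32), in_jit=False)
```

The NumPy branch is not wrong in itself; it only breaks once its inputs are tracers, which is exactly what `vmap` supplies when no jit trace is active.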

Jit the reset function:

rng = jax.random.PRNGKey(0)
batch_rng = jax.random.split(rng, 64)
batch_state = jax.vmap(jax.jit(env.reset))(batch_rng)
print(batch_state.obs)
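For reference, here is the same jit-then-vmap ordering applied to a pure-JAX function. `toy_reset` is a hypothetical stand-in for `env.reset` that only mimics the (4,)-shaped observation from the issue. Because `jax.jit` wraps the function before `jax.vmap` sees it, the body is traced inside a jit context; for a pure-JAX function like this one, jitting the outer batched function instead gives the same result.

```python
import jax
import jax.numpy as jnp


def toy_reset(rng):
    """Hypothetical stand-in for env.reset: returns a 4-dim observation."""
    qpos = jax.random.uniform(rng, (2,), minval=-0.01, maxval=0.01)
    qvel = jnp.zeros((2,))
    return jnp.concatenate([qpos, qvel])


rng = jax.random.PRNGKey(0)
batch_rng = jax.random.split(rng, 64)

# jit first, then vmap over the batch of keys (the ordering suggested above).
batched_obs = jax.vmap(jax.jit(toy_reset))(batch_rng)

# Alternative ordering: jit the whole batched function.
batched_obs_outer = jax.jit(jax.vmap(toy_reset))(batch_rng)
```

Both produce a (64, 4) batch of observations, matching the "Expected" shape in the original report.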

Or disable jit altogether so that jumpy doesn't get tripped up:

rng = jax.random.PRNGKey(0)
batch_rng = jax.random.split(rng, 64)
with jax.disable_jit():
  batch_state = jax.vmap(env.reset)(batch_rng)
  print(batch_state.obs)
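One detail about the second workaround: `jax.disable_jit()` turns off compilation, but `jax.vmap` still passes tracers into the function. The sketch below uses a hypothetical `probe` function that records whether NumPy can materialize its per-example input, and shows that it cannot, even under `disable_jit`. So the workaround presumably works because jumpy switches to its JAX code path when jit is disabled; that is an assumption, since the thread does not show jumpy's source.

```python
import jax
import jax.numpy as jnp
import numpy as onp

seen_concrete = []  # one entry per trace of `probe`


def probe(x):
    """Hypothetical helper: records whether NumPy can materialize `x` at trace time."""
    try:
        onp.asarray(x)
        seen_concrete.append(True)
    except jax.errors.TracerArrayConversionError:
        seen_concrete.append(False)
    return x + 1.0


batch = jnp.ones((64, 4))
with jax.disable_jit():
    # vmap still traces `probe` with batch tracers, even with jit disabled.
    out = jax.vmap(probe)(batch)
```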

Hope that helps a bit. That said, I think this behavior may be intended; @erikfrey may have more to add here.

btaba commented 2 years ago

@davmre I'll close the issue for now since I think the behavior is intended due to the jumpy library, but please feel free to reopen it if you run into any other issues!

jamesheald commented 1 year ago

@btaba @erikfrey Jitting the reset function doesn't work. Any idea how to get this to work?

btaba commented 1 year ago

@jamesheald sorry to hear you're having trouble with jax.jit(reset); can you give a repro?

jamesheald commented 1 year ago

It is the same error as the one raised by the original poster. However, as with the issue I encountered here (https://github.com/google/brax/issues/321), it goes away if I install brax from source instead of via 'pip install brax'; there seem to be several issues with brax when it is installed via 'pip install brax'. If I can do anything to help you identify and fix them, let me know.