Closed davmre closed 2 years ago
Hi davmre, thanks for using brax and pointing out this weirdness! It looks like jumpy is calling onp.vstack on a traced array in jumpy.scan since _in_jit is False, which causes the error. In other words, a traced array is created outside of a jitted function (here, by vmap), so jumpy assumes it can materialize the traced array with numpy, which it can't. Either of the following two approaches resolves the discrepancy by being explicit about jit/non-jit:
Jit the reset function:
batch_rng = jax.random.split(rng, 64)
batch_state = jax.vmap(jax.jit(env.reset))(batch_rng)
print(batch_state.obs)
Or disable jit altogether so that jumpy doesn't get tripped up:
batch_rng = jax.random.split(rng, 64)
with jax.disable_jit():
    batch_state = jax.vmap(env.reset)(batch_rng)
print(batch_state.obs)
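The failure mode described above (NumPy being asked to materialize a traced array) can be reproduced without brax. The following is a minimal sketch, not jumpy's actual code: stack_with_numpy and stack_with_jax are hypothetical helpers, with the first standing in for a code path that assumes it is running on concrete arrays.

```python
import jax
import jax.numpy as jnp
import numpy as onp


def stack_with_numpy(x):
    # Hypothetical stand-in for a helper that assumes concrete inputs
    # and therefore reaches for plain NumPy.
    return onp.vstack([x, x])


def stack_with_jax(x):
    # The jax.numpy equivalent accepts tracers as well as concrete arrays.
    return jnp.vstack([x, x])


xs = jnp.ones((64, 4))

# Under vmap, x is a tracer; NumPy tries to convert it to an ndarray and fails.
try:
    jax.vmap(stack_with_numpy)(xs)
    numpy_error = None
except jax.errors.TracerArrayConversionError as err:
    numpy_error = err

# The jax.numpy version batches without trouble.
batched = jax.vmap(stack_with_jax)(xs)

print(type(numpy_error).__name__)
print(batched.shape)
```

The NumPy helper fails only when it is handed a tracer; called directly on xs[0] it would run normally, which is why the problem shows up only once a transformation such as vmap is applied.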
Hope that helps a bit, but I think this may be intended; @erikfrey may have more to add here
@davmre I'll close the issue for now since I think the behavior is intended due to the jumpy library, but please feel free to re-open if there are any other issues!
@btaba @erikfrey Jitting the reset function doesn't work. Any idea how to get this to work?
@jamesheald sorry to hear you're having trouble with jax.jit(reset), can you give a repro?
It is the same issue as the one raised by the original poster. But as with the issue I encountered here (https://github.com/google/brax/issues/321), it goes away if I install brax from source instead of via 'pip install brax'; there seem to be multiple issues with brax when installed via 'pip install brax'. If I can do anything to help you identify and correct the issues, let me know.
Hi, thanks for all your work with Brax! I'm just getting into it, and am running into surprising behavior on a simple example:
(runnable version here: https://colab.research.google.com/drive/1c8jDkShgSRdeBp9fcC9R3PTfMRNbjOlj?usp=sharing)
Expected: prints a batch of initial observations of shape (64, 4).
Actual: raises a TracerArrayConversionError inside of a brax.jumpy.scan call.
A similar error arises when env.step is wrapped in vmap (see notebook link above).
I understand that the recommended approach to work with batches of states is to pass the batch_size argument to envs.create (which does work), but I would also have expected env.reset and env.step to transform and compose like normal JAX functions. Is this behavior intended?
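For reference, the composition the question has in mind can be sketched with a pure JAX function in place of the environment. This assumes a hypothetical toy_reset that maps one PRNG key to a 4-dimensional observation; it is not brax's env.reset and says nothing about how brax implements it.

```python
import jax
import jax.numpy as jnp


def toy_reset(rng):
    # Hypothetical stand-in for env.reset: one PRNG key in,
    # one 4-dimensional observation out.
    return jax.random.uniform(rng, (4,), minval=-0.05, maxval=0.05)


rng = jax.random.PRNGKey(0)
batch_rng = jax.random.split(rng, 64)

# A pure JAX function composes with vmap and jit in any order.
obs_vmap = jax.vmap(toy_reset)(batch_rng)
obs_vmap_jit = jax.vmap(jax.jit(toy_reset))(batch_rng)
obs_jit_vmap = jax.jit(jax.vmap(toy_reset))(batch_rng)

print(obs_vmap.shape)  # (64, 4)
```

All three variants yield a (64, 4) batch, which is the behavior the post expected from vmap(env.reset).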