google / brax

Massively parallel rigidbody physics simulation on accelerator hardware.
Apache License 2.0

How to deal with an impure reset function #300

Closed Theo-Cheynel closed 1 year ago

Theo-Cheynel commented 1 year ago

Hi,

I want to train an agent to imitate reference motions from a motion capture dataset. I wrote a custom environment, which works well when the reference motion capture clip is "hardcoded" (that is, the same across all environments). However, I would like the reference clip to vary across envs. In other words, I want env.reset to pick a random motion clip (at the moment, motion clips are obtained from another class's __getitem__ method).

The thing is, that would make the reset function impure, because its outputs would differ every time, and JAX jitting only supports pure functions. Is there a workaround you can think of?

At the very least, is it possible to tell ppo.train not to jit the reset function, while still jitting the step function?

Thanks for your help

btaba commented 1 year ago

Hi @Theo-Cheynel, are you able to do something like this in env.reset:

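import jax
import jax.numpy as jp  # assumption: jax.numpy stands in for whatever `jp` aliased in the original

# Placeholder mocap data of shape (10, 3); load the real data once, outside reset.
mocap = jp.ones((10, 3)) * jp.arange(0, 10)[:, None]
rng = jax.random.PRNGKey(0)

def reset(rng):
  # Split the key so the returned rng can be reused for the next reset.
  rng, key = jax.random.split(rng, 2)
  # Sample one row of mocap at random; the result has shape (1, 3).
  return rng, jax.random.choice(key, mocap, (1,))

rng, val = jax.jit(reset)(rng)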
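
For the per-environment case asked about above, here is a minimal sketch of the same idea applied to a batch of environments. It assumes all clips can be stacked into one array of equal-length clips before training; the names `clips`, `NUM_CLIPS`, `NUM_FRAMES`, `FRAME_DIM`, and `num_envs` are hypothetical and the data is a placeholder. The reset stays pure because the only source of randomness is the key passed in, so it can still be jitted.

```python
import jax
import jax.numpy as jnp

# Hypothetical stand-in for the mocap dataset: 4 clips of 10 frames with
# 3 values each, stacked into one array up front (outside of reset).
NUM_CLIPS, NUM_FRAMES, FRAME_DIM = 4, 10, 3
clips = jnp.arange(NUM_CLIPS * NUM_FRAMES * FRAME_DIM, dtype=jnp.float32)
clips = clips.reshape(NUM_CLIPS, NUM_FRAMES, FRAME_DIM)


def reset(rng):
    """Pure reset: the chosen clip depends only on the key passed in."""
    rng, key = jax.random.split(rng)
    # Draw a clip index in [0, NUM_CLIPS) and gather that clip.
    clip_id = jax.random.randint(key, (), 0, NUM_CLIPS)
    return rng, clip_id, clips[clip_id]


# One key per environment, so each environment draws its own clip.
num_envs = 8
keys = jax.random.split(jax.random.PRNGKey(0), num_envs)
batched_reset = jax.jit(jax.vmap(reset))
new_keys, clip_ids, ref_motion = batched_reset(keys)
```

Here `ref_motion` has shape `(num_envs, NUM_FRAMES, FRAME_DIM)`. Calling `batched_reset` again with the same `keys` returns the same clips, while passing `new_keys` draws fresh ones. If the clips have different lengths, they would need padding (or a similar scheme) before stacking, which this sketch does not cover.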