Theo-Cheynel closed this issue 1 year ago.
Hi @Theo-Cheynel, are you able to do something like this in `env.reset`?
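```python
import jax
import jax.numpy as jp

mocap = jp.ones((10, 3)) * jp.arange(0, 10)[:, None]  # load this once
rng = jax.random.PRNGKey(0)

def reset(rng):
    # Split the key: return one half for later calls, sample with the other.
    rng, key = jax.random.split(rng, 2)
    return rng, jax.random.choice(key, mocap, (1,))

rng, val = jax.jit(reset)(rng)
```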
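Because `reset` receives the PRNG key as an argument, it stays a pure function: the same key always yields the same clip. Variation across environments then comes from giving each environment its own key. A minimal sketch, reusing the toy `mocap` array from the snippet above and assuming resets are batched with `jax.vmap` (the exact batching inside `ppo.train` is not shown here):

```python
import jax
import jax.numpy as jp

# Same toy data as above: row i of `mocap` is [i, i, i].
mocap = jp.ones((10, 3)) * jp.arange(0, 10)[:, None]


def reset(rng):
    # Split off a fresh key for sampling and return the other for later calls.
    rng, key = jax.random.split(rng, 2)
    return rng, jax.random.choice(key, mocap, (1,))


num_envs = 4
# One independent key per environment.
env_rngs = jax.random.split(jax.random.PRNGKey(0), num_envs)
# vmap runs reset once per key; jit compiles the batched function once.
env_rngs, clips = jax.jit(jax.vmap(reset))(env_rngs)
print(clips.shape)  # (4, 1, 3)
```

In a real environment, one option is to store the chosen clip (or its index) in the environment state so that `step` can read it without any further randomness.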
Hi,
I want to train an agent to imitate reference motions from a motion capture dataset. I wrote a custom environment, which works well when the reference motion capture clip is "hardcoded" (i.e. the same across all environments). However, I would like the reference clip to vary across environments; in other words, I want `env.reset` to pick a random motion clip (at the moment, motion clips are obtained from another class's `__getitem__` method). The thing is, that would make the reset function impure, because its outputs would differ every time, and JAX jitting only supports pure functions. Is there a workaround you can think of?
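A jitted function cannot call an arbitrary Python `__getitem__` with a traced index, so one workaround, sketched under assumptions, is to load every clip once, outside `reset`, and stack them into a single array that `reset` indexes with a randomly drawn index. Clips of different lengths are zero-padded to the longest one, and their true lengths are kept alongside. `MocapDataset` and `stack_clips` are hypothetical names, and padding is just one option (cropping every clip to a fixed length is another):

```python
import jax
import jax.numpy as jnp
import numpy as np


class MocapDataset:
    """Hypothetical stand-in for the dataset class with __getitem__."""

    def __init__(self):
        # Three clips of different lengths, 3 features per frame; clip i holds i.
        self._clips = [np.full((n, 3), i, dtype=np.float32)
                       for i, n in enumerate([4, 6, 5])]

    def __len__(self):
        return len(self._clips)

    def __getitem__(self, i):
        return self._clips[i]


def stack_clips(dataset):
    """Load every clip once (plain Python) and zero-pad to a common length."""
    clips = [dataset[i] for i in range(len(dataset))]
    lengths = np.array([c.shape[0] for c in clips], dtype=np.int32)
    padded = np.zeros((len(clips), lengths.max(), clips[0].shape[1]), np.float32)
    for i, c in enumerate(clips):
        padded[i, : c.shape[0]] = c
    return jnp.asarray(padded), jnp.asarray(lengths)


clips, lengths = stack_clips(MocapDataset())  # shapes (3, 6, 3) and (3,)


@jax.jit
def reset(rng):
    # The clip index is drawn from the key, so reset stays a pure function.
    idx = jax.random.randint(rng, (), 0, clips.shape[0])
    return idx, clips[idx], lengths[idx]


idx, clip, length = reset(jax.random.PRNGKey(0))
```

The returned `length` can then be used by `step` to tell where the real frames of the padded clip end.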
At the very least, is it possible to tell `ppo.train` not to jit the reset function, while still jitting the step function?

Thanks for your help!
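Instead of excluding `reset` from jit, JAX's `jax.pure_callback` lets a jitted function call ordinary Python code (such as a `__getitem__`) on the host. A minimal sketch under assumptions: `MocapDataset` and `host_lookup` are hypothetical, every clip must share one fixed shape because the callback's output shape is declared up front, and each call costs a host round-trip. Behavior under `jax.vmap` depends on the JAX version (see the `pure_callback` docs), and whether this drops into `ppo.train` unchanged is untested:

```python
import jax
import jax.numpy as jnp
import numpy as np

CLIP_LEN, FEAT = 50, 3


class MocapDataset:
    """Hypothetical dataset whose clips all have shape (CLIP_LEN, FEAT)."""

    def __init__(self, num_clips=10):
        # Clip i is filled with the value i so it is easy to identify.
        self._clips = [np.full((CLIP_LEN, FEAT), i, dtype=np.float32)
                       for i in range(num_clips)]

    def __len__(self):
        return len(self._clips)

    def __getitem__(self, i):
        return self._clips[i]


dataset = MocapDataset()


def host_lookup(idx):
    # Runs as plain Python on the host; idx arrives as a NumPy array.
    return np.asarray(dataset[int(idx)], dtype=np.float32)


@jax.jit
def reset(rng):
    idx = jax.random.randint(rng, (), 0, len(dataset))
    # Declare the callback's output shape/dtype so tracing can continue.
    clip = jax.pure_callback(
        host_lookup, jax.ShapeDtypeStruct((CLIP_LEN, FEAT), jnp.float32), idx
    )
    return idx, clip


idx, clip = reset(jax.random.PRNGKey(0))
```

The randomness still comes from the key, so `host_lookup` itself is deterministic, which is what `pure_callback` expects of its callback.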