google / brax

Massively parallel rigidbody physics simulation on accelerator hardware.
Apache License 2.0

Reacher Task and AutoResetWrapper #443

Closed: esraaelelimy closed this issue 4 months ago

esraaelelimy commented 5 months ago

I am looking to use the Brax Reacher task as an alternative to MuJoCo Reacher for some RL tasks, but I have some concerns. In the MuJoCo Reacher task, if the fingertip reaches the target, a new random target appears. Also, at the beginning of each new episode, the target position changes. In Brax, I see that the target position is only generated when the environment is reset. Moreover, when using the AutoResetWrapper, a reset fetches the 'first state', which means that the random target is generated once at the very beginning and never changes. Does this make the Brax version of Reacher easier to solve than MuJoCo's Reacher? And how can we make the AutoResetWrapper actually change the target on every reset without sacrificing speed?

btaba commented 4 months ago

Hi @esraaelelimy, indeed AutoResetWrapper caches the first_state, but the first_state is sampled with a different rng for each environment, so the diversity of sampled first_states increases with the number of parallel environments. You're right to point out that this is done for performance reasons. To sample a new first_state (and hence a new target), you'll have to call reset with a different rng. See the example here:

https://github.com/google/brax/blob/f9a4d73181d699db0fa38b07c5a651f5dc8ee231/brax/training/agents/ppo/train.py#L418-L431

But we have not done an in-depth analysis of some of these hyperparameters (e.g. num_resets_per_eval).
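
To make the behaviour discussed above concrete, here is a minimal JAX sketch. It does not use Brax: `toy_reset` and `toy_auto_reset_step` are hypothetical stand-ins that only mimic the caching idea (a "first" target stored at reset time and restored when an episode ends). The target range and number of environments are arbitrary assumptions. It illustrates three points from the thread: each environment gets its own target when reset with its own rng, an auto-reset restores the same cached target, and a full reset with fresh rngs draws new targets.

```python
import jax
import jax.numpy as jp


def toy_reset(rng):
    """Hypothetical reset: samples a 2-D target and caches it as the 'first' target."""
    target = jax.random.uniform(rng, (2,), minval=-0.2, maxval=0.2)
    return {"target": target, "first_target": target, "done": jp.zeros(())}


def toy_auto_reset_step(state, done):
    """Hypothetical auto-reset: when done, restore the cached target (no resampling)."""
    target = jp.where(done, state["first_target"], state["target"])
    return {**state, "target": target, "done": done.astype(jp.float32)}


num_envs = 4  # assumption: a small batch, just for illustration
key = jax.random.PRNGKey(0)
key, reset_key = jax.random.split(key)

# One rng per environment, so each environment caches a different target.
state = jax.vmap(toy_reset)(jax.random.split(reset_key, num_envs))
targets_before = state["target"]

# Auto-reset every environment: the cached targets come back unchanged.
state = jax.vmap(toy_auto_reset_step)(state, jp.ones(num_envs, dtype=bool))
targets_after_auto_reset = state["target"]

# Full reset with fresh rngs: every environment draws a new target.
key, reset_key = jax.random.split(key)
state = jax.vmap(toy_reset)(jax.random.split(reset_key, num_envs))
targets_after_full_reset = state["target"]
```

In this toy setup, target diversity within a run comes only from the number of parallel environments, unless the full reset is repeated with new rngs from time to time. How often to do that in real training is the kind of trade-off the num_resets_per_eval remark above refers to.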