google / brax

Massively parallel rigidbody physics simulation on accelerator hardware.
Apache License 2.0

Manually Set Initial Pose of Agent in the Reset Function #440

Closed: SumeetBatra closed this issue 4 months ago

SumeetBatra commented 5 months ago

Hi folks,

I am using brax 0.9.4 and jax 0.4.23. I have a wrapper that I'm using to manually set the initial xy position of the agent with the following methods:

    def set_initial_state(self, init_pos: jax.Array):
        # Store the desired (x, y) start position on the wrapper.
        self.init_pos = self.init_pos.at[:].set(init_pos)

    def reset(self, rng: jax.Array) -> State:
        rng, rng1, rng2 = jax.random.split(rng, 3)

        low, hi = -self._reset_noise_scale, self._reset_noise_scale
        q = self.q
        # Overwrite the torso x, y coordinates with the requested position.
        q = q.at[:2].set(self.init_pos)
        q += jax.random.uniform(
            rng1, (self.sys.q_size(),), minval=low, maxval=hi
        )
        qd = hi * jax.random.normal(rng2, (self.sys.qd_size(),))

        pipeline_state = self.pipeline_init(q, qd)
        obs = self._get_obs(pipeline_state)
        # ... (rest of reset, which builds and returns the State, omitted)

This essentially replaces the default torso xy position with the desired one: I call set_initial_state and then env.reset(). It works the first time, but subsequent calls keep setting the same pose as the first set_initial_state call.

With the debugger, I can step all the way into this wrapper the first time env.reset() is called, but after that I can't step past VectorGymWrapper's reset. My guess is that this has something to do with VectorGymWrapper being jitted, but I'm not sure. Any help would be appreciated!
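The debugger symptom is consistent with how jax.jit works: it runs the Python body once to trace it, and later calls with the same input shapes and dtypes execute the cached compiled computation without re-entering Python, so a breakpoint inside the wrapped function is only reached on the first call. A minimal pure-JAX sketch of this (the `reset` function below is a hypothetical stand-in, not brax code):

```python
import jax
import jax.numpy as jnp

trace_count = 0  # how many times the Python body of `reset` actually runs


def reset(x):
    """Hypothetical stand-in for a jitted reset; not brax code."""
    global trace_count
    trace_count += 1  # Python side effect: only happens while JAX traces
    return x + 1.0


jitted_reset = jax.jit(reset)

for _ in range(3):
    # Same input shape and dtype each time, so calls 2 and 3 hit the cache
    # and run the compiled computation without re-entering Python.
    jitted_reset(jnp.zeros(2))

print(trace_count)  # prints 1
```

To inspect values inside compiled code on every call, JAX provides jax.debug.print and jax.debug.breakpoint, which are staged into the computation rather than run only at trace time.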

btaba commented 4 months ago

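Hi @SumeetBatra, sorry for the very late reply here. Indeed, the first time the code runs (inside a jax.jit), the current value of self.init_pos is captured as a constant in the compiled code. Later calls with the same input shapes reuse that compiled code without re-running the Python, so subsequent changes to self.init_pos have no effect. See here for more info.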
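A minimal reproduction of this behavior in plain JAX, assuming a hypothetical `PoseWrapper` class that mimics the wrapper in the question (it is not brax code). The jitted reset reads `self.init_pos` from the object rather than taking it as an argument, so the first trace bakes in whatever value was set at that moment. The last lines sketch the "re-jit" workaround; `make_jitted_reset` is a hypothetical helper that wraps `env.reset` in a fresh lambda so JAX has no cached trace to reuse.

```python
import jax
import jax.numpy as jnp


class PoseWrapper:
    """Hypothetical stand-in for the wrapper in the question (not brax code)."""

    def __init__(self, noise_scale: float = 1e-3):
        self.init_pos = jnp.zeros(2)
        self.noise_scale = noise_scale

    def set_initial_state(self, init_pos: jax.Array):
        self.init_pos = self.init_pos.at[:].set(init_pos)

    def reset(self, rng: jax.Array) -> jax.Array:
        # self.init_pos is read from the object, not passed in as an argument,
        # so tracing records its current value as a constant.
        s = self.noise_scale
        noise = jax.random.uniform(rng, (2,), minval=-s, maxval=s)
        return self.init_pos + noise


def make_jitted_reset(env: PoseWrapper):
    # A new lambda on every call, so JAX traces it afresh instead of reusing
    # an earlier compilation.
    return jax.jit(lambda key: env.reset(key))


env = PoseWrapper()
rng = jax.random.PRNGKey(0)
reset_fn = make_jitted_reset(env)

env.set_initial_state(jnp.array([1.0, 2.0]))
first = reset_fn(rng)  # traced here; [1.0, 2.0] is baked in

env.set_initial_state(jnp.array([5.0, 5.0]))
stale = reset_fn(rng)  # cache hit: still built around [1.0, 2.0]

# The "re-jit" workaround: rebuild the jitted function after the change.
fresh = make_jitted_reset(env)(rng)  # now reflects [5.0, 5.0]
```

Re-jitting works but pays a full trace and compile every time the pose changes, which can be slow for a large physics pipeline.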

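To set a specific pose in the reset, consider passing init_pos to the reset function as an argument instead of reading it from self. As an ordinary (traced) argument, a new value takes effect without recompiling; as a static argument, it must be hashable (e.g. a tuple rather than a jax.Array) and every new value triggers a recompile. Alternatively, re-jit the reset function after changing init_pos.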
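A minimal sketch of the first suggestion, assuming you call and jit the environment's reset yourself: as far as we know, brax's VectorGymWrapper drives reset with only a PRNG key, so it would not forward an extra argument. `PoseEnv`, `q0` and the `reset(rng, init_pos)` signature are hypothetical, not brax's API. Because init_pos is passed as an ordinary traced argument, one compiled program serves every pose.

```python
import jax
import jax.numpy as jnp


class PoseEnv:
    """Hypothetical toy environment; names and signature are not brax's API."""

    def __init__(self, q_size: int = 7, reset_noise_scale: float = 0.1):
        self.q0 = jnp.zeros(q_size)  # stand-in for the default coordinates
        self.reset_noise_scale = reset_noise_scale

    def reset(self, rng: jax.Array, init_pos: jax.Array) -> jax.Array:
        low, hi = -self.reset_noise_scale, self.reset_noise_scale
        # init_pos is a traced argument, so every call sees the value passed in.
        q = self.q0.at[:2].set(init_pos)
        q += jax.random.uniform(rng, q.shape, minval=low, maxval=hi)
        return q


env = PoseEnv()
rng = jax.random.PRNGKey(0)
reset_fn = jax.jit(env.reset)

q_a = reset_fn(rng, jnp.array([1.0, 2.0]))
q_b = reset_fn(rng, jnp.array([5.0, 5.0]))  # same shape/dtype: no recompile

# Batched version: one start position per parallel environment.
positions = jnp.array([[0.0, 0.0], [1.0, 1.0], [2.0, 0.0], [0.0, 3.0]])
keys = jax.random.split(rng, positions.shape[0])
batch_q = jax.jit(jax.vmap(env.reset))(keys, positions)
```

Since init_pos is traced, jax.vmap can give each parallel environment its own start pose, which fits the batched setting that VectorGymWrapper targets.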