google / brax

Massively parallel rigidbody physics simulation on accelerator hardware.
Apache License 2.0
2.25k stars 249 forks

How to do n controlled physics steps per control step #500

Closed salehrayan closed 23 hours ago

salehrayan commented 3 months ago

I want to take, e.g., 10 "controlled" physics steps per policy action. That is, for every action the policy outputs, I compute 10 motor targets, one for each of the next 10 physics steps. I set opt_timestep=0.002 and n_frames=1.

Based on the MJX example, I tried to do this by modifying pipeline_step. The original:

  def pipeline_step(self, pipeline_state: Any, action: jax.Array) -> base.State:
    """Takes a physics step using the physics pipeline."""

    def f(state, _):
      return (
          self._pipeline.step(self.sys, state, action, self._debug),
          None,
      )

    return jax.lax.scan(f, pipeline_state, (), self._n_frames)[0]

and mine:

    def pipeline_step(self, pipeline_state: Any, motor_targets_times: jax.Array) -> base.State:
        """Takes a physics step using the physics pipeline."""

        def f(state, motor_targets):
            return (
              self._pipeline.step(self.sys, state, motor_targets, self._debug),
              None,
            )

        return jax.lax.scan(f, pipeline_state, motor_targets_times)[0]

The leading axis of motor_targets_times has size 10. This seems to work, but it introduces a notable slowdown of about 1.5x to 2x.

I don't know why this is happening. Is this expected, or am I doing something wrong? Is there a better way to do this?
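
To compare the two pipeline_step variants above in isolation, here is a minimal sketch that may help. It replaces self._pipeline.step with a hypothetical toy_step (a PD-driven point mass, not Brax code), so the names toy_step, step_constant, and step_per_target are assumptions made for illustration only. It shows that scanning over a (10, ...) array of identical targets reaches the same state as repeating one action for 10 substeps, which gives a like-for-like pair to time.

```python
import jax
import jax.numpy as jnp

DT = 0.002  # same value as the opt_timestep quoted above


def toy_step(state, target):
    """Hypothetical stand-in for self._pipeline.step: PD-driven point mass."""
    pos, vel = state
    acc = 50.0 * (target - pos) - 5.0 * vel
    vel = vel + DT * acc
    pos = pos + DT * vel
    return (pos, vel)


def step_constant(state, action, n_frames):
    """Original pattern: the same action is applied for n_frames substeps."""
    def f(s, _):
        return toy_step(s, action), None
    return jax.lax.scan(f, state, (), n_frames)[0]


def step_per_target(state, targets):
    """Modified pattern: one target per substep, scanned over the leading axis."""
    def f(s, target):
        return toy_step(s, target), None
    return jax.lax.scan(f, state, targets)[0]


state0 = (jnp.zeros(3), jnp.zeros(3))
action = jnp.ones(3)
targets = jnp.tile(action, (10, 1))  # shape (10, 3): ten identical targets

const_out = jax.jit(step_constant, static_argnums=2)(state0, action, 10)
per_out = jax.jit(step_per_target)(state0, targets)
```

When timing either variant, note that a scan over 10 targets takes 10 physics steps per call, so the fair baseline is presumably the original pipeline_step with n_frames=10 rather than n_frames=1.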

btaba commented 23 hours ago

FWIU, the way to do "10 repeated actions" is to use action_repeat with the EpisodeWrapper.
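
For reference, a rough sketch of what action repetition amounts to, as far as I understand it: the same action is applied for action_repeat consecutive env steps and the rewards are summed. The toy_env_step function and step_with_action_repeat below are hypothetical stand-ins written for illustration, not the EpisodeWrapper source. Note that this repeats one action, which is not the same as supplying a distinct motor target for each physics step.

```python
import jax
import jax.numpy as jnp


def toy_env_step(state, action):
    """Hypothetical env step returning (next_state, reward)."""
    next_state = state + 0.1 * action
    reward = -jnp.sum(next_state ** 2)
    return next_state, reward


def step_with_action_repeat(state, action, action_repeat):
    """Applies the same action action_repeat times and sums the rewards."""
    def f(s, _):
        next_s, reward = toy_env_step(s, action)
        return next_s, reward  # carry the state, collect the reward

    final_state, rewards = jax.lax.scan(f, state, (), action_repeat)
    return final_state, jnp.sum(rewards)


final_state, total_reward = step_with_action_repeat(
    jnp.zeros(2), jnp.ones(2), 10
)
```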

Closing this to create a discussion thread.