google-deepmind / mujoco

Multi-Joint dynamics with Contact. A general purpose physics simulator.
https://mujoco.org
Apache License 2.0

Support for robots with different DoFs in Mujoco XLA #1194

Closed: MasterXiong closed this issue 1 month ago.

MasterXiong commented 10 months ago

Hi,

Thanks for developing the amazing MuJoCo 3 with JAX support!

I'm working on a scenario where I want to train a single RL policy on multiple robots with different numbers of limbs and joints in parallel, so the state and action spaces differ across robots. I want to accelerate training with JAX, but I'm not sure whether JAX can handle this heterogeneous setting.

We can align the state and action dimensions by zero-padding, but the underlying simulation still differs from robot to robot. My understanding is that JAX cannot run structurally different simulation steps like these in parallel. I just wanted to check whether my understanding is correct, or whether MuJoCo XLA can actually support parallel training across these different robots. Thanks!
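One common way to realize the zero-padding idea is to pad every robot's observation to a shared maximum size, keep a mask of the real entries, and slice the policy's fixed-size output back down to each robot's action count. The sketch below assumes JAX; `pad_obs`, `unpad_action`, `MAX_OBS_DIM` and `MAX_ACT_DIM` are hypothetical names, not part of MJX. Padding only fixes the policy's input and output shapes: each robot's `mjx.Model` and `mjx.Data` still have robot-specific array shapes, so the simulations themselves cannot be stacked into a single `jax.vmap` batch.

```python
import jax.numpy as jnp

# Hypothetical sizes: the largest observation/action dimension over all robots.
MAX_OBS_DIM = 8
MAX_ACT_DIM = 4


def pad_obs(obs: jnp.ndarray, max_dim: int = MAX_OBS_DIM):
    """Zero-pads a 1-D observation to max_dim and returns (padded, mask)."""
    n = obs.shape[0]  # a static Python int, since shapes are known at trace time
    padded = jnp.zeros((max_dim,), dtype=obs.dtype).at[:n].set(obs)
    mask = jnp.arange(max_dim) < n  # True for real entries, False for padding
    return padded, mask


def unpad_action(padded_action: jnp.ndarray, act_dim: int) -> jnp.ndarray:
    """Drops the padded tail of a fixed-size policy output for one robot."""
    return padded_action[:act_dim]


# A 3-D observation from a small robot and a 5-D one from a larger robot
# end up with the same shape, so one policy network can consume both.
obs_small, mask_small = pad_obs(jnp.arange(3.0))
obs_large, mask_large = pad_obs(jnp.arange(5.0))
policy_out = jnp.ones((MAX_ACT_DIM,))
act_small = unpad_action(policy_out, act_dim=2)
```

The mask can be fed to the policy alongside the padded observation so it can tell real zeros from padding.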

erikfrey commented 10 months ago

Hello,

Do these robots interact with each other? If so, you can put them in the same scene. If not, you can step multiple scenes in the same environment step:

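  from jax import numpy as jp
  from mujoco import mjx

  def step(self, state: State, action: jp.ndarray) -> State:
    """Runs one timestep of the environment's dynamics."""
    # mjx.step(model, data) reads controls from data.ctrl, so write each robot's slice there first.
    robot_1_data = mjx.step(self.robot_1_model, state.robot_1_data.replace(ctrl=action[self.robot_1_action_idxs]))
    robot_2_data = mjx.step(self.robot_2_model, state.robot_2_data.replace(ctrl=action[self.robot_2_action_idxs]))
    ...
    obs = self._get_obs(robot_1_data, robot_2_data, ...)
    return state.replace(robot_1_data=robot_1_data, robot_2_data=robot_2_data, obs=obs)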

This is all doable in JAX. Does that make sense?

MasterXiong commented 10 months ago

Hi @erikfrey, thanks for your help!

There is no interaction between the robots, and running multiple scenes in the same environment makes sense to me. But are the robots' step functions run in parallel or sequentially (as in your example code)? I'm moving to JAX to speed up training, so if each robot's step function runs in sequence, I'm afraid there may not be much speedup compared to parallelizing the simulation over multiple CPU cores.

oliverweissl commented 7 months ago

@MasterXiong I'm currently having a similar issue: I want to simulate multiple different robots in parallel. Did you find a solution?

erikfrey commented 7 months ago

Hi @oliverweissl and @MasterXiong, operation fusion in XLA is quite powerful, so I would suggest JIT-compiling two distinct mjx.step calls together in a single function and seeing what happens. I suspect you will see lower latency than if you JIT-compiled them separately and ran one after the other, but I don't know for sure.

Your other option is to merge the two robots into a single MJCF, so their ctrl vectors would be concatenated and you would have a single mjx.step. If I were to bet, though, I would say the former will be faster. That said, I'm often surprised by XLA performance quirks, so if it's easy to try both, please do. I'd be curious what you learn.