We primarily chose JAX for its speed: model and policy optimization in JAX was over 75% faster than in PyTorch. We also crucially needed to be able to differentiate through the Lagrangian dynamics to train the model, and the tools from Brax made this very easy.
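To give a feel for what "differentiating through the dynamics" looks like in Brax, here is a minimal sketch of taking gradients of a rollout loss with respect to actions. The environment name, backend choice, horizon, and loss are illustrative placeholders, not our actual training code:

```python
import jax
import jax.numpy as jnp
from brax import envs  # Brax physics is pure JAX, so it is differentiable

# Illustrative setup; 'ant' and the 'generalized' (Lagrangian) backend
# are just example choices.
env = envs.create(env_name='ant', backend='generalized')
reset_fn = jax.jit(env.reset)
step_fn = jax.jit(env.step)

def rollout_loss(actions, rng):
    """Negative return of a short open-loop rollout (toy loss)."""
    state = reset_fn(rng)
    total = 0.0
    for a in actions:  # static-length loop, unrolled under tracing
        state = step_fn(state, a)
        total = total - state.reward
    return total

rng = jax.random.PRNGKey(0)
actions = jnp.zeros((10, env.action_size))
# Gradients flow through the simulator's dynamics end to end.
grads = jax.grad(rollout_loss)(actions, rng)
```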
As for the simulation for forward rollouts, we chose Brax since we were already using it for the Lagrangian dynamics. You could also use Isaac Sim or Isaac Gym for the forward rollouts with this code. You would just need to write an environment that wraps Isaac Sim/Isaac Gym: its step function would convert the JAX arrays to PyTorch tensors, call the Isaac Sim/Isaac Gym step function, then convert the returned data back to JAX arrays. Note that this step function would not be jit-able. A rough sketch is below.
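This is a minimal sketch of such a bridge, assuming an Isaac Gym-style environment whose `step(action)` returns `(obs, reward, done, info)` as PyTorch tensors; the class name and that interface are hypothetical, but the JAX/PyTorch conversions via DLPack are standard:

```python
import jax
import jax.dlpack
import jax.numpy as jnp
import torch
import torch.utils.dlpack

def torch_to_jax(t: torch.Tensor) -> jnp.ndarray:
    # Zero-copy PyTorch -> JAX conversion via DLPack.
    return jax.dlpack.from_dlpack(torch.utils.dlpack.to_dlpack(t.contiguous()))

def jax_to_torch(x: jnp.ndarray) -> torch.Tensor:
    # Zero-copy JAX -> PyTorch conversion via DLPack.
    return torch.utils.dlpack.from_dlpack(jax.dlpack.to_dlpack(x))

class IsaacBridgeEnv:
    """Hypothetical wrapper exposing an Isaac Gym-style env to JAX code.

    Because step() calls out to PyTorch, it cannot be wrapped in jax.jit.
    """

    def __init__(self, isaac_env):
        self.isaac_env = isaac_env  # placeholder: your Isaac Sim/Gym env

    def step(self, action: jnp.ndarray):
        action_t = jax_to_torch(action)
        # Assumed step signature; adapt to your Isaac Sim/Gym setup.
        obs_t, reward_t, done_t, info = self.isaac_env.step(action_t)
        return (torch_to_jax(obs_t),
                torch_to_jax(reward_t),
                torch_to_jax(done_t),
                info)
```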
Thanks for your reply.
Have you considered using DiffRL, which is based on PyTorch, to implement the differentiable dynamics? Given that the original implementation of MBPO is in PyTorch, and PyTorch is compatible with Isaac Sim and Isaac Gym, what led you to choose JAX instead of implementing all components in PyTorch? What were your initial considerations at the start of the project, and did you face any challenges specifically related to using PyTorch?