google / brax

Massively parallel rigidbody physics simulation on accelerator hardware.
Apache License 2.0

Compute gradient w.r.t. the physical parameters #399

Closed Daffan closed 8 months ago

Daffan commented 9 months ago

Does Brax support computing gradients w.r.t. physical parameters such as friction, stiffness, and damping? Thanks!

btaba commented 8 months ago

Hi @Daffan, yep, this is more or less supported via jax.grad, but YMMV. Here's an example with friction, adapted from here:

import jax
from jax import numpy as jp
from brax.positional import pipeline
from brax.io import mjcf

ball = """
<mujoco>
  <option gravity="0 0 -9.81" timestep="0.002"/>
  <worldbody>
    <geom name="floor" pos="0 0 0" size="40 40 40" type="plane" friction="1.0"/>
    <body pos="0 0 0.2">
      <joint type="free"/>
      <geom size=".2" mass="1.0" friction="1.0"/>
    </body>
  </worldbody>
</mujoco>
"""

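def simulation(pipeline, sys, friction):
    init_qd = jp.zeros(6).at[0].set(1.0)  # 1m/s in +x
    # Overwrite the friction of the first geom entry with the value we differentiate.
    sys.geoms[0] = sys.geoms[0].replace(friction=jp.array([friction]))
    state = jax.jit(pipeline.init)(sys, sys.init_q, init_qd)
    step = jax.jit(pipeline.step)  # build the jitted step once, outside the loop
    for _ in range(50):
        state = step(sys, state, None)  # None: no actuation input
    return state.qd[0]  # final velocity along x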

sys = mjcf.loads(ball)
x = simulation(pipeline, sys, 2.0)
grad_x = jax.grad(simulation, argnums=2)(pipeline, sys, 2.0)
print(f"{x=}, {grad_x=}")
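If you want to sanity-check the idea without Brax, here is a minimal sketch in plain JAX. It is not Brax's contact model, just a toy I'm assuming for illustration: a block sliding on a floor with Coulomb friction, integrated with explicit Euler steps. The names `final_velocity`, `GRAVITY`, `DT` and `NUM_STEPS` are made up for this example. It compares `jax.grad` against a central finite difference, which is also a reasonable way to check gradients coming out of a real pipeline.

```python
import jax
import jax.numpy as jnp

GRAVITY = 9.81   # m/s^2 (assumed, mirrors the gravity in the MJCF above)
DT = 0.002       # s per step (assumed, mirrors the timestep above)
NUM_STEPS = 50


def final_velocity(friction, v0=1.0):
    """Toy sliding block: friction decelerates it until it stops."""

    def step(v, _):
        # Explicit Euler step; clamp at zero so friction never reverses motion.
        v_next = jnp.maximum(v - friction * GRAVITY * DT, 0.0)
        return v_next, None

    init = jnp.asarray(v0, dtype=jnp.float32)
    v_final, _ = jax.lax.scan(step, init, None, length=NUM_STEPS)
    return v_final


friction = 0.5

# Gradient of the final velocity w.r.t. the friction coefficient via autodiff.
ad_grad = jax.grad(final_velocity)(friction)

# Central finite difference as an independent check.
eps = 1e-2
fd_grad = (final_velocity(friction + eps) - final_velocity(friction - eps)) / (2 * eps)

print(f"{ad_grad=}, {fd_grad=}")
```

While the block is still moving, the final velocity is linear in the friction coefficient, so both estimates should be close to `-GRAVITY * DT * NUM_STEPS`. Once the clamp kicks in (the block has stopped), the gradient drops to zero, which is one small example of why gradients through contact and friction can be less informative than you'd hope.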