Hi @Daffan, yep, this is more or less supported via jax.grad, but YMMV. Here's an example with friction, adapted from here:
import jax
from jax import numpy as jp
from brax.positional import pipeline
from brax.io import mjcf
ball = """
<mujoco>
<option gravity="0 0 -9.81" timestep="0.002"/>
<worldbody>
<geom name="floor" pos="0 0 0" size="40 40 40" type="plane" friction="1.0"/>
<body pos="0 0 0.2">
<joint type="free"/>
<geom size=".2" mass="1.0" friction="1.0"/>
</body>
</worldbody>
</mujoco>
"""
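def simulation(pipeline, sys, friction):
    # Initial velocity of the free joint: 1 m/s in +x, no rotation.
    init_qd = jp.zeros(6).at[0].set(1.0)
    # Set the friction of geoms[0] on a copy of the geoms list; assigning into
    # the caller's sys.geoms in place would leak a tracer under jax.grad.
    geoms = list(sys.geoms)
    geoms[0] = geoms[0].replace(friction=jp.array([friction]))
    sys = sys.replace(geoms=geoms)
    # Build the jitted functions once instead of on every loop iteration.
    init_fn = jax.jit(pipeline.init)
    step_fn = jax.jit(pipeline.step)
    state = init_fn(sys, sys.init_q, init_qd)
    for _ in range(50):
        state = step_fn(sys, state, None)  # no actuators, so no action
    # x velocity of the ball after the rollout.
    return state.qd[0]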
sys = mjcf.loads(ball)
x = simulation(pipeline, sys, 2.0)
grad_x = jax.grad(simulation, argnums=2)(pipeline, sys, 2.0)
print(f"{x=}, {grad_x=}")
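To see why jax.grad can differentiate through a rollout like this, here is a minimal sketch that does not depend on Brax: a hypothetical 1-D sliding block with the same timestep (0.002 s) and step count (50) as the example above, rolled out with jax.lax.scan. Friction only enters through the velocity update, so the gradient flows straight through every step. The central finite difference at the end is a cheap way to sanity-check a gradient (the YMMV part) before trusting it. The model and the names final_velocity, G, DT and N_STEPS are illustrative assumptions, not Brax API.

```python
import jax
import jax.numpy as jnp

# Hypothetical toy model (not Brax): a block sliding on a floor with Coulomb friction.
G = 9.81     # gravity, m/s^2
DT = 0.002   # timestep, s
N_STEPS = 50

def final_velocity(friction, v0=1.0):
    """Roll out N_STEPS of explicit Euler and return the block's final x velocity."""
    def step(v, _):
        # Friction decelerates the block by mu * g; clamp at 0 so it never reverses.
        v_next = jnp.maximum(v - friction * G * DT, 0.0)
        return v_next, None

    v_init = jnp.asarray(v0, dtype=jnp.float32)
    v_final, _ = jax.lax.scan(step, v_init, None, length=N_STEPS)
    return v_final

mu = 0.5
v = final_velocity(mu)
dv_dmu = jax.grad(final_velocity)(mu)

# Central finite difference as a sanity check on the autodiff gradient.
eps = 1e-3
fd = (final_velocity(mu + eps) - final_velocity(mu - eps)) / (2 * eps)
print(f"{v=}, {dv_dmu=}, {fd=}")
```

With mu = 0.5 the block never stops within 50 steps, so the clamp is inactive and the final velocity is linear in mu; the autodiff and finite-difference gradients should then agree closely. Near the point where the block comes to rest the clamp makes the rollout non-smooth, which is one reason simulator gradients can be unreliable.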
Does Brax support computing gradients with respect to physical parameters such as friction, stiffness, and damping? Thanks!
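The reply covers friction; the same jax.grad pattern should carry over to the other parameters asked about. A minimal sketch on a hypothetical mass-spring-damper written directly in JAX (it does not model Brax's own joint stiffness or damping fields, which are not shown in this thread): packing the parameters into a dict lets a single jax.grad call return gradients for stiffness and damping at once, as a dict with the same keys. The model, final_position and its constants are illustrative assumptions.

```python
import jax
import jax.numpy as jnp

DT = 0.002      # timestep, s
N_STEPS = 500   # 1 s of simulated time

def final_position(params, x0=0.1, v0=0.0, mass=1.0):
    """Hypothetical 1-D mass-spring-damper; returns the position after N_STEPS."""
    def step(state, _):
        x, v = state
        # Spring force -k*x plus viscous damping -c*v.
        a = (-params["stiffness"] * x - params["damping"] * v) / mass
        v = v + a * DT  # semi-implicit Euler: update velocity first,
        x = x + v * DT  # then position using the new velocity
        return (x, v), None

    init = (jnp.asarray(x0, dtype=jnp.float32), jnp.asarray(v0, dtype=jnp.float32))
    (x, _), _ = jax.lax.scan(step, init, None, length=N_STEPS)
    return x

params = {"stiffness": 50.0, "damping": 0.5}
# Gradients come back as a dict with the same keys as params.
grads = jax.grad(final_position)(params)
print(grads)
```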