google / brax

Massively parallel rigidbody physics simulation on accelerator hardware.
Apache License 2.0

Inconsistent parameter gradient behaviour across pipelines #540

Closed JoeMWatson closed 2 weeks ago

JoeMWatson commented 1 month ago

Hi,

I'm using brax 0.11.0 and am confused about parameter gradients. In this issue from last year, the pipelines appeared to be consistent. However, I'm now looking at parameter gradients w.r.t. link mass, and they no longer appear consistent across pipelines.

If you run the following script,

import jax
from jax import numpy as jp
from brax.positional import pipeline as positional_pipeline
from brax.generalized import pipeline as generalized_pipeline
from brax.mjx import pipeline as mjx_pipeline
from brax.io import mjcf
ball = """
<mujoco>
  <option gravity="0 0 -9.81" timestep="0.002"/>
  <worldbody>
    <geom name="floor" pos="0 0 0" size="40 40 40" type="plane" friction="1.0"/>
    <body pos="0 0 0.2">
      <joint type="free"/>
      <geom size=".2" mass="1.0" friction="1.0"/>
    </body>
  </worldbody>
</mujoco>
"""
for name, pipeline in {
    "pos": positional_pipeline,
    "gen": generalized_pipeline,
    "mjx": mjx_pipeline,
}.items():
    print(name)
    try:
        def simulation(pipeline, sys, params):
            init_qd = jp.zeros(6).at[0].set(1.0)  # 1m/s in +x
            mass = params
            sys = sys.replace(
                link=sys.link.replace(inertia=sys.link.inertia.replace(mass=mass))
            )
            # Initialise the state, then advance the simulation by a single step.
            state = jax.jit(pipeline.init)(sys, sys.init_q, init_qd)
            for _ in range(1):
                state = jax.jit(pipeline.step)(sys, state, None)
            return state.qd[0]  # x-velocity after the step

        sys = mjcf.loads(ball)
        mass = sys.link.inertia.mass
        params = jp.ones_like(mass)
        x = simulation(pipeline, sys, params)
        grad_x = jax.grad(simulation, argnums=2)(pipeline, sys, params)
        print(f"{x=}, {grad_x=}")
    except Exception as exc:
        print(exc)

the output is

pos
x=Array(0.99439585, dtype=float32), grad_x=Array([-0.00160096], dtype=float32)
gen
reshape total size must be unchanged, got new_sizes (6, 3) (of total size 18) for shape (6,) (of total size 6).
mjx
x=Array(1., dtype=float32), grad_x=Array([0.], dtype=float32)

Should we expect to be able to compute parameter gradients for generalized and MJX pipelines?

btaba commented 2 weeks ago

Hi @JoeMWatson, MJX doesn't use sys.link.inertia, which is presumably why the gradient is 0. MJX uses the classic MuJoCo fields instead, so the relevant one here is mjx.Model.body_inertia. However, keep in mind that if mass/inertia is changed on the fly, you'll want to recompile the entire model, since the model contains other fields derived from them.
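
To illustrate the point about the zero gradient, here is a minimal toy sketch in plain JAX. It does not call Brax or MJX; the names make_sys, link_mass, body_mass and the two step functions are hypothetical stand-ins for "a system that stores mass under two different fields" and "pipelines that each read only one of them". It only shows that jax.grad returns zero for a parameter written into a field the step function never reads.

```python
import jax
import jax.numpy as jp


def make_sys(mass):
    # Hypothetical system: the parameter is written into `link_mass` only.
    # `body_mass` is a separate constant copy, unaffected by the parameter.
    return {"link_mass": mass, "body_mass": jp.ones(mass.shape)}


def step_reads_link_mass(sys, v, force=1.0, dt=0.002):
    # Toy explicit Euler velocity update using the field we replaced.
    return v + dt * force / sys["link_mass"]


def step_reads_body_mass(sys, v, force=1.0, dt=0.002):
    # Same update, but reading the other field (analogous to a pipeline
    # that ignores the replaced field).
    return v + dt * force / sys["body_mass"]


def final_velocity(step, mass):
    sys = make_sys(mass)
    return step(sys, 1.0)[0]  # scalar output, as jax.grad requires


mass = jp.ones(1)
grad_link = jax.grad(lambda m: final_velocity(step_reads_link_mass, m))(mass)
grad_body = jax.grad(lambda m: final_velocity(step_reads_body_mass, m))(mass)

print(f"{grad_link=}")  # non-zero: the output depends on the parameter
print(f"{grad_body=}")  # zero: the parameter never reaches the output
```

In the script from the original post, the analogous fix for MJX would be to write the parameter into the field the pipeline actually reads, bearing in mind the caveat above about other derived fields in the model.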