google / brax

Massively parallel rigidbody physics simulation on accelerator hardware.
Apache License 2.0

2-link chain instability in backend="positional" #342

Closed: namheegordonkim closed this issue 1 year ago

namheegordonkim commented 1 year ago

Minimal example: https://colab.research.google.com/drive/13v6rmV1SWx85m70HPxqMFsoc_hCze9W8?usp=sharing

I have a 2-link chain where the "shoulder" joint suspends the link in the air and the "elbow" joint is actuated to the target angle.

Using position control mode in conjunction with self.sys.dof.limits, I am able to make the elbow bend to the max angle as specified in the XML (see example).
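For readers without access to the Colab, the setup described above can be approximated with a small MJCF model. The XML below is a hypothetical stand-in, not the author's file: the body, joint, and actuator names, the capsule sizes, `kp`, and the ±90° elbow range are all assumptions. `joint_upper_limit_rad` is likewise a hypothetical helper that shows where the "max angle as specified in the XML" comes from; in Brax itself the author reads the limits from `self.sys.dof.limits` instead.

```python
import math
import xml.etree.ElementTree as ET

# Hypothetical 2-link chain (assumed values); the real XML is only in the Colab.
CHAIN_XML = """
<mujoco model="two_link_chain">
  <option timestep="0.01"/>
  <worldbody>
    <body name="upper" pos="0 0 2">
      <!-- passive shoulder hinge: hangs the chain from the world -->
      <joint name="shoulder" type="hinge" axis="0 1 0"/>
      <geom type="capsule" fromto="0 0 0 0 0 -0.5" size="0.05"/>
      <body name="lower" pos="0 0 -0.5">
        <joint name="elbow" type="hinge" axis="0 1 0" limited="true" range="-90 90"/>
        <geom type="capsule" fromto="0 0 0 0 0 -0.5" size="0.05"/>
      </body>
    </body>
  </worldbody>
  <actuator>
    <!-- position servo on the elbow only -->
    <position name="elbow_servo" joint="elbow" kp="50"/>
  </actuator>
</mujoco>
"""


def joint_upper_limit_rad(xml_text: str, joint_name: str) -> float:
    """Return a hinge joint's upper range limit in radians (MJCF default unit: degrees)."""
    root = ET.fromstring(xml_text)
    for joint in root.iter("joint"):
        if joint.get("name") == joint_name:
            _, hi = (float(s) for s in joint.get("range").split())
            return math.radians(hi)
    raise KeyError(joint_name)


# Target for the elbow servo: bend all the way to the upper limit.
target = joint_upper_limit_rad(CHAIN_XML, "elbow")
print(f"elbow target = {target:.4f} rad")
```

Reading the limit from the model, rather than hard-coding the angle, keeps the control target in sync if the range in the XML changes.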

However, this only seems to be stable when I use backend="generalized". With backend="positional", the simulation completely breaks down, and I can't put my finger on anything obvious that I'm doing wrong.

In fact, even when I disable the elbow joint's actuator and just let the 2-link pendulum swing in the air, the positional backend still becomes unstable.

Any help would be appreciated. Thanks!

namheegordonkim commented 1 year ago

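Ahh, answering my own question here: I thought the XML's "iterations" attribute in … What made the difference was dividing the physics timestep by n_frames myself, so each env step runs n_frames smaller substeps: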

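from brax.envs.base import PipelineEnv
from brax.io import mjcf


class ChainEnv(PipelineEnv):  # hypothetical class name; the issue only shows __init__

    def __init__(
            self,
            backend='generalized',
            **kwargs,
    ):
        path = "path/to/xml.xml"
        sys = mjcf.load(path)
        # Split each env step into n_frames physics substeps. Dividing dt by the
        # same n_frames keeps one env step equal to the XML timestep.
        n_frames = kwargs.pop('n_frames', 32)
        sys = sys.replace(dt=sys.dt / n_frames)
        super().__init__(sys=sys, backend=backend, n_frames=n_frames, **kwargs)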

This makes the instability go away. Hope this helps someone down the line!
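Why subdividing the timestep helps can be seen in a toy model. The sketch below is not Brax's positional solver; it swaps in a plain semi-implicit Euler integrator on a stiff spring, chosen only because its stability limit is easy to state (it stays bounded only while omega * dt < 2). `simulate`, `OMEGA_SQ`, and `DT` are hypothetical names and values. As with `sys.replace(dt=sys.dt / n_frames)` above, one env step covers the same `DT` either way; only the substep size changes.

```python
def simulate(omega_sq: float, dt: float, n_frames: int, n_steps: int) -> float:
    """Return the peak |x| seen at env-step boundaries over n_steps env steps."""
    x, v = 1.0, 0.0
    sub_dt = dt / n_frames  # same idea as sys.replace(dt=sys.dt / n_frames)
    peak = abs(x)
    for _ in range(n_steps):
        for _ in range(n_frames):
            v += sub_dt * (-omega_sq * x)  # update velocity first...
            x += sub_dt * v                # ...then position (semi-implicit Euler)
        peak = max(peak, abs(x))
    return peak


OMEGA_SQ = 100.0 ** 2  # stiff spring: natural frequency omega = 100 rad/s
DT = 0.03              # "XML" timestep: omega * DT = 3, outside the stable range

coarse = simulate(OMEGA_SQ, DT, n_frames=1, n_steps=100)   # omega * dt = 3
fine = simulate(OMEGA_SQ, DT, n_frames=32, n_steps=100)    # omega * dt ~= 0.094
print(f"1 substep:   peak |x| = {coarse:.3g}")
print(f"32 substeps: peak |x| = {fine:.3g}")
```

With a single substep the amplitude grows geometrically, while 32 substeps keep it near its initial value of 1. The cost is 32 force evaluations per env step, which is the same trade-off the n_frames fix makes.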

btaba commented 1 year ago

Hi @namheegordonkim, sorry I missed this. Besides the solver timestep (which you can modify in the XML via the timestep attribute), instability in the positional backend can also come from the bodies' mass/inertia values. In that case, you can modify the mass/inertia directly via the inertial tag, or set spring_mass_scale/spring_inertia_scale, which scale the mass/inertia.
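The two XML-side knobs mentioned here, the timestep attribute and the inertial tag, can be edited by script before calling `mjcf.load`. The sketch below is a hypothetical helper, `tune_model`, operating on a hypothetical one-link model; the timestep, mass, inertia, and centre-of-mass values are placeholders rather than recommended settings. It relies on standard MJCF behaviour: `<option timestep=...>` sets the solver step, and an explicit `<inertial>` on a body replaces the mass and inertia MuJoCo would otherwise derive from its geoms.

```python
import xml.etree.ElementTree as ET

# Hypothetical single-link model; names and numbers are illustrative only.
MODEL_XML = """
<mujoco model="link">
  <option timestep="0.01"/>
  <worldbody>
    <body name="lower" pos="0 0 1">
      <joint name="elbow" type="hinge" axis="0 1 0"/>
      <geom type="capsule" fromto="0 0 0 0 0 -0.5" size="0.05"/>
    </body>
  </worldbody>
</mujoco>
"""


def tune_model(xml_text, timestep, body_name, mass, diaginertia):
    """Return MJCF text with a new option timestep and an explicit inertial on one body."""
    root = ET.fromstring(xml_text)
    option = root.find("option")
    if option is None:
        option = ET.Element("option")
        root.insert(0, option)
    option.set("timestep", repr(timestep))

    body = root.find(f".//body[@name='{body_name}']")
    if body is None:
        raise KeyError(body_name)
    inertial = body.find("inertial")
    if inertial is None:
        inertial = ET.SubElement(body, "inertial")
    inertial.set("pos", "0 0 -0.25")  # assumed centre of mass: middle of the capsule
    inertial.set("mass", repr(mass))
    inertial.set("diaginertia", " ".join(repr(i) for i in diaginertia))
    return ET.tostring(root, encoding="unicode")


tuned = tune_model(MODEL_XML, timestep=0.001, body_name="lower",
                   mass=1.0, diaginertia=(0.02, 0.02, 0.002))
print(tuned)
```

The resulting string could then be passed to Brax's MJCF loader in place of the file path; whether a given mass/inertia change stabilizes the positional backend still has to be checked by simulation.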