google / brax

Massively parallel rigidbody physics simulation on accelerator hardware.
Apache License 2.0

Reconstructing scene from qpos and qvel #345

Open namheegordonkim opened 1 year ago

namheegordonkim commented 1 year ago

I've mentioned this briefly in #157, but I believe this deserves a separate thread.

In certain applications like Go-Explore (https://www.nature.com/articles/s41586-020-03157-9), it's critical for the user to be able to reset the environment to the state corresponding to the gathered observation, i.e., qpos and qvel.

In my experiments with Brax v2, I haven't been able to reconstruct the simulator state from qpos and qvel alone; I suspect second-order information is also necessary. According to the MuJoCo docs:

class mujoco_py.MjSimState

Represents a snapshot of the simulator’s state.

This includes time, qpos, qvel, act, and udd_state.

Ignoring time and udd_state, act also needs to be taken into account. I'm not sure by what mechanism MuJoCo does this, but we can gather that qpos, qvel, and act should be necessary and sufficient for reconstructing the state.
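
To make the "qpos, qvel, and act" snapshot concrete, here is a minimal numpy sketch that packs time, qpos, qvel and act into one flat vector and splits it again. The helpers `pack_state`/`unpack_state` and the sizes are hypothetical; they mirror the MjSimState fields quoted above and are not a MuJoCo API.

```python
import numpy as np

def pack_state(time, qpos, qvel, act):
    """Concatenate (time, qpos, qvel, act) into one flat float64 vector."""
    return np.concatenate([[time], qpos, qvel, act]).astype(np.float64)

def unpack_state(flat, nq, nv, na):
    """Split a flat snapshot back into (time, qpos, qvel, act), given their sizes."""
    assert flat.shape == (1 + nq + nv + na,)
    time = float(flat[0])
    qpos = flat[1:1 + nq]
    qvel = flat[1 + nq:1 + nq + nv]
    act = flat[1 + nq + nv:]
    return time, qpos, qvel, act

# Hypothetical sizes: one free-floating body (nq=7 with a quaternion, nv=6)
# and no actuator activations (na=0).
snap = pack_state(0.5, np.arange(7.0), np.ones(6), np.zeros(0))
t, qpos, qvel, act = unpack_state(snap, nq=7, nv=6, na=0)
```

Note that nq and nv can differ (quaternions have 4 position coordinates but 3 velocity coordinates), so the sizes have to be stored alongside the snapshot or known from the model.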

My current workaround is to store the actual states gathered during exploration (i.e., in a batched instance of the State class) and pass them as the argument to step(). But this comes with a huge memory overhead.
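
A rough sketch of this workaround and of where the memory goes: keep every field with a leading batch axis, gather the states to restart from by index, and compare the footprint with storing only q and qd. Plain nested dicts of numpy arrays stand in for Brax's State here (with JAX, `jax.tree_util.tree_map` would do the gathering on a real State); the extra field names and all sizes are hypothetical, chosen only for illustration.

```python
import numpy as np

def tree_map(fn, tree):
    """Apply fn to every array leaf of a nested dict."""
    if isinstance(tree, dict):
        return {k: tree_map(fn, v) for k, v in tree.items()}
    return fn(tree)

def tree_nbytes(tree):
    """Total bytes held by all array leaves of a nested dict."""
    if isinstance(tree, dict):
        return sum(tree_nbytes(v) for v in tree.values())
    return tree.nbytes

def gather(batched, idx):
    """Select the states at positions idx along the leading (batch) axis."""
    return tree_map(lambda a: a[idx], batched)

# Hypothetical sizes: 1000 archived states, 10 links, q_size = qd_size = 16.
N, links, dof = 1000, 10, 16
minimal = {
    "q": np.zeros((N, dof), np.float32),
    "qd": np.zeros((N, dof), np.float32),
}
# Hypothetical extra fields standing in for the cached data a full State carries.
full = {
    **minimal,
    "x_pos": np.zeros((N, links, 3), np.float32),
    "x_rot": np.zeros((N, links, 4), np.float32),
    "mass_mx_inv": np.zeros((N, dof, dof), np.float32),
}

cell_states = gather(full, np.array([3, 42]))  # two states to restart from
print(tree_nbytes(minimal), tree_nbytes(full))  # 128000 1432000
```

In this made-up layout the per-state matrix dominates, which is why storing only (q, qd) is so much cheaper when it is sufficient.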

Any idea whether this will be made possible?

erikfrey commented 1 year ago

Hello! Yes, q and qd (or qpos and qvel in MuJoCo terms) are all you need to recreate simulator state. That said, for performance reasons, Brax stores other data that it reuses from one step to the next. Knowing this, you have two options:

1) If you wish to store only the minimal state, q and qd, you can always recreate the full state by calling pipeline.init to recreate the desired brax state, then call pipeline.step afterwards.

2) If you can afford to save the full state, then you can skip calling pipeline.init to recreate the state.
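
Option 1 can be sketched as a helper that rebuilds the state from stored (q, qd) and then replays actions. To keep it self-contained, `init_fn` and `step_fn` are passed in; with Brax v2 they would presumably be `pipeline.init` and `pipeline.step` with `sys` bound (an assumption; check the signatures of your Brax version). The toy pipeline below is hypothetical and only exists so the example runs.

```python
import numpy as np

def restore_and_replay(init_fn, step_fn, q, qd, actions):
    """Option 1: rebuild the full state from (q, qd), then step through the actions."""
    state = init_fn(q, qd)
    states = [state]
    for act in actions:
        state = step_fn(state, act)
        states.append(state)
    return states

# Hypothetical toy pipeline (semi-implicit Euler on a unit point mass). Its
# "full" state carries a derived field (kinetic energy) recomputed from q and qd.
DT = 0.01

def toy_init(q, qd):
    return {"q": q, "qd": qd, "energy": 0.5 * float(qd @ qd)}

def toy_step(state, act):
    qd = state["qd"] + DT * act
    q = state["q"] + DT * qd
    return toy_init(q, qd)

q0, qd0 = np.zeros(2), np.array([1.0, 0.0])
traj = restore_and_replay(toy_init, toy_step, q0, qd0, [np.ones(2)] * 3)
```

Option 2 skips `init_fn` entirely: keep the saved state and call `step_fn(saved_state, act)` directly.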

namheegordonkim commented 1 year ago

Thanks for the fast response, @erikfrey.

If you wish to store only the minimal state, q and qd, you can always recreate the full state by calling pipeline.init to recreate the desired brax state, then call pipeline.step afterwards.

Unfortunately, this has not been the case for me, at least with the positional backend: calling pipeline.init() with q and qd did not reproduce the same matrices. I checked this by saving (State, (q, qd)) tuples and comparing the entries of the saved State against those of pipeline.init(q, qd). I'd be happy to provide a minimal example when I get the chance.
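
The field-by-field comparison described here can be sketched as follows, assuming both states are available as nested dicts of arrays (a real Brax State could be flattened with `jax.tree_util` instead). `diff_fields` is a hypothetical helper, not part of Brax.

```python
import numpy as np

def diff_fields(a, b, atol=1e-6, prefix=""):
    """Hypothetical helper: {field_path: max_abs_diff} for leaves differing by more than atol."""
    out = {}
    for key in a:
        path = f"{prefix}{key}"
        if isinstance(a[key], dict):
            out.update(diff_fields(a[key], b[key], atol, path + "."))
            continue
        if np.size(a[key]) == 0:  # nothing to compare in empty arrays
            continue
        err = float(np.max(np.abs(np.asarray(a[key]) - np.asarray(b[key]))))
        if err > atol:
            out[path] = err
    return out

saved = {"q": np.array([0.1, 0.2]), "x": {"pos": np.array([1.0, 2.0, 3.0])}}
rebuilt = {"q": np.array([0.1, 0.2]), "x": {"pos": np.array([1.0, 2.0, 3.1])}}
print(diff_fields(saved, rebuilt))  # only "x.pos" is reported
```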

erikfrey commented 1 year ago

OK! It's our expectation that this would work, so please do share a repro and we'll take a look.

namheegordonkim commented 1 year ago

Here's the minimal example I promised: https://colab.research.google.com/drive/1lAqGR4mpd4EX4eeXhrlR4CzQEoQncbZk?usp=sharing

You will see that state2 is reconstructed from state1's q and qd values, but after the exact same actions are passed to step(), the two trajectories diverge, as is visible in the rendered output. backend="positional" has the same issue.
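
To pin down when two rollouts driven by identical actions drift apart, rather than judging from the render, one could report the first step whose positions differ beyond a tolerance. A sketch assuming each trajectory is an array of q vectors with shape (T, q_size); `first_divergence` is hypothetical and the tolerance is arbitrary.

```python
import numpy as np

def first_divergence(traj_a, traj_b, atol=1e-5):
    """Index of the first step where the trajectories differ by more than atol, or None."""
    traj_a, traj_b = np.asarray(traj_a), np.asarray(traj_b)
    # Max abs difference per time step, reducing over all non-time axes.
    err = np.max(np.abs(traj_a - traj_b), axis=tuple(range(1, traj_a.ndim)))
    bad = np.nonzero(err > atol)[0]
    return int(bad[0]) if bad.size else None

a = np.zeros((5, 3))
b = a.copy()
b[3:] += 1e-3  # synthetic drift introduced at step 3
print(first_divergence(a, b))  # 3
```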

erwincoumans commented 1 year ago

Can you provide info on exactly what data is reused?

Brax stores other data that it reuses from one step to the next.

namheegordonkim commented 1 year ago

Not to hijack Erwin's question--I further narrowed it down to kinematics.forward().

Repro: https://colab.research.google.com/drive/1P3mtOCNTngZ6A_CJtjJdWzn4zX4TPcda?usp=sharing

It seems that using anything other than the default q and qd makes the results of kinematics.forward() irreproducible.

erikfrey commented 1 year ago

@namheegordonkim I think you're onto something here. I agree the assert in your colab shouldn't fire. We'll take a closer look.

@erwincoumans it varies by pipeline. In generalized, for example, we iteratively update the inverse mass matrix from the previous step instead of recalculating it from scratch. We can certainly improve the documentation to make this clearer.

namheegordonkim commented 1 year ago

Great to hear that you'll be looking into this. I've wrestled with it for a few days and identified some sus lines. I hope these help!

https://github.com/google/brax/blob/c2cd14cf762242d63aeec106d955390c8e14d582/brax/kinematics.py#L88-L105

Here, the parent link and the child link are held together by the joint. Initially x is the world coordinate position of the parent link COM, but using x.vmap().do(j) transforms it to the world coordinate position of the child link. The child link inherits the parent's world coordinate angles and velocities, so all that's left (theoretically) is to add the contributions of the angular velocity of the joint.

However, jax.vmap(jp.cross)(x.pos, jd.ang) seems plainly wrong here: the linear velocity contribution from angular velocity should be the cross product of the angular velocity and the moment arm. First, x.pos doesn't give you the local-coordinate moment arm; sys.link.joint.pos does. Second, the cross product isn't commutative, so shouldn't the order of these two be switched?
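
The two orderings can be compared numerically. The cross product is anti-commutative, so cross(pos, ang) equals ang × (0 − pos), which is exactly the term from the rigid-body transfer rule v_b = v_a + ω × (r_b − r_a) when the reference point moves from `pos` to the world origin. Whether that is the convention Brax uses for `xd` is not established by this sketch; it only shows that the order in the quoted line is not automatically a sign error.

```python
import numpy as np

def transfer_velocity(vel_a, ang, r_a, r_b):
    """Rigid-body velocity transfer: v_b = v_a + omega x (r_b - r_a)."""
    return vel_a + np.cross(ang, r_b - r_a)

omega = np.array([0.0, 0.0, 2.0])  # spin about world z
pos = np.array([1.0, 0.0, 0.0])    # a point on the +x axis

# Anti-commutativity: a x b == -(b x a).
assert np.allclose(np.cross(pos, omega), -np.cross(omega, pos))

# Velocity of the point at `pos`, given zero velocity at the origin.
v_point = transfer_velocity(np.zeros(3), omega, r_a=np.zeros(3), r_b=pos)
# Going the other way, from `pos` to the origin, yields cross(pos, omega).
v_origin = transfer_velocity(np.zeros(3), omega, r_a=pos, r_b=np.zeros(3))
print(v_point, v_origin, np.cross(pos, omega))
```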

I did fool around with passing sys.link.joint into the scan.tree call as another argument, but haven't managed to reproduce the correct linear velocities.

erikfrey commented 1 year ago

Oh whoops, you know, I lied to you. For spring and pbd pipelines, q and qd are not sufficient to reconstruct physics state. Those two pipelines can produce states with joint constraint violations which cannot be expressed in reduced coordinates. That is why everything lines up for you at init, but after a step you start to see the error: no joint constraints are violated at init, but some are after a step.

For spring and pbd, you would want to use either x and xd or x_i and xd_i to reconstruct the rest of the state. For generalized, you should be able to rely on q and qd.
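
A planar pendulum shows why a maximal-coordinate state can hold information that q cannot. A spring/pbd-style state stores the bob's position directly, so it can represent a slightly stretched rod; the reduced coordinate (the angle) cannot, and rebuilding from it silently restores the nominal length. This is a toy numpy sketch with made-up numbers, not Brax code.

```python
import numpy as np

ROD_LENGTH = 1.0  # nominal joint geometry (toy value)

def to_reduced(bob_pos):
    """Maximal -> reduced: keep only the bob's angle from the downward vertical."""
    return np.arctan2(bob_pos[0], -bob_pos[1])

def to_maximal(theta):
    """Reduced -> maximal: place the bob exactly ROD_LENGTH from the pivot at the origin."""
    return ROD_LENGTH * np.array([np.sin(theta), -np.cos(theta)])

# Toy state after a constraint-solver step: the rod is stretched by 2%.
bob = np.array([0.0, -1.02])
rebuilt = to_maximal(to_reduced(bob))
# The angle survives the round trip, but the 2% stretch (the violation) is lost.
print(np.linalg.norm(bob), np.linalg.norm(rebuilt))
```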

Sorry for leading you down a rabbit hole. Please let me know if some part of simulation still doesn't make sense for you.

namheegordonkim commented 1 year ago

I lied to you. For spring and pbd pipelines, q and qd are not sufficient to reconstruct physics state. For generalized, you should be able to rely on q and qd.

Well, FWIW the demo I shared is using generalized.

According to your comment, if I disabled gravity and hung the character up in the air with absolutely no possibility of collision, q and qd should be sufficient for reconstruction, but this isn't the case either :( I do think kinematics.forward() has a bug in it, as mentioned above.

erikfrey commented 1 year ago

Hi @namheegordonkim - no problem, let's get to the bottom of this. First off, just so I understand, where you say:

Well, FWIW the demo I shared is using generalized.

You mean this colab? As far as I can tell you're using the positional backend, right? Just to be sure I fired up your colab and switched it to 'generalized' and sure enough, the assert goes away:

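```python
import jax
import numpy as np
from brax import kinematics
from jax.tree_util import tree_map

# kinematics.forward(env.sys, q1, qd1) should yield x1, xd1 (env, q1, qd1, x1 come from the colab).
x3, xd3 = jax.vmap(kinematics.forward, in_axes=(None, 0, 0))(env.sys, q1, qd1)
tree_map(lambda x, y: np.testing.assert_allclose(x, y, atol=1e-3), x1, x3)
```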

That no longer throws after switching to generalized.

namheegordonkim commented 1 year ago

@erikfrey

I shared two colabs.

Here's the minimal example I promised: https://colab.research.google.com/drive/1lAqGR4mpd4EX4eeXhrlR4CzQEoQncbZk?usp=sharing

Not to hijack Erwin's question--I further narrowed it down to kinematics.forward(). Repro: https://colab.research.google.com/drive/1P3mtOCNTngZ6A_CJtjJdWzn4zX4TPcda?usp=sharing

I was specifically referring to the first one.

erikfrey commented 1 year ago

Oh got it, thanks for clearing that up. This is helping me sharpen how I talk about Brax. So the behavior in the generalized colab is expected (by me at least, haha!). We have not done any work to guarantee Brax is deterministic, that it will produce the same trajectory across diverse hardware (the typical case), or the exact same trajectory across states that were created via step vs. via init.
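
One reason bit-for-bit reproducibility is hard to promise is that floating-point addition is not associative. Reductions that accumulate in a different order (different hardware, or different code paths for init vs. step) can round differently, and those tiny differences can then grow over a rollout. A minimal illustration in plain Python:

```python
a = (0.1 + 0.2) + 0.3
b = 0.1 + (0.2 + 0.3)
print(a, b, a == b)  # 0.6000000000000001 0.6 False
```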

I agree this would be a nice property for Brax, and some engines go through the trouble of making explicit claims here, so we'll add it to our TODO list.

One thing that I am certain of is that init vs step will produce slightly different mass matrices. If you would like to remove this difference, you can try setting matrix_inv_iterations to zero, as this will force brax to use the same matrix inverse operation for both step and init. This will slow down simulation though!
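
A toy analogue of the init-vs-step mass-matrix difference: "init" inverts the matrix exactly, while "step" refines the previous step's inverse with Newton-Schulz iterations, X ← X(2I − AX). With zero iterations the toy falls back to the exact inverse, mimicking what setting matrix_inv_iterations to zero is described to do above. The matrices and the `inv_iterations` parameter are hypothetical, and Brax's actual update may differ in detail.

```python
import numpy as np

def newton_schulz(a, x, iterations):
    """Refine an approximate inverse x of a with Newton-Schulz steps: x <- x (2I - a x)."""
    eye = np.eye(a.shape[0])
    for _ in range(iterations):
        x = x @ (2 * eye - a @ x)
    return x

def inverse_mass(a, prev_inv, inv_iterations):
    """'step'-style inverse: exact when inv_iterations == 0, else warm-started from prev_inv."""
    if inv_iterations == 0:
        return np.linalg.inv(a)
    return newton_schulz(a, prev_inv, inv_iterations)

# Hypothetical mass matrices at the previous step and (slightly changed) at this step.
m_prev = np.array([[2.0, 0.3], [0.3, 1.0]])
m_curr = m_prev + np.array([[0.05, 0.01], [0.01, 0.02]])

exact = np.linalg.inv(m_curr)  # what an 'init' from q, qd would compute
warm = inverse_mass(m_curr, np.linalg.inv(m_prev), inv_iterations=1)
forced = inverse_mass(m_curr, np.linalg.inv(m_prev), inv_iterations=0)

print(np.max(np.abs(warm - exact)))    # small but nonzero
print(np.max(np.abs(forced - exact)))  # exactly zero here
```

The warm-started path is cheaper per step, which is the trade-off behind "This will slow down simulation though!".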

For now I think your best bet may be to store the entire State struct.

erikfrey commented 1 year ago

Also please do let me know if you have a repro for forward kinematics that shows it's broken! It's pretty well tested so I'd be surprised if it's incorrect. But I love surprises!