Open namheegordonkim opened 1 year ago
Hello! Yes, `q` and `qd` (or `qpos` and `qvel` in MuJoCo terms) are all you need to recreate simulator state. That said, for performance reasons, Brax stores other data that it reuses from one step to the next. Knowing this, you have two options:
1) If you wish to store only the minimal state, `q` and `qd`, you can always recreate the full state by calling `pipeline.init` with them and then calling `pipeline.step` afterwards.
2) If you can afford to save the full state, then you can skip calling `pipeline.init` to recreate the state.
Thanks for the fast response, @erikfrey.
> If you wish to store only the minimal state, q and qd, you can always recreate the full state by calling pipeline.init to recreate the desired brax state, then call pipeline.step afterwards.
Unfortunately, this has not been the case for me, at least in the positional backend: calling `pipeline.init()` with `q` and `qd` did not result in the same matrices. I checked this by saving tuples of `(State, (q, qd))` and comparing the entries inside `State` against `pipeline.init(q, qd)`. I'd be happy to provide a minimal example when I get the chance.
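A generic leaf-by-leaf comparison can show which entries of a stored `State` and a rebuilt one disagree. The sketch below is not a Brax utility. `diff_leaves`, `ToyTransform`, and `ToyState` are hypothetical names. It assumes that the states are Python dataclasses, as flax.struct dataclasses (which Brax v2 states are built on) are, and that their leaves convert to NumPy arrays.

```python
import dataclasses
import numpy as np

def diff_leaves(a, b, path="state", atol=1e-6):
    """Return (path, max_abs_error) for every leaf where two nested structures differ.

    Walks dataclasses, dicts, tuples and lists; anything else is compared as an array.
    Missing or mismatched-shape leaves are reported with an error of inf.
    """
    if a is None or b is None:
        return [] if a is b else [(path, float("inf"))]
    if dataclasses.is_dataclass(a) and not isinstance(a, type):
        out = []
        for f in dataclasses.fields(a):
            out += diff_leaves(getattr(a, f.name), getattr(b, f.name), f"{path}.{f.name}", atol)
        return out
    if isinstance(a, dict):
        out = []
        for k in a:
            out += diff_leaves(a[k], b.get(k), f"{path}[{k!r}]", atol)
        return out
    if isinstance(a, (tuple, list)):
        if len(a) != len(b):
            return [(path, float("inf"))]
        out = []
        for i, (x, y) in enumerate(zip(a, b)):
            out += diff_leaves(x, y, f"{path}[{i}]", atol)
        return out
    x, y = np.asarray(a, dtype=float), np.asarray(b, dtype=float)
    if x.shape != y.shape:
        return [(path, float("inf"))]
    if not np.allclose(x, y, rtol=0.0, atol=atol):
        return [(path, float(np.nanmax(np.abs(x - y))))]
    return []

# Toy stand-ins for a stepped state and a state rebuilt from its q and qd.
@dataclasses.dataclass
class ToyTransform:
    pos: np.ndarray
    rot: np.ndarray

@dataclasses.dataclass
class ToyState:
    q: np.ndarray
    x: ToyTransform

stepped = ToyState(q=np.zeros(2), x=ToyTransform(pos=np.zeros((2, 3)), rot=np.ones((2, 4))))
rebuilt = ToyState(q=np.zeros(2), x=ToyTransform(pos=np.full((2, 3), 1e-3), rot=np.ones((2, 4))))
for p, err in diff_leaves(stepped, rebuilt):
    print(f"{p}: max abs diff {err:.1e}")  # state.x.pos: max abs diff 1.0e-03
```

When run on a stepped Brax state and the state rebuilt from its `q` and `qd`, this would list each mismatching field, for example parts of `x` or `xd`, along with its largest absolute error.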
OK! It's our expectation that this would work, so please do share a repro and we'll take a look.
Here's the minimal example I promised: https://colab.research.google.com/drive/1lAqGR4mpd4EX4eeXhrlR4CzQEoQncbZk?usp=sharing
You will see that `state2` is reconstructed from `state1`'s `q` and `qd` values. However, when the exact same actions are passed to `step()`, the two trajectories diverge, as visible in the rendered output. `backend="positional"` has the same issue.
> Brax stores other data that it reuses from one step to the next.
Can you provide info on exactly what data is reused?
Not to hijack Erwin's question--I further narrowed it down to `kinematics.forward()`. Repro: https://colab.research.google.com/drive/1P3mtOCNTngZ6A_CJtjJdWzn4zX4TPcda?usp=sharing
It seems that using anything but the default `q` and `qd` makes `kinematics.forward()` produce results that don't match the original state.
@namheegordonkim I think you're onto something here. I agree the assert in your colab shouldn't fire. We'll take a closer look.
@erwincoumans it varies by pipeline. In generalized, for example, we iteratively update the inverse mass matrix from the previous step instead of recalculating it from scratch. We can certainly improve the documentation to make this clearer.
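One standard way to refine an inverse from the previous step's estimate is Newton-Schulz iteration. The thread doesn't say which update Brax uses, so treat this as an illustrative assumption. Here $M_t$ is the mass matrix at step $t$, $X_k$ is the current estimate of $M_t^{-1}$, and $R_k$ is its residual:

```math
\begin{aligned}
X_{k+1} &= X_k\,(2I - M_t X_k), \qquad R_k := I - M_t X_k,\\
R_{k+1} &= I - M_t X_k\,(2I - M_t X_k) = (I - M_t X_k)^2 = R_k^2,\\
X_0 &= M_{t-1}^{-1} \;\Rightarrow\; R_0 = (M_{t-1} - M_t)\,M_{t-1}^{-1}.
\end{aligned}
```

The residual shrinks quadratically as long as $\lVert R_0 \rVert < 1$. Warm-starting from the previous step keeps $R_0$ small when the configuration changes little between steps. It also means the result depends on the previous step's estimate, which a freshly initialized state does not have.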
Great to hear that you'll be looking into this. I've wrestled with it for a few days and identified some sus lines. I hope these help!
Here, the parent link and the child link are held together by the joint. Initially, `x` is the world-coordinate position of the parent link's COM, but `x.vmap().do(j)` transforms it to the world-coordinate position of the child link. The child link inherits the parent's world-coordinate angles and velocities, so (theoretically) all that's left is to add the contribution of the joint's angular velocity.
However, `jax.vmap(jp.cross)(x.pos, jd.ang)` plainly seems wrong here. The linear velocity contributed by an angular velocity should be the cross product of the angular velocity and the moment arm. First, `x.pos` doesn't give you the local-coordinate moment arm; `sys.link.joint.pos` does. Second, the cross product isn't commutative, so shouldn't the order of these two be switched?
I did fool around with passing `sys.link.joint` into the `scan.tree` call as another argument, but haven't had success in replicating the correct linear velocities.
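The relation under discussion is the standard rigid-body velocity transfer, `v_point = v_ref + omega x (r_point - r_ref)`. The sketch below illustrates only that math in NumPy, and `point_velocity` is a hypothetical helper. Whether the Brax line is right depends on which reference point its linear velocities are expressed about, and this thread does not settle that.

```python
import numpy as np

def point_velocity(v_ref, omega, r_ref, r_point):
    """Velocity of a body-fixed point: v_point = v_ref + omega x (r_point - r_ref)."""
    return v_ref + np.cross(omega, r_point - r_ref)

omega = np.array([0.0, 0.0, 1.0])   # body spinning about +z at 1 rad/s
r_joint = np.zeros(3)               # joint anchor at the world origin, at rest
r_com = np.array([1.0, 0.0, 0.0])   # child COM, 1 m along +x from the anchor

v_com = point_velocity(np.zeros(3), omega, r_joint, r_com)
print(v_com)  # [0. 1. 0.] since z x x = +y

# Swapping the operands flips the sign: r x omega = -(omega x r).
assert np.allclose(np.cross(r_com - r_joint, omega), -v_com)

# Re-expressing the velocity about the world origin gives v_origin = v_com + r_com x omega,
# so a (position x angular velocity) term appears naturally under that convention.
v_origin = point_velocity(v_com, omega, r_com, np.zeros(3))
assert np.allclose(v_origin, v_com + np.cross(r_com, omega))
assert np.allclose(v_origin, 0.0)  # the anchor sits at the origin and is stationary
```

The last two checks bear on the question above. A `pos x ang` term is exactly what appears when a velocity is re-expressed about the world origin, so the operand order alone does not establish a bug.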
Oh whoops, you know, I lied to you. For the spring and pbd pipelines, `q` and `qd` are not sufficient to reconstruct physics state. Those two pipelines can produce states with joint constraint violations, which cannot be expressed in reduced coordinates. That is why everything lines up for you at init, but after a step you start to see the error: no joint constraints are violated at init, but some are after a step.
For spring and pbd, you would want to use either `x` and `xd` or `x_i` and `xd_i` to reconstruct the rest of the state. For generalized, you should be able to rely on `q` and `qd`.
Sorry for leading you down a rabbit hole. Please let me know if some part of simulation still doesn't make sense for you.
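A toy planar hinge (not Brax code) shows why reduced coordinates can't carry this information. `LINK_LENGTH`, `to_reduced`, and `from_reduced` are hypothetical. A maximal-coordinate update stores the child's position directly, so it can represent a stretched joint. The hinge's single reduced coordinate, its angle, cannot.

```python
import numpy as np

LINK_LENGTH = 1.0  # toy hinge: the child COM should stay 1 m from a fixed pivot at the origin

def to_reduced(x_child):
    """Maximal -> reduced coordinates: a planar hinge keeps only its angle."""
    return np.arctan2(x_child[1], x_child[0])

def from_reduced(theta):
    """Reduced -> maximal coordinates: forward kinematics satisfies the joint exactly."""
    return LINK_LENGTH * np.array([np.cos(theta), np.sin(theta)])

# Child position after a soft (spring/PBD-style) update left the joint slightly stretched.
x_child = np.array([1.02, 0.0])
theta = to_reduced(x_child)       # the radial stretch has no reduced-coordinate slot
x_rebuilt = from_reduced(theta)   # lands back on the circle, dropping the violation

violation = np.linalg.norm(x_child) - LINK_LENGTH
print(f"stored: {x_child}, rebuilt: {x_rebuilt}, lost violation: {violation:.3f} m")
```

This is why reconstruction for the spring and pbd pipelines has to start from `x`/`xd` (or `x_i`/`xd_i`) rather than from `q`/`qd`.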
> I lied to you. For spring and pbd pipelines, q and qd are not sufficient to reconstruct physics state. For generalized, you should be able to rely on q and qd.
Well, FWIW the demo I shared is using generalized.
According to your comment, if I disabled gravity and hung the character up in the air with absolutely no possibility of collision, `q` and `qd` should be sufficient for reconstruction, but this isn't the case either :( I do think `kinematics.forward()` has a bug in it, as mentioned above.
Hi @namheegordonkim - no problem, let's get to the bottom of this. First off, just so I understand, where you say:
> Well, FWIW the demo I shared is using generalized.
You mean this colab? As far as I can tell, you're using the positional backend, right? Just to be sure, I fired up your colab and switched it to 'generalized', and sure enough, the assert goes away:
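```python
import jax
import numpy as np
from brax import kinematics

# using kinematics.forward(env.sys, q1, qd1) should yield x1, xd1 (env, q1, qd1, x1 come from the colab).
x3, xd3 = jax.vmap(kinematics.forward, in_axes=(None, 0, 0))(env.sys, q1, qd1)
jax.tree_util.tree_map(lambda x, y: np.testing.assert_allclose(x, y, atol=1e-3), x1, x3)
```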
That no longer throws after switching to generalized.
@erikfrey
I shared two colabs.
> Here's the minimal example I promised: https://colab.research.google.com/drive/1lAqGR4mpd4EX4eeXhrlR4CzQEoQncbZk?usp=sharing
> Not to hijack Erwin's question--I further narrowed it down to kinematics.forward(). Repro: https://colab.research.google.com/drive/1P3mtOCNTngZ6A_CJtjJdWzn4zX4TPcda?usp=sharing
I was specifically referring to the first one.
Oh got it, thanks for clearing that up. This is helping me sharpen how I talk about Brax. So the behavior in the generalized colab is expected (by me at least, haha!). We have not done any work to guarantee that Brax is deterministic, that it will produce the same trajectory across diverse hardware (the typical case), or that it will produce the exact same trajectory from states created via step vs. via init.
I agree this would be a nice property for Brax, and some engines go through the trouble of making explicit claims here, so we'll add that to our TODO.
One thing I am certain of is that init and step will produce slightly different mass matrices. If you want to remove this difference, you can try setting `matrix_inv_iterations` to zero, which forces Brax to use the same matrix inverse operation for both step and init. This will slow down simulation, though!
For now I think your best bet may be to store the entire `State` struct.
Also please do let me know if you have a repro for forward kinematics that shows it's broken! It's pretty well tested so I'd be surprised if it's incorrect. But I love surprises!
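The init-vs-step mass-matrix difference can be reproduced in a few lines of NumPy. This is a toy, not Brax's implementation. `mass_matrix`, `refine_inverse`, `init_inverse`, and `step_inverse` are hypothetical, and Newton-Schulz refinement is assumed as the iterative update. `init_inverse` inverts from scratch, while `step_inverse` warm-starts from the previous step's inverse. Passing `num_iter=0` stands in for `matrix_inv_iterations` set to zero, so that both paths use the same exact inversion.

```python
import numpy as np

def mass_matrix(q):
    """Hypothetical configuration-dependent mass matrix, SPD for every q (det = 2 - cos(q[1])**2 >= 1)."""
    c = np.cos(q[1])
    return np.array([[3.0 + 2.0 * c, 1.0 + c],
                     [1.0 + c, 1.0]])

def refine_inverse(m, m_inv, num_iter):
    """Newton-Schulz refinement of an approximate inverse: X <- X (2I - M X)."""
    eye = np.eye(m.shape[0])
    for _ in range(num_iter):
        m_inv = m_inv @ (2.0 * eye - m @ m_inv)
    return m_inv

def init_inverse(q):
    """Toy 'init': no previous estimate exists, so invert from scratch."""
    return np.linalg.inv(mass_matrix(q))

def step_inverse(q_new, m_inv_prev, num_iter):
    """Toy 'step': warm-start from the previous inverse; num_iter == 0 means exact inversion."""
    if num_iter == 0:
        return init_inverse(q_new)
    return refine_inverse(mass_matrix(q_new), m_inv_prev, num_iter)

q0 = np.array([0.0, 0.30])
q1 = np.array([0.0, 0.31])  # configuration after one small step
m_inv0 = init_inverse(q0)

warm = step_inverse(q1, m_inv0, num_iter=2)    # what a stepped state would carry
scratch = init_inverse(q1)                     # what re-initializing from q1 would recompute
exact_step = step_inverse(q1, m_inv0, num_iter=0)

print(np.abs(warm - scratch).max())            # small, but not guaranteed to be exactly zero
print(np.array_equal(exact_step, scratch))     # True: both paths now run the same operation
```

With a warm start, the two inverses agree to within a small tolerance but need not be bitwise identical. Over a long rollout, tiny differences like these can be amplified into visibly different trajectories. With exact inversion on both paths, this toy's results match bit for bit.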
I've mentioned this briefly in #157, but I believe this deserves a separate thread.
In certain applications like Go-Explore (https://www.nature.com/articles/s41586-020-03157-9), it's critical for the user to be able to reset the environment to the state corresponding to the gathered observation, i.e., qpos and qvel.
In my experiments with Brax v2, I haven't been able to successfully reconstruct the simulator state from qpos and qvel alone--I suspect the 2nd order information is also necessary. According to MuJoCo docs:
Ignoring `time` and `udd_state`, `act` needs to be taken into account. I'm not sure by what mechanism MuJoCo does this, but we can gather that `qpos`, `qvel`, and `act` should be necessary and sufficient for reconstructing the state.
The current workaround is to store the states gathered during exploration themselves (i.e., in a batched instance of the `State` class) and to use them as the argument for `step()`. But this comes at a huge memory overhead.
Any idea whether this will be made possible?