google / brax

Massively parallel rigidbody physics simulation on accelerator hardware.
Apache License 2.0

Best way to store State #368

Closed: Jogima-cyber closed this issue 1 year ago

Jogima-cyber commented 1 year ago

Hello, my question is related to this issue. I want to do some offline work where I need to reset the simulator to past states using the generalized backend, and it was stated there that I should store the entire State. Since State is a complex Python structure, what is the best way to store it? The most naïve way would be to append each State to a Python list, but sampling from such a list would be slow. I think the best thing to do would be to convert the State structure into a list of tensors. What do you think of this approach?

erikfrey commented 1 year ago

Hello! Yes, JAX comes with some handy tools to manipulate or even discard the tree structure, for example:

import jax
state_leaves = jax.tree_util.tree_leaves(state)  # jax.tree_leaves is deprecated/removed in recent JAX

This gives you just the leaves, as a flat list. There's probably a way to reconstruct the tree from the list, too.
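For reference, a minimal sketch of that round trip, assuming only that the State is a JAX pytree (the `ToyState` NamedTuple below is a hypothetical stand-in, not brax's actual State class). `jax.tree_util.tree_flatten` returns the leaves together with a `treedef` that records the structure, and `jax.tree_util.tree_unflatten` rebuilds the tree from the two:

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp
import numpy as np


class ToyState(NamedTuple):
    """Hypothetical stand-in for a simulator State; any JAX pytree behaves the same way."""
    q: jnp.ndarray       # e.g. positions
    qd: jnp.ndarray      # e.g. velocities
    reward: jnp.ndarray


state = ToyState(q=jnp.zeros(3), qd=jnp.ones(3), reward=jnp.array(0.5))

# Flatten: a flat list of arrays plus a treedef that records the tree structure.
leaves, treedef = jax.tree_util.tree_flatten(state)

# The leaves can be stored as plain NumPy arrays; keep `treedef` alongside them.
np_leaves = [np.asarray(x) for x in leaves]

# Unflatten: rebuild an object of the original type from the (converted) leaves.
restored = jax.tree_util.tree_unflatten(treedef, [jnp.asarray(x) for x in np_leaves])
```

If every stored State has the same structure, the `treedef` only needs to be computed once and can be reused for all of them.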

Or, if what you want is a list of States split along State's batch dimension, you can do:

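import jax
from jax import numpy as jp
# One State per batch element; batch_size is the leading dimension of every leaf
states = [jax.tree_util.tree_map(lambda a: jp.take(a, i, axis=0), state) for i in range(batch_size)]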
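Going the other way, a list of single-environment States can be stacked back into one batched State by mapping `jnp.stack` over corresponding leaves. A minimal sketch, again using a hypothetical `ToyState` NamedTuple in place of brax's State and assuming every element has the same structure and leaf shapes:

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp


class ToyState(NamedTuple):
    """Hypothetical stand-in for a batched simulator State."""
    q: jnp.ndarray
    reward: jnp.ndarray


batch_size = 4
state = ToyState(q=jnp.arange(batch_size * 2.0).reshape(batch_size, 2),
                 reward=jnp.arange(batch_size, dtype=jnp.float32))

# Split along the batch dimension (same idea as the snippet above).
states = [jax.tree_util.tree_map(lambda a: jnp.take(a, i, axis=0), state)
          for i in range(batch_size)]

# Stack back: with several trees, tree_map passes the matching leaf of each one to the lambda.
restacked = jax.tree_util.tree_map(lambda *xs: jnp.stack(xs), *states)
```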

That kind of thing. Is this what you're looking for? I'm also not entirely sure what you mean by "awful to sample from".

Jogima-cyber commented 1 year ago

Thank you! Yes, that's exactly the kind of thing I was looking for. I'm going to use the list of States split along the batch dimension.

What I meant by "awful to sample from" is wall-clock time. I may be wrong, but I think sampling from array structures with libraries such as NumPy is much faster than sampling from general-purpose Python containers like deque or list, which know nothing about the data they hold. That's why, in the end, I'm going to use tree_leaves to discard the tree structure as you suggested, although reconstructing the tree afterwards seems a bit painful.
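One alternative that avoids reconstruction altogether is to keep the tree structure and give every leaf an extra leading "buffer" axis, so the whole store is itself one State-shaped pytree of arrays. Sampling then becomes a single indexing gather per leaf. A minimal sketch, assuming a hypothetical `ToyState` in place of brax's State and a fixed-capacity buffer that is fully written before sampling (wrap-around and partial fills are left out); `make_buffer`, `write`, and `sample` are illustrative helpers, not brax APIs:

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp
import numpy as np


class ToyState(NamedTuple):
    """Hypothetical stand-in for a simulator State."""
    q: jnp.ndarray
    reward: jnp.ndarray


def make_buffer(example, capacity):
    """Allocate a zero NumPy array of shape (capacity, *leaf.shape) for every leaf."""
    return jax.tree_util.tree_map(
        lambda a: np.zeros((capacity,) + tuple(a.shape), dtype=a.dtype), example)


def write(buffer, index, state):
    """Copy one State into slot `index`; the NumPy leaves are mutated in place."""
    for buf, x in zip(jax.tree_util.tree_leaves(buffer), jax.tree_util.tree_leaves(state)):
        buf[index] = np.asarray(x)


def sample(buffer, rng, n):
    """Draw n random slots and gather them from every leaf with one indexing op each."""
    capacity = jax.tree_util.tree_leaves(buffer)[0].shape[0]
    idx = rng.integers(0, capacity, size=n)
    return jax.tree_util.tree_map(lambda a: jnp.asarray(a[idx]), buffer)


# Usage: fill an 8-slot buffer, then sample a batch of 5 States.
example = ToyState(q=jnp.zeros(2), reward=jnp.array(0.0))
buffer = make_buffer(example, capacity=8)
for t in range(8):
    write(buffer, t, ToyState(q=jnp.full(2, float(t)), reward=jnp.array(float(t))))
batch = sample(buffer, np.random.default_rng(0), n=5)
```

Keeping the buffer in host NumPy arrays is a deliberate choice in this sketch: NumPy arrays can be written in place, whereas JAX arrays are immutable and an `.at[index].set(...)` update outside `jit` returns a new copy. Because `sample` returns an object of the same pytree type as the stored States, no manual reconstruction step is needed.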