Closed by ASKabalan 4 months ago
As a general statement, Equinox/JAX prefers to operate on pytrees of arrays rather than arrays or lists of pytrees (with some exceptions). Many functions therefore return a single pytree in which each member variable is "stacked", rather than a "stack" of separate pytrees. Since your pytree now contains stacked member variables, you can extract values from it as you normally would from any pytree. If you want a list of individual states, you could do
l = [jax.tree_util.tree_map(lambda x: x[i], solution_struct.ys) for i in range(len(solution_struct.ts))]
or something to that effect.
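To make the suggestion concrete, here is a minimal sketch of the same unstacking idea. It does not run a solver: the `ys` dict and `ts` array are hypothetical stand-ins for `solution_struct.ys` and `solution_struct.ts`, and the helper name `unstack` is made up for illustration. It assumes every leaf carries the stacked axis first.

```python
import jax
import jax.numpy as jnp

# Hypothetical stand-ins for solution_struct.ts / solution_struct.ys:
# each leaf of `ys` has a leading axis with one entry per saved time.
ts = jnp.linspace(0.0, 1.0, 4)
ys = {
    "position": jnp.arange(8.0).reshape(4, 2),  # shape (4, 2)
    "velocity": jnp.arange(4.0),                # shape (4,)
}

def unstack(tree, n):
    """Split a pytree of stacked leaves into a list of n pytrees."""
    return [jax.tree_util.tree_map(lambda x: x[i], tree) for i in range(n)]

# One pytree per saved time, each with the same structure as `ys`.
states = unstack(ys, len(ts))

# The reverse direction: stack a list of pytrees back into one pytree.
restacked = jax.tree_util.tree_map(lambda *xs: jnp.stack(xs), *states)
```

Each element of `states` has the same structure as `ys` but without the leading axis, so `states[1]["position"]` has shape `(2,)`. Keeping the stacked form and indexing only when needed is usually cheaper than building the list, since the list comprehension runs in Python rather than inside a traced computation.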
Thank you, this works.
I am trying to use a PyTree as the state.
But ys comes back as a single state with stacked values, not a list of states.
MWE :
What would be the best-practice way to unpack the states into a list of states?