ayaka14732 opened 2 months ago
To be clear, are these the semantics you have in mind?
```python
import jax
import jax.numpy as jnp

def stack_leaves(pytrees, axis):
    return jax.tree.map(lambda *xs: jnp.stack(xs, axis), *pytrees)
```
Yes
For something like this, I'd probably lean toward recommending users implement what they need via existing API composability, rather than providing a new API for something that can already be pretty succinctly expressed. What do you think?
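To illustrate the composability argument, here is a minimal sketch of both operations from the proposal, built only from existing calls (`jax.tree.map`, `jax.tree.leaves`, `jnp.stack`, `jnp.take`). It assumes a recent JAX that ships the `jax.tree` module. The `unstack_leaves` body is a hypothetical implementation, not an existing JAX API, and it assumes every leaf has the same length along `axis`.

```python
import jax
import jax.numpy as jnp


def stack_leaves(pytrees, axis=0):
    """Stack matching leaves of a sequence of PyTrees along a new axis."""
    # jax.tree.map walks all trees in lockstep (they must share a structure)
    # and passes the corresponding leaves to the lambda as positional args.
    return jax.tree.map(lambda *xs: jnp.stack(xs, axis=axis), *pytrees)


def unstack_leaves(pytree, axis=0):
    """Hypothetical inverse of stack_leaves: split into a list of PyTrees.

    Assumes every leaf has the same size along `axis`.
    """
    n = jax.tree.leaves(pytree)[0].shape[axis]
    # For each index i, take slice i of every leaf (dropping `axis`) and
    # rebuild a tree with the original structure.
    return [
        jax.tree.map(lambda x: jnp.take(x, i, axis=axis), pytree)
        for i in range(n)
    ]


# Usage: stack four small parameter dicts, then split them back apart.
trees = [{"w": jnp.ones((2, 3)) * i, "b": jnp.zeros(3)} for i in range(4)]
stacked = stack_leaves(trees)
print(stacked["w"].shape)  # (4, 2, 3)
restored = unstack_leaves(stacked)
```

The round trip `unstack_leaves(stack_leaves(trees))` should recover the original list of trees, which is the kind of short recipe a cookbook entry could capture.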
Maybe adding a tree util cookbook would be useful? @jakevdp
A pytree cookbook would be an interesting idea! This idea also came up in #20594. @ayaka14732, is that something you'd be interested in thinking about?
- `stack_leaves`: Stack the leaves of one or more PyTrees along a new axis.
- `unstack_leaves`: Unstack the leaves of a PyTree.

References: