google / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
http://jax.readthedocs.io/
Apache License 2.0

Adding `tree_util.stack_leaves()` and `tree_util.unstack_leaves()` #20934

Open ayaka14732 opened 2 months ago

ayaka14732 commented 2 months ago

References:

jakevdp commented 2 months ago

To be clear, are these the semantics you have in mind?

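```python
import jax
import jax.numpy as jnp

def stack_leaves(pytrees, axis):
  # Stack corresponding leaves of same-structure pytrees along `axis`;
  # `*pytrees` passes each tree separately, so the lambda gets one leaf per tree.
  return jax.tree.map(lambda *xs: jnp.stack(xs, axis), *pytrees)
```
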
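For a concrete check of these semantics, here is a minimal sketch that stacks two pytrees with identical structure. The example trees `tree_a` and `tree_b` are made up for illustration, and the `axis=0` default is an assumption added here for convenience, not part of the proposal.

```python
import jax
import jax.numpy as jnp


def stack_leaves(pytrees, axis=0):
    # Stack corresponding leaves of same-structure pytrees along `axis`.
    # `*pytrees` passes each tree as a separate argument to jax.tree.map,
    # so the lambda receives one leaf from every input tree at a time.
    return jax.tree.map(lambda *xs: jnp.stack(xs, axis), *pytrees)


# Two hypothetical parameter trees with the same structure.
tree_a = {"w": jnp.ones((2, 3)), "b": jnp.zeros(3)}
tree_b = {"w": jnp.full((2, 3), 2.0), "b": jnp.ones(3)}

stacked = stack_leaves([tree_a, tree_b])
print(jax.tree.map(lambda x: x.shape, stacked))
```

Each leaf gains a new leading axis of size 2, one slice per input tree. Passing `pytrees` without the star would instead map over the leaves of the outer list, calling the lambda with a single array each time.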
ayaka14732 commented 2 months ago

> To be clear, are these the semantics you have in mind?

Yes

jakevdp commented 2 months ago

For something like this, I'd probably lean toward recommending users implement what they need via existing API composability, rather than providing a new API for something that can already be pretty succinctly expressed. What do you think?
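The same composability argument covers the proposed `unstack_leaves()`. The thread never spells out its semantics, so the sketch below assumes it inverts stacking: given a pytree whose leaves all share the same size along `axis`, it returns a list of pytrees, one per slice. `unstack_leaves` is a hypothetical helper built only from `jax.tree.flatten`, `jax.tree.unflatten`, and `jnp.take`, not an existing JAX function.

```python
import jax
import jax.numpy as jnp


def unstack_leaves(pytree, axis=0):
    # Assumes at least one leaf and that every leaf has the same size along
    # `axis`; the size of the first leaf determines how many trees come back.
    leaves, treedef = jax.tree.flatten(pytree)
    n = leaves[0].shape[axis]
    # For each index i, take slice i of every leaf (dropping `axis`) and
    # rebuild a pytree with the original structure.
    return [
        jax.tree.unflatten(treedef, [jnp.take(leaf, i, axis=axis) for leaf in leaves])
        for i in range(n)
    ]


# A hypothetical stacked tree holding two entries along axis 0.
stacked = {"w": jnp.arange(6.0).reshape(2, 3), "b": jnp.array([0.0, 1.0])}
trees = unstack_leaves(stacked)
print(trees[1])
```

Flattening once and unflattening per slice keeps the tree structure bookkeeping in one place; the Python loop over `range(n)` is fine for small `n` but is not meant to be traced over very large stacks.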

ASEM000 commented 2 months ago

Maybe adding a `tree_util` cookbook would be useful? @jakevdp

jakevdp commented 2 months ago

A pytree cookbook would be an interesting idea! This idea also came up in #20594. @ayaka14732, is that something you'd be interested in thinking about?