gRox167 closed this issue 6 months ago.
After looking at the `tree_map` function, I came up with the idea of separating the outputs into different PyTrees. Currently this only supports functions that return a tuple.
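```python
from __future__ import annotations

from typing import Callable, Tuple, TypeVar

from optree import PyTree, tree_flatten

T = TypeVar('T')
S = TypeVar('S')
U = TypeVar('U')


def tree_transpose_map(
    func: Callable[..., U],
    tree: PyTree[T],
    *rests: PyTree[S],
    is_leaf: Callable[[T], bool] | None = None,
    none_is_leaf: bool = False,
    namespace: str = "",
) -> Tuple[PyTree, ...]:
    # Public equivalent of the internal _C.flatten call
    leaves, treespec = tree_flatten(tree, is_leaf, none_is_leaf=none_is_leaf, namespace=namespace)
    flat_args = [leaves] + [treespec.flatten_up_to(r) for r in rests]
    # zip(*...) regroups the per-leaf output tuples into one sequence per output slot
    output_tuple = zip(*map(func, *flat_args))
    return tuple(treespec.unflatten(o) for o in output_tuple)
```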
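The tuple-only idea can be tried without optree. This is a minimal sketch, assuming only plain lists, tuples and dicts are containers; `flatten` and `unflatten` are hypothetical helpers standing in for optree's internals, and `is_leaf`, `none_is_leaf` and `namespace` are not modelled. Extra trees must match the first tree's structure exactly, instead of going through `flatten_up_to`.

```python
from typing import Any, Callable, List, Tuple

# Hypothetical stand-ins for optree's flatten/unflatten: only plain lists,
# tuples and dicts count as containers; everything else is a leaf.


def flatten(tree: Any) -> Tuple[List[Any], Any]:
    """Return (leaves, structure); None in the structure marks a leaf slot."""
    leaves: List[Any] = []

    def walk(node: Any) -> Any:
        if type(node) in (list, tuple):  # exact types, so namedtuples stay leaves
            return (type(node), [walk(child) for child in node])
        if type(node) is dict:
            keys = sorted(node)  # sorted keys give a deterministic leaf order
            return (dict, keys, [walk(node[k]) for k in keys])
        leaves.append(node)
        return None

    return leaves, walk(tree)


def unflatten(structure: Any, leaves: List[Any]) -> Any:
    """Inverse of flatten: put the leaves back into the recorded containers."""
    it = iter(leaves)

    def build(spec: Any) -> Any:
        if spec is None:
            return next(it)
        if spec[0] is dict:
            _, keys, children = spec
            return {k: build(child) for k, child in zip(keys, children)}
        container, children = spec
        return container(build(child) for child in children)

    return build(structure)


def tree_transpose_map(
    func: Callable[..., Tuple[Any, ...]], tree: Any, *rests: Any
) -> Tuple[Any, ...]:
    leaves, structure = flatten(tree)
    flat_args = [leaves]
    for rest in rests:
        rest_leaves, rest_structure = flatten(rest)
        # Simplification: extra trees must have exactly the same structure
        if rest_structure != structure:
            raise ValueError("all trees must have the same structure")
        flat_args.append(rest_leaves)
    # map yields one output tuple per leaf; zip(*...) regroups them so the
    # i-th group holds the i-th output of every leaf, in leaf order.
    output_tuple = zip(*map(func, *flat_args))
    return tuple(unflatten(structure, list(o)) for o in output_tuple)
```

With a function returning `(x, 2 * x)`, `tree_transpose_map(func, [(1, 2), 3, 4])` gives `([(1, 2), 3, 4], [(2, 4), 6, 8])`.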
Hi @gRox167, thanks for raising this. I think `tree_transpose_map` (or something with a different name) would be a great feature.
> Currently only support tuple output of a function.
To support functions that return arbitrarily nested structures, we need to specify the inner pytree structure.
```python
from __future__ import annotations

from typing import Callable, TypeVar

from optree import PyTree, PyTreeSpec, tree_flatten, tree_structure

T = TypeVar('T')
S = TypeVar('S')
U = TypeVar('U')


def tree_transpose_map(
    func: Callable[..., PyTree[U]],
    tree: PyTree[T],
    *rests: PyTree[S],
    inner_treespec: PyTreeSpec | None = None,
    is_leaf: Callable[[T], bool] | None = None,
    none_is_leaf: bool = False,
    namespace: str = '',
) -> PyTree[U]:  # PyTree[PyTree[U]]
    leaves, outer_treespec = tree_flatten(
        tree,
        is_leaf=is_leaf,
        none_is_leaf=none_is_leaf,
        namespace=namespace,
    )
    if outer_treespec.num_leaves == 0:
        raise ValueError(f'The outer structure must have at least one leaf. Got: {outer_treespec}.')
    flat_args = [leaves] + [outer_treespec.flatten_up_to(r) for r in rests]
    outputs = list(map(func, *flat_args))
    if inner_treespec is None:
        # Default: use the full structure of the first output
        inner_treespec = tree_structure(
            outputs[0],
            is_leaf=is_leaf,
            none_is_leaf=none_is_leaf,
            namespace=namespace,
        )
    if inner_treespec.num_leaves == 0:
        raise ValueError(f'The inner structure must have at least one leaf. Got: {inner_treespec}.')
    # grouped[i][j] is the j-th inner piece produced for the i-th outer leaf
    grouped = [inner_treespec.flatten_up_to(o) for o in outputs]
    transposed = zip(*grouped)
    subtrees = map(outer_treespec.unflatten, transposed)
    return inner_treespec.unflatten(subtrees)
```
```python
In [1]: def foo(x):
   ...:     return x, 2 * x
   ...:

In [2]: tree_transpose_map(foo, [(1, 2), 3, 4])
Out[2]: ([(1, 2), 3, 4], [(2, 4), 6, 8])

In [3]: def bar(x):
   ...:     return {'x': x, 'x_double': 2 * x}
   ...:

In [4]: tree_transpose_map(bar, [(1, 2), 3, 4])
Out[4]: {'x': [(1, 2), 3, 4], 'x_double': [(2, 4), 6, 8]}

In [5]: def nested(x):
   ...:     return {'x': (x, x), 'x_double': (2 * x, 2 * x)}
   ...:

In [6]: tree_transpose_map(nested, [(1, 2), 3, 4], inner_treespec=tree_structure({'x': 1, 'x_double': 2}))
Out[6]:
{
    'x': [((1, 1), (2, 2)), (3, 3), (4, 4)],
    'x_double': [((2, 2), (4, 4)), (6, 6), (8, 8)]
}
```
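The inner-structure version can also be reproduced in plain Python to see the two-level transpose at work. This is a minimal sketch, not optree's implementation: `flatten`, `unflatten` and `flatten_up_to` are hypothetical helpers that treat only plain lists, tuples and dicts as containers, and `inner_structure` plays the role of `inner_treespec` (build it with `flatten(example)[1]`).

```python
from typing import Any, Callable, List, Optional, Tuple

# Hypothetical stand-ins for optree's PyTreeSpec methods: only plain lists,
# tuples and dicts are containers; None in a structure marks a leaf slot.


def flatten(tree: Any) -> Tuple[List[Any], Any]:
    leaves: List[Any] = []

    def walk(node: Any) -> Any:
        if type(node) in (list, tuple):
            return (type(node), [walk(child) for child in node])
        if type(node) is dict:
            keys = sorted(node)
            return (dict, keys, [walk(node[k]) for k in keys])
        leaves.append(node)
        return None

    return leaves, walk(tree)


def unflatten(structure: Any, leaves: List[Any]) -> Any:
    it = iter(leaves)

    def build(spec: Any) -> Any:
        if spec is None:
            return next(it)
        if spec[0] is dict:
            _, keys, children = spec
            return {k: build(child) for k, child in zip(keys, children)}
        container, children = spec
        return container(build(child) for child in children)

    return build(structure)


def flatten_up_to(structure: Any, tree: Any) -> List[Any]:
    """Return the subtrees of `tree` sitting at the leaf slots of `structure`."""
    out: List[Any] = []

    def walk(spec: Any, node: Any) -> None:
        if spec is None:
            out.append(node)  # keep the whole subtree as a single item
        elif spec[0] is dict:
            _, keys, children = spec
            if type(node) is not dict or sorted(node) != keys:
                raise ValueError(f"expected a dict with keys {keys}, got {node!r}")
            for key, child in zip(keys, children):
                walk(child, node[key])
        else:
            container, children = spec
            if type(node) is not container or len(node) != len(children):
                raise ValueError(f"expected {container.__name__} of length {len(children)}, got {node!r}")
            for child, sub in zip(children, node):
                walk(child, sub)

    walk(structure, tree)
    return out


def tree_transpose_map(
    func: Callable[..., Any],
    tree: Any,
    *rests: Any,
    inner_structure: Optional[Any] = None,
) -> Any:
    leaves, outer = flatten(tree)
    if not leaves:
        raise ValueError("the outer structure must have at least one leaf")
    flat_args = [leaves] + [flatten_up_to(outer, rest) for rest in rests]
    outputs = list(map(func, *flat_args))
    if inner_structure is None:
        _, inner_structure = flatten(outputs[0])  # full structure of the first output
    # grouped[i][j] is the j-th inner piece produced for the i-th outer leaf
    grouped = [flatten_up_to(inner_structure, out) for out in outputs]
    if not grouped[0]:
        raise ValueError("the inner structure must have at least one leaf")
    # zip(*grouped) swaps the axes: one column per inner leaf slot
    subtrees = [unflatten(outer, list(column)) for column in zip(*grouped)]
    return unflatten(inner_structure, subtrees)
```

Passing `inner_structure=flatten({'x': 0, 'x_double': 0})[1]` stops the inner flattening at the dict level, so the `(x, x)` tuples stay intact as leaves of the outer trees, matching the `nested` example.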
Thanks for all your brilliant contributions! `tree_transpose_map` works perfectly for my use case. Please let me know if you need any further help, such as documentation or examples.
Required prerequisites
Motivation
Thanks for all the brilliant work of the contributors to this repo! When we have a function with multiple outputs and apply `tree_map` to it, the output will be a `PyTree` whose leaves are of type `Tuple[int, int]`. However, if we want two separate trees of the same structure, each a `PyTree[int]`, we have to do a separate `tree_transpose`. When the returned values have further nested structure, this gets tedious, because we need another `is_leaf` function to determine the leaf nodes. When separating the return values of a multiple-output function, we already know the existing PyTree structure, so it should be possible to separate the PyTree directly.
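The two-step approach can be sketched in plain Python. This is a minimal illustration, not optree code: `tree_map` and `select_output` here are hypothetical helpers that treat only plain lists, tuples and dicts as containers. The point is that the mapped tree alone is ambiguous, since `((1, 2), (2, 4))` could be a tuple node of the input or an output pair, so the separation has to be guided by the input tree's structure.

```python
def tree_map(func, tree):
    """Apply func to every leaf; only plain lists, tuples and dicts are containers."""
    if type(tree) in (list, tuple):
        return type(tree)(tree_map(func, child) for child in tree)
    if type(tree) is dict:
        return {key: tree_map(func, value) for key, value in tree.items()}
    return func(tree)


def select_output(reference, mapped, index):
    """Pull the index-th output out of a tree of output tuples.

    The input tree (reference) says where its leaves were; without it, an
    output tuple cannot be told apart from a tuple node of the input.
    """
    if type(reference) in (list, tuple):
        return type(reference)(
            select_output(ref, out, index) for ref, out in zip(reference, mapped)
        )
    if type(reference) is dict:
        return {key: select_output(reference[key], mapped[key], index) for key in reference}
    return mapped[index]  # at an input leaf, mapped holds func's output tuple


tree = [(1, 2), 3, 4]
mapped = tree_map(lambda x: (x, 2 * x), tree)
# mapped == [((1, 2), (2, 4)), (3, 6), (4, 8)]
first = select_output(tree, mapped, 0)   # [(1, 2), 3, 4]
second = select_output(tree, mapped, 1)  # [(2, 4), 6, 8]
```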
There is a function called `equinox.partition` which could potentially serve this purpose, but it only supports splitting a pytree into two parts, and we still need to write a filter function to determine how to split it.
Solution
In the best situation:
And for mapping objects:
Alternatives
No response
Additional context
If you have any ideas or suggestions, please feel free to comment. I can work on implementing this feature if I have some guidance on how to do it.