metaopt / optree

OpTree: Optimized PyTree Utilities
https://optree.readthedocs.io
Apache License 2.0

[Feature Request] Add a function to transpose the `tree_map`ed values of a multiple output function #126

Closed: gRox167 closed this issue 6 months ago

gRox167 commented 6 months ago

Motivation

Thanks to all the contributors to this repo for their brilliant work! Suppose we have a function with multiple outputs:

def foo(x):
    return x, 2*x

and we apply tree_map to it:

tree_map(foo, [(1,2), 3,4])

The output is a single PyTree whose leaves are tuples, i.e. PyTree[Tuple[int, int]]. However, if we want two separate trees with the same structure as the input, each a PyTree[int], the only option is a separate tree_transpose call:

    x, x_double = tree_transpose(
        tree_structure([(1,2), 3,4]),
        tree_structure(("*","*")),
        tree_map(foo, [(1,2), 3,4])
    )
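To make the shapes concrete, here is a dependency-free sketch. `map_leaves` is a hypothetical stand-in for tree_map that only understands lists and tuples; it is not OpTree's implementation and is only meant to show what comes out and what we would like to get instead.

```python
def map_leaves(func, tree):
    """Hypothetical stand-in for tree_map: recurse into lists and tuples,
    and apply func to anything else (the leaves)."""
    if isinstance(tree, (list, tuple)):
        return type(tree)(map_leaves(func, child) for child in tree)
    return func(tree)


def foo(x):
    return x, 2 * x


tree = [(1, 2), 3, 4]

# What mapping a two-output function gives: one tree with a tuple at every leaf.
mapped = map_leaves(foo, tree)
print(mapped)  # [((1, 2), (2, 4)), (3, 6), (4, 8)]

# What we actually want: one tree per output, each shaped like the input.
# Done naively here by mapping twice, which calls foo twice per leaf.
x = map_leaves(lambda leaf: foo(leaf)[0], tree)
x_double = map_leaves(lambda leaf: foo(leaf)[1], tree)
print(x)         # [(1, 2), 3, 4]
print(x_double)  # [(2, 4), 6, 8]
```

Note that in `mapped` the tuple `(1, 2)` produced by `foo(1)` looks exactly like a tuple node of the input tree, which is why the transpose step has to be told both structures.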

However, when the returned value has further nested structure, this becomes tedious, because we need an additional is_leaf to determine what counts as a leaf node.
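A small dependency-free sketch of why nesting makes this harder. `map_leaves` and `flatten` are hypothetical helpers for lists and tuples only (not OpTree functions), and `nested_foo` is an illustrative function whose second output is itself a tuple.

```python
def map_leaves(func, tree):
    """Hypothetical stand-in for tree_map over nested lists and tuples."""
    if isinstance(tree, (list, tuple)):
        return type(tree)(map_leaves(func, child) for child in tree)
    return func(tree)


def flatten(tree):
    """Hypothetical helper: list the leaves of nested lists/tuples, left to right."""
    if isinstance(tree, (list, tuple)):
        return [leaf for child in tree for leaf in flatten(child)]
    return [tree]


def nested_foo(x):
    return x, (x, 2 * x)  # the second output is itself a tuple


tree = [(1, 2), 3, 4]
mapped = map_leaves(nested_foo, tree)

outer_leaves = flatten(tree)   # the 4 leaves of the input
all_leaves = flatten(mapped)   # 12 leaves: 4 input leaves x 3 output leaves
print(outer_leaves)  # [1, 2, 3, 4]
print(all_leaves)    # [1, 1, 2, 2, 2, 4, 3, 3, 6, 4, 4, 8]
```

Nothing in `mapped` marks where the input structure ends and the output structure begins, so the caller has to supply that boundary, for example through an explicit inner structure or an is_leaf predicate.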

When separating the return values of a multiple-output function, we already know the structure of the input PyTree, so it should be possible to split the result directly.

There is a function called equinox.partition that could potentially serve this purpose, but it only supports splitting a pytree into two parts, and we would still need to write a filter function to decide how to split.

Solution

Ideally, this would look like:

def foo(x):
    return x, 2*x
t = [(1,2), 3,4]
x, x_double = tree_map(foo, t, transpose_output_tree_to_outer_tree = True)

and for a function that returns a mapping:

def foo(x):
    return {"x":x, "x_double":2*x}
t = [(1,2), 3,4]
x_dict= tree_map(foo, t, transpose_output_tree_to_outer_tree = True)
# x_dict will be {"x":[(1,2), 3,4], "x_double":[(2,4), 6,8]}

Alternatives

No response

Additional context

If you have any ideas or suggestions, please feel free to comment. I can work on implementing this feature if I have some guidance on how to do it.

gRox167 commented 6 months ago

After looking at the tree_map function, I came up with an idea for separating the outputs into different PyTrees. Currently it only supports tuple output of a function.

def tree_transpose_map(
    func: Callable[..., U],
    tree: PyTree[T],
    *rests: PyTree[S],
    is_leaf: Callable[[T], bool] | None = None,
    none_is_leaf: bool = False,
    namespace: str = "",
) -> Tuple[PyTree, ...]:
    leaves, treespec = _C.flatten(tree, is_leaf, none_is_leaf, namespace)
    flat_args = [leaves] + [treespec.flatten_up_to(r) for r in rests]
    output_tuple = zip(*map(func, *flat_args))
    return tuple(treespec.unflatten(o) for o in output_tuple)

XuehaiPan commented 6 months ago

Hi @gRox167, thanks for raising this. I think tree_transpose_map (or whatever name we settle on) would be a great feature.

> Currently it only supports tuple output of a function.

To support functions that return an arbitrary nested structure, we should specify the inner pytree structure.

def tree_transpose_map(
    func: Callable[..., PyTree[U]],
    tree: PyTree[T],
    *rests: PyTree[S],
    inner_treespec: PyTreeSpec | None = None,
    is_leaf: Callable[[T], bool] | None = None,
    none_is_leaf: bool = False,
    namespace: str = '',
) -> PyTree[U]:  # PyTree[PyTree[U]]
    leaves, outer_treespec = tree_flatten(
        tree,
        is_leaf=is_leaf,
        none_is_leaf=none_is_leaf,
        namespace=namespace,
    )
    if outer_treespec.num_leaves == 0:
        raise ValueError(f'The outer structure must have at least one leaf. Got: {outer_treespec}.')
    flat_args = [leaves] + [outer_treespec.flatten_up_to(r) for r in rests]
    outputs = list(map(func, *flat_args))

    if inner_treespec is None:
        inner_treespec = tree_structure(
            outputs[0],
            is_leaf=is_leaf,
            none_is_leaf=none_is_leaf,
            namespace=namespace,
        )
    if inner_treespec.num_leaves == 0:
        raise ValueError(f'The inner structure must have at least one leaf. Got: {inner_treespec}.')

    grouped = [inner_treespec.flatten_up_to(o) for o in outputs]
    transposed = zip(*grouped)
    subtrees = map(outer_treespec.unflatten, transposed)
    return inner_treespec.unflatten(subtrees)

In [1]: def foo(x):
   ...:     return x, 2 * x
   ...:     

In [2]: tree_transpose_map(foo, [(1, 2), 3, 4])
Out[2]: ([(1, 2), 3, 4], [(2, 4), 6, 8])

In [3]: def bar(x):
   ...:     return {'x': x, 'x_double': 2 * x}
   ...:     

In [4]: tree_transpose_map(bar, [(1, 2), 3, 4])
Out[4]: {'x': [(1, 2), 3, 4], 'x_double': [(2, 4), 6, 8]}

In [5]: def nested(x):
   ...:     return {'x': (x, x), 'x_double': (2 * x, 2 * x)}
   ...:     

In [6]: tree_transpose_map(nested, [(1, 2), 3, 4], inner_treespec=tree_structure({'x': 1, 'x_double': 2}))
Out[6]:
{
    'x': [((1, 1), (2, 2)), (3, 3), (4, 4)],
    'x_double': [((2, 2), (4, 4)), (6, 6), (8, 8)]
}

gRox167 commented 6 months ago

Thanks for all your brilliant contributions! tree_transpose_map works perfectly for my use case. Please let me know if you need any further help, such as with documentation or examples.