Open jakevdp opened 9 months ago
As a simple fix for this issue, I was thinking of adding the following check to the current implementation:
inner_size = inner_treedef.num_leaves
outer_size = outer_treedef.num_leaves
expected_treedef = outer_treedef.compose(inner_treedef)
if treedef.num_leaves != (inner_size * outer_size) or expected_treedef != treedef:
    raise TypeError(f"Mismatch\n{treedef}\n != \n{expected_treedef}")
Original code: https://github.com/google/jax/blob/0da9b69285811446f069e0ef765cd8e0a8100bf4/jax/_src/tree_util.py#L378-L395
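To try the stricter check without patching JAX, it can be sketched as a standalone wrapper. This is only an illustration: strict_tree_transpose is a hypothetical name, not an existing JAX function, and it delegates to the current jax.tree_util.tree_transpose after validating the structure. Since equal treedefs imply equal leaf counts, the num_leaves comparison is left out here.

```python
from jax import tree_util


def strict_tree_transpose(outer_treedef, inner_treedef, pytree_to_transpose):
    """Hypothetical wrapper: check the full structure, then transpose."""
    treedef = tree_util.tree_structure(pytree_to_transpose)
    expected_treedef = outer_treedef.compose(inner_treedef)
    # Structural equality also covers the leaf-count check.
    if treedef != expected_treedef:
        raise TypeError(f"Mismatch\n{treedef}\n != \n{expected_treedef}")
    return tree_util.tree_transpose(outer_treedef, inner_treedef, pytree_to_transpose)


outer_treedef = tree_util.tree_structure(['*', '*', '*'])
inner_treedef = tree_util.tree_structure(['*', '*'])

# Input whose inner containers are lists, matching inner_treedef: accepted.
ok = strict_tree_transpose(outer_treedef, inner_treedef, [[1, 2], [3, 4], [5, 6]])

# Input whose inner containers are tuples: rejected by the wrapper.
try:
    strict_tree_transpose(outer_treedef, inner_treedef, [(1, 2), (3, 4), (5, 6)])
    rejected = False
except TypeError:
    rejected = True
```

A wrapper like this keeps the existing behavior available to callers who rely on it, which may matter given how much lenient usage could exist.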
This makes the treedef verification stricter. Previously, lists and tuples could be used interchangeably between the treedefs and the input; with this check, such a mismatch raises an error:
import jax
outer_treedef = jax.tree_util.tree_structure(['*', '*', '*'])
inner_treedef = jax.tree_util.tree_structure(['*', '*']) # Here inner tree def uses list
tree = [(1, 2), (3, 4), (5, 6)] # but input inner tree has tuples
print(jax.tree_util.tree_transpose(outer_treedef, inner_treedef, tree))
# on main: [[1, 3, 5], [2, 4, 6]]
# but with `expected_treedef != treedef`, we have an error like
# TypeError: Mismatch
# PyTreeDef([(*, *), (*, *), (*, *)])
# !=
# PyTreeDef([[*, *], [*, *], [*, *]])
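If the stricter check were adopted, one way callers could avoid the list/tuple mismatch is to derive the inner treedef from the data itself instead of writing it by hand. The sketch below assumes every element of the outer list has the same structure as the first one; the explicit assert just mirrors the proposed check.

```python
import jax

tree = [(1, 2), (3, 4), (5, 6)]

# Outer structure: a list with one leaf slot per element of `tree`.
outer_treedef = jax.tree_util.tree_structure(['*'] * len(tree))
# Inner structure taken from the first element, so it is tuple-based here.
inner_treedef = jax.tree_util.tree_structure(tree[0])

# The composed treedef now matches the input exactly.
assert outer_treedef.compose(inner_treedef) == jax.tree_util.tree_structure(tree)

result = jax.tree_util.tree_transpose(outer_treedef, inner_treedef, tree)
# result is a tuple of lists: ([1, 3, 5], [2, 4, 6])
```

Note that the result follows the inner treedef, so it comes back as a tuple of lists rather than a list of lists.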
Another point: three of the existing tests fail with this change, for example testTranspose21 (the other failing tests are testTransposeWithCustomObject and testTranspose22), where we have the following:
pytree_to_transpose: FlatCache([1, 1, 1])
outer_treedef: PyTreeDef(CustomNode(FlatCache[PyTreeDef(*)], [*]))
inner_treedef: PyTreeDef([*, *, *])
treedef: PyTreeDef(CustomNode(FlatCache[PyTreeDef([*, *, *])], [*, *, *]))
expected_treedef: PyTreeDef(CustomNode(FlatCache[PyTreeDef(*)], [[*, *, *]]))
Another observation: inferring the inner treedef with inner_treedef = tree_structure(outer_treedef.flatten_up_to(pytree_to_transpose)[0]) also fails for the same test input when the inner_treedef argument is passed as None. Error message:
ValueError: Mismatch custom node data: PyTreeDef(*) != PyTreeDef([*, *, *]); value: FlatCache([1, 1, 1]).
Any hints or thoughts, @jakevdp (when you have time to check this ticket)?
It looks like only the number of nodes is validated, meaning you can transpose trees that have the wrong input structure: