jax-ml / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
http://jax.readthedocs.io/
Apache License 2.0

`tree_transpose` does not validate tree structure of input #19810

Open jakevdp opened 9 months ago

jakevdp commented 9 months ago

It looks like only the total number of leaves is validated, so you can transpose trees whose input structure is wrong:

import jax

outer_treedef = jax.tree_util.tree_structure(['*', '*', '*'])
inner_treedef = jax.tree_util.tree_structure(('*', '*'))

tree = [1, [2, [3, [4, [5, [6]]]]]]  # Definitely not outer_treedef x inner_treedef

print(jax.tree_util.tree_transpose(outer_treedef, inner_treedef, tree))
# ([1, 3, 5], [2, 4, 6])
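The following minimal sketch shows why the example above is accepted. It contrasts the leaf-count check with a full structural comparison, using the existing `PyTreeDef.num_leaves` and `PyTreeDef.compose`. Variable names mirror the example. `expected_treedef` is the structure a well-formed input would have.

```python
import jax

outer_treedef = jax.tree_util.tree_structure(['*', '*', '*'])
inner_treedef = jax.tree_util.tree_structure(('*', '*'))
tree = [1, [2, [3, [4, [5, [6]]]]]]

treedef = jax.tree_util.tree_structure(tree)
# Structure a well-formed input would have: [(*, *), (*, *), (*, *)]
expected_treedef = outer_treedef.compose(inner_treedef)

# The leaf-count check passes: 6 == 3 * 2 ...
print(treedef.num_leaves == outer_treedef.num_leaves * inner_treedef.num_leaves)  # True
# ... but the structures differ, so a structural check would reject `tree`.
print(treedef == expected_treedef)  # False
```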

vfdev-5 commented 4 months ago

As a simple fix for this issue, I was thinking of adding the following to the current implementation:

  inner_size = inner_treedef.num_leaves
  outer_size = outer_treedef.num_leaves
  expected_treedef = outer_treedef.compose(inner_treedef)
  if treedef.num_leaves != (inner_size * outer_size) or expected_treedef != treedef:
    raise TypeError(f"Mismatch\n{treedef}\n != \n{expected_treedef}")

Original code: https://github.com/google/jax/blob/0da9b69285811446f069e0ef765cd8e0a8100bf4/jax/_src/tree_util.py#L378-L395
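For discussion, the proposed check can be folded into a self-contained function. `strict_tree_transpose` is a hypothetical name. The transposition logic is a sketch written to behave like the linked implementation on the examples in this thread: flatten, split into `outer_size` rows of `inner_size` leaves, transpose, unflatten. It is not the exact JAX source.

```python
import jax
from jax.tree_util import tree_flatten, tree_unflatten


def strict_tree_transpose(outer_treedef, inner_treedef, pytree_to_transpose):
  """Hypothetical variant of tree_transpose with the proposed structural check."""
  flat, treedef = tree_flatten(pytree_to_transpose)
  inner_size = inner_treedef.num_leaves
  outer_size = outer_treedef.num_leaves
  expected_treedef = outer_treedef.compose(inner_treedef)
  # Full structural comparison. Equal treedefs also imply equal leaf counts,
  # so a separate num_leaves check is redundant here.
  if treedef != expected_treedef:
    raise TypeError(f"Mismatch\n{treedef}\n != \n{expected_treedef}")
  # Split the flat leaves into outer_size rows of inner_size leaves each.
  rows = [flat[i * inner_size:(i + 1) * inner_size] for i in range(outer_size)]
  # zip(*rows) yields one tuple per inner position, holding outer_size leaves.
  subtrees = [tree_unflatten(outer_treedef, column) for column in zip(*rows)]
  return tree_unflatten(inner_treedef, subtrees)


outer = jax.tree_util.tree_structure(['*', '*', '*'])
inner = jax.tree_util.tree_structure(('*', '*'))
print(strict_tree_transpose(outer, inner, [(1, 2), (3, 4), (5, 6)]))
# ([1, 3, 5], [2, 4, 6])
try:
  strict_tree_transpose(outer, inner, [1, [2, [3, [4, [5, [6]]]]]])
except TypeError as e:
  print("rejected:", e)
```

With this check, the structurally wrong `tree` from the original report raises `TypeError`, while a correctly nested input transposes as before.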

The proposed check makes treedef verification stricter. Previously, tuples and lists could be mixed freely between the treedefs and the input. Now code like the following raises an error:

import jax
outer_treedef = jax.tree_util.tree_structure(['*', '*', '*'])
inner_treedef = jax.tree_util.tree_structure(['*', '*'])  # here the inner treedef uses a list

tree = [(1, 2), (3, 4), (5, 6)]  # but the input's inner subtrees are tuples
print(jax.tree_util.tree_transpose(outer_treedef, inner_treedef, tree))
# on main: [[1, 3, 5], [2, 4, 6]]
# but with the `expected_treedef != treedef` check, we get an error like:
# TypeError: Mismatch
# PyTreeDef([(*, *), (*, *), (*, *)])
#  !=
# PyTreeDef([[*, *], [*, *], [*, *]])
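If rejecting list/tuple mixing turns out to be too strict, one possible middle ground is to match the outer structure exactly and check only the leaf count of each inner subtree. This is a hypothetical sketch: `loose_tree_transpose` is an illustrative name, not something proposed in this thread or provided by JAX. It assumes that `PyTreeDef.flatten_up_to` raises `ValueError` when the input does not match the outer structure.

```python
import jax
from jax.tree_util import tree_leaves, tree_unflatten


def loose_tree_transpose(outer_treedef, inner_treedef, pytree_to_transpose):
  """Hypothetical middle ground: match the outer structure exactly, but check
  only the leaf count of each inner subtree."""
  inner_size = inner_treedef.num_leaves
  # flatten_up_to raises ValueError if the input does not match outer_treedef.
  inner_subtrees = outer_treedef.flatten_up_to(pytree_to_transpose)
  rows = [tree_leaves(subtree) for subtree in inner_subtrees]
  for row in rows:
    if len(row) != inner_size:
      raise TypeError(f"Expected {inner_size} leaves per inner subtree, got {len(row)}")
  # Transpose rows into one outer-shaped tree per inner position.
  outer_trees = [tree_unflatten(outer_treedef, column) for column in zip(*rows)]
  return tree_unflatten(inner_treedef, outer_trees)


outer = jax.tree_util.tree_structure(['*', '*', '*'])
inner = jax.tree_util.tree_structure(['*', '*'])
# Tuples in the input are accepted even though inner_treedef uses a list.
print(loose_tree_transpose(outer, inner, [(1, 2), (3, 4), (5, 6)]))
# [[1, 3, 5], [2, 4, 6]]
```

This variant still rejects the original report's `tree`, because a 2-element list does not match the 3-element outer list. However, it accepts an inner subtree such as `[1, [2]]` for an inner treedef `[*, *]`, so it validates the inner structure only partially. It also relies on `flatten_up_to`, which compares custom-node data, so it would not help with the `FlatCache` tests.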

Another point: with this change, 3 existing tests fail: testTranspose21, testTranspose22 and testTransposeWithCustomObject. For testTranspose21, for example, we have the following:

pytree_to_transpose: FlatCache([1, 1, 1])
outer_treedef: PyTreeDef(CustomNode(FlatCache[PyTreeDef(*)], [*]))
inner_treedef: PyTreeDef([*, *, *])
treedef: PyTreeDef(CustomNode(FlatCache[PyTreeDef([*, *, *])], [*, *, *]))
expected_treedef: PyTreeDef(CustomNode(FlatCache[PyTreeDef(*)], [[*, *, *]]))

Another observation: if we pass inner_treedef=None for the same test input, so that the inner structure is inferred as inner_treedef = tree_structure(outer_treedef.flatten_up_to(pytree_to_transpose)[0]), the call also fails with the error:

ValueError: Mismatch custom node data: PyTreeDef(*) != PyTreeDef([*, *, *]); value: FlatCache([1, 1, 1]).
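The two failures above have a common cause that a small stand-alone example can isolate. The sketch uses a hypothetical `FlatBox` class. This class is an assumption that imitates what the treedefs above suggest about `FlatCache`: children stored as flat leaves, nested structure kept as auxiliary data. It is not the actual test code.

```python
import jax
from jax import tree_util


class FlatBox:
  """Hypothetical container that, like the FlatCache test class, keeps its
  contents as flat leaves and stores their nested structure as aux data."""

  def __init__(self, leaves, treedef):
    self.leaves = list(leaves)
    self.treedef = treedef

  @classmethod
  def wrap(cls, structured):
    leaves, treedef = tree_util.tree_flatten(structured)
    return cls(leaves, treedef)


tree_util.register_pytree_node(
    FlatBox,
    lambda box: (box.leaves, box.treedef),            # (children, aux data)
    lambda treedef, leaves: FlatBox(leaves, treedef),  # rebuild from aux data + children
)

outer_treedef = tree_util.tree_structure(FlatBox.wrap(0))  # FlatBox around one leaf
inner_treedef = tree_util.tree_structure([0, 0, 0])
actual = tree_util.tree_structure(FlatBox.wrap([1, 1, 1]))

# compose() substitutes inner_treedef at the leaf but keeps the outer aux data
# (PyTreeDef(*)), whereas flattening FlatBox.wrap([1, 1, 1]) records [*, *, *].
print(actual == outer_treedef.compose(inner_treedef))  # False

# flatten_up_to compares aux data too, which is what the ValueError above reports.
try:
  outer_treedef.flatten_up_to(FlatBox.wrap([1, 1, 1]))
except ValueError as e:
  print(e)
```

In this sketch, both the `compose`-based equality check and `flatten_up_to` reject the input. This mirrors the testTranspose21 failure and the `ValueError` reported above.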

Any hints or thoughts, @jakevdp (when you have time to look at this ticket)?