jax-ml / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
http://jax.readthedocs.io/
Apache License 2.0

Improve error in `jax.tree_util.tree_map` for variadic `Pytree` arguments #17798

Open femtomc opened 1 year ago

femtomc commented 1 year ago

I'm debugging part of my library and have run into an error I don't know how to resolve, because the message I'm receiving is fairly opaque:

File ~/miniconda3/envs/py311/lib/python3.11/site-packages/jax/_src/tree_util.py:243, in <listcomp>(.0)
    210 """Maps a multi-input function over pytree args to produce a new pytree.
    211 
    212 Args:
   (...)
    240   [[5, 7, 9], [6, 1, 2]]
    241 """
    242 leaves, treedef = tree_flatten(tree, is_leaf)
--> 243 all_leaves = [leaves] + [treedef.flatten_up_to(r) for r in rest]
    244 return treedef.unflatten(f(*xs) for xs in zip(*all_leaves))

ValueError: Expected list, got (HierarchicalChoiceMap(trie=Trie(inner={'reflection_point': ValueChoiceMap(value=Traced<ShapedArray(float32[361])>with<DynamicJaxprTrace(level=2/0)>)})), HierarchicalChoiceMap(trie=Trie(inner={'reflection_point': ValueChoiceMap(value=Traced<ShapedArray(float32[361])>with<DynamicJaxprTrace(level=2/0)>)}))).
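As the `tree_map` source in the traceback shows, the first tree is flattened and every remaining tree is passed through `treedef.flatten_up_to`, which fails if its container types don't match the first tree's. A minimal sketch of one way to trigger the same kind of `ValueError`, assuming the mismatch is a list in one argument versus a tuple in the other (the toy trees are made up for illustration and the exact message wording may vary across JAX versions):

```python
import jax

# Same keys and leaf count, but the first tree holds a list and the
# second holds a tuple, so their PyTreeDefs differ.
first = {"branches": [1.0, 2.0]}
second = {"branches": (3.0, 4.0)}


def try_tree_map(a, b):
    """Return the ValueError message raised by tree_map, or None on success."""
    try:
        jax.tree_util.tree_map(lambda x, y: x + y, a, b)
    except ValueError as err:
        return str(err)
    return None


message = try_tree_map(first, second)
print(message)
```

Mapping `first` against itself succeeds, so the failure comes only from the container-type mismatch, not from the leaves.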

Separately, I think I've figured out that this error is related to the PyTreeDef structures of the two instances here, but I'm not sure.

Even so, it's tough to figure out what's going on.

When I tried to debug my PyTreeDef instances, I also couldn't tell by eye whether there was a discrepancy:

PyTreeDef(CustomNode(HierarchicalChoiceMap[()], [CustomNode(Trie[()], [CustomNode(HashableDict[['outlier', 'reflection_or_outlier']], [CustomNode(ValueChoiceMap[()], [*]), CustomNode(SwitchChoiceMap[()], [*, [CustomNode(HierarchicalChoiceMap[()], [CustomNode(Trie[()], [CustomNode(HashableDict[['reflection_point']], [CustomNode(ValueChoiceMap[()], [*])])])]), CustomNode(HierarchicalChoiceMap[()], [CustomNode(Trie[()], [CustomNode(HashableDict[['reflection_point']], [CustomNode(ValueChoiceMap[()], [*])])])])]])])])]))

PyTreeDef(CustomNode(HierarchicalChoiceMap[()], [CustomNode(Trie[()], [CustomNode(HashableDict[['outlier', 'reflection_or_outlier']], [CustomNode(ValueChoiceMap[()], [*]), CustomNode(SwitchChoiceMap[()], [*, (CustomNode(HierarchicalChoiceMap[()], [CustomNode(Trie[()], [CustomNode(HashableDict[['reflection_point']], [CustomNode(ValueChoiceMap[()], [*])])])]), CustomNode(HierarchicalChoiceMap[()], [CustomNode(Trie[()], [CustomNode(HashableDict[['reflection_point']], [CustomNode(ValueChoiceMap[()], [*])])])]))])])])]))

To me these look identical, so I can't tell whether I'm actually running into a real bug or what the error is telling me.
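Rather than comparing long PyTreeDef strings by eye, the structures can be compared programmatically with `jax.tree_util.tree_structure` and `==`. The sketch below adds a hypothetical helper, `treedef_diff` (not part of JAX), that also prints both structures around the first differing character:

```python
import jax


def treedef_diff(a, b, context=20):
    """Hypothetical helper: report where the pytree structures of a and b diverge.

    Returns None when the structures are equal; otherwise a short report that
    shows both printed PyTreeDefs around the first differing character.
    """
    ta = jax.tree_util.tree_structure(a)
    tb = jax.tree_util.tree_structure(b)
    if ta == tb:
        return None
    sa, sb = str(ta), str(tb)
    if sa == sb:
        # Different structures can print identically (e.g. two custom node
        # classes with the same name), so the text alone cannot locate it.
        return "structures differ but print identically"
    # Index of the first differing character, or the end of the shorter string.
    i = next(
        (k for k, (ca, cb) in enumerate(zip(sa, sb)) if ca != cb),
        min(len(sa), len(sb)),
    )
    lo = max(0, i - context)
    hi = i + context
    return (
        f"structures differ at character {i}:\n"
        f"  a: ...{sa[lo:hi]}...\n"
        f"  b: ...{sb[lo:hi]}..."
    )


# Same leaves, but the inner container is a tuple in one tree and a list in the other.
report = treedef_diff({"x": [1, (2, 3)]}, {"x": [1, [2, 3]]})
print(report)
```

Comparing printed strings is a blunt tool, but it points straight at a single `[` vs. `(` that is easy to miss in a long structure.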

femtomc commented 1 year ago

😆 I think I found it ...

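There is a single parenthesis where the other structure has a bracket: the first PyTreeDef contains `[*, [CustomNode(...` while the second contains `[*, (CustomNode(...`, with a matching `]` vs. `)` near the end. So one SwitchChoiceMap holds its branches in a list and the other in a tuple, which is what "Expected list, got (...)" was saying.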
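One way to avoid this class of mismatch, sketched under the assumption that the list/tuple difference comes from how a node stores its children: normalize the container type when the node is constructed, so the PyTreeDef is the same regardless of what the caller passed. `Switch` below is a hypothetical stand-in, not the library's actual `SwitchChoiceMap`:

```python
from typing import Sequence

import jax
import jax.numpy as jnp


@jax.tree_util.register_pytree_node_class
class Switch:
    """Hypothetical stand-in for a node like SwitchChoiceMap."""

    def __init__(self, index, branches: Sequence):
        self.index = index
        # Always store a tuple, so list and tuple inputs flatten identically.
        self.branches = tuple(branches)

    def tree_flatten(self):
        # Children: the index leaf and the (tuple) container of branches.
        return (self.index, self.branches), None

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        return cls(*children)


a = Switch(jnp.int32(0), [jnp.ones(3), jnp.zeros(3)])   # list input
b = Switch(jnp.int32(1), (jnp.ones(3), jnp.ones(3)))    # tuple input

same = jax.tree_util.tree_structure(a) == jax.tree_util.tree_structure(b)
summed = jax.tree_util.tree_map(lambda x, y: x + y, a, b)
print(same, summed.branches)
```

Without the `tuple(...)` normalization, `a` and `b` would have different PyTreeDefs and the `tree_map` call would fail in the same way as above.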

@jakevdp this is related to the other issue I opened, #17663, where I similarly end up having to compare these structures by hand. I really wish there were a utility to help with that.