google-deepmind / tree

tree is a library for working with nested data structures
https://tree.readthedocs.io
Apache License 2.0
933 stars 57 forks

Explicit way to register classes #76

Open Conchylicultor opened 2 years ago

Conchylicultor commented 2 years ago

Currently, chex.assert_xyz uses tree under the hood. This makes it impossible to use with JAX pytrees (classes registered via jax.tree_util.register_pytree_node) that do not implement collections.abc.Mapping.

I have a dataclass that already defines __getitem__ and __iter__ for NumPy-like operations, so it cannot implement collections.abc.Mapping. I would like it to work with tree as well, so that chex.assert_xyz works on it.

duburcqa commented 2 years ago

I don't get it. You can inherit from collections.abc.Mapping and implement your own __getitem__ and __iter__. I would even say that is the expected design! I don't see why there would be any blocking point here.

Conchylicultor commented 2 years ago

I can't inherit from collections.abc.Mapping, because I already implement __getitem__ and __iter__ for another use case. Mapping expects __iter__ to return keys (which can then be passed to __getitem__), but my custom class uses __iter__ to slice arrays (like iter(np.array(...))). So the two use cases are incompatible.
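The conflict can be reproduced in pure Python. `Batch` below is a hypothetical stand-in for the dataclass described above (plain lists instead of NumPy arrays): its `__iter__` yields rows, while the mixin methods that `collections.abc.Mapping` supplies (`keys()`, `items()`, `values()`) assume `__iter__` yields keys. This is only an illustrative sketch, not the original poster's class.

```python
from collections.abc import Mapping


class Batch(Mapping):
    """Hypothetical array-like container that also tries to be a Mapping."""

    def __init__(self, **fields):
        self._fields = fields  # field name -> list of per-row values

    def _num_rows(self):
        # All fields are assumed to have the same length.
        first = next(iter(self._fields.values()), [])
        return len(first)

    def __getitem__(self, key):
        # String keys return a whole field; integer keys return one row.
        if isinstance(key, str):
            return self._fields[key]
        return {name: values[key] for name, values in self._fields.items()}

    def __iter__(self):
        # Array-like iteration (like iter(np.array(...))): yield rows, not keys.
        for i in range(self._num_rows()):
            yield self[i]

    def __len__(self):
        return self._num_rows()
```

With `batch = Batch(x=[1, 2], y=[3, 4])`, `list(batch.keys())` yields the row dicts instead of `"x"` and `"y"`, and iterating `batch.items()` raises `TypeError`, because `Mapping` passes each row back into `__getitem__` as if it were a key.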

Conchylicultor commented 2 years ago

Another issue with collections.abc.Mapping is that it does not allow forwarding static metadata (like jax.tree_map does).

duburcqa commented 2 years ago

OK, I understand now. I assume your class is not compatible with collections.abc.Sequence either? tree recognizes that as well.

Anyway, you are right. Reading the docs for jax.tree_map, I really like the way it handles registration of additional classes. It is a shame it is not a standalone package, because it seems to supersede tree and could replace it. What I don't like about tree is its performance: I find extremely basic operations quite slow, because it mixes Python and compiled code. Do you know if jax.tree_map is faster?

HansBug commented 1 year ago

I think we can take torch's pytree and JAX's pytree utilities (jax.tree_util) as design references here. They provide a registration function so that custom classes can be flattened into leaves and later unflattened to restore the original object.

I think this is quite important for dm-tree, both for supporting user-defined custom containers and for wrapper libraries such as treevalue.
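A minimal sketch of what such a registry could look like, assuming a design modeled on the flatten/unflatten pair that `jax.tree_util.register_pytree_node` accepts (flatten returns `(children, aux_data)`, unflatten takes `(aux_data, children)`). `register`, `flatten`, `unflatten`, `tree_map`, and `Points` are all hypothetical names; this is not dm-tree's API.

```python
from dataclasses import dataclass

# Hypothetical registry: exact type -> (flatten_fn, unflatten_fn).
_REGISTRY = {}


def register(cls, flatten_fn, unflatten_fn):
    """flatten_fn(obj) -> (children, aux); unflatten_fn(aux, children) -> obj."""
    _REGISTRY[cls] = (flatten_fn, unflatten_fn)


# Built-in containers go through the same mechanism as user classes.
register(list, lambda xs: (list(xs), None), lambda _, cs: list(cs))
register(tuple, lambda xs: (list(xs), None), lambda _, cs: tuple(cs))
register(dict,
         lambda d: ([d[k] for k in sorted(d)], sorted(d)),  # keys assumed sortable
         lambda keys, cs: dict(zip(keys, cs)))


def flatten(tree):
    """Return (leaves, treedef). Unregistered types are leaves (treedef None)."""
    entry = _REGISTRY.get(type(tree))
    if entry is None:
        return [tree], None
    children, aux = entry[0](tree)
    leaves, child_defs = [], []
    for child in children:
        child_leaves, child_def = flatten(child)
        leaves.extend(child_leaves)
        child_defs.append((len(child_leaves), child_def))
    # aux (static metadata) is stored in the treedef, never among the leaves.
    return leaves, (type(tree), aux, child_defs)


def unflatten(treedef, leaves):
    if treedef is None:
        (leaf,) = leaves
        return leaf
    cls, aux, child_defs = treedef
    children, start = [], 0
    for count, child_def in child_defs:
        children.append(unflatten(child_def, leaves[start:start + count]))
        start += count
    return _REGISTRY[cls][1](aux, children)


def tree_map(fn, tree):
    leaves, treedef = flatten(tree)
    return unflatten(treedef, [fn(x) for x in leaves])


@dataclass
class Points:
    xs: list
    ys: list
    units: str  # static metadata: travels as aux data, not as a leaf


register(
    Points,
    lambda p: ([p.xs, p.ys], p.units),
    lambda units, children: Points(children[0], children[1], units),
)
```

For example, `tree_map(lambda v: v * 2, Points([1, 2], [3], "m"))` gives `Points(xs=[2, 4], ys=[6], units='m')`: only the numeric leaves are transformed, and `units` is restored from the treedef.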

joeryjoery commented 1 year ago

I also ran into this problem: the current implementation doesn't handle the initialization of custom PyTree objects well.

For comparison, the JAX implementation lets you pass auxiliary data (stored in the PyTreeDef returned by flattening) when flattening and unflattening pytrees. This is useful for storing class metadata that shouldn't be interpreted as leaves.

This is not possible with dm-tree's tree.flatten and tree.unflatten_as, since all the data needed to initialize your custom class must be contained in the leaves.
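To see why forcing metadata into the leaves is awkward, here is a small pure-Python sketch. `flatten_leaves`, `unflatten_leaves`, and `map_leaves` are hypothetical stand-ins for leaf-only flattening and rebuilding (they loosely mimic the idea, not dm-tree's actual implementation), and `Measurement` is a made-up class whose `unit` field is metadata.

```python
from dataclasses import dataclass, fields


@dataclass
class Measurement:
    # Hypothetical custom class: `value` is data, `unit` is metadata.
    value: float
    unit: str


def flatten_leaves(obj):
    # Leaf-only flattening: every field becomes a leaf, including `unit`.
    return [getattr(obj, f.name) for f in fields(obj)]


def unflatten_leaves(template, leaves):
    # Rebuild from leaves alone: the constructor only ever sees the leaves.
    return type(template)(*leaves)


def map_leaves(fn, obj):
    return unflatten_leaves(obj, [fn(leaf) for leaf in flatten_leaves(obj)])
```

`map_leaves(lambda x: x * 2, Measurement(1.5, "m"))` returns `Measurement(value=3.0, unit='mm')`: the metadata is transformed along with the data, so every mapped function would have to guard against non-data leaves.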

A solution might be to dynamically generate classes using a metaclass or a call to type, coupling the metadata to the tree structure (I haven't tried this), but that seems a bit hacky and I would prefer the JAX solution.
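The `type`-based idea could be sketched roughly as follows, assuming the metadata is hashable. `Tagged`, `tagged_class`, and `map_children` are hypothetical names: one subclass is generated per metadata value and cached, so the metadata is encoded in the object's type and survives a rebuild that passes only the children to the constructor.

```python
from functools import lru_cache


class Tagged:
    """Hypothetical base class: metadata lives on the class, data in children."""

    meta = None  # overridden by each generated subclass

    def __init__(self, *children):
        self.children = list(children)

    def __eq__(self, other):
        # Same generated class (hence same metadata) and equal children.
        return type(self) is type(other) and self.children == other.children

    def __repr__(self):
        return f"{type(self).__name__}(meta={self.meta!r}, children={self.children!r})"


@lru_cache(maxsize=None)
def tagged_class(meta):
    """Return one cached subclass per metadata value, built with type()."""
    return type(f"Tagged_{meta}", (Tagged,), {"meta": meta})


def map_children(fn, obj):
    # Leaf-only rebuild: only the children reach the constructor, yet the
    # metadata survives because it is carried by type(obj).
    return type(obj)(*[fn(child) for child in obj.children])
```

With `Meters = tagged_class("m")`, `map_children(lambda v: v * 2, Meters(1.5, 2.0))` equals `Meters(3.0, 4.0)` and keeps `meta == "m"`. The hacky part shows up quickly: every distinct metadata value creates a new class, the metadata must be hashable, and the generated classes are not importable by name, which can break pickling.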