Open Conchylicultor opened 2 years ago
I don't get it. You can inherit from collections.abc.Mapping and implement your own __getitem__ and __iter__. I would even say that is the expected design! I don't see why there would be any blocking point here.
I can't inherit from collections.abc.Mapping, because I already implement __getitem__ and __iter__ for another use case.
Basically, Mapping expects __iter__ to return keys (which can then be passed to __getitem__), but my custom class uses __iter__ to slice the array (like iter(np.array(...))). So the two use cases are incompatible.
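To make the conflict concrete, here is a minimal sketch. The names Batch, BatchAsMapping, and mapping_contract_holds are hypothetical, and plain lists stand in for a NumPy array; it only illustrates why an array-style __iter__ cannot satisfy what Mapping expects.

```python
from collections.abc import Mapping


class Batch:
    """Hypothetical array-like container: iterating yields rows, not keys."""

    def __init__(self, rows):
        self.rows = list(rows)

    def __getitem__(self, index):
        # Array-style indexing: an integer selects a row.
        return self.rows[index]

    def __iter__(self):
        # Array-style iteration, like iter(np.array(...)): yields the rows.
        return iter(self.rows)


class BatchAsMapping(Batch, Mapping):
    """Same class, but additionally claiming to be a Mapping."""

    def __len__(self):
        return len(self.rows)


def mapping_contract_holds(obj):
    """Check that every item produced by iter(obj) is a valid key for obj[...]."""
    try:
        for key in obj:
            obj[key]
    except (TypeError, IndexError, KeyError):
        return False
    return True


batch = BatchAsMapping([[1, 2], [3, 4]])
print(mapping_contract_holds(batch))             # rows are not valid keys
print(mapping_contract_holds({"a": 1, "b": 2}))  # a real mapping for comparison
```

Any library that treats the object as a Mapping and walks it with `for key in obj: obj[key]` would hit the same failure.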
Another issue with collections.abc.Mapping is that it does not allow forwarding static metadata (like jax.tree_map does).
OK, I understand now. I assume you are not compatible with collections.abc.Sequence either? Because tree recognizes it.
Anyway, you are right. Reading the docs for jax.tree_map, I really like the way it handles additional class registration. It is a shame it is not a standalone package, because it seems to supersede tree and could replace it. What I don't like about tree is the lack of performance. I find it quite slow for extremely basic operations because it mixes Python and Cython. Do you know if jax.tree_map is faster?
Here I think we can refer to torch's pytree and jax's libtree in terms of design. They provide a register method so that custom classes can be processed after flattening, and can be subsequently unflattened to restore them to their original state.
I think this is quite important for dmtree, so that it can support custom containers defined by users and some wrapped libraries like treevalue.
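A minimal pure-Python sketch of the register-based design described above follows. The names register_node, flatten, unflatten, and Scaled are hypothetical and are not the API of torch, jax, or dm-tree; the sketch only shows how a registered class can be flattened into leaves and rebuilt, with static metadata carried beside the leaves rather than inside them.

```python
_REGISTRY = {}  # maps a node type to its (flatten_fn, unflatten_fn) pair


def register_node(cls, flatten_fn, unflatten_fn):
    """Hypothetical registration hook.

    flatten_fn(obj) -> (children, aux); unflatten_fn(aux, children) -> obj.
    """
    _REGISTRY[cls] = (flatten_fn, unflatten_fn)


# Built-in list support, expressed through the same hook.
register_node(list, lambda xs: (xs, None), lambda _, children: list(children))


def flatten(tree):
    """Return (leaves, treedef); treedef records types, aux data, and child defs."""
    if type(tree) not in _REGISTRY:
        return [tree], None  # a treedef of None marks a leaf
    flatten_fn, _ = _REGISTRY[type(tree)]
    children, aux = flatten_fn(tree)
    leaves, child_defs = [], []
    for child in children:
        child_leaves, child_def = flatten(child)
        leaves.extend(child_leaves)
        child_defs.append((child_def, len(child_leaves)))
    return leaves, (type(tree), aux, child_defs)


def unflatten(treedef, leaves):
    """Rebuild a tree from a treedef and a flat list of leaves."""
    if treedef is None:
        (leaf,) = leaves
        return leaf
    cls, aux, child_defs = treedef
    children, start = [], 0
    for child_def, count in child_defs:
        children.append(unflatten(child_def, leaves[start:start + count]))
        start += count
    _, unflatten_fn = _REGISTRY[cls]
    return unflatten_fn(aux, children)


class Scaled:
    """Hypothetical user container: `values` holds leaves, `unit` is static metadata."""

    def __init__(self, values, unit):
        self.values, self.unit = values, unit


register_node(
    Scaled,
    lambda s: ([s.values], s.unit),                    # (children, aux)
    lambda unit, children: Scaled(children[0], unit),  # rebuild from aux + children
)

leaves, treedef = flatten([Scaled([1, 2], "m"), 3])
doubled = unflatten(treedef, [2 * x for x in leaves])
print(leaves, doubled[0].values, doubled[0].unit, doubled[1])
```

In this sketch the unit string never appears among the leaves, so a map over the leaves cannot touch it, yet it is still available when the Scaled object is reconstructed.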
I also ran into this problem where the current implementation doesn't nicely handle initialization of custom PyTree objects.
To compare, the jax implementation allows you to pass auxiliary data (stored in the PyTreeDef return value) when flattening and unflattening PyTrees. This is useful for storing some class metadata that shouldn't be interpreted as leaves.
This is not possible with dm-tree using tree.flatten and tree.unflatten_as, as all initialization data of your custom class must be contained in the leaves...
A solution might be to dynamically generate classes using a metaclass or a call to type to couple this metadata to the tree structure (haven't tried though), but this seems a bit hacky and I would prefer the jax solution.
Currently, chex.assert_xyz uses tree under the hood. This makes it impossible to use it with Jax trees (jax.tree_utils.register) which do not support collections.abc.Mapping.
I have a dataclass which already defines __getitem__, __iter__ for numpy-like operations, so I cannot support collections.abc.Mapping. I would like to use it with tree too, so chex.assert_xyz works.