google / etils

Collection of eclectic utils for python.
https://etils.readthedocs.io
Apache License 2.0
193 stars, 20 forks

[Enhancement] Add `optree` integration to `etils.etree` #342

Closed: XuehaiPan closed this issue 1 year ago.

XuehaiPan commented 1 year ago

optree is a standalone package (like dm-tree) aimed at high-performance PyTree manipulation (like jax.tree_util). It offers APIs similar to jax.tree_util, but with better performance and more features.
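To make the benchmarked operations concrete, here is a minimal pure-Python sketch of flatten, unflatten, and map over nested dicts, lists, and tuples. It is not optree's implementation (optree does this work in a C++ extension), and the helper names and the spec format are hypothetical. It only mirrors one convention shared by optree and jax.tree_util: dict entries are visited in sorted-key order, so the leaf order is deterministic.

```python
from typing import Any, Callable, List, Tuple

# Hypothetical structure spec: (kind, metadata, child specs). Leaves use kind "leaf".
TreeSpec = Tuple[str, Any, list]


def tree_flatten(tree: Any) -> Tuple[List[Any], TreeSpec]:
    """Split a nested dict/list/tuple into a flat leaf list plus a structure spec."""
    leaves: List[Any] = []

    def _flatten(node: Any) -> TreeSpec:
        if isinstance(node, dict):
            keys = sorted(node)  # sorted keys give a deterministic leaf order
            return ("dict", keys, [_flatten(node[k]) for k in keys])
        if isinstance(node, (list, tuple)):
            kind = "tuple" if isinstance(node, tuple) else "list"
            return (kind, None, [_flatten(child) for child in node])
        leaves.append(node)  # anything that is not a container is a leaf
        return ("leaf", None, [])

    spec = _flatten(tree)
    return leaves, spec


def tree_unflatten(spec: TreeSpec, leaves: List[Any]) -> Any:
    """Rebuild a structure from a spec, consuming leaves in flatten order."""
    it = iter(leaves)

    def _build(node_spec: TreeSpec) -> Any:
        kind, meta, children = node_spec
        if kind == "leaf":
            return next(it)
        built = [_build(child) for child in children]
        if kind == "dict":
            return dict(zip(meta, built))
        return tuple(built) if kind == "tuple" else built

    return _build(spec)


def tree_map(fn: Callable[[Any], Any], tree: Any) -> Any:
    """Apply fn to every leaf while keeping the container structure."""
    leaves, spec = tree_flatten(tree)
    return tree_unflatten(spec, [fn(x) for x in leaves])


tree = {"b": (2, [3, 4]), "a": 1}
leaves, spec = tree_flatten(tree)
print(leaves)  # [1, 2, 3, 4]
print(tree_map(lambda x: x * 10, tree))
```

Tree Map is expressed here as flatten, transform the leaves, then unflatten; a real library may fuse these steps, which is part of where the speed differences in the benchmark come from.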

Some initial benchmark results:

| Average time cost, relative to OpTree (↓ lower is better) | OpTree (v0.9.0) | JAX XLA (v0.4.6) | PyTorch (v2.0.0) | TensorFlow Nest (v2.12.0) | DM-Tree (v0.1.8) |
|---|---|---|---|---|---|
| Tree Flatten | x1.00 | x2.33 | x22.05 | x1.38 | x1.12 |
| Tree UnFlatten | x1.00 | x2.69 | x4.28 | x13.69 | x16.23 |
| Tree Flatten with Path | x1.00 | x16.16 | Not Supported | x21.10 | x27.59 |
| Tree Copy | x1.00 | x2.56 | x9.97 | x9.62 | x11.02 |
| Tree Map | x1.00 | x2.56 | x9.58 | x9.16 | x10.62 |
| Tree Map (nargs) | x1.00 | x2.89 | Not Supported | x74.26 | x31.33 |
| Tree Map with Path | x1.00 | x7.23 | Not Supported | x40.78 | x19.66 |
| Tree Map with Path (nargs) | x1.00 | x6.56 | Not Supported | x69.63 | x29.61 |
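The "with Path" and "(nargs)" rows refer to mapping variants that, presumably, also pass each leaf's location in the tree and map over several trees of identical structure at once. The recursive sketch below illustrates both. The name `tree_map_with_path` is borrowed from optree for readability, but this is a hypothetical pure-Python stand-in: the path format (a tuple of dict keys and sequence indices) and the path-first argument order are assumptions for illustration, and only plain dicts, lists, and tuples are handled.

```python
from typing import Any, Callable, Tuple

Path = Tuple[Any, ...]  # assumed path format: dict keys and sequence indices


def tree_map_with_path(fn: Callable[..., Any], tree: Any, *rests: Any, path: Path = ()) -> Any:
    """Call fn(path, leaf, *other_leaves) on matching leaves of same-shaped trees."""
    if isinstance(tree, dict):
        # Every extra tree must be a dict with exactly the same keys.
        for other in rests:
            if not isinstance(other, dict) or sorted(other) != sorted(tree):
                raise ValueError(f"structure mismatch at path {path}")
        return {
            k: tree_map_with_path(fn, tree[k], *(o[k] for o in rests), path=path + (k,))
            for k in sorted(tree)
        }
    if isinstance(tree, (list, tuple)):
        # Every extra tree must be the same sequence type with the same length.
        for other in rests:
            if type(other) is not type(tree) or len(other) != len(tree):
                raise ValueError(f"structure mismatch at path {path}")
        mapped = [
            tree_map_with_path(fn, child, *(o[i] for o in rests), path=path + (i,))
            for i, child in enumerate(tree)
        ]
        return tuple(mapped) if isinstance(tree, tuple) else mapped
    # Leaf: hand the path plus the aligned leaves of all trees to fn.
    return fn(path, tree, *rests)


params = {"w": [10, 20], "b": 5}
deltas = {"w": [1, 2], "b": 3}
# nargs: map over two trees at once; the path argument is ignored here.
print(tree_map_with_path(lambda p, x, d: x - d, params, deltas))
# with path: replace each leaf by its location in the tree.
print(tree_map_with_path(lambda p, x: p, params))
```

Checking the structure of every extra tree at each node is what makes the nargs variants more expensive than a single-tree map, which matches the larger gaps in those rows for the pure-Python-heavy backends.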

We have already seen some etils folks involved in optree and jax.tree_util discussions. I wonder whether the etils maintainers would be interested in adding optree as a backend for etils.etree.

Ref:

Conchylicultor commented 1 year ago

Good idea. Let me try this