JesseFarebro closed this issue 1 year ago.
@PaParaZz1 can you take a look?
cc @XuehaiPan
One solution is to register the `TreeValue` classes from `treevalue` as JAX pytree node types. FYI, this requires registering every possible class (e.g., both `TreeValue` and `FastTreeValue`), because the JAX pytree registry lookup uses `type(node) is registered_type` rather than `isinstance(node, registered_type)`.
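To illustrate why each class must be registered separately, here is a minimal sketch using hypothetical `Box` and `SubBox` classes (stand-ins for `TreeValue` and `FastTreeValue`, not part of either library). Registering the base class does not make its subclass a pytree node: the subclass is treated as a single opaque leaf until it is registered as well.

```python
import jax


class Box:
    """Hypothetical one-field container, standing in for TreeValue."""

    def __init__(self, value):
        self.value = value


class SubBox(Box):
    """Hypothetical subclass, standing in for FastTreeValue."""


def flatten_box(box):
    # Children to recurse into, plus hashable auxiliary data (the exact class).
    return (box.value,), type(box)


def unflatten_box(cls, children):
    # Rebuild an instance of the same class around the new child.
    return cls(children[0])


jax.tree_util.register_pytree_node(Box, flatten_box, unflatten_box)

# The registered class is flattened into its contents.
leaves_box = jax.tree_util.tree_leaves(Box(1))

# The unregistered subclass is not matched by the base-class registration,
# so the whole SubBox object comes back as one leaf.
sub = SubBox(1)
leaves_sub_before = jax.tree_util.tree_leaves(sub)

# Registering the subclass as well makes it a pytree node too.
jax.tree_util.register_pytree_node(SubBox, flatten_box, unflatten_box)
leaves_sub_after = jax.tree_util.tree_leaves(SubBox(1))

print(leaves_box, leaves_sub_before, leaves_sub_after)
```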
In `_to_dm`, we are returning a `namedtuple` of `TreeValue` instances, which are not jittable.
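A minimal sketch of the failure mode, assuming a hypothetical `Step` namedtuple in place of the one returned by `_to_dm` and a hypothetical `Opaque` class in place of an unregistered `TreeValue`. The namedtuple itself is a pytree, but its unregistered field becomes a leaf, and `jax.jit` rejects leaves that are not valid JAX types with a `TypeError`.

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp


class Opaque:
    """Hypothetical class with no pytree registration (like TreeValue before the fix)."""

    def __init__(self, data):
        self.data = data


class Step(NamedTuple):
    """Hypothetical stand-in for the namedtuple returned by `_to_dm`."""
    obs: object
    reward: object


@jax.jit
def double(step):
    return jax.tree_util.tree_map(lambda x: x * 2, step)


# The Opaque field is treated as a leaf, and jit cannot trace it.
try:
    double(Step(obs=Opaque(jnp.ones(3)), reward=jnp.array(1.0)))
    failed = False
except TypeError as err:
    failed = True
    print(err)
```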
IMO, we'd better use standard Python containers (e.g., `dict`s or `namedtuple`s) rather than `TreeValue` instances in our public API. Standard Python containers have first-party support in the common pytree libraries (`jax`, `torch`, `dm-tree`, `optree`).
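By contrast, a namedtuple of plain dicts and arrays goes through `jax.jit` with no registration at all. This is a sketch with a hypothetical `Step` type whose field names are illustrative only, not envpool's actual output structure.

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp


class Step(NamedTuple):
    """Hypothetical namedtuple; field names are illustrative only."""
    obs: dict
    reward: jnp.ndarray


@jax.jit
def double(step):
    # tree_map rebuilds the same container types (Step, dict) around new leaves.
    return jax.tree_util.tree_map(lambda x: x * 2, step)


step = Step(obs={"pixels": jnp.ones((2, 3)), "state": jnp.arange(3.0)},
            reward=jnp.array(1.0))
out = double(step)
print(type(out).__name__, out.reward)
```

The output keeps the caller's container types, so downstream code can keep using attribute and key access as before.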
Also, note that `treevalue` only supports nested `dict`s with `str` keys. It does not support arbitrary nested Python containers:
```
In [1]: import treevalue

In [2]: tree = {1: 'a', 2: 'b'}

In [3]: treevalue.FastTreeValue(tree)
TypeError: Expected unicode, got int

In [4]: tree = [{'a': 1}, {'a': 2}]

In [5]: treevalue.FastTreeValue(tree)
TypeError: Unknown initialization type for tree value - 'list'.
```
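For comparison, JAX's built-in pytree handling accepts both of these structures as-is. A short sketch using `jax.tree_util` (dict children are visited in sorted key order):

```python
import jax

# Non-string dict keys are valid pytree structure in JAX.
int_keyed = {1: 'a', 2: 'b'}
leaves = jax.tree_util.tree_leaves(int_keyed)

# Lists of dicts are also handled natively, and tree_map preserves the structure.
nested = [{'a': 1}, {'a': 2}]
scaled = jax.tree_util.tree_map(lambda x: x * 10, nested)

print(leaves)   # ['a', 'b']
print(scaled)   # [{'a': 10}, {'a': 20}]
```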
@Hansbug and I are working to fix this compatibility problem with JAX. At present, it seems the solution should be to register `TreeValue` in JAX.
We are adding a `penetrate` function so that `jax.jit` can support `FastTreeValue` (see https://github.com/opendilab/treevalue/pull/77). Its documentation is here: https://opendilab.github.io/treevalue/dev/wrap/api_doc/tree/tree.html#penetrate
This will be released in the next version. Example usage:
```python
import jax
import numpy as np
from treevalue import FastTreeValue, PENETRATE_SESSIONID_ARGNAME, penetrate


# `penetrate` wraps jax.jit so that it accepts FastTreeValue arguments;
# the session-ID argument is marked static.
@penetrate(jax.jit, static_argnames=PENETRATE_SESSIONID_ARGNAME)
def double(x):
    return x * 2


t = FastTreeValue({
    'a': np.random.randint(0, 10, (2, 3)),
    'b': {
        'x': 233,
        'y': np.random.randn(2, 3),
    },
})
print(t)
print(double(t))
print(double(t + 1))
```
Another solution, based on JAX's native pytree registration:
```python
import jax
import numpy as np
from treevalue import FastTreeValue, TreeValue, flatten, unflatten


def flatten_treevalue(container):
    # Split the tree into its leaves and the key paths needed to rebuild it.
    contents = []
    paths = []
    for path, value in flatten(container):
        paths.append(path)
        contents.append(value)
    # JAX expects hashable auxiliary data, so store the paths as a tuple.
    return contents, (type(container), tuple(paths))


def unflatten_treevalue(aux_data, flat_contents):
    # Rebuild a tree of the original class from the stored paths and new leaves.
    type_, paths = aux_data
    return unflatten(zip(paths, flat_contents), return_type=type_)


# The registry matches on the exact type, so each class is registered separately.
jax.tree_util.register_pytree_node(TreeValue, flatten_treevalue, unflatten_treevalue)
jax.tree_util.register_pytree_node(FastTreeValue, flatten_treevalue, unflatten_treevalue)

data = {
    'a': np.random.randint(0, 10, (2, 3)),
    'b': {
        'x': 233,
        'y': np.random.randn(2, 3),
    },
}
t = FastTreeValue(data)


@jax.jit
def double(x):
    return x * 2


print(double(t))
```
@JesseFarebro As of treevalue 1.4.7, `FastTreeValue` can be passed through `jax.jit` directly:
```python
import jax
from treevalue import FastTreeValue

d = {'a': 1, 'b': {'c': 2, 'd': 3}, 'e': 4}
t = FastTreeValue(d)


@jax.jit
def double(x):
    return x * 2


if __name__ == '__main__':
    print(double(t))
```
If you need to register a custom treevalue class, just use `register_treevalue_class`:
```python
import jax
from treevalue import FastTreeValue, register_treevalue_class


class MyTreeValue(FastTreeValue):
    pass


# Subclasses are not matched automatically, so register the custom class explicitly.
register_treevalue_class(MyTreeValue)

d = {'a': 1, 'b': {'c': 2, 'd': 3}, 'e': 4}
t = MyTreeValue(d)


@jax.jit
def double(x):
    return x * 2


if __name__ == '__main__':
    print(double(t))
```
Hi all,
Thanks for the swift action on this issue. One comment: I agree with @XuehaiPan's comments here, https://github.com/sail-sg/envpool/pull/249#issuecomment-1445735615, that standard container types should be used for the public-facing API.
I appreciate the emphasis on performance, but I don't think the tradeoff is worth it for user-facing APIs. Another example of custom tree-like data structures getting in the way is Flax's recent move away from its custom `FrozenDict` structure to regular dicts. Some issues unrelated to immutability spurred this change (e.g., see the long-standing issues in Optax regarding Flax's `FrozenDict`).
Describe the bug
It seems the move away from dm-tree caused some issues, as `TreeValue` doesn't register itself as a valid PyTree node.
To Reproduce
This is a direct rip from your XLA documentation: