sail-sg / envpool

C++-based high-performance parallel environment execution engine (vectorized env) for general RL environments.
https://envpool.readthedocs.io
Apache License 2.0
1.09k stars 100 forks

[BUG] Move away from dm-tree broke compatibility with Jax #246

Closed JesseFarebro closed 1 year ago

JesseFarebro commented 1 year ago

Describe the bug

It seems the move away from dm-tree caused some issues as TreeValue doesn't register itself as a valid PyTree node.

To Reproduce

This is a direct rip from your XLA documentation:

import envpool
import jax

env = envpool.make(
    "Pong-v5",
    env_type="dm",
    num_envs=2,
)
handle, recv, send, _ = env.xla()

def actor_step(iter, loop_var):
    handle0, states = loop_var
    action = 0
    handle1 = send(handle0, action, states.observation.env_id)
    handle1, new_states = recv(handle0)
    return handle1, new_states

@jax.jit
def run_actor_loop(num_steps, init_var):
    return jax.lax.fori_loop(0, num_steps, actor_step, init_var)

env.async_reset()
handle, states = recv(handle)
run_actor_loop(100, (handle, states))
jax._src.traceback_util.UnfilteredStackTrace: TypeError: Cannot interpret value of type <class 'treevalue.tree.tree.tree.TreeValue'> as an abstract array; it does not have a dtype attribute

Trinkle23897 commented 1 year ago

@PaParaZz1 can you take a look?

Benjamin-eecs commented 1 year ago

cc @XuehaiPan

XuehaiPan commented 1 year ago

One solution is to register the TreeValue classes from treevalue as JAX PyTree node types. FYI, this requires registering every concrete class, such as TreeValue and FastTreeValue, because the JAX PyTree registry lookup uses type(node) is registered_type rather than isinstance(node, registered_type).
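To illustrate the exact-type lookup described above, here is a minimal sketch that uses only JAX and two hypothetical classes (`Base` and `Sub` are illustration names, not part of treevalue or EnvPool). Only `Base` is registered, so an instance of its subclass should be treated as a single opaque leaf:

```python
import jax


class Base:
    """Hypothetical container holding two leaves."""

    def __init__(self, a, b):
        self.a = a
        self.b = b


class Sub(Base):
    """Subclass that is deliberately NOT registered."""


def flatten_base(node):
    # Return the children, plus static aux data (here: the concrete type).
    return (node.a, node.b), type(node)


def unflatten_base(aux_type, children):
    # Rebuild the container from the aux data and the children.
    return aux_type(*children)


# Register only the base class with the JAX PyTree registry.
jax.tree_util.register_pytree_node(Base, flatten_base, unflatten_base)

base_leaves = jax.tree_util.tree_leaves(Base(1, 2))
sub_leaves = jax.tree_util.tree_leaves(Sub(1, 2))

print(base_leaves)      # the two children of the registered type
print(len(sub_leaves))  # the unregistered subclass counts as one leaf
```

This is why each of TreeValue and FastTreeValue would need its own registration call.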

In _to_dm:

https://github.com/sail-sg/envpool/blob/cd2ece04945f9ae1970efa610d085c1fb20b2cb5/envpool/python/dm_envpool.py#L74-L90

we are returning a namedtuple of TreeValue instances, which are non-jitable.

IMO, we'd better use standard Python containers (e.g., dicts or namedtuples) rather than TreeValue instances in our public API. The standard Python containers always have first-party support for many pytree libraries (jax, torch, dm-tree, optree).
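As a rough sketch of what that first-party support looks like in JAX, the snippet below passes a namedtuple containing a dict straight through jax.jit with no registration step. `TimeStep` here is a hypothetical stand-in defined for this example, not EnvPool's actual return type, and the field names are assumptions:

```python
from typing import Any, NamedTuple

import jax
import jax.numpy as jnp


class TimeStep(NamedTuple):
    """Hypothetical stand-in for a dm-style timestep."""

    reward: Any
    observation: dict


@jax.jit
def double(ts):
    # Namedtuples and dicts are built-in pytree nodes, so tree_map
    # reaches every array and rebuilds the same container structure.
    return jax.tree_util.tree_map(lambda x: x * 2, ts)


ts = TimeStep(
    reward=jnp.ones(2),
    observation={"obs": jnp.zeros((2, 4)), "env_id": jnp.arange(2)},
)
out = double(ts)

print(type(out).__name__)
print(out.observation["env_id"])
```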

Also, note that treevalue only supports nested dicts with str keys. It does not support arbitrary nested Python containers:

In [1]: import treevalue

In [2]: tree = {1: 'a', 2: 'b'}

In [3]: treevalue.FastTreeValue(tree)
TypeError: Expected unicode, got int

In [4]: tree = [{'a': 1}, {'a': 2}]

In [5]: treevalue.FastTreeValue(tree)
TypeError: Unknown initialization type for tree value - 'list'.
PaParaZz1 commented 1 year ago

> @PaParaZz1 can you take a look?

@Hansbug and I are working to fix this compatibility problem with JAX. At present, it seems that the solution should be to register TreeValue in JAX.

HansBug commented 1 year ago

We are adding a penetrate function so that jax.jit supports FastTreeValue (see: https://github.com/opendilab/treevalue/pull/77). Here is the usage: https://opendilab.github.io/treevalue/dev/wrap/api_doc/tree/tree.html#penetrate

This will be released in the next version.

import jax
import numpy as np

from treevalue import FastTreeValue, PENETRATE_SESSIONID_ARGNAME, penetrate

@penetrate(jax.jit, static_argnames=PENETRATE_SESSIONID_ARGNAME)
def double(x):
    return x * 2

t = FastTreeValue({
    'a': np.random.randint(0, 10, (2, 3)),
    'b': {
        'x': 233,
        'y': np.random.randn(2, 3)
    }
})

print(t)
print(double(t))
print(double(t + 1))
HansBug commented 1 year ago

Another solution, based on native JAX registration:

import jax
import numpy as np

from treevalue import FastTreeValue, flatten, unflatten, TreeValue

def flatten_treevalue(container):
    contents = []
    paths = []
    for path, value in flatten(container):
        paths.append(path)
        contents.append(value)

    return contents, (type(container), paths)

def unflatten_treevalue(aux_data, flat_contents):
    type_, paths = aux_data
    return unflatten(zip(paths, flat_contents), return_type=type_)

jax.tree_util.register_pytree_node(TreeValue, flatten_treevalue, unflatten_treevalue)
jax.tree_util.register_pytree_node(FastTreeValue, flatten_treevalue, unflatten_treevalue)

data = {
    'a': np.random.randint(0, 10, (2, 3)),
    'b': {
        'x': 233,
        'y': np.random.randn(2, 3)
    }
}
t = FastTreeValue(data)

@jax.jit
def double(x):
    return x * 2

print(double(t))
HansBug commented 1 year ago

treevalue 1.4.7 now supports this usage with jax.jit @JesseFarebro

import jax

from treevalue import FastTreeValue

d = {'a': 1, 'b': {'c': 2, 'd': 3}, 'e': 4}
t = FastTreeValue(d)

@jax.jit
def double(x):
    return x * 2

if __name__ == '__main__':
    print(double(t))

If you need to register a custom treevalue class, just use register_treevalue_class:

import jax

from treevalue import FastTreeValue, register_treevalue_class

class MyTreeValue(FastTreeValue):
    pass

register_treevalue_class(MyTreeValue)

d = {'a': 1, 'b': {'c': 2, 'd': 3}, 'e': 4}
t = MyTreeValue(d)

@jax.jit
def double(x):
    return x * 2

if __name__ == '__main__':
    print(double(t))
JesseFarebro commented 1 year ago

Hi all,

Thanks for the swift action on this issue. One comment: I agree with @XuehaiPan's comments here, https://github.com/sail-sg/envpool/pull/249#issuecomment-1445735615, that standard container types should be used for the public-facing API.

I appreciate the emphasis on performance but I think the tradeoff for user-facing APIs isn't worth it. Another example of custom tree-like data structures getting in the way can be seen in Flax's recent move away from their custom FrozenDict structure to regular dicts. There were some issues irrespective of immutability that spurred on this change (e.g., see long-standing issues in Optax RE: Flax FrozenDict).