cgarciae / treeo

A small library for creating and manipulating custom JAX Pytree classes
https://cgarciae.github.io/treeo
MIT License
Stacking of Treeo.Tree #23

Closed by peterroelants 1 year ago

peterroelants commented 1 year ago

I'm running into some issues when trying to stack a list of Treeo.Tree objects into a single object. I've made a short example:

from dataclasses import dataclass

import jax
import jax.numpy as jnp
import treeo as to

@dataclass
class Person(to.Tree):
    height: jnp.array = to.field(node=True) # I am a node field!
    age_static: jnp.array = to.field(node=False) # I am a static field!, I should not be updated.
    name: str = to.field(node=False) # I am a static field!

persons = [
    Person(height=jnp.array(1.8), age_static=jnp.array(25.), name="John"),
    Person(height=jnp.array(1.7), age_static=jnp.array(100.), name="Wald"),
    Person(height=jnp.array(2.1), age_static=jnp.array(50.), name="Karen")
]

# Stack (struct of arrays instead of list of structs)
jax.tree_map(lambda *values: jnp.stack(values, axis=0), *persons)

However, this fails with the following exception:

---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
Cell In[1], line 18
     11     name: str = to.field(node=False) # I am a static field!
     13 persons = [
     14     Person(height=jnp.array(1.8), age_static=jnp.array(25.), name="John"),
     15     Person(height=jnp.array(1.7), age_static=jnp.array(100.), name="Wald"),
     16     Person(height=jnp.array(2.1), age_static=jnp.array(50.), name="Karen")
     17 ]
---> 18 jax.tree_map(lambda *values: jnp.stack(values, axis=0), *persons)

File ~/workspace/lcms_polymer_model/env/env_conda_local/lcms_polymer_model_env/lib/python3.10/site-packages/jax/_src/tree_util.py:199, in tree_map(f, tree, is_leaf, *rest)
    166 """Maps a multi-input function over pytree args to produce a new pytree.
    167 
    168 Args:
   (...)
    196   [[5, 7, 9], [6, 1, 2]]
    197 """
    198 leaves, treedef = tree_flatten(tree, is_leaf)
--> 199 all_leaves = [leaves] + [treedef.flatten_up_to(r) for r in rest]
    200 return treedef.unflatten(f(*xs) for xs in zip(*all_leaves))

File ~/workspace/lcms_polymer_model/env/env_conda_local/lcms_polymer_model_env/lib/python3.10/site-packages/jax/_src/tree_util.py:199, in <listcomp>(.0)
    166 """Maps a multi-input function over pytree args to produce a new pytree.
    167 
    168 Args:
   (...)
    196   [[5, 7, 9], [6, 1, 2]]
    197 """
    198 leaves, treedef = tree_flatten(tree, is_leaf)
--> 199 all_leaves = [leaves] + [treedef.flatten_up_to(r) for r in rest]
    200 return treedef.unflatten(f(*xs) for xs in zip(*all_leaves))

ValueError: Mismatch custom node data: {'_field_metadata': {'height': <treeo.types.FieldMetadata object at 0x7fb8b898ba00>, 'age_static': <treeo.types.FieldMetadata object at 0x7fb8b90c0a90>, 'name': <treeo.types.FieldMetadata object at 0x7fb8b8bf9db0>, '_field_metadata': <treeo.types.FieldMetadata object at 0x7fb8b89b56f0>, '_factory_fields': <treeo.types.FieldMetadata object at 0x7fb8b89b5750>, '_default_field_values': <treeo.types.FieldMetadata object at 0x7fb8b89b5660>, '_subtrees': <treeo.types.FieldMetadata object at 0x7fb8b89b5720>}, 'age_static': DeviceArray(25., dtype=float32, weak_type=True), 'name': 'John'} != {'_field_metadata': {'height': <treeo.types.FieldMetadata object at 0x7fb8b898ba00>, 'age_static': <treeo.types.FieldMetadata object at 0x7fb8b90c0a90>, 'name': <treeo.types.FieldMetadata object at 0x7fb8b8bf9db0>, '_field_metadata': <treeo.types.FieldMetadata object at 0x7fb8b89b56f0>, '_factory_fields': <treeo.types.FieldMetadata object at 0x7fb8b89b5750>, '_default_field_values': <treeo.types.FieldMetadata object at 0x7fb8b89b5660>, '_subtrees': <treeo.types.FieldMetadata object at 0x7fb8b89b5720>}, 'age_static': DeviceArray(100., dtype=float32, weak_type=True), 'name': 'Wald'}; value: Person(height=DeviceArray(1.7, dtype=float32, weak_type=True), age_static=DeviceArray(100., dtype=float32, weak_type=True), name='Wald').

Versions used:

From a certain perspective this is expected, because jax.tree_map does not map over static (node=False) fields: they end up in the custom node data of the tree definition, which has to match across all inputs. So in this sense, this might not really be an issue with Treeo. However, I'm looking for some guidance on how to stack objects like this when they have static fields. Has anyone tried something similar and come up with a nice solution?
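To make the cause concrete, here is a minimal sketch that does not use Treeo at all. It assumes a hypothetical plain-JAX stand-in for Person, registered with jax.tree_util.register_pytree_node_class, where name plays the role of a static field and is returned as auxiliary data. It is only meant to illustrate that static data becomes part of the tree definition, not to describe how Treeo is implemented internally.

```python
import jax.numpy as jnp
from jax import tree_util


@tree_util.register_pytree_node_class
class Person:
    """Hypothetical plain-JAX stand-in for the Treeo Person class."""

    def __init__(self, height, name):
        self.height = height  # node field: flattened as a child (leaf)
        self.name = name      # static field: stored as auxiliary data

    def tree_flatten(self):
        # children first, then the auxiliary (static) data
        return (self.height,), self.name

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        (height,) = children
        return cls(height, aux_data)


john = Person(jnp.array(1.8), "John")
john2 = Person(jnp.array(1.9), "John")
wald = Person(jnp.array(1.7), "Wald")

# Same static data: the tree definitions compare equal.
same = tree_util.tree_structure(john) == tree_util.tree_structure(john2)

# Different static data: the tree definitions differ, which is why a
# multi-input tree_map over john and wald is rejected.
different = not (tree_util.tree_structure(john) == tree_util.tree_structure(wald))

# Stacking works as long as the static data matches.
stacked = tree_util.tree_map(lambda *values: jnp.stack(values, axis=0), john, john2)
```

In this sketch stacked.height has shape (2,) and stacked.name is "John", while the same call with john and wald would fail with a mismatch error like the one above.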

peterroelants commented 1 year ago

It seems to me that to fix this I would need to somehow map over PyTreeDefs. I have opened a question on the JAX GitHub here: https://github.com/google/jax/discussions/13768

cgarciae commented 1 year ago

I don't think you can collectively tree_map Pytrees with different static fields (node=False). There are ways to work around this, but the outcome is undefined (which static value should be chosen?).
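One way to sidestep the "which static value?" question is to avoid static fields that differ per object: keep every per-person array as a node and carry non-array data such as names in a separate list. The sketch below is an assumption-laden illustration in plain JAX, not Treeo API; PersonArrays and the records list are hypothetical names. In Treeo terms it would roughly correspond to declaring age with node=True and moving name out of the class. The trade-off is that age is then a regular leaf, so anything that maps over leaves will also touch it.

```python
import jax.numpy as jnp
from jax import tree_util


@tree_util.register_pytree_node_class
class PersonArrays:
    """Hypothetical container whose fields are all nodes (no static data)."""

    def __init__(self, height, age):
        self.height = height
        self.age = age

    def tree_flatten(self):
        # No auxiliary data, so all instances share one tree definition.
        return (self.height, self.age), None

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        return cls(*children)


records = [("John", 1.8, 25.0), ("Wald", 1.7, 100.0), ("Karen", 2.1, 50.0)]

# Non-array data is kept in a parallel list, in the same order as the stack.
names = [name for name, _, _ in records]
persons = [PersonArrays(jnp.array(h), jnp.array(a)) for _, h, a in records]

# Struct of arrays instead of list of structs.
stacked = tree_util.tree_map(lambda *values: jnp.stack(values, axis=0), *persons)
```

Here stacked.height and stacked.age each have shape (3,), and names[i] identifies the person at index i along the stacked axis.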

peterroelants commented 1 year ago

I don't think you can collectively tree_map Pytrees with different static fields (node=False).

It seems from the discussion at https://github.com/google/jax/discussions/13768 that this is indeed not possible.