pip install pytreeclass
Install development version
pip install git+https://github.com/ASEM000/pytreeclass
pytreeclass
is a JAX-compatible class builder to create and operate on stateful JAX PyTrees in a performant and intuitive way, by building on familiar concepts found in numpy
, dataclasses
, and others.
See documentation and π³ Common recipes to check if this library is a good fit for your work. If you find the package useful consider giving it a π.
```python import jax import jax.numpy as jnp import pytreeclass as tc @tc.autoinit class Tree(tc.TreeClass): a: float = 1.0 b: tuple[float, float] = (2.0, 3.0) c: jax.Array = jnp.array([4.0, 5.0, 6.0]) def __call__(self, x): return self.a + self.b[0] + self.c + x tree = Tree() mask = jax.tree_map(lambda x: x > 5, tree) tree = tree\ .at["a"].set(100.0)\ .at["b"][0].set(10.0)\ .at[mask].set(100.0) print(tree) # Tree(a=100.0, b=(10.0, 3.0), c=[ 4. 5. 100.]) print(tc.tree_diagram(tree)) # Tree # βββ .a=100.0 # βββ .b:tuple # β βββ [0]=10.0 # β βββ [1]=3.0 # βββ .c=f32[3](ΞΌ=36.33, Ο=45.02, β[4.00,100.00]) print(tc.tree_summary(tree)) # βββββββ¬βββββββ¬ββββββ¬βββββββ # βName βType βCountβSize β # βββββββΌβββββββΌββββββΌβββββββ€ # β.a βfloat β1 β β # βββββββΌβββββββΌββββββΌβββββββ€ # β.b[0]βfloat β1 β β # βββββββΌβββββββΌββββββΌβββββββ€ # β.b[1]βfloat β1 β β # βββββββΌβββββββΌββββββΌβββββββ€ # β.c βf32[3]β3 β12.00Bβ # βββββββΌβββββββΌββββββΌβββββββ€ # βΞ£ βTree β6 β12.00Bβ # βββββββ΄βββββββ΄ββββββ΄βββββββ # ** pass it to jax transformations ** # works with jit, grad, vmap, etc. @jax.jit @jax.grad def sum_tree(tree: Tree, x): return sum(tree(x)) print(sum_tree(tree, 1.0)) # Tree(a=3.0, b=(3.0, 0.0), c=[1. 1. 1.]) ``` |
Under jax.jit jax requires states to be explicit, this means that for any class instance; variables needs to be separated from the class and be passed explictly. However when using TreeClass
no need to separate the instance variables ; instead the whole instance is passed as a state.
Using the following pattern,Updating state functionally can be achieved under jax.jit
```python import jax import pytreeclass as tc class Counter(tc.TreeClass): def __init__(self, calls: int = 0): self.calls = calls def increment(self): self.calls += 1 counter = Counter() # Counter(calls=0) ``` |
Here, we define the update function. Since the increment method mutate the internal state, thus we need to use the functional approach to update the state by using .at
. To achieve this we can use .at[method_name].__call__(*args,**kwargs)
, this functional call will return the value of this call and a new model instance with the update state.
```python @jax.jit def update(counter): value, new_counter = counter.at["increment"]() return new_counter for i in range(10): counter = update(counter) print(counter.calls) # 10 ``` |
Num of layers | Flax/tc time |
Equinox/tc time |
10 | 1.427 | 6.671 |
100 | 1.1130 | 2.714 |