ASEM000 / pytreeclass

Visualize, create, and operate on pytrees in the most intuitive way possible.
https://pytreeclass.rtfd.io/en/latest
Apache License 2.0
data dataclasses deep-learning jax machine-learning pipelines pytorch pytree tensorflow


[**Installation**](#installation) | [**Description**](#description) | [**Quick Example**](#quick_example) | [**Stateful Computation**](#stateful_computation) | [**Benchmarks**](#more) | [**Acknowledgements**](#acknowledgements)

![Tests](https://github.com/ASEM000/pytreeclass/actions/workflows/test_default.yml/badge.svg) ![Tests](https://github.com/ASEM000/pytreeclass/actions/workflows/test_jax.yml/badge.svg) ![Tests](https://github.com/ASEM000/pytreeclass/actions/workflows/test_numpy.yml/badge.svg) ![Tests](https://github.com/ASEM000/pytreeclass/actions/workflows/test_torch.yml/badge.svg) ![pyver](https://img.shields.io/badge/python-3.8%203.9%203.10%203.11_-blue) ![codestyle](https://img.shields.io/badge/codestyle-black-black) [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/ASEM000/pytreeclass/blob/main/assets/intro.ipynb) [![Downloads](https://static.pepy.tech/badge/pytreeclass)](https://pepy.tech/project/pytreeclass) [![codecov](https://codecov.io/gh/ASEM000/pytreeclass/branch/main/graph/badge.svg?token=TZBRMO0UQH)](https://codecov.io/gh/ASEM000/pytreeclass) [![Documentation Status](https://readthedocs.org/projects/pytreeclass/badge/?version=latest)](https://pytreeclass.readthedocs.io/en/latest/?badge=latest) ![GitHub commit activity](https://img.shields.io/github/commit-activity/m/ASEM000/pytreeclass) [![DOI](https://zenodo.org/badge/512717921.svg)](https://zenodo.org/badge/latestdoi/512717921) ![PyPI](https://img.shields.io/pypi/v/pytreeclass) [![CodeFactor](https://www.codefactor.io/repository/github/asem000/pytreeclass/badge)](https://www.codefactor.io/repository/github/asem000/pytreeclass)

🛠️ Installation

pip install pytreeclass

Install the development version:

pip install git+https://github.com/ASEM000/pytreeclass

📖 Description

`pytreeclass` is a JAX-compatible class builder for creating and operating on stateful JAX PyTrees in a performant and intuitive way, building on familiar concepts from NumPy, dataclasses, and others.

See the documentation and 🍳 Common recipes to check whether this library is a good fit for your work. If you find the package useful, consider giving it a 🌟.
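For context, the sketch below shows what turning a class into a pytree involves in plain JAX, using `jax.tree_util.register_pytree_node_class` with hand-written `tree_flatten`/`tree_unflatten` methods. This is the kind of boilerplate a class builder like `TreeClass` is meant to remove. `ManualTree` is a hypothetical name, and this is a comparison sketch only, not how pytreeclass is implemented internally.

```python
import jax


@jax.tree_util.register_pytree_node_class
class ManualTree:
    """Hypothetical hand-registered pytree with two leaves, `a` and `b`."""

    def __init__(self, a, b):
        self.a = a
        self.b = b

    def tree_flatten(self):
        # children: leaves that JAX transforms; aux_data: static metadata (none here)
        return (self.a, self.b), None

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        # rebuild an instance from (possibly transformed) leaves
        return cls(*children)

    def __call__(self, x):
        return self.a * x + self.b


tree = ManualTree(2.0, 1.0)
leaves, treedef = jax.tree_util.tree_flatten(tree)
print(leaves)  # [2.0, 1.0]

# gradient w.r.t. every leaf of the instance: d/da = x, d/db = 1
grads = jax.grad(lambda t, x: t(x))(tree, 3.0)
print(grads.a, grads.b)  # 3.0 1.0
```

Because the whole instance is registered as a pytree, `jax.grad` returns another `ManualTree` whose fields hold the gradients.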

⏩ Quick Example

```python
import jax
import jax.numpy as jnp
import pytreeclass as tc

@tc.autoinit
class Tree(tc.TreeClass):
    a: float = 1.0
    b: tuple[float, float] = (2.0, 3.0)
    c: jax.Array = jnp.array([4.0, 5.0, 6.0])

    def __call__(self, x):
        return self.a + self.b[0] + self.c + x

tree = Tree()
# boolean mask pytree: True where a leaf value is > 5
mask = jax.tree_util.tree_map(lambda x: x > 5, tree)

# chain functional updates; each `.at[...].set(...)` returns a new instance
tree = (
    tree
    .at["a"].set(100.0)
    .at["b"][0].set(10.0)
    .at[mask].set(100.0)
)

print(tree)
# Tree(a=100.0, b=(10.0, 3.0), c=[  4.   5. 100.])

print(tc.tree_diagram(tree))
# Tree
# ├── .a=100.0
# ├── .b:tuple
# │   ├── [0]=10.0
# │   └── [1]=3.0
# └── .c=f32[3](μ=36.33, σ=45.02, ∈[4.00,100.00])

print(tc.tree_summary(tree))
# ┌─────┬──────┬─────┬──────┐
# │Name │Type  │Count│Size  │
# ├─────┼──────┼─────┼──────┤
# │.a   │float │1    │      │
# ├─────┼──────┼─────┼──────┤
# │.b[0]│float │1    │      │
# ├─────┼──────┼─────┼──────┤
# │.b[1]│float │1    │      │
# ├─────┼──────┼─────┼──────┤
# │.c   │f32[3]│3    │12.00B│
# ├─────┼──────┼─────┼──────┤
# │Σ    │Tree  │6    │12.00B│
# └─────┴──────┴─────┴──────┘

# pass it to JAX transformations: works with jit, grad, vmap, etc.
@jax.jit
@jax.grad
def sum_tree(tree: Tree, x):
    return sum(tree(x))

print(sum_tree(tree, 1.0))
# Tree(a=3.0, b=(3.0, 0.0), c=[1. 1. 1.])
```
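Conceptually, the masked update `.at[mask].set(100.0)` replaces every leaf value whose mask entry is `True` and leaves the rest alone. A rough plain-JAX analogue on a dictionary pytree (standing in for the `Tree` instance) is sketched below with `jax.tree_util.tree_map` and `jnp.where`; it illustrates the semantics only and is not pytreeclass's implementation.

```python
import jax
import jax.numpy as jnp

# a dict pytree standing in for the `Tree` instance (illustration only)
tree = {"a": 1.0, "b": (2.0, 3.0), "c": jnp.array([4.0, 5.0, 6.0])}

# boolean mask with the same structure: True where the leaf value is > 5
mask = jax.tree_util.tree_map(lambda x: x > 5, tree)

# leaf-wise select: take 100.0 where the mask is True, keep the old value elsewhere
updated = jax.tree_util.tree_map(lambda x, m: jnp.where(m, 100.0, x), tree, mask)

print(updated["c"].tolist())  # [4.0, 5.0, 100.0]
```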

📜 Stateful computations

Under `jax.jit`, JAX requires state to be explicit: for any class instance, the variables need to be separated from the class and passed explicitly. With `TreeClass`, there is no need to separate the instance variables; instead, the whole instance is passed as the state.
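For contrast, here is a minimal sketch of the explicit-state style JAX expects when no `TreeClass` is involved: the variables live in a plain dictionary (chosen here for illustration) that is passed into, and returned from, a jitted function.

```python
import jax


@jax.jit
def increment(state):
    # state is an explicit pytree (a dict); return a new one instead of mutating
    return {"calls": state["calls"] + 1}


state = {"calls": 0}
for _ in range(10):
    state = increment(state)

print(int(state["calls"]))  # 10
```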

Using the following pattern, state can be updated functionally under `jax.jit`:

```python
import jax
import pytreeclass as tc

class Counter(tc.TreeClass):
    def __init__(self, calls: int = 0):
        self.calls = calls

    def increment(self):
        self.calls += 1

counter = Counter()  # Counter(calls=0)
```

Here, we define the update function. Since the `increment` method mutates the internal state, we need to update the state functionally using `.at`. Calling `.at[method_name](*args, **kwargs)` (that is, `.at[method_name].__call__(*args, **kwargs)`) returns both the method's return value and a new instance with the updated state.

```python
@jax.jit
def update(counter):
    # returns (increment's return value, new Counter with the updated `calls`)
    value, new_counter = counter.at["increment"]()
    return new_counter

for i in range(10):
    counter = update(counter)

print(counter.calls)  # 10
```
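The semantics of `.at[method_name](...)` described above can be mimicked in plain Python: copy the instance, run the mutating method on the copy, and return the method's result together with the copy. `functional_call` is a hypothetical helper written only to illustrate the idea; it is not pytreeclass's implementation, and the plain `Counter` here is not a pytree.

```python
import copy


class Counter:
    """Plain-Python stand-in for the TreeClass `Counter` (not a pytree)."""

    def __init__(self, calls: int = 0):
        self.calls = calls

    def increment(self):
        self.calls += 1


def functional_call(obj, method_name, *args, **kwargs):
    """Hypothetical helper: run a mutating method on a copy of `obj`.

    Returns (method return value, updated copy); `obj` itself is left untouched.
    """
    # shallow copy: enough here because `calls` is an int that gets rebound;
    # mutable containers held by `obj` would be shared with the copy
    new_obj = copy.copy(obj)
    value = getattr(new_obj, method_name)(*args, **kwargs)
    return value, new_obj


counter = Counter()
value, new_counter = functional_call(counter, "increment")
print(counter.calls, new_counter.calls)  # 0 1
print(value)  # None
```

The original object is never modified, so the caller decides which version of the state to keep.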

➕ Benchmarks

Benchmark flatten/unflatten compared to Flax and Equinox

[Figure: flatten/unflatten benchmark results, CPU and GPU panels]
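To get a feel for flatten/unflatten cost in your own environment, a round trip can be timed with `timeit` and `jax.tree_util`. This sketch uses a hypothetical nested-dict tree rather than Flax, Equinox, or pytreeclass models, so it does not reproduce the benchmark above and its numbers are not comparable to it.

```python
import timeit

import jax
import jax.numpy as jnp

# hypothetical tree: 100 "layers", each a dict holding a weight and a bias array
tree = {f"layer_{i}": {"w": jnp.ones((4, 4)), "b": jnp.zeros(4)} for i in range(100)}


def roundtrip():
    # flatten to (leaves, treedef), then rebuild the same structure
    leaves, treedef = jax.tree_util.tree_flatten(tree)
    return jax.tree_util.tree_unflatten(treedef, leaves)


n = 1_000
seconds = timeit.timeit(roundtrip, number=n)
print(f"{seconds / n * 1e6:.1f} us per flatten/unflatten round trip")
```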
Benchmark simple training against `flax` and `equinox`: training a simple sequential linear model.

| Num of layers | Flax/tc time ratio | Equinox/tc time ratio |
|---|---|---|
| 10 | 1.427 | 6.671 |
| 100 | 1.1130 | 2.714 |

📙 Acknowledgements