brentyi / jax_dataclasses

Pytrees + dataclasses ❤️
MIT License

jax_dataclasses


Overview

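jax_dataclasses provides a simple wrapper around dataclasses.dataclass for use in JAX. Among other things, it automatically registers each decorated class as a pytree, so instances can be passed through JAX transformations such as jax.jit, jax.vmap, and jax.grad.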

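Distinguishing features include an annotation-based interface for marking static fields (jdc.Static[]), a jdc.jit() wrapper that understands the same annotation in function signatures, and jdc.copy_and_mutate for making changes to nested, otherwise-immutable structures.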

Installation

In Python >=3.7:

pip install jax_dataclasses

We can then import:

import jax_dataclasses as jdc

Core interface

jax_dataclasses is meant to provide a drop-in replacement for dataclasses.dataclass: jdc.pytree_dataclass has the same interface as dataclasses.dataclass, but also registers the target class as a pytree node.
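To make "registers the target class as a pytree node" concrete, here is a rough hand-written equivalent for a single class, using only jax.tree_util.register_pytree_node. It is a sketch of the general mechanism, not jax_dataclasses' internal implementation; the Point class and the _flatten/_unflatten helpers are hypothetical names chosen for illustration.

```python
import dataclasses

import jax
import jax.numpy as jnp


# Hypothetical example class: a frozen stdlib dataclass holding two arrays.
@dataclasses.dataclass(frozen=True)
class Point:
    x: jax.Array
    y: jax.Array


def _flatten(p: Point):
    # Children are the arrays JAX should trace; the second value is auxiliary
    # data stored in the treedef (nothing is needed here).
    return (p.x, p.y), None


def _unflatten(aux, children) -> Point:
    del aux  # Unused.
    return Point(*children)


jax.tree_util.register_pytree_node(Point, _flatten, _unflatten)

p = Point(x=jnp.zeros(2), y=jnp.ones(2))
leaves = jax.tree_util.tree_leaves(p)  # Two leaves: p.x and p.y.
doubled = jax.tree_util.tree_map(lambda v: v * 2, p)  # Still a Point.
norm_sq = jax.jit(lambda q: jnp.sum(q.x**2 + q.y**2))(p)  # Pytrees cross jit.
```

jdc.pytree_dataclass removes the need to write these two functions by hand and keep them in sync with the class's field list.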

We also provide several aliases: jdc.field, jdc.asdict, jdc.astuple, jdc.is_dataclass, and jdc.replace are identical to their counterparts in the standard dataclasses library.
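Since these aliases behave like the standard-library functions, their semantics can be previewed with plain dataclasses. The sketch below uses a hypothetical frozen Config class (frozen because jax_dataclasses classes are always frozen) to show why replace is the usual way to "update" an instance, and what asdict and astuple return.

```python
import dataclasses


# Hypothetical frozen dataclass; jax_dataclasses marks its classes frozen too.
@dataclasses.dataclass(frozen=True)
class Config:
    learning_rate: float
    num_steps: int


cfg = Config(learning_rate=1e-3, num_steps=100)

# Frozen instances reject attribute assignment...
try:
    cfg.num_steps = 200  # type: ignore[misc]
    assignment_failed = False
except dataclasses.FrozenInstanceError:
    assignment_failed = True

# ...so updates go through `replace`, which returns a new instance.
cfg_long = dataclasses.replace(cfg, num_steps=200)

as_dict = dataclasses.asdict(cfg_long)  # {'learning_rate': 0.001, 'num_steps': 200}
as_tuple = dataclasses.astuple(cfg_long)  # (0.001, 200)
```

With jax_dataclasses, the same calls would be spelled jdc.replace(...), jdc.asdict(...), and so on.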

Static fields

To mark a field as static (in this context: constant at compile-time), we can wrap its type with jdc.Static[]:

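import jax
import jax_dataclasses as jdc

@jdc.pytree_dataclass
class A:
    a: jax.Array  # Child of the node: an array (or a nested pytree).
    b: jdc.Static[bool]  # Static: stored in the treedef, constant at compile time.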

In a pytree node, static fields will be treated as part of the treedef instead of as a child of the node; all fields that are not explicitly marked static should contain arrays or child nodes.
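The same registration mechanism explains static fields: a static value travels as auxiliary data in the treedef rather than as a leaf. The sketch below registers a hypothetical stdlib-dataclass version of A by hand with jax.tree_util.register_pytree_node; it illustrates the concept rather than how jax_dataclasses implements jdc.Static[].

```python
import dataclasses

import jax
import jax.numpy as jnp


# Hypothetical class mirroring `A` above: `a` is a child, `b` is static.
@dataclasses.dataclass(frozen=True)
class A:
    a: jax.Array
    b: bool


def _flatten(obj: A):
    # Only `a` is a child; `b` goes into the auxiliary data, i.e. the treedef.
    return (obj.a,), obj.b


def _unflatten(b: bool, children) -> A:
    (a,) = children
    return A(a=a, b=b)


jax.tree_util.register_pytree_node(A, _flatten, _unflatten)

x = A(a=jnp.ones(3), b=True)
y = A(a=jnp.ones(3), b=False)

num_leaves = len(jax.tree_util.tree_leaves(x))  # 1: `b` is not a leaf.
same_structure = jax.tree_util.tree_structure(x) == jax.tree_util.tree_structure(y)


@jax.jit
def scale(obj: A) -> jax.Array:
    # `obj.b` is a concrete Python bool during tracing, so a plain `if` works.
    return obj.a * 2.0 if obj.b else obj.a
```

Because static values live in the treedef, changing one yields a different tree structure, which under jax.jit leads to a separate trace. Static values should therefore generally be hashable and cheap to compare.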

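Bonus: if you like jdc.Static[], we also provide jdc.jit(), which extends the same annotation to function signatures: parameters annotated with jdc.Static[] are treated as static arguments during compilation. For example: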

import jax
import jax_dataclasses as jdc

@jdc.jit
def f(a: jax.Array, b: jdc.Static[bool]) -> jax.Array:
  ...  # Function body; `b` is a compile-time constant here.
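For comparison, plain jax.jit needs the static parameter named explicitly. The sketch below uses only jax.jit with static_argnames and is roughly what a jdc.Static[bool] annotation on b expresses, assuming jdc.jit treats Static-annotated parameters as static arguments; it is not jdc.jit's implementation.

```python
from functools import partial

import jax
import jax.numpy as jnp


# Marks `b` static by name, which is what a jdc.Static[bool] annotation
# is described as doing when used with jdc.jit.
@partial(jax.jit, static_argnames=("b",))
def f(a: jax.Array, b: bool) -> jax.Array:
    # `b` is a concrete Python bool during tracing, so `if` is allowed.
    if b:
        return a * 2.0
    return a


out_true = f(jnp.arange(3.0), b=True)
out_false = f(jnp.arange(3.0), b=False)
```

The annotation-based version keeps a parameter's static-ness next to its type, so the two cannot drift apart when the signature changes.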

Mutations

All dataclasses are automatically marked as frozen and thus immutable (even when no frozen= parameter is passed in). To make changes to nested structures easier, jdc.copy_and_mutate (a) makes a copy of a pytree and (b) returns a context in which any of that copy's contained dataclasses are temporarily mutable:

import jax
from jax import numpy as jnp
import jax_dataclasses as jdc

@jdc.pytree_dataclass
class Node:
  child: jax.Array

obj = Node(child=jnp.zeros(3))

with jdc.copy_and_mutate(obj) as obj_updated:
  # Make mutations to the dataclass. This is primarily useful for nested
  # dataclasses.
  #
  # Does input validation by default: if the treedef, leaf shapes, or dtypes
  # of `obj` and `obj_updated` don't match, an AssertionError will be raised.
  # This can be disabled with a `validate=False` argument.
  obj_updated.child = jnp.ones(3)

print(obj)
print(obj_updated)
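To see what copy_and_mutate saves, here is the same kind of update written with nested replace calls on hypothetical frozen Inner/Outer classes (plain stdlib dataclasses; no pytree registration is needed for this illustration). Each additional level of nesting adds another replace.

```python
import dataclasses

import jax
import jax.numpy as jnp


# Hypothetical nested frozen structure.
@dataclasses.dataclass(frozen=True)
class Inner:
    weight: jax.Array


@dataclasses.dataclass(frozen=True)
class Outer:
    inner: Inner
    bias: jax.Array


obj = Outer(inner=Inner(weight=jnp.zeros(3)), bias=jnp.zeros(1))

# Without a mutation helper, every level on the path to the changed field
# has to be rebuilt explicitly with `replace`.
obj_updated = dataclasses.replace(
    obj,
    inner=dataclasses.replace(obj.inner, weight=jnp.ones(3)),
)
```

With jdc.copy_and_mutate, the equivalent change would be a single assignment such as obj_updated.inner.weight = jnp.ones(3) inside the with block.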

Alternatives

A few other solutions exist for automatically integrating dataclass-style objects into pytree structures. Great ones include chex.dataclass, flax.struct, and tjax.dataclass. These all influenced this library.

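The main differentiators of jax_dataclasses include its annotation-based static fields: where flax.struct, for example, marks a static field with field(pytree_node=False), jax_dataclasses uses the type annotation jdc.Static[...].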

You can also eschew the dataclass-style interface entirely; see how brax registers pytrees. This is a reasonable thing to prefer. It requires some floating strings and gives up things that I care about but you may not (like immutability and __post_init__), but it offers more flexibility with custom __init__ methods.
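A minimal sketch of the non-dataclass approach, using jax.tree_util.register_pytree_node_class with a hypothetical Gaussian class. It is not brax's actual code; it only shows how a custom __init__ can coexist with pytree registration, and why tree_unflatten typically has to bypass __init__.

```python
import jax
import jax.numpy as jnp


@jax.tree_util.register_pytree_node_class
class Gaussian:
    """Hypothetical class with a custom __init__ that derives stored state."""

    def __init__(self, mean: jax.Array, std: jax.Array):
        # Custom constructor logic: store log-std instead of std.
        self.mean = mean
        self.log_std = jnp.log(std)

    def tree_flatten(self):
        # Children are listed by hand here and unpacked in tree_unflatten.
        return (self.mean, self.log_std), None

    @classmethod
    def tree_unflatten(cls, aux, children):
        del aux  # Unused.
        # Bypass __init__, which would otherwise re-apply the log transform.
        obj = object.__new__(cls)
        obj.mean, obj.log_std = children
        return obj


g = Gaussian(mean=jnp.zeros(2), std=jnp.ones(2))
g2 = jax.tree_util.tree_map(lambda v: v + 1.0, g)
```

Note that nothing in this class prevents attribute mutation, which is one of the tradeoffs mentioned above.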

Misc

jax_dataclasses was originally written for and factored out of jaxfg, where Nick Heppert provided valuable feedback.