jax-ml / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
http://jax.readthedocs.io/
Apache License 2.0

Automatically treat dataclasses as pytrees #2371

Open shoyer opened 4 years ago

shoyer commented 4 years ago

JAX should automatically treat dataclasses as pytrees, so they don't have to be explicitly registered.

Ideally we would also support some syntax for non-differentiable parameters. Flax does so by adding custom metadata to dataclasses.Field.metadata with the special flax.struct.field constructor, which seems like a very clean way to do this.
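
Here is a minimal sketch of the metadata approach. It assumes a hypothetical `static_field` helper that writes a `"static"` key into `Field.metadata`; this is neither the Flax nor the JAX API. Fields tagged static are stored in the treedef's auxiliary data, and all other fields become children.

```python
import dataclasses
import jax
import jax.numpy as jnp


def static_field(**kwargs):
    # Hypothetical helper: tag a field as static (non-vspace) via Field.metadata.
    return dataclasses.field(metadata={"static": True}, **kwargs)


def register_with_metadata(cls):
    # Split the fields once, at registration time, using the metadata tag.
    fields = dataclasses.fields(cls)
    data_names = tuple(f.name for f in fields if not f.metadata.get("static", False))
    meta_names = tuple(f.name for f in fields if f.metadata.get("static", False))

    def flatten(obj):
        children = tuple(getattr(obj, n) for n in data_names)
        aux = tuple(getattr(obj, n) for n in meta_names)  # must be hashable
        return children, aux

    def unflatten(aux, children):
        return cls(**dict(zip(data_names, children)), **dict(zip(meta_names, aux)))

    jax.tree_util.register_pytree_node(cls, flatten, unflatten)
    return cls


@register_with_metadata
@dataclasses.dataclass(frozen=True)
class Layer:
    weights: jnp.ndarray
    activation: str = static_field(default="relu")


layer = Layer(jnp.ones(3))
leaves = jax.tree_util.tree_leaves(layer)  # only `weights`; `activation` is aux data
```

Transformations such as `jax.tree_util.tree_map` then touch only `weights`, while `activation` is carried through unchanged.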

I started working on this in a branch, but haven't tested anything so it very likely is entirely broken/non-functional! If somebody wants to finish this up it would be awesome :) https://github.com/google/jax/compare/master...shoyer:dataclasses-pytree

xref https://github.com/google/jax/issues/1808

mattjj commented 4 years ago

Thanks for starting this! Treating dataclasses as pytrees sgtm.

(Quick terminology aside: can a value really be "differentiable"? Maybe we should say "values that don't inhabit a vector space", or "non-vspace values" for short.)

The non-vspace values question is interestingly related, but it could be considered separately from dataclasses. After all, if we think of dataclasses as just pytrees (i.e. isomorphic to tuples), as the title of this issue proposes, then asking about non-vspace values in dataclasses would be no different from asking the same question about pytrees in general. And since a leaf is a pytree, we can just ask the question in the case of a single scalar: what should grad(f)(1) be? It's currently an error, but you could imagine it giving some sentinel value (and correspondingly associating to all non-vspace values a trivial "unit" tangent space). That decision would automatically extend to nontrivial containers that include non-vspace values.
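
For reference, this is how the single-scalar case looks in current JAX releases (behavior may differ across versions). `jax.grad` on an integer input raises a `TypeError`. The `allow_int=True` option instead returns a gradient whose dtype is `jax.dtypes.float0`, a zero-size tangent type that is close to the trivial "unit" tangent space idea:

```python
import jax


def f(x):
    return x * 2.0


try:
    jax.grad(f)(1)  # integer input: not a vector-space value
except TypeError as err:
    print("error:", type(err).__name__)

# Opt in to a trivial tangent space for integer inputs.
g = jax.grad(f, allow_int=True)(1)
print(g.dtype)  # float0
```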

We could instead choose for dataclasses to act differently from pytrees, and then maybe do something interesting with their type annotations (which other pytree types don't have). But that seems like a different proposal from "treat dataclasses as pytrees", and may require more discussion (e.g. it may require we change our convention for grad(f)(1) as well for consistency).

Should we separate out the question of what one might do with dataclass metadata, and restrict this issue to being about treating dataclasses as pytrees? AIUI that's what your code is already up to!

TuanNguyen27 commented 4 years ago

I'd be interested in finishing this, could you outline the missing steps @shoyer & @mattjj ? :)

shoyer commented 4 years ago

This needs tests, and likely some iteration on the code from my branch (or a new implementation) to get things working. Note that the flattening/unflattening logic here needs to be written in C++.

awav commented 4 years ago

Hello @shoyer! Thanks for the initiative! This is a crucial feature and one of the main things that stops me from using JAX in day-to-day research. I would like to finish this work, and I will be able to do so after the NeurIPS 2020 deadline. You mentioned that this PR lacks tests; is there anything else? Does it still need [un]flatten implementations? Thanks!

tomhennigan commented 4 years ago

For context, in the various other tree libraries (tf.nest and dm-tree) we have pushed back on dataclasses automatically being treated as nests, because for these "struct" types it is not clear whether treating them as containers is intended in all cases. These libraries treat namedtuple/typing.NamedTuple as trees because it has a weird duality of being both a structure (supporting named attribute access) and an iterable (supporting x.__iter__()). attr.s (basically dataclasses) slipped in as a historical accident (and because these APIs don't have a way for users to register custom types).

Instead of treating all dataclasses as jaxtrees, could we instead create a drop-in replacement for dataclass for users who know they want this behavior? Here's an example implementation, which is basically a fork of flax.struct:

from dataclasses import dataclass
from typing import Type, TypeVar
import jax
import jax.numpy as jnp

T = TypeVar("T")

def jax_tree(cls: Type[T]) -> Type[T]:
  # Heuristic split: arrays and nested jax_tree instances are data (children);
  # every other attribute is static metadata.
  is_data = lambda x: isinstance(x, jnp.ndarray) or hasattr(x, '__jax_dataclass')

  def flatten_fun(obj):
    meta = {}
    data = {}
    for k, v in obj.__dict__.items():
      if isinstance(v, list):  # We can add other containers here.
        are_data = list(map(is_data, v))
        assert all(are_data) or not any(are_data)
        data[k] = v
      elif is_data(v):
        data[k] = v
      else:
        meta[k] = v
    # A tuple (not a list) keeps the auxiliary data hashable, which jit requires.
    meta['__data_keys'] = tuple(data.keys())
    return tuple(data.values()), tuple(meta.items())

  def unflatten_fun(meta, data):
    meta = dict(meta)
    data = dict(zip(meta.pop('__data_keys'), data))
    return cls(**meta, **data)

  jax.tree_util.register_pytree_node(cls, flatten_fun, unflatten_fun)

  cls.__jax_dataclass = True
  return dataclass(cls)

# Apply the decorator directly; assigning it to jax.tree would shadow the
# jax.tree module that newer JAX releases provide.
@jax_tree
class Bar:
  c: jnp.ndarray

@jax_tree
class Foo(object):
  a: jnp.ndarray
  b: Bar

>>> foo = Foo(jnp.ones([]), Bar(jnp.zeros([])))
>>> jax.tree_util.tree_leaves(foo)
[Array(1., dtype=float32), Array(0., dtype=float32)]

NeilGirdhar commented 4 years ago

@tomhennigan Unlike Flax's version, it looks like you're trying to guess which components are auxiliary (static) parameters and which are pytree-like. I tried to do that too, but unfortunately some components can be both hashable and pytree-like, for example elements of type float or int. In that case, how would you know how to categorize them? The user might want them to be traced sometimes, or might want them passed in statically so that they can be used, for example, as the bound of a scan.
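
To make the int ambiguity concrete, here is a small sketch. It assumes a hypothetical `register` helper and two otherwise identical dataclasses, since a type can only be registered once. The same `n: int` field works as a Python loop bound only when it is kept static in the auxiliary data; as a traced leaf it cannot be converted to a Python int inside `jax.jit`.

```python
import dataclasses
import jax
import jax.numpy as jnp


def register(cls, static_names):
    # Hypothetical helper: fields named in `static_names` go into the (hashable)
    # auxiliary data; all other fields become pytree children.
    names = [f.name for f in dataclasses.fields(cls)]
    data = tuple(n for n in names if n not in static_names)
    meta = tuple(n for n in names if n in static_names)
    jax.tree_util.register_pytree_node(
        cls,
        lambda o: (tuple(getattr(o, k) for k in data),
                   tuple(getattr(o, k) for k in meta)),
        lambda aux, ch: cls(**dict(zip(data, ch)), **dict(zip(meta, aux))),
    )
    return cls


@dataclasses.dataclass(frozen=True)
class TracedN:
    x: jnp.ndarray
    n: int


@dataclasses.dataclass(frozen=True)
class StaticN:
    x: jnp.ndarray
    n: int


register(TracedN, static_names=())      # n becomes a traced leaf under jit
register(StaticN, static_names=("n",))  # n stays a concrete Python int


@jax.jit
def repeat_double(c):
    y = c.x
    for _ in range(c.n):  # range() needs a concrete Python int
        y = y * 2
    return y


print(repeat_double(StaticN(jnp.ones(2), 3)))  # works: the loop unrolls 3 times

try:
    repeat_double(TracedN(jnp.ones(2), 3))
except TypeError as err:  # n is a tracer here, so it cannot be used as a Python int
    print(type(err).__name__)
```

Nothing about the field's type tells the two cases apart; only the registration choice does.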

I ended up modifying flax.struct.dataclass into tjax.dataclass, and went with something pretty similar. The only changes I made were to correctly ignore class variables and to not encode member variables that aren't initialized (since those are typically initialized in __post_init__):

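    import dataclasses

    # data_clz is the dataclass being decorated.
    tree_fields, hashed_fields = [], []
    for field_info in dataclasses.fields(data_clz):
        # Skip fields excluded from __init__ (typically set in __post_init__).
        if not field_info.init:
            continue
        if field_info.metadata.get('pytree_like', True):
            tree_fields.append(field_info.name)
        else:
            hashed_fields.append(field_info.name)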
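
Here is a small sketch of why skipping `init=False` fields works. It assumes a frozen dataclass whose derived field is computed in `__post_init__`; the `Scaled` class and its fields are illustrative. `dataclasses.fields()` already omits `ClassVar` annotations, and unflattening through the constructor re-runs `__post_init__`, so the derived field is recomputed from the children. One caveat of this design: JAX transformations may sometimes unflatten with non-array placeholder leaves, in which case a computation in `__post_init__` could fail.

```python
import dataclasses
from typing import ClassVar
import jax
import jax.numpy as jnp


@dataclasses.dataclass(frozen=True)
class Scaled:
    registry_name: ClassVar[str] = "scaled"  # class variable: not a dataclass field
    values: jnp.ndarray
    total: jnp.ndarray = dataclasses.field(init=False)  # derived in __post_init__

    def __post_init__(self):
        # frozen=True, so bypass __setattr__ to store the derived field.
        object.__setattr__(self, "total", jnp.sum(self.values))


# Only init=True fields are encoded, mirroring the loop above.
_names = tuple(f.name for f in dataclasses.fields(Scaled) if f.init)

jax.tree_util.register_pytree_node(
    Scaled,
    lambda obj: (tuple(getattr(obj, n) for n in _names), None),
    lambda _, children: Scaled(**dict(zip(_names, children))),
)

s = Scaled(jnp.arange(3.0))
doubled = jax.tree_util.tree_map(lambda v: v * 2, s)  # total is recomputed: 6.0
```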

Also, I fixed a problem where Flax's field factory doesn't preserve metadata passed into it:

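import dataclasses
from typing import Any

def field(pytree_like: bool = True, **kwargs: Any) -> dataclasses.Field:
    # Merge user-supplied metadata with the pytree_like flag instead of dropping it.
    return dataclasses.field(metadata={**kwargs.pop('metadata', {}),
                                       'pytree_like': pytree_like},
                             **kwargs)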

Other minor things I did were

Hope any of this is useful :)

If there's one change that I would love to see, but haven't done in my own code yet, it's to declare JAX dataclasses using a mixin rather than a decorator. The benefits are that

awav commented 4 years ago

Hello all!

@tomhennigan, @NeilGirdhar thanks for your input.

For context in the various other tree libraries (tf.nest and dm-tree) we have pushed back on dataclasses automatically being treated as nests because for these "struct" types it is not clear if this (treating these as containers) is intended in all cases

Can't the same reasoning be applied to lists, tuples, and namedtuples? Often a user does not want to differentiate through these structures either.

Instead of treating all dataclasses as jaxtrees, could we instead create a drop-in replacement for dataclass for users who know they want this behavior?

Agreed and I think the same option should be available for other structures. Actually, this is another important missing bit in JAX! I raised that issue before https://github.com/google/jax/issues/2588.

Example: a non-parametric Gaussian process model with trainable hyperparameters, the lengthscale and variance of the squared exponential kernel. Often we need to experiment with the model by computing gradients w.r.t. only the variance, only the lengthscale, or both. I had to write different code for each case, which is super annoying considering that other frameworks (TF and PyTorch) support trainability of tensors out of the box.
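
One way to get per-hyperparameter trainability with plain pytrees is sketched below. The sketch assumes the hyperparameters live in a dict, uses a stand-in loss rather than a GP marginal likelihood, and relies on a hypothetical `grad_wrt` helper. The trainable subset is passed to `jax.grad` and the rest is closed over as constants, so a single function covers all three cases.

```python
import jax
import jax.numpy as jnp


def se_kernel(params, x1, x2):
    # Squared-exponential kernel on 1-D inputs.
    sq_dist = (x1[:, None] - x2[None, :]) ** 2
    return params["variance"] * jnp.exp(-0.5 * sq_dist / params["lengthscale"] ** 2)


def loss(params, x):
    # Stand-in objective; a real GP would use the marginal likelihood.
    return jnp.sum(se_kernel(params, x, x))


def grad_wrt(names, params, x):
    # Differentiate only the entries in `names`; the rest are closed over as constants.
    trainable = {k: v for k, v in params.items() if k in names}
    frozen = {k: v for k, v in params.items() if k not in names}
    return jax.grad(lambda t: loss({**t, **frozen}, x))(trainable)


params = {"variance": 1.5, "lengthscale": 0.7}
x = jnp.linspace(0.0, 1.0, 4)
print(grad_wrt({"variance"}, params, x))                 # gradient for 'variance' only
print(grad_wrt({"variance", "lengthscale"}, params, x))  # gradients for both
```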

Long story short, I think two features would bring more users to JAX:

What are the next steps? @shoyer, @mattjj

NathanHowell commented 3 years ago

@mattjj @shoyer I've opened a PR with an implementation, see https://github.com/tensorflow/tensorflow/pull/46894

jakevdp commented 3 years ago

There has been some offline discussion; I think the consensus is that treating dataclasses as pytrees by default is probably not something we want to do, for a few reasons. There is some past experience in TF suggesting this could be problematic: TF recurses into arbitrary data structures, which leads to the need for hacky workarounds (e.g. allow/deny lists in AutoGraph).

Better would be to allow manual registration of dataclasses, perhaps by making JAX's register_pytree_node_class function handle dataclasses without having to manually define tree_flatten and tree_unflatten.

One side-note here: it may be better to only handle frozen dataclasses in this case, because general dataclasses are mutable, which can cause subtle bugs due to side-effects.
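
The following sketch shows the kind of subtle bug meant here. It assumes a minimal field-based registration through a hypothetical `register_fields` helper. Inside `jax.jit` the function receives a freshly unflattened copy, so mutating a mutable dataclass silently leaves the caller's object untouched, whereas a frozen dataclass fails loudly with `FrozenInstanceError`.

```python
import dataclasses
import jax
import jax.numpy as jnp


def register_fields(cls):
    # Hypothetical helper: every field is a child, with no auxiliary data.
    names = tuple(f.name for f in dataclasses.fields(cls))
    jax.tree_util.register_pytree_node(
        cls,
        lambda obj: (tuple(getattr(obj, n) for n in names), None),
        lambda _, children: cls(**dict(zip(names, children))),
    )
    return cls


@register_fields
@dataclasses.dataclass
class MutableState:
    count: jnp.ndarray


@register_fields
@dataclasses.dataclass(frozen=True)
class FrozenState:
    count: jnp.ndarray


@jax.jit
def bump(state):
    state.count = state.count + 1  # mutates the traced copy, not the caller's object
    return state.count


m = MutableState(jnp.array(0))
print(bump(m), m.count)  # returns 1, but m.count is still 0

try:
    bump(FrozenState(jnp.array(0)))
except dataclasses.FrozenInstanceError:
    print("frozen: mutation rejected at trace time")
```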

jakevdp commented 3 years ago

I haven't thought through all the corner cases, but what if we provided a function similar to register_pytree_node_class that worked for dataclasses; something along the lines of this?

import dataclasses
import jax

def register_pytree_node_dataclass(cls):
  # Shallow field access: unlike dataclasses.asdict, this neither deep-copies
  # values nor turns nested dataclasses into plain dicts.
  _flatten = lambda obj: jax.tree_util.tree_flatten(
      {f.name: getattr(obj, f.name) for f in dataclasses.fields(obj)})
  _unflatten = lambda d, children: cls(**d.unflatten(children))
  jax.tree_util.register_pytree_node(cls, _flatten, _unflatten)
  return cls

@register_pytree_node_dataclass
@dataclasses.dataclass
class DataClass:
  a: int
  b: list
  c: dict

d = DataClass(2, ['a', 'b', 'c'], {'x': 1, 'y': 2})
leaves, treedef = jax.tree_util.tree_flatten(d)
print(treedef.unflatten(leaves))
# DataClass(a=2, b=['a', 'b', 'c'], c={'x': 1, 'y': 2})

Alternatively, we could make register_pytree_node_class provide these default flatten/unflatten functions if the class is a dataclass and does not have them defined.
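
One possible shape for that fallback is sketched below. `register_pytree_node_class_with_defaults` is a hypothetical name, not a JAX function. It defers to `jax.tree_util.register_pytree_node_class` when the class defines `tree_flatten`/`tree_unflatten`, and otherwise derives both from the dataclass fields, skipping `init=False` fields.

```python
import dataclasses
import jax


def register_pytree_node_class_with_defaults(cls):
    # Hypothetical variant of jax.tree_util.register_pytree_node_class.
    if hasattr(cls, "tree_flatten") and hasattr(cls, "tree_unflatten"):
        return jax.tree_util.register_pytree_node_class(cls)
    if not dataclasses.is_dataclass(cls):
        raise TypeError(
            f"{cls.__name__} defines no tree_flatten/tree_unflatten and is not a dataclass")
    names = tuple(f.name for f in dataclasses.fields(cls) if f.init)
    jax.tree_util.register_pytree_node(
        cls,
        lambda obj: (tuple(getattr(obj, n) for n in names), names),
        lambda aux, children: cls(**dict(zip(aux, children))),
    )
    return cls


@register_pytree_node_class_with_defaults
@dataclasses.dataclass(frozen=True)
class Point:
    x: float
    y: float


p = jax.tree_util.tree_map(lambda v: v + 1, Point(1.0, 2.0))  # Point(x=2.0, y=3.0)
```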

NathanHowell commented 3 years ago

There has been some offline discussion; I think the consensus is that treating dataclasses as pytrees by default is probably not something we want to do, for a few reasons. There is some past experience in TF suggesting this could be problematic: TF recurses into arbitrary data structures, which leads to the need for hacky workarounds (e.g. allow/deny lists in AutoGraph).

Do you have more details about this hacky workaround? Is it just for namedtuple? I can understand not wanting to recurse into arbitrary classes.

Better would be to allow manual registration of dataclasses, perhaps by making JAX's register_pytree_node_class function handle dataclasses without having to manually define tree_flatten and tree_unflatten.

I'd like all of my dataclasses to always be registered. If you have ideas for how to do this without manually registering each one or using JAX-specific decorators, I think that would be a reasonable compromise. If this is a legitimate concern, it seems that it should apply to namedtuple too.

One side-note here: it may be better to only handle frozen dataclasses in this case, because general dataclasses are mutable, which can cause subtle bugs due to side-effects.

Agreed; too bad they are mutable by default.

jakevdp commented 3 years ago

Currently the pytree registry is implemented as a static mapping from a Python type to a flatten/unflatten pair. Since standard dataclasses do not share a common type, there is no way to register all possible dataclasses automatically without changing the registry mechanism in the pytree source.
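
To illustrate the type-keyed registry, here is a sketch with two structurally identical dataclasses where only one is registered; `A` and `B` are illustrative names. The unregistered class is treated as a single opaque leaf. `dataclasses.is_dataclass` can recognize any dataclass, but because the registry lookup is by type, that predicate plays no role.

```python
import dataclasses
import jax


@dataclasses.dataclass
class A:
    value: float


@dataclasses.dataclass
class B:
    value: float


# Register only A; the registration is keyed on the type A itself.
jax.tree_util.register_pytree_node(
    A, lambda a: ((a.value,), None), lambda _, children: A(*children)
)

print(jax.tree_util.tree_leaves(A(1.0)))  # [1.0]: A is a registered node
print(jax.tree_util.tree_leaves(B(1.0)))  # [B(value=1.0)]: B is an opaque leaf
print(dataclasses.is_dataclass(B))        # True, yet B is still a leaf
```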

I suspect the best compromise we can hope for at the moment is explicit registration.

hawkinsp commented 3 years ago

One note is that we do have some more heuristic logic for identifying named tuples:

https://cs.opensource.google/tensorflow/tensorflow/+/master:tensorflow/compiler/xla/python/pytree.cc;l=101

jakevdp commented 3 years ago

@hawkinsp - what would you think about pytree exposing a boolean flag (False by default) that optionally registers dataclasses?

hawkinsp commented 3 years ago

I think a flag would be confusing. Whatever behavior we have should be consistent.

NathanHowell commented 3 years ago

@hawkinsp yeah, the same heuristic is used in a couple places. I lifted it out to a function https://github.com/tensorflow/tensorflow/pull/46894/files#diff-f29d52c716a11de00124e98f74e14994b626a81cdde917580fb1ec16a54f0ab2R134-R138

@jakevdp, I can just write an importlib hook that registers every dataclass, which effectively removes this limitation, so that's fine, though I know many others would like this functionality by default (hence this issue); it's better developer ergonomics. I'd still like to know why this is acceptable for namedtuple but not for dataclasses. They seem equivalent to me (and both have type annotations), but one is much nicer to write:

from collections import namedtuple
from dataclasses import dataclass

class Foo(namedtuple(field_names="test", typename="Foo")):
    test: str

@dataclass(frozen=True)
class Bar:
    test: str

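
For comparison, here is how default JAX behavior treats the two classes from this comment (the imports are added so the sketch runs on its own). The namedtuple subclass is flattened into its fields, while the unregistered dataclass is a single leaf:

```python
from collections import namedtuple
from dataclasses import dataclass
import jax


class Foo(namedtuple(field_names="test", typename="Foo")):
    test: str


@dataclass(frozen=True)
class Bar:
    test: str


print(jax.tree_util.tree_leaves(Foo("x")))  # ['x']: recognized as a namedtuple
print(jax.tree_util.tree_leaves(Bar("x")))  # [Bar(test='x')]: one opaque leaf
```
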
jakevdp commented 3 years ago

it's better developer ergonomics

It's better developer ergonomics for people who want dataclasses registered; worse for people who don't.

I'd still like to know why this is acceptable for namedtuple but not dataclasses

I think many people would prefer namedtuple to not be registered automatically, but that would be a fairly significant breaking change at this point so it's unlikely to happen.

NathanHowell commented 3 years ago

@jakevdp OK, fair enough, but what is the buggy behavior we're trying to avoid? "Internal discussions" aren't helpful for those of us on GitHub. It seems odd that undesirable functionality will be halfway supported. Maybe y'all can document it here and close the related issues as won't fix.

hawkinsp commented 3 years ago

@jheek might want to comment here.

We probably could look into unregistering NamedTuple by default, especially now that dataclasses are more common.

NeilGirdhar commented 3 years ago

I've been following the issue for a while. I prefer the drop-in replacement idea like the one suggested by @tomhennigan. This would allow an easy, explicit way to mark fields as static or non-static.

Marking fields as static cannot be automatically done based on the types of the elements. For example, an integer element might or might not be able to be static. If the integer is the result of a tracer, it needs to be non-static, but if it's used as the limit of a scan, it has to be static. There is no way to know at definition time.

I ended up forking from flax too. To mark attributes as static, I went with using a modified field constructor, but I named the parameter static.

I like @hawkinsp's suggestion of deregistering NamedTuple. People have been using it as a lazy way (for example, optax does it here) to create aggregate pytree structures. I think it's really ugly, because what they really want are dataclasses, and they're exposing iteration on structures that aren't meant to be iterated.

jheek commented 3 years ago

I'd still like to know why this is acceptable for namedtuple but not dataclasses

Named tuples are iterable and are a subclass of tuple. One could argue it's an implementation detail that the pytree handling of tuple works only on the tuple type and not on its subclasses. Not handling NamedTuples as tuples would create an inconsistency with how most APIs deal with tuples and iterables, because those handle NamedTuples as ordinary tuples. But I think the real problem with NamedTuple is that inheriting from tuple is more a buggy side effect than a feature. With dataclasses we have a better alternative, and most use cases of NamedTuple should disappear, except maybe when there is truly an order in the fields (like a NamedSequential structure, for example). But I guess deregistering could help push adoption of the right pattern.
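
Here is a short illustration of that tuple behavior, using plain Python only (`Pair` and `PairDC` are illustrative names). A namedtuple unpacks positionally, and its tuple operations return plain tuples, while a dataclass has no iteration order at all:

```python
from collections import namedtuple
from dataclasses import dataclass

Pair = namedtuple("Pair", ["mean", "scale"])


@dataclass(frozen=True)
class PairDC:
    mean: float
    scale: float


p = Pair(0.0, 1.0)
mean, scale = p              # positional unpacking works
print(isinstance(p, tuple))  # True
print(p + (2.0,))            # (0.0, 1.0, 2.0): tuple ops silently drop the named type

try:
    iter(PairDC(0.0, 1.0))   # dataclasses expose no iteration order
except TypeError:
    print("PairDC is not iterable")
```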

I think registering dataclasses by default (with a flag or always) does lead to problems. It might be worth considering allowing dataclasses to be registered post-hoc e.g.:

@jax.tree_util.dataclass
class Foo:
  ...

# equivalent to
@jax.tree_util.register_dataclass
@dataclasses.dataclass(frozen=True)
class Foo:
  ...

Having a field function for defining metadata can be really useful at times. In JAX it often avoids having to list flag-like function arguments in static_argnums.

A non-registered dataclass can also really help. One thing that comes to mind is configs. They often contain bools, integers, and floats, so they could be lifted into JAX values, but they are intended to be constants, which allows optimisations such as removing dead branches and eliminating multiplications by 0 (for example, weight_decay=0).
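
Here is a sketch of the config pattern. It assumes a frozen (hence hashable) dataclass passed through `static_argnums`; `OptConfig` and `update` are illustrative names. Because `cfg` is static, `if cfg.weight_decay:` is an ordinary Python branch resolved at trace time, so with `weight_decay=0.0` the decay term is never traced. Each distinct config triggers its own compilation.

```python
import dataclasses
import functools
import jax
import jax.numpy as jnp


@dataclasses.dataclass(frozen=True)  # frozen => hashable, as static_argnums requires
class OptConfig:
    learning_rate: float
    weight_decay: float = 0.0


@functools.partial(jax.jit, static_argnums=1)
def update(w, cfg):
    g = 2 * w  # gradient of sum(w**2), standing in for a real gradient
    if cfg.weight_decay:  # plain Python branch: cfg is a compile-time constant
        g = g + cfg.weight_decay * w
    return w - cfg.learning_rate * g


w = jnp.ones(3)
print(update(w, OptConfig(0.1)))                    # decay branch never traced
print(update(w, OptConfig(0.1, weight_decay=0.5)))  # separate compilation with decay
```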

carlosgmartin commented 1 year ago

@jakevdp

Alternatively, we could make register_pytree_node_class provide these default flatten/unflatten functions if the class is a dataclass and does not have them defined.

This sounds like a great idea. I've opened a feature request here: https://github.com/google/jax/issues/15655.

mishmish66 commented 11 months ago

I think registering dataclasses by default (with a flag or always) does lead to problems. It might be worth considering allowing dataclasses to be registered post-hoc e.g.:

@jax.tree_util.dataclass
class Foo:
  ...

# equivalent to
@jax.tree_util.register_dataclass
@dataclasses.dataclass(frozen=True)
class Foo:
  ...

Crashing the party to give my opinion: adding a new decorator seems the most convenient option to me, and I would love to have this functionality. A decorator separate from the usual register_pytree_node_class one would allow syntax like @register_dataclass(static_fieldnames=["name", "quest", "favorite_color"]), which I think would be really cool to have. That would require some silliness with different return types on different branches, but it looks aesthetic to me.
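
Here is a sketch of a decorator that works both bare and with a `static_fieldnames` argument. `pytree_dataclass` is a hypothetical name, and the sketch forces `frozen=True`. The "different return types" are the two branches: with no class it returns a decorator, otherwise it returns the registered class. Recent JAX releases also ship `jax.tree_util.register_dataclass`, which accepts data and metadata field names; this sketch does not rely on it.

```python
import dataclasses
import jax
import jax.numpy as jnp


def pytree_dataclass(cls=None, *, static_fieldnames=()):
    # Hypothetical decorator (not a JAX API), usable as @pytree_dataclass
    # or as @pytree_dataclass(static_fieldnames=[...]).
    if cls is None:
        # Configured form: return a decorator instead of a class.
        return lambda c: pytree_dataclass(c, static_fieldnames=static_fieldnames)

    cls = dataclasses.dataclass(frozen=True)(cls)
    names = [f.name for f in dataclasses.fields(cls)]
    unknown = set(static_fieldnames) - set(names)
    if unknown:
        raise ValueError(f"unknown static fields: {sorted(unknown)}")
    data = tuple(n for n in names if n not in static_fieldnames)
    meta = tuple(n for n in names if n in static_fieldnames)

    def flatten(obj):
        return (tuple(getattr(obj, n) for n in data),
                tuple(getattr(obj, n) for n in meta))

    def unflatten(aux, children):
        return cls(**dict(zip(data, children)), **dict(zip(meta, aux)))

    jax.tree_util.register_pytree_node(cls, flatten, unflatten)
    return cls


@pytree_dataclass(static_fieldnames=["name", "quest"])
class Knight:
    name: str
    quest: str
    favorite_color: jnp.ndarray


@pytree_dataclass
class Vec2:
    x: jnp.ndarray
    y: jnp.ndarray
```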

tomhennigan commented 11 months ago

@mishmish66 you might find the dataclasses utility functions in chex useful.

import dataclasses
import chex

# Convenience decorator matching @dataclasses.dataclass:

@chex.dataclass
class MyChexDataclass:
  foo: chex.ArrayTree

# You can also register regular dataclasses:

@dataclasses.dataclass
class MyRegularDataclass:
  foo: chex.ArrayTree

chex.register_dataclass_type_with_jax_tree_util(MyRegularDataclass)

mishmish66 commented 11 months ago

@mishmish66 you might find the dataclasses utility functions in chex useful.

This whole chex tool is great, thanks for the suggestion I'll definitely be using this!