google / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
http://jax.readthedocs.io/
Apache License 2.0

@differentiable decorator for classes #1808

Open cgarciae opened 4 years ago

cgarciae commented 4 years ago

Hey, inspired by S4TF, where regular structs can be differentiable while specifying which fields should be differentiated, I implemented this @differentiable decorator, which registers the type via register_pytree_node and defines a proper tangent structure for it. It works like this:

Use the @differentiable decorator on your class and specify which attributes are differentiable using type hints:

@differentiable
@dataclass
class Model:
    a: int
    b: differentiable
    c: differentiable
    d: str = "value"

Only the fields annotated with differentiable, or with other types that were decorated with @differentiable, are considered. Here a @dataclass is used, but it would work with regular classes as well. Now that the class is registered, you can differentiate through it:

import jax.numpy as np
from jax import grad

def loss(model):
    return 3 * np.sum(model.b) + 5 * np.sum(model.c)

dloss = grad(loss)

model = Model(a=1, b=np.ones((1,)), c=np.ones((2,)))
dmodel = dloss(model)

Each class decorated with @differentiable gets a tangent class generated for it, so here dmodel is of type ModelTangent (which can be accessed via Model.Tangent). In this example dmodel would look like this:

ModelTangent(
    b=DeviceArray([3.], dtype=float32), 
    c=DeviceArray([5., 5.], dtype=float32)
)

I think that the move method described in Differentiable data structures could also be automatically implemented for classes that use this decorator.
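
For reference, a rough sketch of what such an auto-generated move could look like (a hypothetical helper the decorator could attach; assumes a dataclass so dataclasses.replace works, and stays functional rather than mutating to be JAX-friendly):

from dataclasses import replace

def move(self, tangent, step=1.0):
    # Shift every differentiable field along the corresponding tangent
    # component and return an updated copy; other fields are untouched.
    updates = {
        name: getattr(self, name) + step * getattr(tangent, name)
        for name in type(self).Tangent._fields
    }
    return replace(self, **updates)

# Inside the decorator one could then attach it with: cls.move = move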

Is this of interest to jax? Should I try to make a PR?

cgarciae commented 4 years ago

Here is the current code:

from jax.tree_util import register_pytree_node
from collections import namedtuple

def differentiable(cls):

    # Mark the class so that fields annotated with a decorated type are also
    # picked up as differentiable.
    cls.__differentiable__ = True

    # Collect field annotations from a dataclass or from plain class annotations.
    if hasattr(cls, "__dataclass_fields__"):
        fields = {name: field.type for name, field in cls.__dataclass_fields__.items()}
    elif hasattr(cls, "__annotations__"):
        fields = cls.__annotations__.copy()
    else:
        fields = {}

    # A field is differentiable if it is annotated with `differentiable` itself
    # or with another class that was decorated with @differentiable.
    diff_fields = [
        name
        for name, type_ in fields.items()
        if type_ is differentiable or getattr(type_, "__differentiable__", False)
    ]

    # The tangent type only carries the differentiable fields.
    cls.Tangent = namedtuple(f"{cls.__name__}Tangent", diff_fields)

    def flatten(self):
        # Children are the differentiable fields; there is no aux data.
        diff_vars = [getattr(self, name) for name in diff_fields]
        return diff_vars, None

    def unflatten(aux_data, diff_vars):
        # Note: this rebuilds the Tangent, not the original class.
        return cls.Tangent(*diff_vars)

    register_pytree_node(cls, flatten, unflatten)

    return cls
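
For illustration, here is roughly what the registration does when you flatten and unflatten a Model (a quick sketch reusing the Model class from the first comment):

import jax.numpy as np
from jax.tree_util import tree_flatten, tree_unflatten

model = Model(a=1, b=np.ones((1,)), c=np.ones((2,)))
leaves, treedef = tree_flatten(model)       # leaves are just model.b and model.c
restored = tree_unflatten(treedef, leaves)  # a ModelTangent, not a Model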

shoyer commented 4 years ago

This is an interesting idea, but I’m concerned that your current implementation would break round trips for pytree flatten/unflatten. JAX ends up doing this internally in most of the higher-level control flow functions (jvp, scan, etc.).

What about simply adding support for data classes in the pytree module, possibly with some introspection of annotations to distinguish between static and array-like arguments? I’m not even certain that we need the separate differentiable decorator. It seems like just introspecting types and seeing if they are registered as pytrees or arrays could be enough.
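
A minimal sketch of what that could look like (hypothetical helper name and annotation rule, not an existing JAX API; the point is that static fields go into the aux data, so flatten/unflatten round-trips back to the original class):

from dataclasses import fields, is_dataclass

import jax.numpy as np
from jax.tree_util import register_pytree_node

def register_dataclass_pytree(cls, array_types=(np.ndarray, float)):
    # Fields annotated as array-like become pytree children; everything else
    # is carried through as static aux data.
    assert is_dataclass(cls)
    child_names = [f.name for f in fields(cls) if f.type in array_types]
    static_names = [f.name for f in fields(cls) if f.type not in array_types]

    def flatten(obj):
        children = [getattr(obj, name) for name in child_names]
        aux = tuple(getattr(obj, name) for name in static_names)
        return children, aux

    def unflatten(aux, children):
        # Rebuild the original class so the round trip is preserved.
        kwargs = dict(zip(child_names, children))
        kwargs.update(zip(static_names, aux))
        return cls(**kwargs)

    register_pytree_node(cls, flatten, unflatten)
    return cls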

cgarciae commented 4 years ago

Hey @shoyer, a couple of questions/comments:

  1. Can you point me to where this is implemented? I am just starting to get familiar with the library.
  2. We would still need the type hints to filter fields, e.g. in batchnorm you don't want the gradient flowing through the global statistics (see the sketch after this list).
  3. What kind of patterns break control flow?
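
For example, with the decorator above a batchnorm-like container (hypothetical class) would only expose scale and bias to differentiation, while the running statistics and hyperparameters stay out of the tangent:

from dataclasses import dataclass
import jax.numpy as np

@differentiable
@dataclass
class BatchNormState:
    scale: differentiable
    bias: differentiable
    running_mean: np.ndarray  # no gradient: not annotated as differentiable
    running_var: np.ndarray
    momentum: float = 0.99    # hyperparameter, also left out of the tangent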

shoyer commented 4 years ago

The core pytree logic is in C++ now: https://github.com/google/jax/blob/master/jaxlib/pytree.cc Also see the Python wrapper: https://github.com/google/jax/blob/master/jax/tree_util.py

I don't fully understand your use cases for filtering gradients. I'm sure they are valid, but they might need a different sort of solution, e.g., perhaps optimizers should be aware of different types of variables, somehow?

Here's an example of using a trivial lax.scan to write the identity function. When run on your current example, it breaks as shown below:

import jax
from jax import lax
import jax.numpy as np

def scanned_identity(x, repeats=1):
  """Loop over an object with scan and return it."""
  out, _ = lax.scan(
      lambda carry, x: (carry, ()),
      x,
      np.arange(repeats),
  )
  return out

print(scanned_identity({'a': 1, 'b': 2}))
# {'a': DeviceArray(1, dtype=int32), 'b': DeviceArray(2, dtype=int32)}

from dataclasses import dataclass

@differentiable
@dataclass
class Model:
    a: int
    b: differentiable
    c: differentiable
    d: str = "value"

model = Model(a=1, b=np.ones((1,)), c=np.ones((2,)))
print(scanned_identity(model))
# TypeError: scan carry output and input must have same type structure, got PyTreeDef(namedtuple[<class '__main__.ModelTangent'>], [*,*]) and PyTreeDef(<class '__main__.Model'>[None], [*,*]).
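
One possible direction (just a sketch, not a tested fix): keep the non-differentiable fields as aux data and rebuild the original class in unflatten, i.e. replace the two inner functions of the decorator above with something like the following (fields, diff_fields, and cls come from the decorator's closure; the aux data has to stay hashable):

    def flatten(self):
        diff_vars = [getattr(self, name) for name in diff_fields]
        static_vars = tuple((name, getattr(self, name)) for name in fields if name not in diff_fields)
        return diff_vars, static_vars

    def unflatten(static_vars, diff_vars):
        # Rebuild the original class (not the Tangent) so scan/jit round-trip.
        return cls(**dict(static_vars), **dict(zip(diff_fields, diff_vars)))

The tradeoff is that gradients then come back as a full Model (with a and d passed through unchanged) rather than as a ModelTangent.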

mattjj commented 4 years ago

I haven't parsed the details here yet, but I think there is a use case (based on conversations long ago with rxwei, jekbradbury, sharadmv, and dougalm) where one wants to differentiate with respect to a container (i.e. a product type) that contains, say, a float and an int. That's potentially different from using the pytree mechanism to shuttle off the int part, since we need to pay attention to that int part for, say, jit.
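
For concreteness, a small illustration of that tension with the decorator from this thread (reusing the Model class defined above): the int field is dropped by flatten entirely, so a jitted function that actually uses it can no longer see it.

import jax
import jax.numpy as np

@jax.jit
def scaled_loss(model):
    # Uses both the int field `a` and the differentiable field `b`.
    return model.a * np.sum(model.b)

model = Model(a=2, b=np.ones((1,)), c=np.ones((2,)))
# scaled_loss(model) would presumably fail: inside jit the argument is rebuilt
# via unflatten, which returns a ModelTangent that has no `a` attribute.
# Stashing `a` in the aux data instead would keep it around, but as a static
# value rather than a traced one.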

We likely need to teach the AD system that the tangent space for integral values is core.unit, modeling the trivial vector space. (Alternatively we could have the differentiation api handle pytrees differently. Dynamic typing gives us a lot of options!)

Sorry if the above is cryptic. I wanted to jot down some quick thoughts without yet delving into the excellent points made in this thread already.

cgarciae commented 4 years ago

@shoyer I see, I could change it so that the output structure is not a Tangent but a copy of the original (Edit: this doesn't work). I didn't do this initially because S4TF has a structure called AllDifferentiableVariables and I thought it might be useful here.

Like @mattjj is saying, the use case is to be able to differentiate w.r.t. container types where some parts are differentiable, others are not (integers, strings, etc.), and others are differentiable but you don't want the gradient to modify them. Look at S4TF's implementation of BatchNormalization here: you don't want the algorithm to differentiate through the hyperparameters or the statistics, but you still want the layer's learnable parts to be differentiable.

I think it would be nice to be able to keep hyperparams, parameters, and other kinds of information in the same structure. Stax tends to split the parameters from the layer, which I take it is nice because of its functional properties, but it's a bit less familiar compared to how PyTorch, Keras, or S4TF do it. I think S4TF has the cleanest solution, but jax could easily be on par.