Open cgarciae opened 4 years ago
Here is the current code:
from jax.tree_util import register_pytree_node
from collections import namedtuple
def differentiable(cls):
    cls.__differentiable__ = True
    if hasattr(cls, "__dataclass_fields__"):
        fields = {name: value.type for name, value in cls.__dataclass_fields__.items()}
    elif hasattr(cls, "__annotations__"):
        fields = cls.__annotations__.copy()
    diff_fields = [
        name
        for name, type_ in fields.items()
        if type_ == differentiable or getattr(type_, "__differentiable__", False)
    ]
    cls.Tangent = namedtuple(f"{cls.__name__}Tangent", diff_fields)

    def flatten(self):
        diff_vars = [getattr(self, name) for name in diff_fields]
        return diff_vars, None

    def unflatten(_self, diff_vars):
        return cls.Tangent(*diff_vars)

    register_pytree_node(cls, flatten, unflatten)
    return cls
This is an interesting idea, but I’m concerned that your current implementation would break round trips for pytree flatten/unflatten. JAX ends up doing this internally in most of the higher-level control flow functions (jvp, scan, etc.).
What about simply adding support for data classes in the pytree module, possibly with some introspection of annotations to distinguish between static and array-like arguments? I’m not even certain that we need the separate differentiable decorator. It seems like just introspecting types and seeing if they are registered as pytrees or arrays could be enough.
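A minimal sketch of what such dataclass support could look like, written against the public register_pytree_node API. The helper name register_dataclass_pytree and the STATIC_TYPES tuple are hypothetical and only for illustration; they are not part of JAX. Fields annotated with a static type travel as auxiliary data, everything else becomes a pytree child, and unflatten rebuilds the original class so the round trip is preserved:

```python
import dataclasses

import jax
import jax.numpy as jnp
from jax.tree_util import register_pytree_node, tree_flatten, tree_unflatten

# Assumption for this sketch: these annotation types mark a field as static.
STATIC_TYPES = (int, str, bool)


def register_dataclass_pytree(cls):
    """Hypothetical helper: register a dataclass as a pytree node."""
    fields = dataclasses.fields(cls)
    static = [f.name for f in fields if f.type in STATIC_TYPES]
    dynamic = [f.name for f in fields if f.type not in STATIC_TYPES]

    def flatten(obj):
        children = tuple(getattr(obj, name) for name in dynamic)
        aux = tuple(getattr(obj, name) for name in static)  # must be hashable
        return children, aux

    def unflatten(aux, children):
        kwargs = dict(zip(dynamic, children))
        kwargs.update(zip(static, aux))
        return cls(**kwargs)  # same type as the input, so round trips hold

    register_pytree_node(cls, flatten, unflatten)
    return cls


@register_dataclass_pytree
@dataclasses.dataclass
class Model:
    a: int
    b: jnp.ndarray
    c: jnp.ndarray
    d: str = "value"


model = Model(a=1, b=jnp.ones((1,)), c=jnp.ones((2,)))
leaves, treedef = tree_flatten(model)
rebuilt = tree_unflatten(treedef, leaves)


def loss(m):
    return jnp.sum(m.b) + jnp.sum(m.c ** 2)


# The gradient comes back as a Model, with static fields copied through.
grads = jax.grad(loss)(model)
```

With this layout only b and c are leaves, so jax.grad never sees a or d, and the gradient has the same container type as the input.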
Hey @shoyer, a couple of questions/comments:
The core pytree logic is in C++ now: https://github.com/google/jax/blob/master/jaxlib/pytree.cc
Also see the Python wrapper: https://github.com/google/jax/blob/master/jax/tree_util.py
I don't fully understand your use cases for filtering gradients. I'm sure they are valid, but they might need a different sort of solution, e.g., perhaps optimizers should be aware of different types of variables, somehow?
Here's an example of using a trivial lax.scan to write the identity function. When run on your current example, it breaks as shown below:
import jax
from jax import lax
import jax.numpy as np
def scanned_identity(x, repeats=1):
    """Loop over an object with scan and return it."""
    out, _ = lax.scan(
        lambda carry, x: (carry, ()),
        x,
        np.arange(repeats),
    )
    return out

print(scanned_identity({'a': 1, 'b': 2}))
# {'a': DeviceArray(1, dtype=int32), 'b': DeviceArray(2, dtype=int32)}
from dataclasses import dataclass
@differentiable
@dataclass
class Model:
    a: int
    b: differentiable
    c: differentiable
    d: str = "value"

model = Model(a=1, b=np.ones((1,)), c=np.ones((2,)))
print(scanned_identity(model))
# TypeError: scan carry output and input must have same type structure, got PyTreeDef(namedtuple[<class '__main__.ModelTangent'>], [*,*]) and PyTreeDef(<class '__main__.Model'>[None], [*,*]).
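A small sketch of how one might check for this failure mode without going through scan. The Pair class, PairTangent, and the roundtrips helper are hypothetical names for illustration; the registration deliberately mirrors the issue's pattern of unflattening into a different type:

```python
from collections import namedtuple

from jax.tree_util import (
    register_pytree_node,
    tree_flatten,
    tree_structure,
    tree_unflatten,
)

PairTangent = namedtuple("PairTangent", ["x", "y"])


class Pair:
    def __init__(self, x, y):
        self.x = x
        self.y = y


# Deliberately lossy registration: unflatten returns a PairTangent, not a Pair.
register_pytree_node(
    Pair,
    lambda p: ((p.x, p.y), None),
    lambda aux, children: PairTangent(*children),
)


def roundtrips(tree):
    """Return True if flatten followed by unflatten preserves the structure."""
    leaves, treedef = tree_flatten(tree)
    rebuilt = tree_unflatten(treedef, leaves)
    return tree_structure(rebuilt) == treedef


dict_ok = roundtrips({"a": 1, "b": 2})
pair_ok = roundtrips(Pair(1.0, 2.0))
```

Here dict_ok is True and pair_ok is False, which is the same structure mismatch that scan reports for its carry.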
I haven't parsed the details here yet, but I think there is a use case (based on conversations long ago with rxwei, jekbradbury, sharadmv, and dougalm) where one wants to differentiate with respect to a container (i.e. a product type) that contains, say, a float and an int. That's potentially different from using the pytree mechanism to shuttle off the int part, since we need to pay attention to that int part for, say, jit.
We likely need to teach the AD system that the tangent space for integral values is core.unit, modeling the trivial vector space. (Alternatively we could have the differentiation API handle pytrees differently. Dynamic typing gives us a lot of options!)
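For illustration only: this thread talks about core.unit, but to my knowledge later JAX releases expose the trivial tangent space for integers through the allow_int flag of jax.grad and the float0 dtype. The sketch below assumes a JAX version that has both; the function f is a made-up example:

```python
import jax


def f(x, n):
    # x is a float and n is an int; the output is real-valued.
    return x * n


# allow_int=True permits differentiating with respect to the integer argument.
g_x, g_n = jax.grad(f, argnums=(0, 1), allow_int=True)(2.0, 3)

# g_x is an ordinary float gradient; g_n carries no information and has the
# special float0 dtype, which stands in for the trivial vector space.
n_tangent_dtype = g_n.dtype
```

The float part gets a normal gradient while the int part gets a placeholder, so a container holding both can be differentiated as a whole.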
Sorry if the above is cryptic. I wanted to jot down some quick thoughts without yet delving into the excellent points made in this thread already.
@shoyer I see, I could change it so that the output structure is not a Tangent but a copy of the original (Edit: this doesn't work). I didn't do this because S4TF has a structure called AllDifferentiableVariables and I thought it might be useful here.
Like @mattjj is saying, the use case is to be able to differentiate with respect to container types in which some parts are differentiable, others are not (integers, strings, etc.), and others are differentiable but should not be modified by the gradient. Look at S4TF's implementation of BatchNormalization here: you don't want the algorithm to differentiate through the hyperparameters or the statistics, but you still want the Layer's learnable parts to be differentiable.
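One way to get this behavior with plain pytrees today is to split the container by role and differentiate only with respect to the learnable part. This is a rough sketch, not the S4TF design; the dictionary keys and the loss function are made up for illustration:

```python
import jax
import jax.numpy as jnp


def loss(trainable, frozen, hparams, x):
    # Normalize with stored statistics, then apply the learnable affine map.
    x_hat = (x - frozen["mean"]) / jnp.sqrt(frozen["var"] + hparams["eps"])
    y = trainable["scale"] * x_hat + trainable["offset"]
    return jnp.sum(y ** 2)


trainable = {"scale": jnp.ones(3), "offset": jnp.zeros(3)}  # learnable
frozen = {"mean": jnp.zeros(3), "var": jnp.ones(3)}  # running statistics
hparams = {"eps": 1e-5}  # hyperparameters
x = jnp.arange(3.0)

# jax.grad differentiates with respect to the first argument only, so the
# statistics and hyperparameters are used in the computation but get no gradient.
grads = jax.grad(loss)(trainable, frozen, hparams, x)
```

The cost is exactly what the comment above points out: the layer's state ends up spread over several structures instead of living in one object.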
I think it would be nice to be able to keep hyperparameters, parameters, and other kinds of information in the same structure. Stax tends to split the parameters from the layer, which I take to be nice because of its functional properties, but it's a bit less familiar compared with how PyTorch, Keras, or S4TF do it. I think S4TF has the cleanest solution, but jax could easily be on par.
Hey, inspired by S4TF, where regular structs can be differentiable while specifying which fields should be differentiated, I implemented this @differentiable attribute, which registers the type in register_pytree_node and defines its proper tangent structure. It works like this:
Use the @differentiable decorator on your class and specify which attributes are differentiable using type hints:
Only the fields tagged with differentiable, or with other types that were decorated with @differentiable, are considered. Here a @dataclass is used, but it would work with regular classes as well. Now that it's registered, you can differentiate through it:
Each class decorated with @differentiable gets a tangent class implemented for it, so here dmodel is of type ModelTangent (which can be accessed via Model.Tangent). In this example dmodel would look like this:
I think that the move method described in Differentiable data structures could also be automatically implemented for classes that use this decorator.
Is this of interest to jax? Should I try to make a PR?