google / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
http://jax.readthedocs.io/
Apache License 2.0

Add support for attr classes #3908

Open AdrienCorenflos opened 4 years ago

AdrienCorenflos commented 4 years ago

Why not add support for attr classes? They're basically namedtuples on steroids.

It looks to me like it would only need a small contribution to the pytree file, mostly at these lines:

Then you'd be able to do something like:

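```python
import attr
import jax
import numpy as np


@attr.s
class World:
    p: np.ndarray = attr.ib()
    v: np.ndarray = attr.ib()

    # Under this proposal, jit would accept `self` because World would be a pytree.
    @jax.jit
    def step(self, dt):
        a = -9.8                   # constant acceleration (gravity)
        v = self.v + a * dt        # explicit Euler update of the velocity
        p = self.p + self.v * dt   # ... and of the position
        return attr.evolve(self, p=p, v=v)
```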

World instances would be recognized as acceptable arguments and outputs of jitted functions.

EDIT: This might even help with AOT compilation à la Numba.

EDIT 2: Additionally, you could help keep the methods pure by requiring that the class be declared frozen (`@attr.s(frozen=True)`), so that methods cannot mutate `self` in place.
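A minimal sketch (independent of JAX) of what freezing buys. With `@attr.s(frozen=True)`, attribute assignment raises, so a method has to return a new instance via `attr.evolve` instead of mutating `self`. This only rules out mutating the instance itself; it does not make a method pure in general (it could still touch globals or other mutable state).

```python
import attr


@attr.s(frozen=True)
class World:
    p: float = attr.ib()
    v: float = attr.ib()


w = World(p=0.0, v=1.0)

try:
    w.p = 5.0  # in-place mutation is rejected on a frozen attrs class
    mutated = True
except attr.exceptions.FrozenInstanceError:
    mutated = False

# The supported way to "change" a frozen instance is to build a new one.
w2 = attr.evolve(w, p=5.0)
```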

Originally posted by @AdrienCorenflos in https://github.com/google/jax/issues/1567#issuecomment-658731070

jakevdp commented 4 years ago

Related: #2371

hawkinsp commented 4 years ago

Note that the list of pytree containers is extensible: you can register your own custom classes as pytrees.
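A minimal sketch of that extension point, using `jax.tree_util.register_pytree_node_class` on a plain Python class (`Particle` is a made-up name for illustration). The class tells JAX how to split itself into children, which `jit` traces, and auxiliary data, which is treated as static.

```python
import jax
import jax.numpy as jnp
from jax.tree_util import register_pytree_node_class


@register_pytree_node_class
class Particle:
    def __init__(self, p, v):
        self.p = p
        self.v = v

    def tree_flatten(self):
        # (children traced by JAX, static auxiliary data)
        return (self.p, self.v), None

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        return cls(*children)


@jax.jit
def step(state, dt):
    a = -9.8
    return Particle(state.p + state.v * dt, state.v + a * dt)


out = step(Particle(jnp.zeros(3), jnp.ones(3)), 0.1)
```

Because `Particle` is registered, it can be both an argument and a return value of the jitted function.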

I suspect we won't add generic support for making all attr classes pytrees, since it's not obvious that there's a single behavior that makes sense for all attr class instances, but there's nothing to stop you from turning individual attr classes into pytrees by registering an extension where that makes sense.

Perhaps you could write a small utility for registering an attr class as a pytree and share it either here or in another GitHub repository? I suspect that would mostly solve the problem.
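One possible shape for such a utility, offered as a sketch rather than anything JAX ships. `register_attrs_pytree` is a hypothetical helper built on `jax.tree_util.register_pytree_node` and `attr.fields`. It also illustrates the point above that there is no single obvious behavior. Here, fields opt out of being traced through a made-up `metadata={"static": True}` convention, and such static values must be hashable. It assumes every field is an `__init__` argument under its own name, so it will not handle `init=False` or leading-underscore names, which attrs renames in `__init__`.

```python
import attr
import jax
import jax.numpy as jnp
from jax.tree_util import register_pytree_node


def register_attrs_pytree(cls):
    """Hypothetical helper: register an attrs class as a JAX pytree.

    Fields with metadata={"static": True} (a convention invented for this
    sketch) go into aux data and are compile-time constants; all other
    fields become children that jit/grad/vmap trace.
    """
    fields = attr.fields(cls)
    dynamic = [f.name for f in fields if not f.metadata.get("static", False)]
    static = [f.name for f in fields if f.metadata.get("static", False)]

    def flatten(obj):
        children = tuple(getattr(obj, name) for name in dynamic)
        aux = tuple(getattr(obj, name) for name in static)
        return children, aux

    def unflatten(aux, children):
        kwargs = dict(zip(dynamic, children))
        kwargs.update(zip(static, aux))
        return cls(**kwargs)

    register_pytree_node(cls, flatten, unflatten)
    return cls


@register_attrs_pytree
@attr.s(frozen=True)
class World:
    p = attr.ib()
    v = attr.ib()
    g = attr.ib(default=-9.8, metadata={"static": True})

    @jax.jit
    def step(self, dt):
        v = self.v + self.g * dt
        p = self.p + self.v * dt
        return attr.evolve(self, p=p, v=v)


w = World(p=jnp.zeros(2), v=jnp.ones(2))
w1 = w.step(0.5)
```

Since aux data is part of the tree structure, changing a static field such as `g` triggers recompilation, whereas new values for `p` and `v` reuse the compiled `step`.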

AdrienCorenflos commented 4 years ago

Sure, can do, though I would argue that any behaviour that makes sense for namedtuples or even dictionaries makes sense for attr classes too.
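For comparison, namedtuples and dicts are already built-in pytree nodes, so the behavior being referred to looks like this (`Point` is just an example type). Each field or value becomes a child, and dict keys are flattened in sorted order.

```python
from collections import namedtuple

import jax

Point = namedtuple("Point", ["x", "y"])

nt_leaves = jax.tree_util.tree_leaves(Point(x=1.0, y=2.0))
dict_leaves = jax.tree_util.tree_leaves({"y": 2.0, "x": 1.0})  # keys sorted
doubled = jax.tree_util.tree_map(lambda a: 2 * a, Point(x=1.0, y=2.0))
```

`tree_map` rebuilds the same container type, so `doubled` is again a `Point`.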

(Replying by email to Peter Hawkins's comment of Mon, 3 Aug 2020, 21:19.)


bionicles commented 4 years ago

It would be useful to be able to register leaves. For example, OpenAI Gym is popular, but you can't JIT its 'step' because the env and space classes are not JAX types. If you register the Gym spaces with tree_util, they become internal nodes of the tree rather than leaves, so you end up mapping over the different pieces of a space separately: instead of seeing a 'Box' space as a whole, you see 'low', 'high', 'shape' and 'dtype' independently. All the nodes within Box are valid leaves, so the Box ought to be valid too, but there's no way to tell JAX "here's a custom leaf struct" that works with jit (or if there is, I haven't found it yet!).

@mattjj how hard would it be to make a leaf registry?

jekbradbury commented 4 years ago

It sounds like you want a type to be a leaf for the purposes of tree_map, but a node for the purposes of jit (as in, you want to wrap each individual array in the Box in a tracer; it doesn't make sense to wrap the whole Box in a tracer because it wouldn't have a shape). Unfortunately, the pytree infrastructure is built primarily for jit (and other JAX API transformations), and functions like tree_map are utilities that are helpful to the extent that you want to do something similar to what JAX does in those entrypoints.
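A sketch of the "node for jit" side of this. It assumes a hypothetical `Box` class standing in for a Gym-style space; it is not `gym.spaces.Box`. The arrays are registered as children, so `jit` wraps `low` and `high` in separate tracers, while `dtype` is kept as static auxiliary data (on the assumption that it rarely changes between calls).

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.tree_util import register_pytree_node


class Box:
    """Hypothetical stand-in for a Gym-style Box space (not gym's class)."""

    def __init__(self, low, high, dtype=np.float32):
        self.low = low
        self.high = high
        self.dtype = dtype


register_pytree_node(
    Box,
    lambda b: ((b.low, b.high), b.dtype),  # arrays are children, dtype is static
    lambda dtype, children: Box(*children, dtype=dtype),
)


@jax.jit
def clip(space, x):
    # Inside jit, space.low and space.high are separate tracers.
    return jnp.clip(x, space.low, space.high)


box = Box(jnp.zeros(3), jnp.ones(3))
y = clip(box, jnp.array([-1.0, 0.5, 2.0]))
```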

Can you use a different tree library for tree_map-type things, e.g. dm-tree, and register your type as a node for JAX?

jekbradbury commented 4 years ago

Actually, I think we could also do something like a map_up_to for a particular type that you want to temporarily treat as a leaf during a particular tree_map.
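More recent JAX versions expose this as the `is_leaf` argument of `jax.tree_util.tree_map`: a predicate that makes `tree_map` stop descending at matching nodes and pass them to the function whole. Here is a sketch, again with a hypothetical registered `Box` class; `describe` is just an illustrative function.

```python
import jax
import jax.numpy as jnp
from jax.tree_util import register_pytree_node


class Box:
    """Hypothetical Gym-style space, registered as a pytree node."""

    def __init__(self, low, high):
        self.low = low
        self.high = high


register_pytree_node(
    Box,
    lambda b: ((b.low, b.high), None),
    lambda _, children: Box(*children),
)

spaces = {"position": Box(jnp.zeros(2), jnp.ones(2)), "scale": jnp.array(3.0)}

# Default: flattening descends into Box and sees low and high separately.
n_leaves = len(jax.tree_util.tree_leaves(spaces))


def describe(x):
    return "box" if isinstance(x, Box) else "array"


# With is_leaf, each Box is handed to `describe` as a single object.
kinds = jax.tree_util.tree_map(describe, spaces, is_leaf=lambda x: isinstance(x, Box))
```

The registration still applies everywhere else, so `jit` keeps treating `Box` as a node while this one `tree_map` call treats it as a leaf.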

ilemhadri commented 2 years ago

> I suspect we won't add generic support for making all attr classes pytrees, since it's not obvious that there's a single behavior that makes sense for all attr class instances, but there's nothing to stop you from turning individual attr classes into pytrees by registering an extension where that makes sense.

I agree. From my perspective, attr classes are classes on steroids rather than namedtuples on steroids. Since there is no general JAX support for arbitrary classes, I would expect the same for attr classes.

If @jekbradbury agrees, maybe we can close this issue!