google / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
http://jax.readthedocs.io/
Apache License 2.0

Is it possible to use objects with Jax's machine learning library #1967

Closed Montana closed 4 years ago

Montana commented 4 years ago

Using classes works for me as long as I don't refer to any class objects inside my functions, but I was wondering whether using objects is possible with JAX?

-Montana

jbaron commented 4 years ago

You can use objects, but you also need to register some code that tells JAX how to flatten and unflatten them when passing them to a traceable function. This is, however, a straightforward process. (Some examples can be found here: https://jax.readthedocs.io/en/latest/notebooks/JAX_pytrees.html.)

Alternatively I sometimes use named tuples if I just want to group some parameters in order to increase readability of my code.
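
As a minimal sketch of the named-tuple approach: `Params`, `predict`, and the field names `w` and `b` are hypothetical names chosen for illustration, not anything from this thread. It assumes only that JAX treats named tuples as pytrees without any registration step.

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp


class Params(NamedTuple):  # hypothetical container, for illustration only
    w: jnp.ndarray
    b: jnp.ndarray


@jax.jit
def predict(params, x):
    # Fields are read by name; JAX sees the tuple as a pytree of arrays.
    return params.w * x + params.b


params = Params(w=jnp.array(2.0), b=jnp.array(1.0))
y = predict(params, jnp.array(3.0))  # computes 2.0 * 3.0 + 1.0
```

Because named tuples are immutable, this style also steers you away from in-place mutation inside transformed functions.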

Montana commented 4 years ago

Thanks @jbaron, that makes sense, since I'd expect NamedTuples to work like other Python builtins. I appreciate your response.

-Montana

umangjpatel commented 4 years ago

Can you send me an example of how to do that? Something like creating a class and extending it to work with JAX.

mattjj commented 4 years ago

The technique @jbaron suggests works if you want your classes to be treated as containers, isomorphic to tuples (i.e. product types). The functions you register define the isomorphism: one function for how to take an instance and flatten it to an iterable (plus some metadata, like dict keys), and another for how to take the iterable (plus the metadata) and reproduce a class instance equal to the original one.

Doing that will let instances of your own container-like classes be passed as arguments and returned as values to and from JAX-transformed functions (like jitted functions). After all, JAX will effectively see them as flat tuples, using the functions you provided to do the conversion back and forth.

from jax import jit

class Special:
  def __init__(self, x, y):
    self.x = x
    self.y = y

@jit
def f(special):
  return special.x + special.y

special = Special(1, 2)

f(special)  # TypeError: Argument '<__main__.Special object at 0x7f1403ee5e10>' of type <class '__main__.Special'> is not a valid JAX type

from jax import tree_util
tree_util.register_pytree_node(Special, lambda s: ((s.x, s.y), None), lambda _, xs: Special(xs[0], xs[1]))

f(special)  # 3
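
To make the flatten/unflatten round trip visible, here is a small sketch that calls `tree_util.tree_flatten` and `tree_util.tree_unflatten` directly on a registered class. The class is redefined so the block is self-contained; the variable names `leaves`, `treedef`, and `rebuilt` are just illustrative.

```python
from jax import tree_util


class Special:
    def __init__(self, x, y):
        self.x = x
        self.y = y


tree_util.register_pytree_node(
    Special,
    lambda s: ((s.x, s.y), None),  # flatten: (children, aux metadata)
    lambda aux, children: Special(children[0], children[1]),  # unflatten
)

original = Special(1, 2)

# Flatten into a list of leaves plus a structure description.
leaves, treedef = tree_util.tree_flatten(original)

# Rebuild an instance from the leaves using the registered unflatten function.
rebuilt = tree_util.tree_unflatten(treedef, leaves)

same_fields = (rebuilt.x, rebuilt.y) == (original.x, original.y)
same_object = rebuilt is original
```

Here `same_fields` is true while `same_object` is false: the rebuilt instance is equal in content but is a different Python object.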

Object identity is not preserved (which is a condition of functional purity):

@jit
def g(s1, s2):
  return s1 is s2

g(special, special)  # False

You can use methods on the instances you pass in, but side-effects won't work (they will likely silently fail rather than raising an error).
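
A small sketch of how a side effect can get lost, assuming a hypothetical `Counter` class with a mutating `add` method (neither is from this thread). The mutation happens on the copy that is rebuilt inside the jitted function, so the caller's object is left untouched.

```python
import jax
import jax.numpy as jnp
from jax import tree_util


class Counter:  # hypothetical class, for illustration only
    def __init__(self, total):
        self.total = total

    def add(self, x):
        # Side effect: mutates the instance in place.
        self.total = self.total + x


tree_util.register_pytree_node(
    Counter,
    lambda c: ((c.total,), None),
    lambda aux, children: Counter(children[0]),
)


@jax.jit
def bump(counter, x):
    counter.add(x)  # mutates only the instance rebuilt for tracing
    return counter.total


counter = Counter(jnp.array(0.0))
result = bump(counter, jnp.array(5.0))

# The returned value reflects the update, but the caller's object does not.
caller_total = float(counter.total)
```

If the updated state is needed, one option is to return a new `Counter` from the function instead of relying on the mutation.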

As @jbaron said, these things get flattened and then unflattened/re-created when being passed into or out of JAX-transformed functions, so if you keep that in mind it might be easier to remember the constraints (no object identity, no side effects).

Since I think @Montana 's original question was answered, and hopefully @umangjpatel 's request for an example was as well, I'll close this issue. Please open others if more questions come up!

Montana commented 4 years ago

This answered my question, thank you!

samskiter commented 1 year ago

I just wanted to highlight that I don't think this answer is entirely true: you can't just apply this method and have grad work in every case. See this discussion and my specific reproduction: https://github.com/google/jax/discussions/17341

jakevdp commented 1 year ago

> I just wanted to highlight that I don't think this answer is entirely true: you can't just apply this method and have grad work in every case.

To make this more explicit: objects which have impure methods (i.e. those with side-effects – see JAX Sharp Bits: Pure Functions) will not work correctly with jax.grad and other transformations, even if they are registered as pytrees.
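
For contrast, here is a sketch of the case that is expected to work: `jax.grad` applied to a pure function of a registered class. The `loss` function is a made-up example, and the class is redefined so the block is self-contained. It does not reproduce the linked discussion; it only shows the pure pattern.

```python
import jax
import jax.numpy as jnp
from jax import tree_util


class Special:
    def __init__(self, x, y):
        self.x = x
        self.y = y


tree_util.register_pytree_node(
    Special,
    lambda s: ((s.x, s.y), None),
    lambda aux, children: Special(children[0], children[1]),
)


def loss(s):
    # Pure: reads the fields and returns a value, with no mutation.
    return s.x ** 2 + 3.0 * s.y


# The gradient has the same pytree structure as the input,
# so it comes back as a Special holding d(loss)/dx and d(loss)/dy.
grads = jax.grad(loss)(Special(jnp.array(2.0), jnp.array(1.0)))
```

Here `grads.x` holds the derivative `2 * x` evaluated at `x = 2.0`, and `grads.y` holds the constant derivative `3.0`.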