patrick-kidger / equinox

Elegant easy-to-use neural networks + scientific computing in JAX. https://docs.kidger.site/equinox/
Apache License 2.0

Alternative Sharing #598

Open bowlingmh opened 10 months ago

bowlingmh commented 10 months ago

I was looking at how sharing was added into equinox, and was a bit surprised at the approach taken. I had a similar challenge where I had a PyTree with shared jax.Arrays, and jit'ing a function that took such a PyTree as input ended up being really slow (presumably as it copied the same array over and over on device).

I took a different approach, and I'm wondering if this might be a better or alternative choice for equinox.

Obviously a PyTree can represent a DAG, as you can simply have two sub-PyTrees be, in fact, the same object. The problem is that running that through a jax transformation (namely jit) treats the repeated objects as independent. So why not use the same partition/combine idea from equinox to do an unshare/share separation around the problematic jax transformations? Here unshare removes the duplication, leaving a placeholder that records where each duplicate came from, and share restores the duplication on the inside of the transformation.
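
To make the "treated as independent" point concrete, here is a minimal sketch that is not part of the original post. The names `w`, `tree` and `total` are made up for illustration. It only shows that flattening keeps both references to the same array, and that tracing turns each of them into a separate input.

```python
import jax
import jax.numpy as jnp
import jax.tree_util as jtu

w = jnp.ones((3,))
tree = {"encoder": w, "decoder": w}  # the same array object appears twice

# Flattening yields one leaf per position, even though both are the same object.
leaves = jtu.tree_leaves(tree)
same_object = leaves[0] is leaves[1]

def total(t):
    return t["encoder"].sum() + t["decoder"].sum()

# Tracing flattens the PyTree, so each leaf becomes its own input variable.
jaxpr = jax.make_jaxpr(total)(tree)
num_inputs = len(jaxpr.jaxpr.invars)

print(len(leaves), same_object, num_inputs)
```

The traced function sees two inputs, so nothing in the trace records that they started out as a single array.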

Here's what I implemented to solve the problem. The key is a meta-decorator share_through that does this unshare/share around some other transformation decorator. In my case, I use it as @share_through(eqx.filter_jit) on top of a function I want to jit that takes PyTrees with shared structure.

import functools

import jax.tree_util as jtu

class _ShareIndex(int):
  # Placeholder leaf: the flat index of the first occurrence of a shared leaf.
  def __repr__(self):
    return f'*{int(self)}*'

def _tree_unshare(x):
  # Replace every repeated leaf (same object identity) with a _ShareIndex
  # pointing at the flat position of its first occurrence.
  leaves, treedef = jtu.tree_flatten(x)

  shared_leaves = []
  shared_ids = {}
  for l in leaves:
    if id(l) in shared_ids:
      shared_leaves.append(_ShareIndex(shared_ids[id(l)]))
    else:
      shared_leaves.append(l)
      shared_ids[id(l)] = len(shared_leaves) - 1

  return jtu.tree_unflatten(treedef, shared_leaves)

def _tree_share(x):
  # Inverse of _tree_unshare: swap each placeholder for the leaf it points to.
  leaves, treedef = jtu.tree_flatten(x)

  leaves = [leaves[l] if isinstance(l, _ShareIndex) else l for l in leaves]

  return jtu.tree_unflatten(treedef, leaves)

def share_through(dec):
  def _dec(f):
    # Runs inside the transformation: restore sharing on the inputs,
    # then strip it again from the outputs.
    def _f_inner(*args, **kwargs):
      args = tuple(_tree_share(a) for a in args)
      kwargs = {k: _tree_share(v) for k,v in kwargs.items()}
      return _tree_unshare(f(*args, **kwargs))

    _dec_f = dec(_f_inner)

    # Runs outside the transformation: strip sharing from the inputs,
    # then restore it on the outputs.
    @functools.wraps(f)
    def _f_outer(*args, **kwargs):
      args = tuple(_tree_unshare(a) for a in args)
      kwargs = {k: _tree_unshare(v) for k,v in kwargs.items()}
      return _tree_share(_dec_f(*args, **kwargs))

    return _f_outer

  return _dec
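
As a quick check of the unshare/share round trip, here is a condensed sketch outside of any transformation. It is not from the original post: `ShareIndex`, `unshare` and `share` are hypothetical stand-ins that mirror the underscore-prefixed helpers so that the block runs on its own, and the example tree is made up.

```python
import jax.numpy as jnp
import jax.tree_util as jtu

class ShareIndex(int):
    """Placeholder leaf: flat index of the first occurrence of a shared leaf."""

def unshare(tree):
    leaves, treedef = jtu.tree_flatten(tree)
    first_seen, out = {}, []
    for i, leaf in enumerate(leaves):
        if id(leaf) in first_seen:
            # Repeated object: store a pointer to its first occurrence.
            out.append(ShareIndex(first_seen[id(leaf)]))
        else:
            first_seen[id(leaf)] = i
            out.append(leaf)
    return jtu.tree_unflatten(treedef, out)

def share(tree):
    leaves, treedef = jtu.tree_flatten(tree)
    leaves = [leaves[l] if isinstance(l, ShareIndex) else l for l in leaves]
    return jtu.tree_unflatten(treedef, leaves)

w = jnp.arange(3.0)
b = jnp.zeros(3)
tree = {"first": w, "second": b, "third": w}  # w is shared

flat = unshare(tree)      # "third" becomes a placeholder pointing at "first"
restored = share(flat)    # "third" is the same object as "first" again
```

After `unshare`, each distinct array appears once and the duplicate is a plain Python integer subclass, which is why the post pairs this with eqx.filter_jit (which treats non-array leaves as static) rather than jax.jit.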

Here, you don't need to specifically label anything as shared. No special construction or labelling of the PyTree. No extra programmatic syntax to manipulate the shared PyTrees. You can specify sharing by simply sharing, and if you don't want something shared, make sure it's a copy and not the same object. This seems to fit equinox's ethos perfectly. So maybe this could just be built into eqx.filter_jit.

The above is limited to sharing at leaves (which is where all the dynamic arrays will be anyway), but it could relatively easily be modified to allow shared internal subtrees where stubs encode paths to where the original subtree lives.

Another limitation is that two PyTrees with two different sharing structures (e.g., one has shared leaves, and the other doesn't) will result in a re-JIT, as the input signature would be tied to its sharing structure. But overall, this seems like desirable behavior.

Maybe there are other limitations that prevent this from handling all that one would want sharing to handle?

Thoughts?

I could build this into a PR for filter_jit, but don't want to take the time if there's no interest.

patrick-kidger commented 10 months ago

Hey there!

So (variants of) what you have here are actually something I considered as well. In the end I decided against it.

The main reason for this is that right now, there is no coupling between eqx.Module and eqx.filter_{jit, ...}. The former is just a PyTree; the latter is just a transformation that acts on PyTrees. Adding sharing behaviour in the manner you propose would mean we now have a new concept (_ShareIndex) which couples these two things together.

This kind of coupling is something that we've carefully avoided so far in Equinox, on the basis that it makes Equinox code easier to reason about, and easier to interchange with non-Equinox JAX code. (For example, your proposal would mean forcing the use of eqx.filter_jit over jax.jit for this use-case.)

There are a handful of other considerations that pop up here too, like the extra runtime overhead of unsharing before calling the transformation, and the possibility of silently wrong behaviour (no sharing) if you forget to use the share_through version of a transformation.

I hope that helps!

bowlingmh commented 10 months ago

Hi Patrick,

Huh. I thought my approach actually fully respected the zero coupling, as it's one of my favourite aspects of equinox. The _ShareIndex isn't intended to be a user-facing concept at all; it's one internal implementation for how to squeeze shared parameters through the jit boundary. From the user side, one does sharing the most natural way possible: simply having the same jax array appear multiple times in the PyTree. And it doesn't matter if the PyTree is an eqx.Module or just standard python PyTree types.

But maybe I am misreading your comment.

Here's an alternative way to understand the difference... Your approach implements sharing within the eqx.Module class and means you can be indifferent to using eqx.filter_jit or jax.jit to see sharing work. My suggestion is to implement sharing within the *jit decorator and means you can be indifferent to using eqx.Module or other generic PyTree structures. Does that sound right? And while I don't think my suggestion ties eqx.Module to eqx.filter_jit it does tie one to using eqx.filter_jit if you want sharing, and I can see that being undesirable. Especially, as you note, it would potentially incur unnecessary overhead when sharing is not being used (although that could be on a flag). I can also see the argument that eqx.filter_jit's concern is about handling static fields through the jit boundary, and it was never intended to be about sharing. So separation of concerns suggests it shouldn't be handled there.

I guess my problem is that my other favourite aspect of equinox is that I don't have to do awkward unnatural things to make simple natural things work, such as trying to keep my parameters cleanly separate from code. That felt refreshing and freeing. I think this is what made me sad about eqx.nn.Shared, which feels like it adds awkward unnatural access to parameters.

Really my complaint lies with jax (why wouldn't jit allow for leaf-sharing in the PyTree?). For me, equinox has largely been about making jax behave more sanely. So, thank you for that! Even if there will be areas where it still feels unnecessarily awkward.

Thanks for your time and for your work on equinox!

patrick-kidger commented 9 months ago

Ah you're right, I've misspoken a little: the coupling wouldn't be between eqx.Module and eqx.filter_jit, but between sharing and eqx.filter_jit.

One problem this highlights is what happens when you nest multiple transformations, each of which would also need to respect sharing. You may remember to use eqx.filter_jit, but a later careless use of jax.grad could silently de-share things.
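
As an illustration of that silent de-sharing, here is a small sketch that is not from the thread. The parameter names, the loss and the 0.1 step size are invented for the example. Plain jax.grad returns one gradient per position in the PyTree, so a weight that was meant to be tied drifts apart after a single update.

```python
import jax
import jax.numpy as jnp
import jax.tree_util as jtu

w = jnp.array([1.0, 2.0])
x = jnp.array([3.0, 4.0])
params = {"a": w, "b": w}  # intended as a single tied weight

def loss(p):
    # The two positions are used differently, so their gradients differ.
    return jnp.sum(p["a"] * x) + jnp.sum(p["b"] * 2.0 * x)

grads = jax.grad(loss)(params)

# A tied weight would receive the sum of both contributions.
tied_grad = grads["a"] + grads["b"]

# A naive per-leaf gradient step updates each copy separately.
new_params = jtu.tree_map(lambda p, g: p - 0.1 * g, params, grads)
still_tied = bool(jnp.array_equal(new_params["a"], new_params["b"]))
print(still_tied)
```

No error is raised anywhere; the two copies simply stop being equal, which is the failure mode described above.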

> Really my complaint lies with jax (why wouldn't jit allow for leaf-sharing in the PyTree?).

Right! In large part this is about working with the constraints of JAX's model of computation itself.

Although for what it's worth -- possibly I'm just too deeply into this now -- I've actually come around to believing that PyTrees are indeed better than "PyDAGs" (=trees with sharing). This handles a few edge cases much more elegantly:

Conversely, I definitely do appreciate that the eqx.nn.Shared notation is a little clunky. That's the price we pay for these nice things I suppose :)