bowlingmh opened this issue 10 months ago
Hey there!
So (variants of) what you have here are actually something I considered as well. In the end I decided against it.
The main reason for this is that right now, there is no coupling between `eqx.Module` and `eqx.filter_{jit, ...}`. The former is just a PyTree; the latter is just a transformation that acts on PyTrees. Adding sharing behaviour in the manner you propose would mean we now have a new concept (`_ShareIndex`) which couples these two things together.
This kind of coupling is something that we've carefully avoided so far in Equinox, on the basis that it makes Equinox code easier to reason about, and easier to interchange with non-Equinox JAX code. (For example, your proposal would mean forcing the use of `eqx.filter_jit` over `jax.jit` for this use-case.)
There are a handful of other considerations that pop up here too, like the extra runtime overhead of unsharing before calling the transformation, and the possibility of silently wrong behaviour (no sharing) if you forget to use the `share_through` version of a transformation.
I hope that helps!
Hi Patrick,
Huh. I thought my approach actually fully respected the zero coupling, as it's one of my favourite aspects of equinox. The `_ShareIndex` isn't intended to be a user-facing concept at all; it's just one internal implementation detail for squeezing shared parameters through the jit boundary. From the user side, one does sharing in the most natural way possible: simply having the same jax array appear multiple times in the PyTree. And it doesn't matter whether the PyTree is an `eqx.Module` or just standard Python PyTree types.
But maybe I am misreading your comment.
Here's an alternative way to understand the difference... Your approach implements sharing within the `eqx.Module` class, which means you can be indifferent to using `eqx.filter_jit` or `jax.jit` and still see sharing work. My suggestion implements sharing within the `*jit` decorator, which means you can be indifferent to using `eqx.Module` or other generic PyTree structures. Does that sound right? And while I don't think my suggestion ties `eqx.Module` to `eqx.filter_jit`, it does tie you to using `eqx.filter_jit` if you want sharing, and I can see that being undesirable. Especially since, as you note, it would potentially incur unnecessary overhead when sharing is not being used (although that could be controlled by a flag). I can also see the argument that `eqx.filter_jit`'s concern is handling static fields through the jit boundary, and it was never intended to be about sharing. So separation of concerns suggests sharing shouldn't be handled there.
I guess my problem is that my other favourite aspect of equinox is that I don't have to do awkward, unnatural things to make simple, natural things work, such as trying to keep my parameters cleanly separate from code. That felt refreshing and freeing. I think this is what made me sad about `eqx.nn.Shared`, which feels like it adds awkward, unnatural access to parameters.
Really my complaint lies with jax (why wouldn't jit allow for leaf-sharing in the PyTree?). I feel like so much of equinox has been about making jax act more sanely for me. So, thank you for that! Even if there are still areas where it feels unnecessarily awkward.
Thanks for your time and for your work on equinox!
Ah, you're right, I've misspoken a little: the coupling wouldn't be between `eqx.Module` and `eqx.filter_jit`, but between sharing and `eqx.filter_jit`.
One problem this highlights is what happens when you nest multiple transformations, all of which would also need to respect sharing. You may remember to use `eqx.filter_jit`, but a later careless use of `jax.grad` could silently de-share things.
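As a minimal illustration of that de-sharing, in plain JAX and with a made-up `loss` function: when the same array appears twice in the argument PyTree, `jax.grad` treats the two positions as independent leaves and returns a gradient for each, instead of one gradient for the shared parameter.

```python
import jax
import jax.numpy as jnp

def loss(params):
    a, b = params
    return a * b

w = jnp.array(2.0)

# The same array object appears twice, but jax.grad flattens the tuple
# into two independent leaves and returns one gradient per position:
# d(a*b)/da = b = 2.0 and d(a*b)/db = a = 2.0.
per_position_grads = jax.grad(loss)((w, w))

# Treating w as a single shared parameter, d(w*w)/dw = 2*w = 4.0,
# which is the sum of the per-position gradients above.
shared_grad = jax.grad(lambda v: loss((v, v)))(w)
```

Recovering the shared gradient means summing over every position the array occupied, which is the kind of bookkeeping a sharing-aware wrapper would need to redo for each transformation it is combined with.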
> Really my complaint lies with jax (why wouldn't jit allow for leaf-sharing in the PyTree?).
Right! In large part this is about working with the constraints of JAX's model of computation itself.
Although for what it's worth (possibly I'm just too deep into this now), I've actually come around to believing that PyTrees are indeed better than "PyDAGs" (i.e. trees with sharing). This handles a few edge cases much more elegantly:
- The `state_dict` of a PyTorch model (which is a DAG, not a tree) will actually duplicate shared parameters when they appear in multiple places.
- `eqx.nn.MLP`, which can take a single activation function and from this broadcast out a separate copy of it for each neuron.

Conversely, I definitely do appreciate that the `eqx.nn.Shared` notation is a little clunky. That's the price we pay for these nice things, I suppose :)
I was looking at how sharing was added into equinox, and was a bit surprised at the approach taken. I had a similar challenge where I had a PyTree with shared `jax.Array`s, and jit'ing a function that took such a PyTree as input ended up being really slow (presumably because it copied the same array over and over on device).
I took a different approach, and I'm wondering if this might be a better or alternative choice for equinox.
Obviously a PyTree can represent a DAG, as you can simply have two sub-PyTrees be, in fact, the same object. The problem is that running it through a jax transformation (namely jit) treats the repeated objects as independent. So why not use the same partition/combine idea from equinox to do an unshare/share separation around the problematic jax transformations? Here `unshare` removes duplication, using a placeholder to record the duplication, and `share` restores the duplication on the inside of the transformation.
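A small plain-JAX check of the claim that jit treats repeated objects as independent (the function `f` is illustrative only): the input tuple holds the same array object twice, but `jax.jit` flattens it into two leaves and gives each its own tracer, so the identity is lost inside the traced function.

```python
import jax
import jax.numpy as jnp

w = jnp.ones(3)
params = (w, w)  # the same array object appears twice in the PyTree
assert params[0] is params[1]

identity_inside_jit = []  # filled in at trace time

@jax.jit
def f(params):
    a, b = params
    # Records whether jit handed the function one object or two.
    identity_inside_jit.append(a is b)
    return jnp.sum(a * b)

result = f(params)
num_leaves = len(jax.tree_util.tree_leaves(params))  # one leaf per position
```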
Here's what I implemented to solve the problem. The key is a meta-decorator `share_through` that does this unshare/share around some other transformation decorator. In my case, I use it as `@share_through(eqx.filter_jit)` on top of a function I want to jit that takes PyTrees with shared structure.

Here, you don't need to specifically label anything as shared. No special construction or labelling of the PyTree. No extra programmatic syntax to manipulate the shared PyTrees. You can specify sharing by simply sharing, and if you don't want something shared, make sure it's a copy and not the same object. This seems to fit equinox's ethos perfectly. So maybe this should just be built into `eqx.filter_jit`.

The above is limited to sharing at leaves (which is where all the dynamic arrays will be anyway), but it could relatively easily be modified to allow shared internal subtrees, where stubs encode paths to where the original subtree lives.
Another limitation is that two PyTrees with two different sharing structures (e.g., one has shared leaves, and the other doesn't) will result in a re-JIT, as the input signature would be tied to its sharing structure. But overall, this seems like desirable behavior.
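A minimal sketch of the leaf-level `unshare`/`share`/`share_through` scheme described above. This is a hypothetical reconstruction, not the code from this issue: it assumes shared leaves are detected by object identity (`id`), only deduplicates `jax.Array` leaves, and assumes the wrapped transformation treats its second argument (the sharing structure) as static, which for plain `jax.jit` means passing `static_argnums=1`. The usage at the bottom also shows the re-JIT behaviour: a different sharing structure is a different static argument, so it triggers a new trace.

```python
import functools

import jax
import jax.numpy as jnp


def unshare(tree):
    """Flatten `tree`, keeping one copy of each repeated jax.Array.

    Returns the unique leaves plus a hashable `structure`: the treedef and,
    for every original leaf position, an index into the unique leaves.
    Sharing is detected by object identity, an assumption of this sketch.
    """
    leaves, treedef = jax.tree_util.tree_flatten(tree)
    unique, indices, seen = [], [], {}
    for leaf in leaves:
        key = id(leaf) if isinstance(leaf, jax.Array) else None
        if key is not None and key in seen:
            indices.append(seen[key])  # repeated array: point at its first copy
            continue
        if key is not None:
            seen[key] = len(unique)
        indices.append(len(unique))
        unique.append(leaf)
    return unique, (treedef, tuple(indices))


def share(unique, structure):
    """Inverse of `unshare`: every shared position gets the same object back."""
    treedef, indices = structure
    return jax.tree_util.tree_unflatten(treedef, [unique[i] for i in indices])


def share_through(transform):
    """Hypothetical meta-decorator: unshare outside `transform`, share inside.

    Assumes `transform` treats the second argument (`structure`) as static.
    """
    def decorator(fn):
        @transform
        def inner(unique, structure):
            args, kwargs = share(unique, structure)
            return fn(*args, **kwargs)

        @functools.wraps(fn)
        def wrapper(*args, **kwargs):
            unique, structure = unshare((args, kwargs))
            return inner(unique, structure)

        return wrapper

    return decorator


# Usage: plain jax.jit has to be told that the structure argument is static.
traces = []  # one entry per trace: were the two inputs the same object?

@share_through(functools.partial(jax.jit, static_argnums=1))
def f(params):
    a, b = params
    traces.append(a is b)  # runs only while tracing
    return jnp.sum(a * b)

w = jnp.ones(3)
r1 = f((w, w))            # traced once; sharing survives, so a is b
r2 = f((w, w))            # same sharing structure: cached, no retrace
r3 = f((w, jnp.ones(3)))  # different sharing structure: retraced
```

Because only identical objects are merged, two equal but distinct arrays stay unshared, matching the "make sure it's a copy" rule above. Under plain `jax.jit`, any non-array leaf left in `unique` is also treated as dynamic, so this sketch assumes all-array inputs; `eqx.filter_jit` treats non-array arguments as static, so it could presumably be passed directly as `transform`, as in the `@share_through(eqx.filter_jit)` usage described above.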
Maybe there are other limitations that prevent this from handling everything one would want sharing to handle?
Thoughts?
I could build this into a PR for `filter_jit`, but don't want to take the time if there's no interest.