Closed neel04 closed 9 months ago
This looks like a bug in XLA:TPU. I'd suggest filing it either on the main JAX repository or on the XLA repo. You'll probably need to find an MWE, though.
I don't understand much of how equinox maintains state - my basic understanding is that the actual state at runtime is held by jax internally and equinox just issues host callbacks to mutate that state as needed - where the Module is just the abstract PyTree representation?
Equinox doesn't maintain any state, actually! It was an important part of Equinox's design that we not go around mutating things. The Module is a PyTree of arrays, and these are explicitly updated, by you, when you do things like gradient descent.
Equinox actually uses callbacks sparingly -- these are the only functions which use them, and they're all fairly uncommon: eqx.error_if, eqx.branched_error_if, eqx.filter_pure_callback, eqx.debug.store_dce, eqx.internal.noinline (and this one isn't even a documented feature :) ).
Are you explicitly using any of these?
Nope. So it's probably some constraint on TPUs placed by XLA 😥 I guess I'll try and debug it, but so far I haven't had any luck.
@patrick-kidger Turns out, the problem was using eqx.internal.while_loop, specifically the checkpointed form. Using the bounded type of loop works perfectly fine.
I don't know if you feel it's worth it to resolve the bug on TPUs 😅 If you have access to them, I can try to make a repro for them.
Ah! Indeed I'd forgotten, eqx.internal.while_loop uses eqx.error_if, which then uses a callback.
I'd suggest reporting this as an XLA bug regardless, but it's probably not a bug I can resolve directly.
As a possible workaround, you can try commenting out every error_if inside the implementation of eqx.internal.while_loop. I imagine I'll do something like that in the next release of Equinox, actually -- this code has now proven itself pretty reliable.
Thanks! I tried commenting out all the callbacks and rebuilt equinox like this, and kind='checkpointed' works pretty well now.
I talked to James Bradbury on twitter, and he said it's a bit harder to work with equinox, as host callbacks are "messy" and closer to a "hack", so it's likely that it's clashing with some TPU-specific optimizations built directly into XLA. I guess it'd require quite a bit of surgery to fix this bug - so it might be some time before it's fixed 🙂
Until that's fixed, I suppose in the next release, maybe you could expose some flag for internal.while_loop that optionally disables host callbacks if the user desires (at least the non-critical ones). It might go against the equinox philosophy of simplicity I suppose, but IMO it's fine because it's already an internal application, so users are more likely to be careful when using it and (hopefully) read all the documentation.
Again, thanks for everything and providing such a lovely library ❤️ and have a good weekend!
Ah, marvellous! I'm glad that's working for you. I've just written #631 to fix this up for the next release.
Since it is specifically error_if that is causing you problems, then one other option worth knowing is the environment variable EQX_ON_ERROR=nan. (Documentation here.) This will disable every eqx.error_if used anywhere in your program. This is really intended as a "debug vs release mode" optimisation -- to remove the checks once you're satisfied that your program should always work. But it could also help with this issue.
Looks like using EQX_ON_ERROR=nan works pretty well here and resolves the issue 😄
TPU_DEBUGGING flag that in turn would switch on and off the various flags to minimize collisions with XLA, since it's a bit of a black box that we usually need to work around... Thanks for everything again and have a great weekend!
I'm training a custom arch of mine, and had a use case where I wanted to perform 2 (different) forward passes which have different computational graphs. I wanted to take the outputs from both flows, and evaluate an aggregated loss.
But apparently, if I compute two branches, I get the below error.
Traceback
Code
I don't have a repro unfortunately. But my code flow looks like this:
This is a bit convoluted, but the core is that there are 2 different forward passes which explicitly depend on the same model, which is reused here with slightly different arguments (mainly prev_thought).
Because the error occurs only when both of the flows are present - either through using jax.lax.cond to dynamically switch between both or simply aggregating outputs from both forward passes simultaneously - the common problem seems to be that XLA is unable to handle both computational flows.
(Note: jax.lax.cond is lowered to select when vmap-ed, which is why both flows do end up getting computed too.)
This error is triggered only on TPUs, not on GPUs, so perhaps it might just turn out to be a limitation of equinox. I don't understand much of how equinox maintains state - my basic understanding is that the actual state at runtime is held by jax internally and equinox just issues host callbacks to mutate that state as needed - where the Module is just the abstract PyTree representation?
The error kind of sounds like equinox issued multiple host callbacks and they collide. Why it's only a problem on TPUs specifically could be down to TPU-specific optimizations of XLA.
Would you have any idea regarding this?
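The note above about jax.lax.cond being lowered to select under vmap can be checked with plain JAX. This is a minimal sketch with a made-up function `branchy`, unrelated to the model in the issue: when the predicate is batched, the jaxpr of the vmapped function contains a `select_n` operation, meaning both branches are evaluated for every element.

```python
import jax
import jax.numpy as jnp


def branchy(x):
    # Per-example control flow: double positive inputs, negate the rest.
    return jax.lax.cond(x > 0, lambda v: v * 2.0, lambda v: -v, x)


batched = jax.vmap(branchy)
xs = jnp.array([-1.0, 2.0])

out = batched(xs)

# Inspect what vmap turned the cond into. With a batched predicate,
# both branches are computed and the results are merged elementwise.
jaxpr_text = str(jax.make_jaxpr(batched)(xs))
print(jaxpr_text)
```

So any callback that sits inside either branch will run on every vmapped call, regardless of which branch is logically taken.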