patrick-kidger / equinox

Elegant easy-to-use neural networks + scientific computing in JAX. https://docs.kidger.site/equinox/
Apache License 2.0

[TPU] XLA: Channel is used for multiple host instructions #628

Closed: neel04 closed this 9 months ago

neel04 commented 9 months ago

I'm training a custom architecture of mine, and I have a use case where I want to perform two different forward passes, each with its own computational graph. I then want to take the outputs from both flows and evaluate an aggregated loss.

However, when I compute both branches, I get the error below.

Traceback

Traceback (most recent call last):
  File "/kaggle/working/ReAct_Jax/train_model.py", line 47, in <module>
    main(key)
  File "/kaggle/working/ReAct_Jax/train_model.py", line 43, in main
    trainer.train(args.epochs, trainloader, valloader, key)
  File "/kaggle/working/ReAct_Jax/ReAct/utils/trainer.py", line 235, in train
    loss, model, opt_state = make_step(model, seq, label, pad_mask, rndm_n, rndm_k,
                             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.11/site-packages/equinox/_jit.py", line 206, in __call__
    return self._call(False, args, kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.11/site-packages/equinox/_module.py", line 875, in __call__
    return self.__func__(self.__self__, *args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.11/site-packages/equinox/_jit.py", line 198, in _call
    out = self._cached(dynamic_donate, dynamic_nodonate, static)
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
jaxlib.xla_extension.XlaRuntimeError: INTERNAL: during context [pre-optimization]: RET_CHECK failure (third_party/tensorflow/compiler/xla/service/hlo_verifier.cc:2330) instructions.size() == 2 channel 11 is used for multiple host send/recv instructions
--------------------
For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these.

Code

Unfortunately I don't have a repro, but my code flow looks like this:

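import jax
import equinox as eqx
from jaxtyping import Array, PRNGKeyArray

# eqx.filter_jit (rather than jax.jit) treats the model's non-array leaves and the
# Python ints n, k as static, so n and k can serve as loop bounds inside the model.
@eqx.filter_jit
def n_k_loop(model: eqx.Module, input_arr: Array, pad_mask: Array, n: int, k: int, key: PRNGKeyArray) -> Array:
    key1, key2 = jax.random.split(key, 2)

    # forward pass the model without tracking grads (stop_gradient below detaches it)
    _, intermediate_array = model(
        input_arr, n,
        pad_mask=pad_mask,
        prev_thought=None,
        key=key1)

    intermediate_array = jax.lax.stop_gradient(intermediate_array)

    # n-k passes, but track the gradient this time
    output, _ = model(input_arr, k, pad_mask=pad_mask, prev_thought=intermediate_array, key=key2)

    return output

@eqx.filter_jit
def k_loop(model: eqx.Module, input_arr: Array, pad_mask: Array, n: int, k: int, key: PRNGKeyArray) -> Array:
    key1, _ = jax.random.split(key, 2)  # only one key is needed in this flow

    # a single fully differentiable pass, starting from no previous state
    output, _ = model(input_arr, n, pad_mask=pad_mask, prev_thought=None, key=key1)

    return output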

This is a bit convoluted, but the core point is that there are two different forward passes, both using the same model, called with slightly different arguments (mainly prev_thought).
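A minimal, self-contained sketch of the two-flow pattern, assuming a toy recurrent update in plain `jax.numpy` in place of the real Equinox model (the names `step`, `run`, `n_k_loop`, `k_loop` and `aggregated_loss` here are hypothetical). It shows the key property of the first flow: the `jax.lax.stop_gradient` detaches the first pass, so only the second pass contributes gradients to the shared parameter `w`.

```python
import jax
import jax.numpy as jnp


def step(w, x, prev):
    # Hypothetical recurrent "thought" update standing in for one model iteration.
    return jnp.tanh(w * x + prev)


def run(w, x, num_iters, prev):
    # Apply `step` num_iters times (num_iters is a Python int, so this unrolls).
    for _ in range(num_iters):
        prev = step(w, x, prev)
    return prev


def n_k_loop(w, x, n, k):
    # First pass: n iterations whose result is detached from the gradient.
    intermediate = jax.lax.stop_gradient(run(w, x, n, jnp.zeros_like(x)))
    # Second pass: k iterations from the detached state; gradients flow here.
    return run(w, x, k, intermediate)


def k_loop(w, x, n):
    # A single fully differentiable pass of n iterations from scratch.
    return run(w, x, n, jnp.zeros_like(x))


def aggregated_loss(w, x, n, k):
    # Combine the outputs of both computational flows into one scalar loss.
    return jnp.mean(n_k_loop(w, x, n, k) ** 2) + jnp.mean(k_loop(w, x, n) ** 2)


x = jnp.arange(3.0)
loss, dloss_dw = jax.value_and_grad(aggregated_loss)(0.5, x, 4, 2)

# With k=0 only the detached first pass remains, so no gradient reaches w.
detached_grad = jax.grad(lambda w: jnp.sum(n_k_loop(w, x, 4, 0)))(0.5)
```

Both flows read the same parameter, which mirrors reusing one model with different `prev_thought` arguments.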

The error occurs only when both flows are present, whether I use jax.lax.cond to switch between them dynamically or simply aggregate the outputs of both forward passes. So the common problem seems to be that XLA is unable to handle both computational flows together.

(Note: when vmap-ed with a batched predicate, jax.lax.cond is converted to a select, which is why both flows end up being computed anyway.)
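The select behaviour mentioned in the note can be checked directly. A minimal sketch with a hypothetical two-branch function `f`: after `jax.vmap` with a batched predicate, the traced program contains a `select_n` instead of a conditional, so both branches are evaluated for every element and the right result is picked afterwards.

```python
import jax
import jax.numpy as jnp


def f(pred, x):
    # Two different branches; under vmap with a batched predicate both are evaluated.
    return jax.lax.cond(pred, lambda y: y * 2.0, lambda y: y + 1.0, x)


batched_f = jax.vmap(f)
preds = jnp.array([True, False])
xs = jnp.array([1.0, 2.0])

# Inspect the traced program: the cond has been replaced by a select.
jaxpr_text = str(jax.make_jaxpr(batched_f)(preds, xs))
print(jaxpr_text)

out = batched_f(preds, xs)  # values 2.0 (first branch) and 3.0 (second branch)
```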

Would you have any idea regarding this?

patrick-kidger commented 9 months ago

This looks like a bug in XLA:TPU. I'd suggest filing it either on the main JAX repository or on the XLA repo. You'll probably need to find a minimal working example (MWE), though.

> I don't understand much about how Equinox maintains state. My basic understanding is that the actual state at runtime is held internally by JAX, and that Equinox just issues host callbacks to mutate that state as needed, with the Module being just the abstract PyTree representation?

Equinox doesn't maintain any state, actually! It was an important part of Equinox's design that we not go around mutating things. The Module is a PyTree of arrays, and these are explicitly updated, by you, when you do things like gradient descent.

Equinox actually uses callbacks sparingly. These are the only functions that use them, and they're all fairly uncommon:

Are you explicitly using any of these?

neel04 commented 9 months ago

> Are you explicitly using any of these?

Nope. So it's probably some TPU-specific constraint imposed by XLA 😥 I'll try to debug it, but so far I haven't had any luck.

neel04 commented 9 months ago

@patrick-kidger It turns out the problem was eqx.internal.while_loop, specifically its checkpointed form. Using the bounded kind of loop works perfectly fine.

I don't know whether you think it's worth resolving this bug on TPUs 😅 If you have access to them, I can try to make a repro.

patrick-kidger commented 9 months ago

Ah! Indeed, I'd forgotten: eqx.internal.while_loop uses eqx.error_if, which in turn uses a callback.

I'd suggest reporting this as an XLA bug regardless, but it's probably not a bug I can resolve directly.

As a possible workaround, you can try commenting out every error_if inside the implementation of eqx.internal.while_loop. I imagine I'll do something like that in the next release of Equinox anyway, since this code has now proven itself pretty reliable.

neel04 commented 9 months ago

Thanks! I commented out all the callbacks, rebuilt Equinox, and kind='checkpointed' works pretty well now.

I talked to James Bradbury on Twitter, and he said Equinox is a bit harder to work with here because host callbacks are "messy" and closer to a "hack", so they are likely clashing with some TPU-specific optimizations built directly into XLA. I guess fixing this bug would require quite a bit of surgery, so it might be some time before it's fixed 🙂

Until that's fixed, perhaps the next release could expose a flag on internal.while_loop that optionally disables host callbacks when the user wants (at least the non-critical ones). That might go against the Equinox philosophy of simplicity, but IMO it's fine: it's already an internal API, so users are more likely to be careful when using it and (hopefully) read all the documentation.

Again, thanks for everything and providing such a lovely library ❤️ and have a good weekend!

patrick-kidger commented 9 months ago

Ah, marvellous! I'm glad that's working for you. I've just written #631 to fix this up for the next release.

Since it is specifically error_if that is causing you problems, then one other option worth knowing is the environment variable EQX_ON_ERROR=nan. (Documentation here.) This will disable every eqx.error_if used anywhere in your program. This is really intended as a "debug vs release mode" optimisation -- to remove the checks once you're satisfied that your program should always work. But it could also help with this issue.

neel04 commented 9 months ago

Looks like using EQX_ON_ERROR=nan works pretty well here and resolves the issue 😄

#631 looks good. If more XLA bugs crop up in the future, we could set up a dedicated TPU_DEBUGGING flag that in turn toggles the various options to minimize collisions with XLA, since it's a bit of a black box that we usually need to work around...

Thanks for everything again and have a great weekend!