stanford-crfm / haliax

Named Tensors for Legible Deep Learning in JAX
Apache License 2.0
149 stars, 11 forks

Add context to help fix "`Array`s that are not fully addressable" errors #37

Open · dlwh opened this issue 1 year ago

dlwh commented 1 year ago
    raise RuntimeError(
jax._src.traceback_util.UnfilteredStackTrace: RuntimeError: Running operations on `Array`s that are not fully addressable by this process (i.e. `Array`s with data sharded across multiple devices and processes.) is dangerous. It’s very important that all processes run the same cross-process computations in the same order otherwise it can lead to hangs. If you’re not already familiar with JAX’s multi-process programming model, please read https://jax.readthedocs.io/en/latest/multi_process.html. To fix this error, run your `jitted` computation inside `with jax.spmd_mode('allow_all'):` context manager.

This one doesn't show up inside jit (by construction), so it's a bit harder to intercept. Maybe just a FAQ entry?
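One way to add the context this issue asks for is to check a pytree for arrays that are not fully addressable *before* running an eager (non-jitted) operation on it, and name the offending leaves in the error. The sketch below is only an illustration: `find_non_addressable` and `check_addressable` are hypothetical helpers, not part of haliax. It relies on the real `jax.Array.is_fully_addressable` property and `jax.tree_util.tree_flatten_with_path`.

```python
import jax
import jax.numpy as jnp
from jax.tree_util import keystr, tree_flatten_with_path


def find_non_addressable(tree):
    """Hypothetical helper: return the pytree paths of every jax.Array leaf
    whose shards are not all on devices owned by this process."""
    leaves_with_paths, _ = tree_flatten_with_path(tree)
    return [
        keystr(path)
        for path, leaf in leaves_with_paths
        if isinstance(leaf, jax.Array) and not leaf.is_fully_addressable
    ]


def check_addressable(tree, what="value"):
    """Hypothetical helper: raise an error that names the offending leaves,
    instead of JAX's generic message, before an eager operation runs."""
    bad = find_non_addressable(tree)
    if bad:
        raise RuntimeError(
            f"{what} contains arrays sharded across other processes: {bad}. "
            "Run this computation inside jit (e.g. jax.jit or haliax's "
            "named_jit) so every process executes it in the same order."
        )


# Single-process usage: every array is addressable, so this passes silently.
state = {"w": jnp.ones((4, 4)), "step": 0}
check_addressable(state, "state")
```

Non-array leaves (like the integer `step`) are skipped, so the check can be run on a whole training state.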

dlwh commented 1 year ago

Also:

  File "/home/dlwh/levanter/src/levanter/trainer.py", line 204, in initial_state
    model, opt_state = named_jit(self._init_model_and_opt_state, self.parameter_axis_mapping)(model_init)
  File "/home/dlwh/venv310/lib/python3.10/site-packages/haliax/partitioning.py", line 333, in f
    out, out_static = cached_pjitted_fun(dynamic_donated, dynamic_reserved, static)
  File "/home/dlwh/venv310/lib/python3.10/site-packages/jax/_src/array.py", line 679, in _array_mlir_constant_handler
    return mlir.ir_constants(val._value,
  File "/home/dlwh/venv310/lib/python3.10/site-packages/jax/_src/array.py", line 524, in _value
    raise RuntimeError("Fetching value for `jax.Array` that spans "
RuntimeError: Fetching value for `jax.Array` that spans non-addressable devices is not possible. You can use `jax.experimental.multihost_utils.process_allgather` for this use case.

This one happens because arrays snuck into a closure where they shouldn't have been.
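A minimal sketch of the difference, assuming the JAX behavior visible in the traceback above: an array captured from the enclosing scope is lowered as a constant (the `_array_mlir_constant_handler` frame), which requires fetching its full value on this host, while an array passed as an explicit argument stays a traced input and keeps its sharding. On a single process both versions run and agree; only the closure version would fail for an array that spans non-addressable devices. The function names are illustrative.

```python
import jax
import jax.numpy as jnp

# Stands in for a (possibly multi-host sharded) model parameter.
params = jnp.arange(8.0)


@jax.jit
def scale_closure(x):
    # Risky: `params` is captured from the enclosing scope, so it is baked
    # into the compiled program as a constant during lowering.
    return x * params


@jax.jit
def scale_arg(x, p):
    # Safer: the array is an explicit argument, so jit treats it as an input
    # and never needs to fetch its full value on this host.
    return x * p


x = jnp.ones(8)
out_closure = scale_closure(x)
out_arg = scale_arg(x, params)
```

When the closure comes from something like a model-init function passed to `named_jit`, the same fix applies: thread the arrays through as arguments rather than letting the function capture them.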

ASKabalan commented 3 months ago

I am having the same issue.

Is there any way we can debug this error in a jitted function?

dlwh commented 3 months ago

Can you give me a reproducer? This shouldn’t be happening inside jit.

(Reply sent by email to Wassim KABALAN's comment of Mon, Jul 1, 2024, 2:43 AM.)


ASKabalan commented 3 months ago

I created this discussion with an MWE: https://github.com/google/jax/discussions/22212. It is not inside a jit in this example, but in my code I call this shard_map from a jitted function.
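For reference, the pattern of calling `shard_map` from inside a jitted function looks roughly like the sketch below. It is not a reproduction of the linked MWE; the mesh axis name `"x"` and the function names are assumptions, and the import fallback covers both newer JAX releases (top-level `jax.shard_map`) and older ones (`jax.experimental.shard_map`).

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, PartitionSpec as P

try:
    from jax import shard_map  # top-level API in newer JAX releases
except ImportError:
    from jax.experimental.shard_map import shard_map  # older releases

# A 1-D mesh over every device this process can see.
mesh = Mesh(np.array(jax.devices()), ("x",))


def block_sum(block):
    # Each device sums its own shard; psum then combines the partial sums
    # across the "x" axis so every device holds the same total.
    return jax.lax.psum(jnp.sum(block, keepdims=True), "x")


@jax.jit
def global_sum(x):
    # shard_map is called inside jit, so its body is traced and compiled
    # as part of global_sum instead of being dispatched eagerly.
    return shard_map(block_sum, mesh=mesh, in_specs=P("x"), out_specs=P())(x)


x = jnp.arange(8.0)  # length must divide evenly by the number of devices
result = global_sum(x)  # shape (1,)
```

If an error like the ones above still appears with this structure, the thread suggests it is worth reporting with a reproducer, since it should not happen inside jit.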

dlwh commented 3 months ago

Try again with JAX nightly? Some improvements were just made.