Open dlwh opened 1 year ago
also
  File "/home/dlwh/levanter/src/levanter/trainer.py", line 204, in initial_state
    model, opt_state = named_jit(self._init_model_and_opt_state, self.parameter_axis_mapping)(model_init)
  File "/home/dlwh/venv310/lib/python3.10/site-packages/haliax/partitioning.py", line 333, in f
    out, out_static = cached_pjitted_fun(dynamic_donated, dynamic_reserved, static)
  File "/home/dlwh/venv310/lib/python3.10/site-packages/jax/_src/array.py", line 679, in _array_mlir_constant_handler
    return mlir.ir_constants(val._value,
  File "/home/dlwh/venv310/lib/python3.10/site-packages/jax/_src/array.py", line 524, in _value
    raise RuntimeError("Fetching value for `jax.Array` that spans "
RuntimeError: Fetching value for `jax.Array` that spans non-addressable devices is not possible. You can use `jax.experimental.multihost_utils.process_allgather` for this use case.
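This one happens because arrays snuck into a closure when they shouldn't have: a `jax.Array` captured by a jitted function is lowered as a constant, which requires fetching its value.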
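A minimal sketch of the closure pattern, not the actual levanter/haliax code: the names `weights`, `scale_closure`, and `scale_explicit` are hypothetical. On a single host both versions run, since every device is addressable. The assumption is that in a multi-host run, where `weights` would span non-addressable devices, the closure version is the one that hits the constant-handler path shown in the traceback.

```python
import jax
import jax.numpy as jnp

# Hypothetical parameter array. In a multi-host run this could be sharded
# across devices that the current process cannot address.
weights = jnp.arange(4.0)


@jax.jit
def scale_closure(x):
    # `weights` is captured from the enclosing scope, so it is embedded in the
    # compiled program as a constant and its concrete value must be fetched.
    return x * weights


@jax.jit
def scale_explicit(x, w):
    # `w` is a traced argument, so its value is not needed at lowering time.
    return x * w


x = jnp.ones(4)
out_closure = scale_closure(x)
out_explicit = scale_explicit(x, weights)
```

Passing the array as an argument, as in `scale_explicit`, keeps it out of the closure.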
I am having the same issue.
Is there any way we can debug this error in a jitted function?
Can you give me a reproducer? This shouldn’t be happening inside jit.
I created this discussion with an MWE: https://github.com/google/jax/discussions/22212 It is not inside a JIT in this example, but in my code I call this `shard_map` from a jitted function.
Try again with JAX nightly? Some improvements were just made.
This one doesn't show up inside jit (by construction), so it's a bit harder to intercept. Maybe just a FAQ entry?
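For such a FAQ entry, here is a rough sketch of the workaround the error message itself points to, `jax.experimental.multihost_utils.process_allgather`. The array `x` is a hypothetical stand-in. On a single process this reduces to a plain host copy, and the multi-host behavior is assumed rather than demonstrated here.

```python
import numpy as np
import jax.numpy as jnp
from jax.experimental import multihost_utils

# Hypothetical array; in a real multi-host job this would be sharded across
# devices owned by different processes.
x = jnp.arange(4.0)

# Gather the array's value onto every process as a host (NumPy) array,
# instead of reading the value of a non-fully-addressable array directly.
gathered = multihost_utils.process_allgather(x)
```

The gathered result lives on the host, so it can be inspected or logged outside of jit.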