sokol11 opened 3 weeks ago
Hi. I tried running some of the examples on a multi-TPU machine and it throws this "assert not ragged" error: https://github.com/google/jax/issues/13931
The error occurs even when telling jit to use a single device, like so:
x, state = jax.jit(strategy.ask, device=jax.devices("tpu")[0])(rng_ask, state)
Running without jit, e.g.,
x, state = strategy.ask(rng_ask, state)
causes input shape errors, which I guess is somewhat expected.
Everything runs fine on a single-device machine.
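For comparison, here is a minimal sketch of another way to keep a jitted call on one device. It commits the inputs with jax.device_put instead of passing jit's `device` argument, which newer JAX releases deprecate. In JAX, computations on committed arrays run on the device those arrays are committed to. `toy_ask` is a hypothetical stand-in for strategy.ask, not the library's function. The sketch uses jax.devices()[0] so it runs on any backend; on the TPU host that would be jax.devices("tpu")[0]. It is not claimed to avoid the "assert not ragged" error.

```python
import jax
import jax.numpy as jnp

# Hypothetical stand-in for strategy.ask: samples a small population
# around a mean vector and returns (population, state).
def toy_ask(rng, mean):
    noise = jax.random.normal(rng, (4, mean.shape[0]))
    return mean + noise, mean

# Pick a single device (on a TPU host: jax.devices("tpu")[0]).
device = jax.devices()[0]

# Commit the inputs to that device; the jitted computation then runs
# there and its outputs are committed to the same device.
rng = jax.device_put(jax.random.PRNGKey(0), device)
mean = jax.device_put(jnp.zeros(3), device)

x, state = jax.jit(toy_ask)(rng, mean)
print(x.shape, x.devices())
```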