RobertTLange / evosax

Evolution Strategies in JAX 🦎
Apache License 2.0
480 stars, 44 forks

Issue when running on multiple devices #73

Open sokol11 opened 3 weeks ago

sokol11 commented 3 weeks ago

Hi. I tried running some of the examples on a multi-TPU machine, and they fail with the "assert not ragged" error reported in https://github.com/google/jax/issues/13931.

The error occurs even when I tell jit to use a single device, like so: `x, state = jax.jit(strategy.ask, device=jax.devices("tpu")[0])(rng_ask, state)`
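For comparison, here is a minimal sketch of another common way to keep a jitted call on one device. It commits the inputs to that device with `jax.device_put` before calling the jitted function, so JAX runs the computation where the committed inputs live. `toy_ask` is a hypothetical stand-in for `strategy.ask` that only samples a Gaussian population around a mean. It is not the evosax API, and this sketch does not confirm that the approach avoids the "assert not ragged" error.

```python
import jax
import jax.numpy as jnp

# Pick the first available device. On the TPU machine from this issue this
# would be jax.devices("tpu")[0].
device = jax.devices()[0]


def toy_ask(rng, mean, sigma, popsize=8):
    """Hypothetical stand-in for strategy.ask: sample a Gaussian population."""
    # popsize is left at its Python default, so it stays a static constant
    # under jit and the output shape is known at trace time.
    noise = jax.random.normal(rng, (popsize, mean.shape[0]))
    return mean + sigma * noise


rng = jax.random.PRNGKey(0)
mean = jnp.zeros(3)

# Commit every array input to `device`; the jitted call then executes there
# instead of relying on jit's own device argument.
rng_d, mean_d = jax.device_put((rng, mean), device)
x = jax.jit(toy_ask)(rng_d, mean_d, 0.1)
```

Committing inputs explicitly keeps placement visible at the call site, which can help separate a placement problem from a problem inside the jitted function itself.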

Running without jit, e.g. `x, state = strategy.ask(rng_ask, state)`, causes input shape errors, which I suppose is somewhat expected.

Everything runs fine on a single-device machine.