RobertTLange / evosax

Evolution Strategies in JAX 🦎
Apache License 2.0
475 stars, 44 forks

Replace jax.lax.select with jax.numpy.where. #60

Open: carlosgmartin opened this pull request 1 year ago

carlosgmartin commented 1 year ago

Fixes https://github.com/RobertTLange/evosax/issues/45#issuecomment-1528573782. See also the jax.lax docs:

Where possible, prefer to use libraries such as jax.numpy instead of using jax.lax directly. The jax.numpy API follows NumPy, and is therefore more stable and less likely to change than the jax.lax API.
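Below is a minimal sketch of the kind of substitution this PR proposes. The arrays are hypothetical and are not taken from the evosax diff. jax.lax.select expects on_true and on_false to have the same shape and dtype, with pred either a scalar or the same shape. jax.numpy.where follows NumPy semantics, including broadcasting, so it also accepts a Python scalar as one of the branches.

```python
import jax
import jax.numpy as jnp

# Hypothetical inputs: a boolean mask and two candidate arrays of matching shape.
pred = jnp.array([True, False, True])
on_true = jnp.array([1.0, 2.0, 3.0])
on_false = jnp.array([10.0, 20.0, 30.0])

# Before: lax.select requires on_true/on_false to share shape and dtype (no broadcasting).
old = jax.lax.select(pred, on_true, on_false)

# After: jnp.where mirrors numpy.where and gives the same result for same-shaped inputs.
new = jnp.where(pred, on_true, on_false)

# jnp.where also broadcasts, e.g. a scalar fallback value for masked-out entries.
broadcast = jnp.where(pred, on_true, 0.0)
```

For inputs of identical shape the two calls are interchangeable, so the swap keeps behavior while moving to the NumPy-style API that the quoted docs describe as more stable.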