RobertTLange / evosax

Evolution Strategies in JAX 🦎
Apache License 2.0

Type issue #45

Closed matthewrhysjones closed 1 year ago

matthewrhysjones commented 1 year ago

Hi, thanks for this really cool library! I'm having an issue with a dtype mismatch caused by the EvoState class.

My fitness function contains a matrix inversion (implemented with a Cholesky decomposition), and to try to avoid pesky numerical instability I've enabled 64-bit precision at the start of my script with:

from jax.config import config
config.update("jax_enable_x64", True)
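As a rough sketch of the setup described here (not the poster's actual fitness function), a Cholesky-based solve can be run in double precision once x64 is enabled. The matrix and right-hand side are made-up illustrative values; `jax.config.update(...)` is an equivalent spelling of the config call that also works in recent JAX releases.

```python
import jax
import jax.numpy as jnp
from jax.scipy.linalg import cho_factor, cho_solve

# Enable 64-bit types before any arrays are created.
jax.config.update("jax_enable_x64", True)

# A small symmetric positive-definite matrix (illustrative values only).
A = jnp.array([[4.0, 2.0], [2.0, 3.0]])
b = jnp.array([1.0, 1.0])

# Solve A x = b through a Cholesky factorization instead of forming inv(A).
c_and_lower = cho_factor(A)
x = cho_solve(c_and_lower, b)

print(A.dtype, x.dtype)  # both float64 once x64 is enabled
```

Solving against the factorization rather than explicitly inverting `A` is generally the more numerically stable choice, which is the motivation given above.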

When I run evosax, I get the error:

TypeError: lax.select requires arguments to have the same dtypes, got float32, float64. (Tip: jnp.where is a similar function that does automatic type promotion on inputs).
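The error and its tip can be reproduced in isolation. This sketch enables x64, builds one float32 scalar (standing in for `state.best_fitness`) and one float64 scalar (standing in for a fitness value), and shows that `jax.lax.select` rejects the pair while `jnp.where` promotes it. The variable names are illustrative assumptions, not evosax internals.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)

# float32, like the EvoState default for best_fitness.
best_fitness = jnp.asarray(jnp.finfo(jnp.float32).max, dtype=jnp.float32)
# float64, like a fitness value computed with x64 enabled.
new_fitness = jnp.asarray(0.5, dtype=jnp.float64)
replace_best = new_fitness < best_fitness

# lax.select performs no type promotion, so mismatched dtypes raise a TypeError.
try:
    jax.lax.select(replace_best, new_fitness, best_fitness)
    select_failed = False
except TypeError as err:
    select_failed = True
    print("lax.select:", err)

# jnp.where promotes float32 and float64 to float64 before selecting.
updated = jnp.where(replace_best, new_fitness, best_fitness)
print(updated, updated.dtype)
```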

The issue is that all the dtypes are float64 except state.best_fitness, which is float32. In the EvoState base class, the default value of best_fitness is fixed at float32 precision, which explains why jax_enable_x64 doesn't override it:

import chex
import jax.numpy as jnp
from flax import struct

@struct.dataclass
class EvoState:
    mean: chex.Array
    archive: chex.Array
    fitness: chex.Array
    sigma: chex.Array
    best_member: chex.Array
    best_fitness: float = jnp.finfo(jnp.float32).max
    gen_counter: int = 0
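The reasoning above can be checked directly without evosax: `jnp.finfo(jnp.float32).max` is a concrete float32 scalar, so an array built from it stays float32 even with x64 enabled, whereas an untyped Python float follows the x64 setting.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)

# The default used for EvoState.best_fitness in the snippet above.
default_best = jnp.asarray(jnp.finfo(jnp.float32).max)
# An untyped Python float picks up the x64 default.
plain = jnp.asarray(1.0)
# The double-precision maximum, which a float64 run would need instead.
best64 = jnp.asarray(jnp.finfo(jnp.float64).max)

print(default_best.dtype, plain.dtype, best64.dtype)
```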

Could you advise on the best way forward?

Thanks!

brunoramirez12 commented 1 year ago

Hello, I suggest addressing this issue by using jnp.float32() to cast the best_fitness attribute to float32 before passing it to a JAX function that requires its arguments to have the same dtype.

You could change the fitness function to include this line before calling that JAX function:

state.best_fitness = jnp.float32(state.best_fitness)

This casts the attribute to float32, which should fix the type error when using JAX functions that require matching dtypes.

matthewrhysjones commented 1 year ago

Thanks for the suggestion! I initially tried this, but the class attributes appear to be frozen, and the assignment raises:

FrozenInstanceError: cannot assign to field 'best_fitness'

RobertTLange commented 1 year ago

Hi there, thank you for the kind words and interest in evosax. Flax struct-style objects are immutable, but you can create an updated copy with replace, e.g. state = state.replace(best_fitness=jnp.finfo(jnp.float64).max). I hope this helps! Best wishes, Rob
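`flax.struct.dataclass` builds on a frozen standard-library dataclass, which is why direct assignment fails while `replace` works. The sketch below uses a hypothetical stand-in class `ToyState` (not evosax's `EvoState`) and `dataclasses.replace` so it runs without flax; in evosax itself the call is the `state.replace(...)` method shown above. It then checks that a `lax.select` on the updated field no longer mixes dtypes.

```python
import dataclasses

import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)


@dataclasses.dataclass(frozen=True)
class ToyState:
    """Hypothetical stand-in for a frozen, flax-struct-style state."""
    best_fitness: jax.Array


state = ToyState(best_fitness=jnp.asarray(jnp.finfo(jnp.float32).max))

# Direct assignment is rejected on a frozen dataclass.
try:
    state.best_fitness = jnp.asarray(0.0)
    assign_failed = False
except dataclasses.FrozenInstanceError:
    assign_failed = True

# replace() returns a new instance with the field swapped for a float64 value.
state = dataclasses.replace(state, best_fitness=jnp.asarray(jnp.finfo(jnp.float64).max))

# Both branches are now float64, so lax.select accepts them.
new_fitness = jnp.asarray(0.5, dtype=jnp.float64)
best = jax.lax.select(new_fitness < state.best_fitness, new_fitness, state.best_fitness)
print(state.best_fitness.dtype, best)
```

Because `replace` returns a new object, the result has to be rebound to `state` each time.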

matthewrhysjones commented 1 year ago

This worked perfectly, thanks!

carlosgmartin commented 1 year ago

I ran into this issue as well, except in my case the objective function returns integers. Simplified example:

import jax
import evosax

def f(x, key):
    return (x[0] + jax.random.normal(key)).astype(int)  # returning .astype(float) instead avoids the error

key = jax.random.PRNGKey(0)
strategy = evosax.OpenES(popsize=2, num_dims=1)
state = strategy.initialize(key)
while True:
    print(state.mean)
    key, key_ask, key_eval = jax.random.split(key, 3)
    xs, state = strategy.ask(key_ask, state)
    ks = jax.random.split(key_eval, strategy.popsize)
    ys = jax.vmap(f)(xs, ks)
    state = strategy.tell(xs, ys, state)
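Until the library promotes dtypes itself, one user-side workaround is to cast the evaluated fitness to a float before passing it to `tell`. The sketch below reuses `f` from the example and only inspects dtypes, so it runs without evosax; `xs` is a made-up stand-in for the population returned by `strategy.ask`.

```python
import jax
import jax.numpy as jnp


def f(x, key):
    # Integer-valued objective, as in the example above.
    return (x[0] + jax.random.normal(key)).astype(int)


key = jax.random.PRNGKey(0)
xs = jnp.zeros((2, 1))  # stand-in for the population from strategy.ask
ks = jax.random.split(key, xs.shape[0])

ys = jax.vmap(f)(xs, ks)           # integer dtype
ys_float = ys.astype(jnp.float32)  # matches the float32 best_fitness default
print(ys.dtype, ys_float.dtype)
```

In the loop above, this amounts to passing `ys.astype(jnp.float32)` to `strategy.tell`.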

@RobertTLange Would you consider replacing jax.lax.select with jax.numpy.where, like the error message suggests, to save users the trouble of manually casting to floats? Thanks!
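A sketch of what the suggested change could look like: a hypothetical `update_best` helper (not evosax's actual code) that tracks the best member with `jnp.where`, so an integer fitness is promoted to the stored float dtype instead of raising. It assumes minimization, matching the `finfo(...).max` default for `best_fitness`.

```python
import jax.numpy as jnp


def update_best(x, fitness, best_member, best_fitness):
    """Hypothetical best-so-far update for a minimization problem."""
    idx = jnp.argmin(fitness)
    gen_best_fitness = fitness[idx]
    replace = gen_best_fitness < best_fitness
    # jnp.where applies JAX type promotion, unlike jax.lax.select.
    new_best_fitness = jnp.where(replace, gen_best_fitness, best_fitness)
    new_best_member = jnp.where(replace, x[idx], best_member)
    return new_best_member, new_best_fitness


x = jnp.array([[0.0], [1.0]], dtype=jnp.float32)
fitness = jnp.array([3, -2], dtype=jnp.int32)  # integer objective values
best_member = jnp.zeros(1, dtype=jnp.float32)
best_fitness = jnp.asarray(jnp.finfo(jnp.float32).max)

member, value = update_best(x, fitness, best_member, best_fitness)
print(member, value, value.dtype)
```

One trade-off: implicit promotion can silently lose information (for example, int32 values above 2**24 are not all exactly representable in float32), whereas an explicit cast makes the conversion visible.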