RobertTLange / evosax

Evolution Strategies in JAX 🦎
Apache License 2.0
514 stars, 43 forks

Issue with `float64` precision #67

Open bheijden opened 9 months ago

bheijden commented 9 months ago

Hi,

Great work on the toolkit!

I encountered a problem while using the CMA-ES algorithm. One parameter had a standard deviation of approximately 1e-6, i.e. a variance of 1e-12. Next to the larger entries of the covariance matrix, this is far below what float32 can resolve (machine epsilon ≈ 1.2e-7), so float32 could not represent the covariance matrix accurately. As a result, samples were generated with zero variance. Rescaling the covariance matrix would have been one option, but as a first workaround I switched to float64 precision with:

# Must run at startup, before any JAX arrays are created!
from jax import config
config.update("jax_enable_x64", True)
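The float32 precision limit described above can be checked directly. In this sketch, NumPy stands in for JAX because both round float32 and float64 values the same way. The value 1.0 is an illustrative stand-in for an O(1) entry of the covariance matrix. A variance of 1e-12 is representable on its own in float32, but it vanishes as soon as it is added to 1.0. Float64 keeps it.

```python
import numpy as np

# Machine epsilon: the relative spacing between 1.0 and the next representable number.
eps32 = np.finfo(np.float32).eps  # about 1.19e-07
eps64 = np.finfo(np.float64).eps  # about 2.22e-16

# On its own, a variance of 1e-12 is a perfectly valid (nonzero) float32 value...
tiny32 = np.float32(1e-12)

# ...but adding it to an O(1) quantity rounds it away entirely in float32.
lost32 = np.float32(1.0) + np.float32(1e-12) == np.float32(1.0)
lost64 = np.float64(1.0) + np.float64(1e-12) == np.float64(1.0)

print(eps32, eps64)
print("float32 loses 1e-12:", lost32)  # True
print("float64 loses 1e-12:", lost64)  # False
```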

However, enabling x64 led to a dtype mismatch between float32 and float64, because one value in the strategy state is still created as float32, as the traceback shows:

  File "/home/r2ci/rex/sysid/evo.py", line 118, in evo_step
    new_state = solver.strategy.tell(x, loss_nonan, state, solver.strategy_params)
  File "/home/r2ci/evosax/evosax/strategy.py", line 140, in tell
    best_member, best_fitness = get_best_fitness_member(
  File "/home/r2ci/evosax/evosax/utils/helpers.py", line 22, in get_best_fitness_member
    best_fitness = jax.lax.select(
TypeError: lax.select requires arguments to have the same dtypes, got float32, float64. (Tip: jnp.where is a similar function that does automatic type promotion on inputs).
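The remaining float32 value can be traced to how the default is built. `jnp.finfo(jnp.float32).max` is a NumPy float32 scalar, not a Python float, so it keeps its explicit float32 dtype even with `jax_enable_x64` on. A Python float or a freshly created array, by contrast, defaults to float64. The variable names below are illustrative only.

```python
import jax
import jax.numpy as jnp

# Enable 64-bit mode before any JAX arrays are created.
jax.config.update("jax_enable_x64", True)

limit = jnp.finfo(jnp.float32).max              # NumPy float32 scalar, not a JAX array
from_numpy_scalar = jnp.asarray(limit)          # keeps its explicit float32 dtype
from_python_float = jnp.asarray(float(limit))   # Python float -> default float64
fitness = jnp.zeros(10)                         # new arrays default to float64

print(from_numpy_scalar.dtype, from_python_float.dtype, fitness.dtype)
```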

The fix below, applied to the `best_fitness` default in the `EvoState` of `CMA_ES`, resolves the problem. Other strategies that define the same default are probably affected as well.

from typing import Optional

import chex
import jax.numpy as jnp
from flax import struct


@struct.dataclass
class EvoState:
    p_sigma: chex.Array
    p_c: chex.Array
    C: chex.Array
    D: Optional[chex.Array]
    B: Optional[chex.Array]
    mean: chex.Array
    sigma: float
    weights: chex.Array
    weights_truncated: chex.Array
    best_member: chex.Array
    # Convert to a Python float and back to a JAX array, so the default becomes
    # float64 when jax_enable_x64 is set before this module is imported.
    best_fitness: float = jnp.array(float(jnp.finfo(jnp.float32).max))
    # best_fitness: float = jnp.finfo(jnp.float32).max  # old: stays a NumPy float32 scalar
    gen_counter: int = 0
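The failure and the fix can be reproduced without evosax. This sketch assumes x64 is enabled, and all variable names are hypothetical. It mimics the kind of `jax.lax.select` call shown in the traceback. Selecting between a float64 fitness and the old float32 default raises `TypeError`. The proposed default works, and so does `jnp.where`, which the error message suggests because it promotes dtypes automatically.

```python
import jax
import jax.numpy as jnp

# Enable 64-bit mode before any JAX arrays are created.
jax.config.update("jax_enable_x64", True)

fitness = jnp.asarray(1.5, dtype=jnp.float64)               # float64 fitness value
old_default = jnp.finfo(jnp.float32).max                    # NumPy float32 scalar
new_default = jnp.array(float(jnp.finfo(jnp.float32).max))  # float64 under x64
is_better = fitness < new_default                           # boolean predicate

# Old default: lax.select refuses to mix float64 and float32 operands.
try:
    jax.lax.select(is_better, fitness, old_default)
    old_default_fails = False
except TypeError:
    old_default_fails = True

# Proposed default: both operands are float64, so select succeeds.
best_new = jax.lax.select(is_better, fitness, new_default)

# jnp.where promotes the float32 operand to float64 instead of raising.
best_where = jnp.where(is_better, fitness, old_default)

print(old_default_fails, best_new, best_where.dtype)
```

Changing the default keeps `lax.select` strict, which avoids silently promoting values elsewhere. Switching to `jnp.where` would instead accept any mix of float dtypes.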

MWE:

import jax
import jax.numpy as jnp

# Must run at startup, before any JAX arrays are created (and before importing evosax)!
from jax import config
config.update("jax_enable_x64", True)

from evosax import CMA_ES
from evosax.problems import BBOBFitness

# Instantiate the evolution strategy instance
strategy = CMA_ES(num_dims=2, popsize=10)

# Get default hyperparameters (e.g. lrate, etc.)
es_params = strategy.default_params
es_params = es_params.replace(init_min=-3.0, init_max=3.0)

# Use a separate PRNG key for each random operation
rng = jax.random.PRNGKey(0)
rng_init, rng_ask, rng_eval = jax.random.split(rng, 3)

# Initialize the strategy
state = strategy.initialize(rng_init, es_params)

# Instantiate helper class for classic evolution strategies benchmarks
evaluator = BBOBFitness("RosenbrockOriginal", num_dims=2)
# Ask for a set of candidate solutions to evaluate
x, state = strategy.ask(rng_ask, state, es_params)
# Evaluate the population members
fitness = evaluator.rollout(rng_eval, x)
# Update the evolution strategy (raises the float32/float64 TypeError without the fix)
state = strategy.tell(x, fitness, state, es_params)
print(state)
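Until a library-side fix lands, one possible user-side workaround is to upcast every float32 leaf of the state right after `strategy.initialize`. This is only a sketch. `upcast_float32_leaves` is a hypothetical helper, and the approach assumes the evosax state is a JAX pytree, which `flax.struct.dataclass` instances are. It is demonstrated here on a plain dict standing in for `EvoState`.

```python
import jax
import jax.numpy as jnp

# Enable 64-bit mode before any JAX arrays are created.
jax.config.update("jax_enable_x64", True)


def upcast_float32_leaves(tree):
    """Hypothetical helper: cast every float32 leaf of a pytree to float64.

    All other leaves (float64 arrays, integers, ...) are returned unchanged.
    """

    def _cast(leaf):
        arr = jnp.asarray(leaf)
        if arr.dtype == jnp.float32:
            return arr.astype(jnp.float64)
        return leaf

    return jax.tree_util.tree_map(_cast, tree)


# Stand-in for an evosax state: the float32 default is the problematic leaf.
state = {
    "mean": jnp.zeros(2),                        # already float64 under x64
    "best_fitness": jnp.finfo(jnp.float32).max,  # NumPy float32 scalar
    "gen_counter": 0,                            # integer leaf, left untouched
}
state = upcast_float32_leaves(state)
print({k: jnp.asarray(v).dtype for k, v in state.items()})
```

With evosax, the untested usage would be `state = upcast_float32_leaves(strategy.initialize(rng_init, es_params))` before the first `tell`.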
bheijden commented 9 months ago

Closing as a duplicate of #45.

bheijden commented 9 months ago

On second thought, I am reopening this issue, because the proposed solution could be a permanent fix that requires no extra work from the user.

Feel free to close if you think it's not worth the effort!