RobertTLange / evosax

Evolution Strategies in JAX 🦎
Apache License 2.0

fix top_k log #23

newtonkwan closed this 2 years ago

newtonkwan commented 2 years ago

ESLog computes the top k agents incorrectly: with maximize=True, log["top_fitness"] repeats the single best fitness instead of listing the k best values.

Expectation:

    import jax.numpy as jnp
    from evosax.utils import ESLog

    es_logging = ESLog(num_dims=2, num_generations=10, top_k=3, maximize=True)
    log = es_logging.initialize()
    x = jnp.array([[1, 2], [2, 4], [4, 6], [6, 7]])
    fitness = jnp.array([1, 2, 3, 4])
    log = es_logging.update(log, x, fitness)  # merge this population into the log
    assert jnp.array_equal(log["top_fitness"], jnp.array([4, 3, 2]))
    # Pass: [4. 3. 2.]
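The expected values can be cross-checked without ESLog. A minimal sketch, assuming only jax is installed: jax.lax.top_k returns the k largest entries along the last axis, in descending order, together with their indices, which is what top_fitness should hold when maximize=True. The params lookup through the returned indices is illustrative.

```python
import jax
import jax.numpy as jnp

x = jnp.array([[1, 2], [2, 4], [4, 6], [6, 7]], dtype=jnp.float32)
fitness = jnp.array([1, 2, 3, 4], dtype=jnp.float32)

# k largest fitness values (descending) and their positions in the population
top_fitness, top_idx = jax.lax.top_k(fitness, 3)
top_params = x[top_idx]  # parameter rows of the three best agents

print(top_fitness)  # [4. 3. 2.]
print(top_params)
```

For maximize=False, the same call on -fitness selects the k smallest values (negate the returned values back afterwards).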

Reality:

    es_logging = ESLog(num_dims=2, num_generations=10, top_k=3, maximize=True)
    log = es_logging.initialize()
    x = jnp.array([[1, 2], [2, 4], [4, 6], [6, 7]])
    fitness = jnp.array([1, 2, 3, 4])
    log = es_logging.update(log, x, fitness)
    assert jnp.array_equal(log["top_fitness"], jnp.array([4, 3, 2]))
    # AssertionError: log["top_fitness"] is [4. 4. 4.]
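The observed [4. 4. 4.] is consistent with two argsort() results being added elementwise when only one of them should contribute. The sketch below is a hypothetical reconstruction, not the actual es_logger.py source, and it ignores the stored archive for simplicity: buggy_top_idx (a made-up name) weights the input of the second argsort by (1 - maximize) instead of weighting its output. With maximize=True that input is all zeros, and a stable argsort of an all-zero array is arange(n), not zeros, so it shifts every index.

```python
import jax.numpy as jnp

def buggy_top_idx(vals, maximize, top_k):
    # Hypothetical reconstruction: the (1 - maximize) weight multiplies the
    # *input* of the second argsort, so that term never vanishes.
    return (
        maximize * (-1 * vals).argsort()
        + ((1 - maximize) * vals).argsort()
    )[:top_k]

fitness = jnp.array([1.0, 2.0, 3.0, 4.0])
idx = buggy_top_idx(fitness, maximize=True, top_k=3)
# argsort(-fitness) = [3, 2, 1, 0]; argsort(zeros) = [0, 1, 2, 3]
# their sum is [3, 3, 3, 3], so every slot points at the best agent
print(idx)           # [3 3 3]
print(fitness[idx])  # [4. 4. 4.]
```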

Fixes issue #24.

Line 58 of es_logger.py: the second argsort() in the top_idx computation is one parenthesis off.
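Continuing the same hypothetical reconstruction (not the literal es_logger.py code, and not necessarily the exact change made in #23), moving the parenthesis so that the weight multiplies the second argsort's output makes the switched-off term contribute zeros. The indices then come out in the expected order for both settings of maximize. fixed_top_idx is an illustrative name.

```python
import jax.numpy as jnp

def fixed_top_idx(vals, maximize, top_k):
    # Hypothetical fix: weight the *outputs* of both argsorts, so the
    # branch that is switched off adds zeros instead of arange(n).
    return (
        maximize * (-1 * vals).argsort()
        + (1 - maximize) * vals.argsort()
    )[:top_k]

fitness = jnp.array([1.0, 2.0, 3.0, 4.0])

idx_max = fixed_top_idx(fitness, maximize=True, top_k=3)
print(fitness[idx_max])  # [4. 3. 2.]

idx_min = fixed_top_idx(fitness, maximize=False, top_k=3)
print(fitness[idx_min])  # [1. 2. 3.]
```

A single argsort on a sign-flipped key, for example jnp.argsort(jnp.where(maximize, -vals, vals)), would avoid summing two index arrays altogether.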