google / evojax

Apache License 2.0

add CR-FM-NES algorithm #44

Closed (dietmarwo closed this 1 year ago)

dietmarwo commented 1 year ago

Adds a wrapper for CR-FM-NES; see "Fast Moving Natural Evolution Strategy for High-Dimensional Problems (CR-FM-NES)" (pdf).

It wraps the fcmaes Eigen/C++ version of CR-FM-NES which is derived from https://github.com/nomuramasahir0/crfmnes.

Since there are NumPy- and Eigen-based implementations of CR-FM-NES available (and soon a JAX-based one), it will be possible to compare the performance of these three "backends" for the same algorithm. This commit wraps only the C++/Eigen-based implementation, crfmnes.cpp.

Tested on an NVIDIA 3090 + AMD 5950x running Linux Mint 20 (Ubuntu-based). Wall-time performance is similar to PGPE and better than CMA_ES_JAX. On waterworld, the benchmark results are above those of all other algorithms. Run "pip install fcmaes --upgrade" before testing.

google-cla[bot] commented 1 year ago

Thanks for your pull request! It looks like this may be your first contribution to a Google open source project. Before we can look at your pull request, you'll need to sign a Contributor License Agreement (CLA).

View this failed invocation of the CLA check for more information.

For the most up to date status, view the checks section at the bottom of the pull request.

dietmarwo commented 1 year ago

I have a question regarding train_ant_map_elites.py: the maximal value I observe using the default iterations and MAPElites is about 710. Increasing the number of iterations doesn't improve the result very much (781.887 after 1900 * 1024 evaluations). But using the new solver proposed in this PR:

    solver = FCRFMC(
        pop_size=512,
        param_size=policy.num_params,
        init_stdev=0.07,
        seed=config.seed,
        logger=logger,
    )

the score goes steadily up to 3922 after 15,000 * 512 evaluations. Why is such a low iteration number configured, and which value are we aiming for here? Is MAPElites the best solver for this problem? FCRFMC reaches > 3000 using 1900 * 1024 evaluations and > 2600 using the same wall time (5600 sec on AMD 5950x + NVIDIA 3090 on Linux).
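To make the budgets in this comment easier to compare, here is a small sketch of the arithmetic. It assumes each budget is iterations times population size and reads the long FCRFMC run as 15,000 iterations of 512 evaluations, which is an interpretation of the figures quoted above. The helper name `evaluations` is made up for illustration.

```python
def evaluations(iterations: int, pop_size: int) -> int:
    """Total number of policy evaluations for a run (iterations * pop_size)."""
    return iterations * pop_size


# Budgets as read from the comment above (an interpretation, see lead-in).
map_elites_budget = evaluations(1900, 1024)  # MAPElites run
fcrfmc_budget = evaluations(15_000, 512)     # long FCRFMC run

print(map_elites_budget)                     # 1945600
print(fcrfmc_budget)                         # 7680000
print(round(fcrfmc_budget / map_elites_budget, 2))
```

So the 3922 score was obtained with roughly four times the evaluation budget of the 1900 * 1024 runs, which is why the equal-budget and equal-wall-time figures are reported separately.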

Further tests revealed that PGPE also outperforms MAPElites significantly, especially for lower iteration numbers:

    solver = PGPE(
        pop_size=512,
        param_size=policy.num_params,
        optimizer='adam',
        center_learning_rate=0.014,
        stdev_learning_rate=0.088,
        init_stdev=0.069,
        logger=logger,
        seed=config.seed,
    )

At higher iteration counts (> 10000) FCRFMC is better than PGPE, and the gap widens as the iterations increase. This may indicate application scenarios for Fast Moving Natural Evolution Strategy: a high number of parameters (> 5000) and a high iteration number, where often a smaller population than other algorithms need is sufficient. It would be interesting to compare wall time on different TPUs and GPUs for the Eigen implementation against a future JAX-based one.

lerrytang commented 1 year ago

Thanks for sending the PR! This is to answer your question about map-elites.

The map-elites example is meant to demonstrate the interfaces for implementing a quality diversity (QD) method. I used a simple GA for optimization and didn't tune the hyper-parameters. While other algorithms may give higher scores, QD methods output an archive of policies, each of which behaves differently according to the pre-defined behavior descriptors. QD methods are therefore interesting in their own right.
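As a rough illustration of what "an archive of policies" means, here is a minimal sketch of a MAP-Elites style archive. It is not the EvoJAX implementation: the class `GridArchive`, the one-dimensional descriptor, and the string placeholders for policies are all assumptions made for the example.

```python
import math


class GridArchive:
    """Toy MAP-Elites archive: one elite per cell of a 1-D behavior grid."""

    def __init__(self, num_cells: int, low: float, high: float):
        self.num_cells, self.low, self.high = num_cells, low, high
        self.scores = [-math.inf] * num_cells  # best score seen per cell
        self.params = [None] * num_cells       # parameters of each elite

    def cell_index(self, descriptor: float) -> int:
        # Map a behavior descriptor in [low, high] to a grid cell.
        frac = (descriptor - self.low) / (self.high - self.low)
        return min(max(int(frac * self.num_cells), 0), self.num_cells - 1)

    def add(self, params, score: float, descriptor: float) -> bool:
        # Keep the candidate only if it beats the current elite of its cell.
        i = self.cell_index(descriptor)
        if score > self.scores[i]:
            self.scores[i], self.params[i] = score, params
            return True
        return False


archive = GridArchive(num_cells=4, low=0.0, high=1.0)
kept_a = archive.add("policy_a", score=1.0, descriptor=0.10)  # fills cell 0
kept_b = archive.add("policy_b", score=5.0, descriptor=0.90)  # fills cell 3
kept_c = archive.add("policy_c", score=0.5, descriptor=0.15)  # loses in cell 0
filled = sum(math.isfinite(s) for s in archive.scores)
print(filled, max(archive.scores))
```

A single-objective solver would report only the best score, whereas the archive also keeps `policy_a`, a lower-scoring policy that behaves differently.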

I'll merge the PR once it passes my local tests.

lerrytang commented 1 year ago

Hi, I tried running the code locally and got this error from fcmaes: OSError: /lib/x86_64-linux-gnu/libm.so.6: version "GLIBC_2.29" not found. Can you describe your runtime and give more detailed instructions for installing fcmaes? (Apparently, pip install alone is not sufficient.)
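One quick way to check whether a machine is affected by this error is to ask Python which C library it is linked against. The sketch below uses the standard-library call `platform.libc_ver()`; the helper names `parse_version` and `glibc_at_least` are made up for illustration, and the 2.29 threshold is taken from the error message above rather than from any fcmaes documentation.

```python
import platform


def parse_version(text: str) -> tuple:
    """Turn '2.31' into (2, 31); non-numeric parts are ignored."""
    return tuple(int(part) for part in text.split(".") if part.isdigit())


def glibc_at_least(found: str, required: str = "2.29") -> bool:
    """True if the glibc version string `found` is >= `required`."""
    return bool(found) and parse_version(found) >= parse_version(required)


# libc_ver() returns a (lib, version) pair; both may be empty strings
# when the C library cannot be detected (for example on non-glibc systems).
lib, version = platform.libc_ver()
if lib == "glibc":
    print(f"glibc {version}; new enough: {glibc_at_least(version)}")
else:
    print("Could not detect glibc on this system")
```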

dietmarwo commented 1 year ago

See fast-cma-es section "Installation"

"To use the C++ optimizers a gcc-9.3 (or newer) runtime is required. This is the default on newer Linux versions. If you are on an old Linux distribution you need to install gcc-9 or a newer version. On ubuntu this is:

    sudo add-apt-repository ppa:ubuntu-toolchain-r/test
    sudo apt update
    sudo apt install gcc-9"

Or see how-install-glibc-2-29-or-higher-in-ubuntu-18-04

I didn't mention this because gcc 9 has been standard on all current Linux distributions for some time.

Here is a detailed performance comparison for the NVIDIA 3090: EvoJax.adoc. And here is a Python implementation providing ask/tell, crfmnes.py, which is about 4 times slower than the C++ version but could be used as the basis for a JAX port. I tried one myself but could not reach the C++ performance; maybe someone with more experience with JAX will be more successful.
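For readers unfamiliar with the ask/tell pattern mentioned above, here is a minimal sketch of the calling convention. It is deliberately not CR-FM-NES and not the fcmaes API: `ToyES` is a hypothetical isotropic Gaussian strategy with a fixed step size, and `score` is a made-up objective where higher is better.

```python
import random


class ToyES:
    """Toy Gaussian evolution strategy exposing ask() and tell().

    NOT CR-FM-NES: it only illustrates the calling pattern.
    """

    def __init__(self, param_size, pop_size, init_stdev, seed=0):
        self.rng = random.Random(seed)
        self.mean = [0.0] * param_size
        self.stdev = init_stdev
        self.pop_size = pop_size
        self.population = []

    def ask(self):
        # Sample pop_size candidates around the current mean.
        self.population = [
            [m + self.stdev * self.rng.gauss(0.0, 1.0) for m in self.mean]
            for _ in range(self.pop_size)
        ]
        return self.population

    def tell(self, fitness):
        # Move the mean to the average of the best quarter of the population.
        order = sorted(range(self.pop_size), key=lambda i: fitness[i], reverse=True)
        elite = [self.population[i] for i in order[: max(1, self.pop_size // 4)]]
        self.mean = [sum(column) / len(elite) for column in zip(*elite)]


def score(x):
    # Hypothetical objective: negative squared distance to (1, ..., 1).
    return -sum((xi - 1.0) ** 2 for xi in x)


solver = ToyES(param_size=5, pop_size=32, init_stdev=0.3, seed=0)
start = score(solver.mean)
for _ in range(50):
    candidates = solver.ask()                    # solver proposes parameters
    solver.tell([score(c) for c in candidates])  # caller reports the scores
final = score(solver.mean)
print(start, final)
```

The point of the split is that the solver never calls the objective itself, so the evaluation step in the loop can be replaced by whatever batched rollout the caller has available.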