google / evojax

Apache License 2.0

Add a Python/JAX port of CR-FM-NES #46

Closed: dietmarwo closed this 1 year ago

dietmarwo commented 1 year ago

This PR adds a Python/JAX port of CR-FM-NES, the algorithm from "Fast Moving Natural Evolution Strategy for High-Dimensional Problems" (https://arxiv.org/abs/2201.11422). It is derived from https://github.com/nomuramasahir0/crfmnes.

This variant is slightly faster than FCRFMC (the C++ port) on fast GPUs/TPUs, but slower on CPUs and for smaller dimensions. It uses 32-bit precision (FCRFMC uses 64-bit), which mostly does not harm convergence; the exception is Waterworld MA at very high iteration counts.
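To illustrate why 32-bit precision can matter only late in a run, here is a small sketch (not code from this PR): JAX creates float32 arrays by default, and float32 has a much larger machine epsilon than float64, so a very small update to a parameter of magnitude 1.0 can be rounded away entirely. The step size `1e-9` is a hypothetical value chosen only to show the effect.

```python
import numpy as np
import jax.numpy as jnp

# Under JAX's default configuration (jax_enable_x64 off), arrays are float32.
x = jnp.ones(3)
print(x.dtype)  # float32

# Machine epsilon: roughly the smallest relative change to 1.0 that survives.
eps32 = float(jnp.finfo(jnp.float32).eps)  # about 1.2e-07
eps64 = float(np.finfo(np.float64).eps)  # about 2.2e-16

# A hypothetical tiny late-stage update to a parameter of magnitude 1.0.
step = 1e-9
lost_in_32 = bool(jnp.float32(1.0) + jnp.float32(step) == 1.0)
lost_in_64 = bool(np.float64(1.0) + np.float64(step) == 1.0)
print(lost_in_32, lost_in_64)  # True False
```

Whether this is what actually limits Waterworld MA at high iteration counts is not established here; it only shows the kind of rounding that 64-bit arithmetic avoids.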

As with FCRFMC, wall time and convergence are mostly comparable to PGPE on the benchmarks: slower at the beginning, but improving relative to PGPE at higher iteration counts.

Since there are no for-loops, I found no place where `jax.jit` helped. Instead, I converted most `np.array`s into `jnp.array`s placed on the GPU/TPU.
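A minimal sketch of that conversion pattern, assuming a hypothetical mean vector and sample matrix (the names `mean_np`, `mean`, `z`, `sigma` are illustrative, not taken from the PR): a NumPy array is moved to the default JAX device once, and later vectorized operations run there eagerly, without `jax.jit`.

```python
import numpy as np
import jax
import jax.numpy as jnp

# Hypothetical host-side state, e.g. a mean vector created with NumPy.
mean_np = np.zeros(8, dtype=np.float32)

# jax.device_put copies the array to the default device (GPU/TPU if present,
# otherwise CPU) and returns a jax.Array; jnp.asarray behaves similarly.
mean = jax.device_put(mean_np)

# Later vectorized math runs on that device op by op, without jax.jit.
key = jax.random.PRNGKey(0)
z = jax.random.normal(key, (16, 8))  # 16 hypothetical samples of dimension 8
sigma = 0.1
x = mean + sigma * z  # broadcasting: (8,) + (16, 8) -> (16, 8)
print(x.shape, x.dtype)
```

Each operation here is dispatched individually; `jax.jit` would mainly help by fusing such operations, which is presumably less important when the per-generation work is a few large array operations.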

```python
def sort_indices_by(evals: np.ndarray, z: jnp.ndarray) -> jnp.ndarray:
```

This function takes `evals` as an `np.ndarray` rather than a `jnp.ndarray`, because the latter slowed things down on my NVIDIA 3090.
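Only the signature appears in the PR description, so the body below is a hypothetical sketch, not the PR's implementation. It shows how a host-side `evals` array and a device-side `z` could be combined: the short fitness vector is sorted with NumPy on the host, and only the ranking indices are returned as a JAX array. The rule for infeasible candidates (marked `np.inf`, ranked by the squared norm of their sample column) is an assumption loosely modeled on the upstream crfmnes reference.

```python
import numpy as np
import jax.numpy as jnp


def sort_indices_by(evals: np.ndarray, z: jnp.ndarray) -> jnp.ndarray:
    """Hypothetical sketch: rank lam candidates by fitness (lower is better).

    evals: (lam,) NumPy array on the host; np.inf marks an infeasible candidate.
    z:     (dim, lam) JAX array of samples, one column per candidate.
    """
    lam = evals.size
    order = np.argsort(evals)  # cheap host-side sort of a short vector
    n_feasible = int(np.sum(evals < np.inf))
    if n_feasible != lam:
        # Assumption: infeasible candidates are ranked by the squared norm of
        # their sample column (smaller norm ranks higher).
        infeasible = order[n_feasible:]
        dist = np.asarray(jnp.sum(z[:, infeasible] ** 2, axis=0))
        order[n_feasible:] = infeasible[np.argsort(dist)]
    return jnp.asarray(order)


# Usage: reorder sample columns by rank.
evals = np.array([3.0, np.inf, 1.0, 2.0])
z = jnp.array([[0.0, 5.0, 1.0, 2.0]])  # dim=1, lam=4
idx = sort_indices_by(evals, z)
z_sorted = z[:, idx]
```

Keeping `evals` in NumPy means the sort never leaves the host, which is consistent with the observation above that passing a `jnp.ndarray` was slower.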

Since this is pure Python code, there should be no missing shared libraries on Ubuntu 18 this time.

I added test results for CRFMNES (this Python implementation) to EvoJax.adoc.

lerrytang commented 1 year ago

Thanks for the contribution! Will you send a PR to update this page, covering FCRFMC as well?

dietmarwo commented 1 year ago

I am on vacation until the end of next week. I will update the page when I am back.