google / evojax

Evosax - Augmented Random Search #9

Status: Closed (RobertTLange closed 2 years ago)

RobertTLange commented 2 years ago

Hi there 🤗 - I am opening a first PR, which adds Augmented Random Search (Mania et al., 2018) to evojax. The implementation wraps evosax's.
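For readers unfamiliar with the method, here is a minimal pure-Python sketch of one basic ARS update (antithetic perturbations, top-k direction selection, reward-std scaling). It is an illustration only and not the evosax or evojax code: `ars_step`, `toy_reward`, and all default hyperparameters are hypothetical, and observation normalization from the paper is omitted.

```python
import random
import statistics


def ars_step(theta, reward_fn, rng, num_dirs=8, top_k=4, lr=0.02, sigma=0.03):
    """One basic ARS update on a flat parameter list (illustrative sketch)."""
    # Sample Gaussian search directions and evaluate both signs of each one.
    evaluated = []
    for _ in range(num_dirs):
        delta = [rng.gauss(0.0, 1.0) for _ in theta]
        r_plus = reward_fn([t + sigma * d for t, d in zip(theta, delta)])
        r_minus = reward_fn([t - sigma * d for t, d in zip(theta, delta)])
        evaluated.append((r_plus, r_minus, delta))

    # Keep the top_k directions, ranked by the better reward of each pair.
    evaluated.sort(key=lambda item: max(item[0], item[1]), reverse=True)
    top = evaluated[:top_k]

    # Scale the step by the standard deviation of the rewards that are used.
    reward_std = statistics.pstdev([r for item in top for r in item[:2]]) + 1e-8
    scale = lr / (top_k * reward_std)

    # Move along each kept direction in proportion to its reward difference.
    return [
        t + scale * sum((rp - rm) * delta[i] for rp, rm, delta in top)
        for i, t in enumerate(theta)
    ]


def toy_reward(params):
    # Hypothetical objective with its maximum (0.0) at all-ones.
    return -sum((p - 1.0) ** 2 for p in params)


rng = random.Random(0)
theta = [0.0] * 5
initial_reward = toy_reward(theta)
for _ in range(300):
    theta = ars_step(theta, toy_reward, rng)
final_reward = toy_reward(theta)
```

On this toy objective the reward should improve over the starting point; real tasks would replace `toy_reward` with a policy rollout.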

There are a couple of general considerations:

I have also created a mini-repository for running the benchmarks and storing logs and configuration files. Maybe this can be of general interest? I am planning to add an mle-hyperopt parameter search pipeline. You can find the ARS logs here, and this is the benchmarking summary for ARS:

| | Benchmarks | Parameters | Results |
|---|---|---|---|
| CartPole (easy) | 900 (max_iter=1000) | Link | 902.107 |
| CartPole (hard) | 600 (max_iter=1000) | Link | 666.6442 |
| Waterworld | 6 (max_iter=500) | Link | 6.1300 |
| Waterworld (MA) | 2 (max_iter=2000) | Link | 1.4831 |
| Brax Ant | 3000 (max_iter=300) | Link | 3298.9746 |
| MNIST | 90.0 (max_iter=2000) | Link | 0.9610 |

Update: I added hyperparameter search utilities and ran a coarse grid search over the initial learning rate and standard deviation. Here are some results for the CartPole and MNIST tasks:

[Figures: grid-search results for Cartpole-Easy, Cartpole-Hard, and MNIST]

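A coarse grid search over the initial learning rate and standard deviation could look roughly like the sketch below. This is not the mle-hyperopt pipeline mentioned above: `grid_search`, `toy_score`, the parameter names `lr_init` and `sigma_init`, and the grid values are all hypothetical placeholders.

```python
import itertools


def grid_search(evaluate, grid):
    """Evaluate every combination in `grid` and return the best one."""
    names = sorted(grid)
    best_config, best_score = None, float("-inf")
    for values in itertools.product(*(grid[name] for name in names)):
        config = dict(zip(names, values))
        score = evaluate(**config)
        if score > best_score:
            best_config, best_score = config, score
    return best_config, best_score


def toy_score(lr_init, sigma_init):
    # Placeholder objective; a real run would train ARS and return its score.
    return -abs(lr_init - 0.02) - abs(sigma_init - 0.05)


# Hypothetical coarse grid; the values are for illustration only.
coarse_grid = {
    "lr_init": [0.005, 0.02, 0.08],
    "sigma_init": [0.01, 0.05, 0.1],
}
best_config, best_score = grid_search(toy_score, coarse_grid)
```

In practice each `evaluate` call is a full training run, so a grid this small (nine runs) is about as coarse as is useful before refining around the best cell.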
alantian commented 2 years ago

Hey @RobertTLange, thanks for the detailed PR!

I have merged it with some quick additions:

BTW, it seems that evosax may have issues under Python 3.6, which you may want to look into (although this does not affect the rest of EvoJAX for now). Please refer to:

  1. The output of the CI workflow, https://github.com/google/evojax/runs/5268178005, and
  2. The output of my local smoke test:
Python 3.6.13 |Anaconda, Inc.| (default, Feb 23 2021, 12:58:59)
[GCC Clang 10.0.0 ] on darwin
Type "help", "copyright", "credits" or "license" for more information.
>>> import evosax
>>> from evosax import Augmented_RS, FitnessShaper
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
ImportError: cannot import name 'Augmented_RS'
>>>
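One way to keep a failure like this from breaking unrelated code is to guard the optional import. The sketch below is an assumption about how that could be done, not what EvoJAX actually does; `try_import` is a hypothetical helper, and only the names `evosax`, `Augmented_RS`, and `FitnessShaper` come from the traceback above.

```python
import importlib
import sys


def try_import(module_name, attr_names):
    """Return (attrs, None) on success or (None, message) on failure."""
    try:
        module = importlib.import_module(module_name)
    except ImportError as err:
        return None, "{} is not importable: {}".format(module_name, err)
    missing = [name for name in attr_names if not hasattr(module, name)]
    if missing:
        version = "{}.{}".format(*sys.version_info[:2])
        return None, "{} lacks {} under Python {}".format(
            module_name, ", ".join(missing), version
        )
    return {name: getattr(module, name) for name in attr_names}, None


# Either outcome is handled without raising at import time.
attrs, error = try_import("evosax", ["Augmented_RS", "FitnessShaper"])
if error is not None:
    print("ARS wrapper disabled:", error)
```

With a guard like this, the ARS wrapper can be skipped while the rest of the package stays importable.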