RobertTLange / evosax

Evolution Strategies in JAX 🦎
Apache License 2.0

CMA_ES for ant brax robot performs badly #64

Open eleninisioti opened 7 months ago

eleninisioti commented 7 months ago

Hi! I am interested in training the ant robot using the CMA-ES strategy.

I tried running the notebook examples/07_brax_control.ipynb, and the performance I get is off. Running it as-is with the OpenES strategy, I reach a performance of about 400 after 100 generations, while the notebook reports 900. That is not a huge difference, and it does reach good performance by 1000 generations, but could it indicate that some library version is off?
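Since the gap to the notebook may come from different library versions, one quick way to record the versions that matter next to a result is the standard library's `importlib.metadata`. This is a minimal sketch; the package list is an assumption about which dependencies are relevant here.

```python
from importlib.metadata import PackageNotFoundError, version


def package_versions(names):
    """Return {distribution name: installed version, or None if not installed}."""
    versions = {}
    for name in names:
        try:
            versions[name] = version(name)
        except PackageNotFoundError:
            versions[name] = None
    return versions


if __name__ == "__main__":
    # Assumed set of packages that influence the Brax/ES results.
    for name, ver in package_versions(
        ["jax", "jaxlib", "brax", "mujoco", "evosax", "evojax"]
    ).items():
        print(f"{name:8s} {ver}")
```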

When running CMA_ES with an empty es_config and es_params, my performance decreases instead of increasing: it starts at -53 and ends up around -1000 after 1000 generations. Perhaps some config parameter is off, but isn't it weird that it decreases?
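One thing worth ruling out for the CMA_ES run is a sign mismatch. Many ES implementations minimize fitness, while Brax returns rewards that should be maximized. The sketch below is a toy truncation-selection ES in plain NumPy, not evosax's CMA_ES. The `reward` function and the `maximize` flag are hypothetical. When the sign is wrong, the reward is driven steadily downward, which resembles the -53 to -1000 curve described above. This does not establish that a sign error is the actual cause here.

```python
import numpy as np


def reward(x):
    """Hypothetical toy reward (higher is better), peaking at x = 3 in every dim."""
    return -np.sum((x - 3.0) ** 2, axis=-1)


def truncation_es(fitness_fn, num_dims=5, popsize=32, num_elite=8, sigma=0.3,
                  generations=100, maximize=True, seed=0):
    """Toy (mu, lambda) ES that, like many ES libraries, keeps the LOWEST scores.

    With maximize=True the fitness is negated before selection, so higher
    rewards win. With maximize=False the raw rewards are minimized, which
    pushes the search towards worse rewards.
    """
    rng = np.random.default_rng(seed)
    mean = np.zeros(num_dims)
    for _ in range(generations):
        population = mean + sigma * rng.standard_normal((popsize, num_dims))
        fitness = fitness_fn(population)
        scores = -fitness if maximize else fitness
        elite = population[np.argsort(scores)[:num_elite]]  # lowest scores win
        mean = elite.mean(axis=0)
    return float(fitness_fn(mean))


if __name__ == "__main__":
    print("start reward:  ", float(reward(np.zeros(5))))
    print("maximize=True: ", truncation_es(reward, maximize=True))
    print("maximize=False:", truncation_es(reward, maximize=False))
```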

I cloned the latest version from git, and this is my conda list (running on a GPU):

_libgcc_mutex             0.1                        main
_openmp_mutex             5.1                       1_gnu
absl-py                   2.0.0                    pypi_0    pypi
blinker                   1.7.0                    pypi_0    pypi
brax                      0.9.3                    pypi_0    pypi
bzip2                     1.0.8                h7b6447c_0
ca-certificates           2023.12.12           h06a4308_0
chex                      0.1.85                   pypi_0    pypi
click                     8.1.7                    pypi_0    pypi
cloudpickle               3.0.0                    pypi_0    pypi
cma                       3.3.0                    pypi_0    pypi
contextlib2               21.6.0                   pypi_0    pypi
contourpy                 1.2.0                    pypi_0    pypi
cycler                    0.12.1                   pypi_0    pypi
dm-env                    1.6                      pypi_0    pypi
dm-tree                   0.1.8                    pypi_0    pypi
etils                     1.6.0                    pypi_0    pypi
evojax                    0.2.16                   pypi_0    pypi
evosax                    0.1.5                    pypi_0    pypi
flask                     3.0.0                    pypi_0    pypi
flask-cors                4.0.0                    pypi_0    pypi
flax                      0.6.11                   pypi_0    pypi
fonttools                 4.47.0                   pypi_0    pypi
fsspec                    2023.12.2                pypi_0    pypi
glfw                      2.6.4                    pypi_0    pypi
grpcio                    1.60.0                   pypi_0    pypi
gym                       0.26.2                   pypi_0    pypi
gym-notices               0.0.8                    pypi_0    pypi
importlib-resources       6.1.1                    pypi_0    pypi
itsdangerous              2.1.2                    pypi_0    pypi
jax                       0.4.23                   pypi_0    pypi
jaxlib                    0.4.23+cuda11.cudnn86          pypi_0    pypi
jaxopt                    0.8.2                    pypi_0    pypi
jinja2                    3.1.2                    pypi_0    pypi
kiwisolver                1.4.5                    pypi_0    pypi
ld_impl_linux-64          2.38                 h1181459_1
libffi                    3.4.4                h6a678d5_0
libgcc-ng                 11.2.0               h1234567_1
libgomp                   11.2.0               h1234567_1
libstdcxx-ng              11.2.0               h1234567_1
libuuid                   1.41.5               h5eee18b_0
markdown-it-py            3.0.0                    pypi_0    pypi
markupsafe                2.1.3                    pypi_0    pypi
matplotlib                3.8.2                    pypi_0    pypi
mdurl                     0.1.2                    pypi_0    pypi
ml-collections            0.1.1                    pypi_0    pypi
ml-dtypes                 0.3.2                    pypi_0    pypi
msgpack                   1.0.7                    pypi_0    pypi
mujoco                    2.3.7                    pypi_0    pypi
ncurses                   6.4                  h6a678d5_0
nest-asyncio              1.5.8                    pypi_0    pypi
numpy                     1.26.3                   pypi_0    pypi
nvidia-cublas-cu11        11.11.3.6                pypi_0    pypi
nvidia-cublas-cu12        12.3.4.1                 pypi_0    pypi
nvidia-cuda-cupti-cu11    11.8.87                  pypi_0    pypi
nvidia-cuda-cupti-cu12    12.3.101                 pypi_0    pypi
nvidia-cuda-nvcc-cu11     11.8.89                  pypi_0    pypi
nvidia-cuda-nvcc-cu12     12.3.107                 pypi_0    pypi
nvidia-cuda-nvrtc-cu11    11.8.89                  pypi_0    pypi
nvidia-cuda-nvrtc-cu12    12.3.107                 pypi_0    pypi
nvidia-cuda-runtime-cu11  11.8.89                  pypi_0    pypi
nvidia-cuda-runtime-cu12  12.3.101                 pypi_0    pypi
nvidia-cudnn-cu11         8.9.6.50                 pypi_0    pypi
nvidia-cudnn-cu12         8.9.7.29                 pypi_0    pypi
nvidia-cufft-cu11         10.9.0.58                pypi_0    pypi
nvidia-cufft-cu12         11.0.12.1                pypi_0    pypi
nvidia-cusolver-cu11      11.4.1.48                pypi_0    pypi
nvidia-cusolver-cu12      11.5.4.101               pypi_0    pypi
nvidia-cusparse-cu11      11.7.5.86                pypi_0    pypi
nvidia-cusparse-cu12      12.2.0.103               pypi_0    pypi
nvidia-nccl-cu11          2.19.3                   pypi_0    pypi
nvidia-nccl-cu12          2.19.3                   pypi_0    pypi
nvidia-nvjitlink-cu12     12.3.101                 pypi_0    pypi
openssl                   3.0.12               h7f8727e_0
opt-einsum                3.3.0                    pypi_0    pypi
optax                     0.1.7                    pypi_0    pypi
orbax-checkpoint          0.4.8                    pypi_0    pypi
packaging                 23.2                     pypi_0    pypi
pillow                    10.2.0                   pypi_0    pypi
pip                       23.3.1                   pypi_0    pypi
protobuf                  4.25.1                   pypi_0    pypi
pygments                  2.17.2                   pypi_0    pypi
pyopengl                  3.1.7                    pypi_0    pypi
pyparsing                 3.1.1                    pypi_0    pypi
python                    3.11.7               h955ad1f_0
python-dateutil           2.8.2                    pypi_0    pypi
pytinyrenderer            0.0.14                   pypi_0    pypi
pyyaml                    6.0.1                    pypi_0    pypi
readline                  8.2                  h5eee18b_0
rich                      13.7.0                   pypi_0    pypi
scipy                     1.11.4                   pypi_0    pypi
setuptools                68.2.2                   pypi_0    pypi
six                       1.16.0                   pypi_0    pypi
sqlite                    3.41.2               h5eee18b_0
tensorboardx              2.6.2.2                  pypi_0    pypi
tensorstore               0.1.52                   pypi_0    pypi
tk                        8.6.12               h1ccaba5_0
toolz                     0.12.0                   pypi_0    pypi
trimesh                   4.0.8                    pypi_0    pypi
typing-extensions         4.9.0                    pypi_0    pypi
tzdata                    2023d                h04d1e81_0
werkzeug                  3.0.1                    pypi_0    pypi
wheel                     0.41.2                   pypi_0    pypi
xz                        5.4.5                h5eee18b_0
zipp                      3.17.0                   pypi_0    pypi
zlib                      1.2.13               h5eee18b_0
eleninisioti commented 7 months ago

Not related to the CMA-ES question, but examples/07_brax_control.ipynb has another issue: it reshapes the observations to (1, 87), while for the ant the observation size is 27, so I don't see how it could run without errors to begin with. I also had to remove the legacy_spring argument, and saving rollouts to HTML does not work: I get an error that the rollout states do not have 'qp'. So it seems I am using the wrong version of brax and/or mujoco. Here's the code I am trying to run:


import numpy as np
from evojax.obs_norm import ObsNormalizer
from evojax.sim_mgr import SimManager
from evojax.task.brax_task import BraxTask
from evojax.policy import MLPPolicy
import os
from evosax import Strategies
from evosax.utils.evojax_wrapper import Evosax2JAX_Wrapper

def get_brax_task(
    env_name = "ant",
    hidden_dims = [32, 32, 32, 32],
):
    train_task = BraxTask(env_name, test=False)
    test_task = BraxTask(env_name, test=True)
    policy = MLPPolicy(
        input_dim=train_task.obs_shape[0],
        output_dim=train_task.act_shape[0],
        hidden_dims=hidden_dims,
    )
    return train_task, test_task, policy

train_task, test_task, policy = get_brax_task("ant")
solver = Evosax2JAX_Wrapper(
    Strategies["CMA_ES"],
    param_size=policy.num_params,
    pop_size=256,
    es_config={},
    es_params={},
    seed=0,
)
obs_normalizer = ObsNormalizer(
    obs_shape=train_task.obs_shape, dummy=False
)
sim_mgr = SimManager(
    policy_net=policy,
    train_vec_task=train_task,
    valid_vec_task=test_task,
    seed=0,
    obs_normalizer=obs_normalizer,
    pop_size=256,
    use_for_loop=False,
    n_repeats=16,
    test_n_repeats=1,
    n_evaluations=128
)

print(f"START EVOLVING {policy.num_params} PARAMS.")
# Run ES Loop. (The runs reported above used 1000 generations.)
for gen_counter in range(1):
    params = solver.ask()
    scores, _ = sim_mgr.eval_params(params=params, test=False)
    solver.tell(fitness=scores)
    if gen_counter == 0 or (gen_counter + 1) % 5 == 0:
        test_scores, _ = sim_mgr.eval_params(
            params=solver.best_params, test=True
        )
        print(
            {
                "num_gens": gen_counter + 1,
            },
            {
                "train_perf": float(np.nanmean(scores)),
                "test_perf": float(np.nanmean(test_scores)),
            },
        )

from brax import envs
from brax.io import html
import jax

env = envs.create(env_name="ant")
task_reset_fn = jax.jit(env.reset)
policy_reset_fn = jax.jit(policy.reset)
step_fn = jax.jit(env.step)
act_fn = jax.jit(policy.get_actions)
obs_norm_fn = jax.jit(obs_normalizer.normalize_obs)

best_params = solver.best_params
obs_params = sim_mgr.obs_params

total_reward = 0
rollout = []
rng = jax.random.PRNGKey(seed=42)
task_state = task_reset_fn(rng=rng)
policy_state = policy_reset_fn(task_state)
while not task_state.done:
    rollout.append(task_state)
    # Normalize a copy of the observation for the policy only. [None, :] adds
    # the batch axis, so no observation size (27 or 87) is hard-coded, and the
    # observation is not normalized twice.
    policy_input = task_state.replace(
        obs=obs_norm_fn(task_state.obs[None, :], obs_params))
    act, policy_state = act_fn(policy_input, best_params[None, :], policy_state)
    # Step the environment with its own, un-normalized state.
    task_state = step_fn(task_state, act[0])
    total_reward = total_reward + task_state.reward

print("Cumulative reward:", total_reward)
# Brax >= 0.9 keeps the physics state in `pipeline_state` (older Brax used `qp`).
output = html.render(env.sys, [s.pipeline_state for s in rollout])
saving_directory = "projects"
if not os.path.exists(saving_directory):
    os.makedirs(saving_directory)

with open(saving_directory + "/rollout_cmaes.html", "w") as f:
    f.write(output)
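For context on the CMA_ES run: full CMA-ES adapts a dense n x n covariance matrix over all n search dimensions, which here are all policy parameters. The sketch below only does the arithmetic for the policy built in `get_brax_task`. It assumes that evojax's `MLPPolicy` is a plain dense MLP with a bias in every layer and that the ant has an 8-dimensional action space. Whether the problem size is what makes CMA_ES diverge here is not established.

```python
def mlp_num_params(input_dim, hidden_dims, output_dim):
    """Weights + biases of a fully connected MLP (assumes every layer has a bias)."""
    dims = [input_dim, *hidden_dims, output_dim]
    return sum(d_in * d_out + d_out for d_in, d_out in zip(dims[:-1], dims[1:]))


def full_covariance_bytes(num_dims, bytes_per_entry=4):
    """Memory for one dense num_dims x num_dims covariance matrix (float32 by default)."""
    return num_dims * num_dims * bytes_per_entry


if __name__ == "__main__":
    # Ant in Brax 0.9.x: 27-dim observations; assumed 8-dim actions;
    # hidden dims taken from get_brax_task above.
    n = mlp_num_params(27, [32, 32, 32, 32], 8)
    print("policy parameters:", n)          # 4328
    print("covariance entries:", n * n)     # 18731584
    print("covariance MiB (float32):", full_covariance_bytes(n) / 2**20)
```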