eleninisioti opened this issue 10 months ago (status: Open)
Not related to the CMA-ES question, but examples/07_brax_control.ipynb has another issue: it reshapes the observations to (1, 87), while for the ant the observation size is 27, so I don't see how it could run without errors in the first place. I also had to remove the legacy_spring argument, and saving rollouts to HTML does not work either; I get an error that the rollout state has no attribute 'qp'. So it seems I am using the wrong version of brax and/or mujoco (see the version-check sketch after the code below). Here is the code I am trying to run:
import numpy as np
from evojax.obs_norm import ObsNormalizer
from evojax.sim_mgr import SimManager
from evojax.task.brax_task import BraxTask
from evojax.policy import MLPPolicy
import os
from evosax import Strategies
from evosax.utils.evojax_wrapper import Evosax2JAX_Wrapper
def get_brax_task(
    env_name="ant",
    hidden_dims=[32, 32, 32, 32],
):
    train_task = BraxTask(env_name, test=False)
    test_task = BraxTask(env_name, test=True)
    policy = MLPPolicy(
        input_dim=train_task.obs_shape[0],
        output_dim=train_task.act_shape[0],
        hidden_dims=hidden_dims,
    )
    return train_task, test_task, policy
train_task, test_task, policy = get_brax_task("ant")
solver = Evosax2JAX_Wrapper(
    Strategies["CMA_ES"],
    param_size=policy.num_params,
    pop_size=256,
    es_config={},
    es_params={},
    seed=0,
)
obs_normalizer = ObsNormalizer(
    obs_shape=train_task.obs_shape, dummy=False
)
sim_mgr = SimManager(
    policy_net=policy,
    train_vec_task=train_task,
    valid_vec_task=test_task,
    seed=0,
    obs_normalizer=obs_normalizer,
    pop_size=256,
    use_for_loop=False,
    n_repeats=16,
    test_n_repeats=1,
    n_evaluations=128,
)
print(f"START EVOLVING {policy.num_params} PARAMS.")
# Run ES Loop.
for gen_counter in range(1):
    params = solver.ask()
    scores, _ = sim_mgr.eval_params(params=params, test=False)
    solver.tell(fitness=scores)
    if gen_counter == 0 or (gen_counter + 1) % 5 == 0:
        test_scores, _ = sim_mgr.eval_params(
            params=solver.best_params, test=True
        )
        print(
            {
                "num_gens": gen_counter + 1,
            },
            {
                "train_perf": float(np.nanmean(scores)),
                "test_perf": float(np.nanmean(test_scores)),
            },
        )
from brax import envs
from brax.io import html
import jax
env = envs.create(env_name="ant")
task_reset_fn = jax.jit(env.reset)
policy_reset_fn = jax.jit(policy.reset)
step_fn = jax.jit(env.step)
act_fn = jax.jit(policy.get_actions)
obs_norm_fn = jax.jit(obs_normalizer.normalize_obs)
best_params = solver.best_params
obs_params = sim_mgr.obs_params
total_reward = 0
rollout = []
rng = jax.random.PRNGKey(seed=42)
task_state = task_reset_fn(rng=rng)
policy_state = policy_reset_fn(task_state)
while not task_state.done:
    rollout.append(task_state)
    task_state = task_state.replace(
        obs=obs_norm_fn(task_state.obs[None, :], obs_params).reshape(1, 27)
    )
    act, policy_state = act_fn(task_state, best_params[None, :], policy_state)
    task_state = task_state.replace(
        obs=obs_norm_fn(task_state.obs[None, :], obs_params).reshape(27,)
    )
    task_state = step_fn(task_state, act[0])
    total_reward = total_reward + task_state.reward
print("Cumulative reward:", total_reward)
output = html.render(env.sys, [s.qp for s in rollout])
saving_directory = "projects"
if not os.path.exists(saving_directory):
    os.makedirs(saving_directory)
with open(saving_directory + "/rollout_cmaes.html", "w") as f:
    f.write(output)
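For reference, here is a minimal sketch of how I check which brax API my installation exposes, since the 'qp' error looks like an API change rather than a bug in the script itself. It assumes recent brax releases (0.9 and later), where the legacy physics backends were removed and env states carry a pipeline_state field instead of qp; the attribute names below are my assumption based on the newer brax docs, not something taken from the notebook:

import jax
from brax import envs

env = envs.create(env_name="ant")
# Confirm the observation size the notebook should reshape to (I expect 27 for ant, not 87).
print("obs size:", env.observation_size, "act size:", env.action_size)

state = jax.jit(env.reset)(jax.random.PRNGKey(0))
if hasattr(state, "qp"):
    # Old brax API: html.render expects a list of QP objects.
    print("old API: html.render(env.sys, [s.qp for s in rollout])")
else:
    # Newer brax API: the physics data lives in state.pipeline_state instead.
    print("new API: html.render(env.sys, [s.pipeline_state for s in rollout])")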
Hi! I am interested in training the ant robot using the CMA-ES strategy.
I tried running the notebook examples/07_brax_control.ipynb and the performance I get is off. Running it as is, with the OpenES strategy, I am at a score of about 400 after 100 generations, while the notebook reports 900. That is not a huge difference, and it gets to a good performance by 1000 generations, but could it indicate that some library version is off?
When running CMA_ES with an empty es_config and es_params, my performance decreases instead of increasing: it starts at -53 and ends up around -1000 after 1000 generations. Perhaps some config parameter is off, but isn't it strange that it decreases? (A sketch for inspecting the default CMA-ES settings follows below.)
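To help narrow this down, this is a rough sketch of how I inspect the defaults that an empty es_config and es_params fall back to. It assumes the evosax 0.1.x API, where a strategy is constructed with popsize and num_dims and exposes default_params; how the Evosax2JAX_Wrapper routes overrides (es_config vs es_params) and the 0.03 value at the end are only my assumptions for illustration:

from evosax import Strategies

# Build CMA-ES directly, with the same population size and parameter count,
# just to print the default hyperparameters the wrapper would otherwise use.
strategy = Strategies["CMA_ES"](popsize=256, num_dims=policy.num_params)
print(strategy.default_params)

# If e.g. the initial step size looks off, it could presumably be overridden
# through the wrapper, e.g. es_config={"sigma_init": 0.03} (placeholder value).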
I cloned the latest version from git, and the following is my conda list (running on a GPU):