RobertTLange / evosax

Evolution Strategies in JAX 🦎
Apache License 2.0
479 stars 44 forks

Control_Brax: HTML output for debugging & Obs Normalization #16

Closed: donthomasitos closed this issue 1 year ago

donthomasitos commented 2 years ago

Great library, thank you for your work!

I want to add HTML output for Brax and stumbled across a problem. Brax's built-in HTML output needs a list of env_state.qp, which can be collected at test time like this:

import jax
from brax.io import html  # Brax's HTML visualizer

qps = []  # env_state.qp snapshots, one per step, for html.save_html
for _ in range(config["num_env_steps"]):
    qps.append(env_state.qp)
    rng_net, rng = jax.random.split(rng)
    # norm_obs = evaluator.obs_normalizer.normalize_obs(env_state.obs, evaluator.obs_params)  <- problem is here
    act = network.apply({"params": best_params}, env_state.obs, rng=rng_net)
    env_state = jit_env_step(env_state, act)
html.save_html("output.html", env.sys, qps)

But I can't access the obs_normalizer from the evaluator: JAX complains that data leaks out of a JIT'ed function. Am I misunderstanding the architecture? Do you have a recommendation for implementing this output? Simply disabling normalize_obs works (hence the commented-out line), but I have found normalization beneficial in many scenarios.
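One common way to avoid leaked-tracer errors like this is to keep the normalizer a pure function and pass its statistics into the jitted step as explicit arguments, instead of reading them off an object whose attributes were set inside `jit`. A minimal sketch, assuming the statistics are stored as a `(mean, var, count)` tuple of arrays; `normalize_obs`, `policy_obs`, the tuple layout, and the clip value are illustrative and not taken from evosax:

```python
import jax
import jax.numpy as jnp


def normalize_obs(obs, obs_params, clip=5.0):
    # Hypothetical layout: running mean, running variance, sample count.
    mean, var, _count = obs_params
    # Standardize, then clip to a fixed range (a common RL choice).
    return jnp.clip((obs - mean) / jnp.sqrt(var + 1e-8), -clip, clip)


@jax.jit
def policy_obs(obs, obs_params):
    # obs_params is an argument of the jitted function, so it is traced
    # fresh on each call rather than captured from an outer jit trace.
    return normalize_obs(obs, obs_params)


# Concrete statistics, e.g. copied out of the training loop as plain arrays.
obs_params = (jnp.array([1.0, -2.0]), jnp.array([4.0, 1.0]), jnp.array(100.0))
norm = policy_obs(jnp.array([3.0, 0.0]), obs_params)
```

In the rollout loop above, this would amount to computing `norm_obs = policy_obs(env_state.obs, obs_params)` with `obs_params` held as concrete arrays outside any jitted function, and feeding `norm_obs` to `network.apply`.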

RobertTLange commented 2 years ago

Hi Thomas, thank you for the kind words and for bringing this up. The Brax rollout wrapper is still very much "under construction". I am currently battling the NeurIPS deadline, but will put together a new release with better documentation once that is done. I also suspect that something may be wrong with the obs normalization: I compared it with evojax's, and on the Ant task the two start to give different performance after some generations. I will get back to you once I find the time. Best, Rob
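As a reference point when comparing two normalizers, here is one common scheme for maintaining running observation statistics: merge each batch into a running mean and sum of squared deviations using the parallel-variance update of Chan et al. This is a minimal sketch with hypothetical names (`init_obs_stats`, `update_obs_stats`, `obs_variance`); it is not claimed to match either evosax's or evojax's implementation.

```python
import jax.numpy as jnp


def init_obs_stats(obs_dim):
    # Hypothetical state: (mean, M2, count), where M2 is the running
    # sum of squared deviations from the mean.
    return (jnp.zeros(obs_dim), jnp.zeros(obs_dim), jnp.array(0.0))


def update_obs_stats(stats, batch):
    """Merge a batch of shape [N, obs_dim] into the running statistics."""
    mean, m2, count = stats
    b_count = batch.shape[0]
    b_mean = batch.mean(axis=0)
    b_m2 = ((batch - b_mean) ** 2).sum(axis=0)
    total = count + b_count
    delta = b_mean - mean
    new_mean = mean + delta * b_count / total
    # Cross term corrects for the two groups having different means.
    new_m2 = m2 + b_m2 + delta**2 * count * b_count / total
    return (new_mean, new_m2, total)


def obs_variance(stats):
    # Population variance; guard against division by zero before any update.
    _mean, m2, count = stats
    return m2 / jnp.maximum(count, 1.0)


# Usage: two batch updates give the same statistics as one pass over all data.
data = jnp.arange(12.0).reshape(6, 2)
stats = init_obs_stats(2)
stats = update_obs_stats(stats, data[:2])
stats = update_obs_stats(stats, data[2:])
```

Because the merge is exact, splitting the same observations into different batch sizes should not change the result beyond floating-point error; a divergence between two implementations would then point at other details, such as clipping, epsilon, or when the statistics are updated.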

RobertTLange commented 1 year ago

Fixed in PR #34; see the new Brax notebook.