google / brax

Massively parallel rigidbody physics simulation on accelerator hardware.
Apache License 2.0

Replacing gym's Mujoco envs with brax envs #49

Open vwxyzjn opened 3 years ago

vwxyzjn commented 3 years ago

Had a conversation with @jkterry1 on https://github.com/openai/gym/issues/2366, and it appears brax would also be a great alternative for the mujoco envs replacement.

To help with this transition, I tried out brax with PyTorch. Here is a basic report: https://wandb.ai/costa-huang/brax/reports/Brax-as-Pybullet-replacement--Vmlldzo5ODI4MDk. The source code is here: https://github.com/vwxyzjn/cleanrl/blob/mybranch/cleanrl/brax/readme.md

One of the biggest issues with brax adoption is the env normalization:

I think the best fix going forward is to move the normalization from brax's training code into the environment itself. This will also help throughput with the JaxToTorchWrapper. Otherwise, the observation goes from GPU to CPU for gym's or sb3's normalization wrapper, then back to GPU for torch, which just doesn't make sense.
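As a rough illustration of what environment-side normalization could look like, here is a minimal sketch of a running mean/variance observation normalizer. It is written in a functional style (statistics in, new statistics out), so the same logic could in principle be ported to jax.numpy and kept on device. NumPy is used here only for simplicity, and the names RunningStats, init_stats, update_stats, and normalize are hypothetical; they are not part of brax or gym.

```python
from typing import NamedTuple

import numpy as np


class RunningStats(NamedTuple):
    """Running per-dimension statistics of the observations (hypothetical helper)."""
    mean: np.ndarray
    var: np.ndarray
    count: float


def init_stats(obs_dim: int) -> RunningStats:
    # A tiny initial count avoids division by zero before the first update.
    return RunningStats(np.zeros(obs_dim), np.ones(obs_dim), 1e-4)


def update_stats(stats: RunningStats, batch: np.ndarray) -> RunningStats:
    """Merge a batch of observations (shape [num_envs, obs_dim]) into the stats.

    Uses the parallel mean/variance combination formula and returns new
    statistics instead of mutating the old ones.
    """
    batch_mean = batch.mean(axis=0)
    batch_var = batch.var(axis=0)
    batch_count = batch.shape[0]

    delta = batch_mean - stats.mean
    total = stats.count + batch_count
    new_mean = stats.mean + delta * batch_count / total
    m2 = (stats.var * stats.count
          + batch_var * batch_count
          + delta ** 2 * stats.count * batch_count / total)
    return RunningStats(new_mean, m2 / total, total)


def normalize(stats: RunningStats, obs: np.ndarray,
              clip: float = 10.0, eps: float = 1e-8) -> np.ndarray:
    """Standardize observations and clip them to [-clip, clip]."""
    return np.clip((obs - stats.mean) / np.sqrt(stats.var + eps), -clip, clip)


# Usage with a fake batch of observations from 512 sub-environments.
batch = np.random.default_rng(0).normal(size=(512, 8))
stats = update_stats(init_stats(8), batch)
normalized = normalize(stats, batch)
```

Because the update returns a new RunningStats rather than writing into existing arrays, nothing here relies on in-place mutation.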

One small thing is that given the brax environment directly produces the vector env, there is also no way to inject a ClipActionsWrapper(env), which may or may not have a performance impact. That said, this can be implemented in the training side with ease.

erwincoumans commented 3 years ago

Yes, as I suggested previously, Brax seems a good option for OpenAI Gym, since it allows for GPU and TPU accelerators (training in minutes instead of hours), next to CPU. We can use this issue to track progress and add an itemized todo.

jkterry1 commented 3 years ago

To recap the to do list:

  1. Add suitable rendering
  2. Further tune observation/action spaces to make them as close as possible
  3. Make sure we are not reproducing the list of bugs in MuJoCo environments from Antonin Raffin that I sent you

I feel like there may have been a 4th issue, but I don't sleep very much and can no longer recall it. @erwincoumans @benelot do you remember?

vwxyzjn commented 3 years ago

One note on suitable rendering: I feel implementing env.render("rgb_array") might be too expensive and counterproductive. Implementing env.render("html") at the end of an episode might be preferable.

jkterry1 commented 3 years ago

They're planning to add a new rendering engine such that "rgb_array" will be suitable

jkterry1 commented 3 years ago

I don't know if this is the 4th feature I can't remember, but another thing we'll need to eventually deal with that I briefly discussed is action/observation space documentation for the new Gym website we're working on, in the flavor of https://www.pettingzoo.ml/classic/chess

joaogui1 commented 3 years ago

I would like to help with this, what can I do to help?

jkterry1 commented 3 years ago

@joaogui1 Probably nothing, at least at the moment. Right now I'm waiting on the Brax team to do some work and for the guy who created the pybullet replacement envs to get back from vacation; this will take 4-6 weeks. If you'd like to help with gym maintenance problems in general though, please email me and we can coordinate some things (jkterry@umd.edu)

joaogui1 commented 3 years ago

Got it, will wait a little then, thanks!

sgillen commented 3 years ago

I'm also happy to help on this, I've spent a lot of time with the mujoco/pybullet environments at this point. Can certainly help with points 2/3 that @jkterry1 posted in this thread.

erikfrey commented 3 years ago

We have started working on 1) the renderer. We're looking at porting a simple technique like https://github.com/rougier/tiny-renderer to JAX as a new module in brax.io

Tuning observation/action space could start in parallel if anyone is interested. I think the steps would involve:

  1. Reset a Gym Mujoco env (say Ant) to default state and inspect the observation space and its description.
  2. Compare to the Brax Ant env and make adjustments.
  3. Step both and compare dynamic observations (e.g. contact forces).

I think the envs are already ~80% comparable, and the last 20% is just sleuthing to read the mujoco docs, and confirm the format matches. I think we can get to the point where the meaning of each observation dimension is the same in both envs, even if the dynamics are still different.
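The comparison steps above could be sketched roughly as follows. This is only an illustration: compare_observations is a hypothetical helper, the label names are made up, and the observation vectors are placeholder values rather than real output from either simulator. In practice the two lists would come from resetting (and later stepping) the Gym Mujoco env and the Brax env.

```python
import math


def compare_observations(obs_a, obs_b, labels, abs_tol=1e-3):
    """Compare two observation vectors dimension by dimension.

    Returns (mismatches, size_difference), where mismatches is a list of
    (index, label, value_a, value_b) for shared dimensions that differ by
    more than abs_tol, and size_difference is len(obs_a) - len(obs_b).
    """
    shared = min(len(obs_a), len(obs_b))
    mismatches = []
    for i in range(shared):
        if not math.isclose(obs_a[i], obs_b[i], abs_tol=abs_tol):
            label = labels[i] if i < len(labels) else f"dim_{i}"
            mismatches.append((i, label, obs_a[i], obs_b[i]))
    return mismatches, len(obs_a) - len(obs_b)


# Placeholder observations after reset (not real simulator output).
mujoco_obs = [0.75, 1.0, 0.0, 0.0, 0.0, 0.1]
brax_obs = [0.5, 1.0, 0.0, 0.0, 0.0]
labels = ["torso_z", "quat_w", "quat_x", "quat_y", "quat_z"]

mismatches, size_difference = compare_observations(mujoco_obs, brax_obs, labels)
for index, label, a, b in mismatches:
    print(f"dim {index} ({label}): mujoco={a} brax={b}")
print(f"observation size difference: {size_difference}")
```

A report like this only checks that matching indices carry matching values; confirming that each index means the same thing in both envs still requires reading the docs for each observation layout.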

sgillen commented 3 years ago

I can get that going next week. I will use Mujoco 1.5 due to this issue. It looks like the Brax environments are based off the v2 version of the Mujoco environments, so I'll start by comparing to those. Based on https://github.com/openai/gym/pull/1304 I think the v3 versions are supposed to be identical if using default args, not 100% sure that's true though.

vwxyzjn commented 3 years ago

This is so great to hear! I also have a quick update. Gym now has a normalization wrapper: https://github.com/openai/gym/pull/2387. The usage is roughly

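import gym
import numpy as np
import pybullet_envs  # registers the *BulletEnv-v0 environments with gym

env = gym.make("HalfCheetahBulletEnv-v0")
env = gym.wrappers.RecordEpisodeStatistics(env)
env = gym.wrappers.ClipAction(env)
# normalize observations with running statistics, then clip them
env = gym.wrappers.NormalizeObservation(env)
env = gym.wrappers.TransformObservation(env, lambda obs: np.clip(obs, -10, 10))
# same for rewards
env = gym.wrappers.NormalizeReward(env)
env = gym.wrappers.TransformReward(env, lambda reward: np.clip(reward, -10, 10))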

However, as I suggested earlier, this might not be as fast as implementing the normalization on brax's side. Also, directly applying these wrappers to a brax environment won't work, because jax's device arrays end up replacing the numpy arrays inside the wrappers.

A typical example is gym.wrappers.RecordEpisodeStatistics: its episode_returns array gets cast to a jax array, which causes problems because jax arrays are immutable.
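To illustrate the kind of change that would sidestep the immutability problem, here is a minimal sketch of episode-statistics tracking that never writes into an existing array and instead returns new arrays on every step. It uses NumPy so it runs anywhere, but it only relies on arithmetic and where, so the same pattern should carry over to jax.numpy. The function update_episode_stats is hypothetical and is not the gym or brax implementation.

```python
import numpy as np


def update_episode_stats(returns, lengths, reward, done):
    """Out-of-place update of per-environment episode statistics.

    returns, lengths: running totals for the episode in progress, shape [num_envs]
    reward, done:     outputs of the latest vectorized env step, shape [num_envs]

    Returns new running totals (reset to zero where an episode finished)
    plus the final return and length of the episodes that just ended
    (zero for environments whose episode is still running).
    """
    returns = returns + reward
    lengths = lengths + 1

    finished_returns = np.where(done, returns, 0.0)
    finished_lengths = np.where(done, lengths, 0)

    returns = np.where(done, 0.0, returns)
    lengths = np.where(done, 0, lengths)
    return returns, lengths, finished_returns, finished_lengths


# Usage: read-only inputs mimic immutable arrays; the update still works.
episode_returns = np.zeros(3)
episode_lengths = np.zeros(3, dtype=np.int64)
episode_returns.setflags(write=False)
episode_lengths.setflags(write=False)

reward = np.array([1.0, 2.0, 3.0])
done = np.array([False, True, False])
episode_returns, episode_lengths, final_returns, final_lengths = update_episode_stats(
    episode_returns, episode_lengths, reward, done)
```

The trade-off is that the caller has to carry the statistics around explicitly instead of keeping them as mutable wrapper attributes.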

sgillen commented 2 years ago

Ok, I was a bit busier than I expected this week, but as promised I did start comparing the ant environments this evening. Here is a notebook I was using that may be useful to anyone else who wants to compare and tweak the envs.

With regards to the observations:

  1. I believe all the state position and velocity information match up. For Mujoco it seems to be: z + quaternion for the torso (5), 8 joint angles, dxyz/drot (6) for torso, 8 more joint velocities, which matches exactly what brax has.
  2. The contact information is where big differences appear. Brax seems to be missing some internal bodies that are present in the Mujoco model, which accounts for the difference in observation size (the brax team was already aware of this).
  3. I'm not sure what the ordering for the contact forces is in brax. It doesn't match what mujoco does (see the notebook linked above) and it also doesn't seem to match up with the bodies in env.sys.body_idx.keys().

With regard to rewards:

  1. The rewards also exclude any contact force penalty because of the lack of those forces caused by a bug with gym+Mujoco 2.0 (see the issue I posted above), but I think it would be best to put them back.

If the goal is to make as faithful a representation of the mujoco envs as possible (which IMO it shouldn't necessarily be), then we will at least need to address the following:

  1. The mj ant starts life suspended .75m in the air, the brax ant at .5
  2. mj adds a relatively large amount of random noise to its initial state on reset.
  3. Inertial parameters for the two envs are different. Does brax have a way to infer an inertia from geometry? This is what mj does.
  4. No matter the ordering, the magnitudes of the force and moment are substantially different, but that may be because of the difference in mass.
  5. Torque limits appear different: 300 in brax vs 150 in mj (units? That would be a lot of N*m)
  6. These are minor, but may want to find out what brax integrator settings are closest to an rk4 with dt = .01.
  7. May also want to tune friction parameters, which will probably need to be done empirically.

TLDR: For the ant, the difference in observations is in the ordering and number of contact forces. To make them match exactly we would need to reorder the existing forces and insert some dummy, zeroed elements into the observation. That said, the "missing" contact forces weren't useful in the old env, and the ordering of contacts shouldn't matter to an RL agent, so IMHO it would be enough to adjust the mass, inertia, and torque limit, add back in the contact force reward/penalty, and maybe add the wider distribution to the initial state.

sgillen commented 2 years ago

@vwxyzjn good to hear about the normalization wrapper, I agree that the normalization and clipping should all be done on the brax side. This makes things awkward with respect to saving and loading environments / agents, since it will make brax a special case for gym, sb3 etc. Relatedly, I also think that if the brax envs aren't going to be extremely fast, it would be better to just use pybullet.

erikfrey commented 2 years ago

@vwxyzjn we recently started using a similar Wrapper concept for wrapping envs in Brax, inspired by Gym. e.g. EpisodeWrapper collects episode statistics and sets done at the episode boundary, and so on:

https://github.com/google/brax/blob/main/brax/envs/wrappers.py#L43

I don't think it would be too hard to make the brax API mirror what gym is doing, and still keep it all on device.

erikfrey commented 2 years ago

@sgillen this is super helpful - thanks for putting together this thorough comparison. I hear you that our envs don't need to be exactly 1:1 to MuJoCo's - that said, we'd be happy to prioritize any fixes to the differences you brought up, according to whether they:

Of the differences you found, do you have a suggestion for which might be the most important to address?

benelot commented 2 years ago

I agree with @sgillen on the tasks, but would reorder to:

  1. add back in the contact force reward/penalty
  2. adjust the mass, inertia, and torque limit
  3. add the wider distribution to initial state

On 1: If we want to copy the previous env, we need it, whether it helps with training or not, otherwise we diverge.

On 2: Is there any reason these were set in the brax ant env the way they are? Torque limit looks like the result of f(mass, inertia, mujoco_engine_details), so we should be able to set similar ones to mujoco. If they can not be adjusted exactly, I would suggest to fall back to the metric of "similar learning curve". In pybullet I once looked at the metric of "similar observation distribution shape" which says something about in which observational manifold the ant moves.

On 3: This is certainly important for higher robustness of the learned policy. Especially in the humanoids, adding some noise during testing but not training easily messes them up.

On my side I started to play a bit with brax and built an initial version of the humanoid standup but, being on vacation, I am not done yet. I plan to begin building a first version of all required mujoco envs in brax next week, just to see how they perform. Then we can do the same for every env as @sgillen did for ant.

jkterry1 commented 2 years ago

Just to confirm, does the list of inconsistencies include the list of MuJoCo bugs I sent, which we want to make sure we aren't reproducing?

sgillen commented 2 years ago

@erikfrey I agree with @benelot list on what to prioritize. They will probably impact training, making the environment slightly harder if anything, but also closer to the original. The contact reward might lead to more pleasing gaits but it's hard to say.

@jkterry1 I am not sure, can you post that list of bugs here?

benelot commented 2 years ago

@jkterry1 possibly means those: (according to Antonin Raffin)

jkterry1 commented 2 years ago

@benelot that's the list, thanks a ton

erikfrey commented 2 years ago

Can confirm that our HalfCheetah is at least not broken in the ways discussed in those blogs. In fact this is something we had to address in our paper comparing our envs to Mujoco's. See section E1 in the appendix for a brief discussion about this problem.

That said, I am quite prepared for folks to find new and interesting bugs as these envs get more attention! We'll be happy to address them when they come up :-)

We are 90% done on hopper. If someone would like to take a pass at Walker2d or Swimmer, please let me know. Otherwise we'll get to them soon.

erikfrey commented 2 years ago

Quick update - we now have the Hopper env, and tomorrow we will land Walker2d. We'll also add them soon to the colab with good default hparams. Other things in flight:

erikfrey commented 2 years ago

OK! We now support state to pixels for env.render:

https://github.com/google/brax/blob/main/brax/io/image.py

Please keep in mind this is CPU rendering, so better for eval rendering and other programmatic use cases, rather than training. We will move to GPU/TPU rendering in the future, which should be suitable for training.

In the coming days we'll update our colabs with an example of how to use it.

slerman12 commented 2 years ago

I'm trying to make the Brax and MuJoCo setups more apples-to-apples. I'm not sure what major differences need to be accounted for. Is there a set of operations that need to be called on Brax to get settings as similar to MuJoCo as possible? (e.g. the normalization mentioned in this issue)

sgillen commented 2 years ago

Hi @slerman12, I think the process of making the brax environments similar to Mujoco is still ongoing. This thread has some info on the major differences at this point; you can use the notebook I posted above as a starting point for comparing the environments in an "apples-to-apples" way. The normalization is not a difference by itself, since the Mujoco envs don't have normalization built in. Usually training frameworks like stable baselines will normalize observations from environments, but that presents some difficulty in brax.

jkterry1 commented 2 years ago

Per the meeting, we still need the following things before merging into Gym:

Adding missing environments:

  1. Swimmer (Benjamin Ellenberger)
  2. Standup (Brax team)
  3. Inverted pendulum (Daniel Freeman)
  4. Inverted double pendulum (Daniel Freeman)

Other cleanups:

  1. Remove 0s where applicable (Brax team)
  2. Remove unnecessary inheritance regarding hopper (Brax team)

benelot commented 2 years ago

I have not found pusher, reacher, striker, thrower anywhere in the brax repo. I think they are required as well @jkterry1. Are they somewhere internal @cdfreeman-google?

erikfrey commented 2 years ago

Reacher is here: https://github.com/google/brax/blob/main/brax/envs/reacher.py

Ah, I wasn't aware of pusher, striker, thrower as they are not here: https://gym.openai.com/envs/#mujoco

BUT I do see them here: https://github.com/openai/gym/tree/master/gym/envs/mujoco

We'll look into those on the Brax side unless anyone jumps in and would like to claim them.

erikfrey commented 2 years ago

OK more updates:

jkterry1 commented 2 years ago

I wasn't planning to include pusher, striker and thrower in the new version of the environments. No one really uses them; they weren't even in the PPO paper. If issues arise during the beta period and adding them does prove to be important, we can of course do it then.

erikfrey commented 2 years ago

Got it, OK. We'll leave those out for now, then. Just pushed Humanoid standup in humanoid_standup.py. It's a real "stand up" environment:

[image: standup]

Inverted pendulums coming soon.

vwxyzjn commented 2 years ago

These look great! I’d love to follow up by making a viable training example with brax using PyTorch PPO. It seems to me that the blockers would be the normalization wrapper and the episode stats wrapper. Specifically, I am looking for replacements for these wrappers on the brax side:

        env = gym.wrappers.RecordEpisodeStatistics(env)
        if capture_video:
            if idx == 0:
                env = gym.wrappers.RecordVideo(env, f"videos/{run_name}")
        env = gym.wrappers.ClipAction(env)
        env = gym.wrappers.NormalizeObservation(env)
        env = gym.wrappers.TransformObservation(env, lambda obs: np.clip(obs, -10, 10))
        env = gym.wrappers.NormalizeReward(env)
        env = gym.wrappers.TransformReward(env, lambda reward: np.clip(reward, -10, 10))

RecordEpisodeStatistics doesn't work well with brax because brax overrides the default numpy array used to keep track of episode stats. RecordVideo should work with brax given the support for env.render(mode='rgb_array'). The rest of the wrappers don't work well because 1) brax would override some numpy arrays in these wrappers with Jax arrays, as in RecordEpisodeStatistics, and 2) if using accelerators, these wrappers will considerably slow down throughput: the observation would have to be transferred to CPU for wrappers like NormalizeObservation, then transferred back to GPU for PyTorch inference, back to CPU again for ClipAction, then to GPU again for the brax engine, which is a very time-consuming and convoluted process.

erikfrey commented 2 years ago

@vwxyzjn good suggestion! These wrappers all look easy to make jax versions of. It will be a lot of copy/paste... so it's tempting to think of some unified wrapper setup that can equally handle numpy arrays, jax arrays, or torch tensors, but that's probably an exercise left for some other time.

@vwxyzjn what would you propose is a good test PPO to demonstrate this all working end-to-end within? stable-baselines or something?

sgillen commented 2 years ago

Just to chime in, I think using stable-baselines3 with brax will be extremely slow without first making an sb3-compatible brax vector environment (in addition to the wrappers). I think that will be fairly straightforward though, given that brax already provides its own vectorized environment.

There may be better ways to test the wrappers, but in either case I am planning to try to get good performance in brax from stable-baselines (without any wrappers) this evening! I think it would be interesting and useful to compare the existing pre-trained mujoco/bullet zoo agents to brax environments trained with the same hyper params.

erikfrey commented 2 years ago

@sgillen OK, please let us know how that works for you. We can probably jump in to help in the next day or two... and if you provide an end-to-end workflow that sorta works but is, say, missing wrappers, then that gives us a really concrete context within which to make the right changes.

Thank you!

vwxyzjn commented 2 years ago

@erikfrey thanks for the reply. I am testing out using my own library CleanRL, and its implementation of PPO is a bit simplified but matches the implementation details of SB3's. See 11 Core Implementation Details of PPO, 6 Implementation Details for Continuous Actions of PPO (draft) for a detailed walkthrough. Notably, my PPO implementation uses Gym's vector env, which might make the experiment easier since brax already has a gym vector wrapper.

sgillen commented 2 years ago

Ok, I've made the brax sb3 vec_env, and it does increase the performance, but the assumption that environments return numpy arrays is more baked into stable-baselines3 than I had hoped. Wrappers or no, all observations need to be copied to numpy, and therefore the cpu, before being used for training (even though the training itself is in torch). This leads to some painfully slow training times, at least compared to the millions of fps that you can get from the native brax algorithms.

You can see what I'm talking about here. With a batch size of 4096 and a GPU I see ~45,000 FPS. For comparison, I'm seeing ~50 FPS for the single-threaded brax ant environment on CPU and ~400 FPS for the bullet ant. I think a lot of training in sb3 is actually single-threaded, so that 400 FPS is probably around what a lot of SB3 users are used to. Either way, the vector env is an improvement but still clearly leaves a lot of performance on the table. Please let me know if you see something I have missed that can improve the performance I am seeing there.

It looks like @vwxyzjn's cleanRL might be torch end to end? I am curious what sorts of speeds you are seeing with the JaxToTorch wrapper (without the clipping / normalizing wrappers)?

vwxyzjn commented 2 years ago

I'd love to try it out, but for some reason every time I run the colab notebook (the pytorch example) I get this error:

[image: screenshot of the error]

I have made an example notebook here: https://colab.research.google.com/drive/10Ud8dRpdgiYjNSDYNvsKDCmpGelgqD4k?usp=sharing

It also gets stuck on gym.reset().

Here is the code change modified from the PPO implementation that worked with MuJoCo and Pybullet. https://www.diffchecker.com/QKd2BjVB

vwxyzjn commented 2 years ago

I had a hard time running the colab notebook, so I decided to run it locally as follows:

git clone https://github.com/vwxyzjn/cleanrl.git
cd cleanrl
git checkout refactor-brax
poetry install
poetry install -E brax
poetry run pip install --upgrade "jax[cuda]" -f https://storage.googleapis.com/jax-releases/jax_releases.html
cd cleanrl/brax
XLA_PYTHON_CLIENT_PREALLOCATE=false poetry run python ppo_brax_througput.py

The results are interesting: the reported FPS keeps increasing over the run. Also, the first environment reset was very slow, maybe because of the "Very slow compile?" warning below.

(cleanrl-0hpcRfYV-py3.9) ➜  brax git:(refactor-brax) ✗ XLA_PYTHON_CLIENT_PREALLOCATE=false python ppo_brax_througput.py
2021-10-27 10:38:56.231309: E external/org_tensorflow/tensorflow/compiler/xla/service/slow_operation_alarm.cc:55] 
********************************
Very slow compile?  If you want to file a bug, run with envvar XLA_FLAGS=--xla_dump_to=/tmp/foo and attach the results.
Compiling module jit_reset.14286
********************************
SPS: 5
SPS: 11
SPS: 17
SPS: 23
SPS: 29
SPS: 34
..
SPS: 135
SPS: 141
SPS: 146
..
SPS: 1460
SPS: 1462

And if I don't use the GPU, it seems much faster:

git clone https://github.com/vwxyzjn/cleanrl.git
cd cleanrl
git checkout refactor-brax
poetry install
poetry install -E brax
cd cleanrl/brax
poetry run python ppo_brax_througput.py
(brax) ➜  brax git:(refactor-brax) python ppo_brax_througput.py --cuda False
WARNING:absl:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
SPS: 158
SPS: 310
SPS: 454
SPS: 592
...
SPS: 3479
SPS: 3513

The runs above were run with a vector env with 512 sub envs.

erikfrey commented 2 years ago

@sgillen Thanks for putting that together. Looking through the stable-baselines3 code, I think the main crux is to implement something like DeviceRolloutBuffer. Building that and then plumbing it through the code would greatly increase SPS - this would be useful not just for brax but any other simulators that use accelerators. Maybe something we can talk to them about in the future.

@vwxyzjn I got your code working just fine in colab, thank you for setting that up! I did have to make one change: you hardcoded '4096' into the brax envs batch size, which was causing the code down below to complain about shape. I just made this change:

gym_env = gym.make("brax-ant-v0", batch_size=args.num_envs)

And now the code runs great. Device copying latency can be a killer so it does not surprise me that at batch size 128, CPU perf is still better. This is a great test harness, thanks for setting it up. I'll use this as a base to dive in.

erikfrey commented 2 years ago

@vwxyzjn also in case you were wondering, the increasingly higher SPS is because the first call to reset() and step() is where the compilation happens. JIT compilation is pretty slow. If you want the "stable" SPS, you can add something like this:

    # env warmup
    next_obs = envs.reset()
    next_obs, reward, done, info = envs.step(actions[0])
    next_obs = envs.reset()

    # TRY NOT TO MODIFY: start the game
    global_step = 0
    ...
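
To make the "stable" number concrete, here is a minimal sketch of measuring steady-state SPS with the warmup steps excluded from the timing. The helper measure_sps and its parameters are hypothetical, and step_fn stands in for one call to envs.step(...). With JAX, work is dispatched asynchronously, so step_fn would presumably need to block on the result (for example with block_until_ready()) for the timing to be meaningful.

```python
import time


def measure_sps(step_fn, num_envs, warmup_steps=2, measure_steps=100,
                clock=time.perf_counter):
    """Return environment steps per second, ignoring the first warmup_steps calls.

    step_fn:  zero-argument callable that advances all sub-environments once
    num_envs: number of sub-environments advanced by each step_fn call
    clock:    injectable timer, which makes the function easy to test
    """
    # Warmup: the first calls trigger JIT compilation and are not timed.
    for _ in range(warmup_steps):
        step_fn()

    start = clock()
    for _ in range(measure_steps):
        step_fn()
    elapsed = clock() - start
    return num_envs * measure_steps / elapsed


# Usage with a dummy step function standing in for envs.step(actions).
def dummy_step():
    sum(range(1000))


sps = measure_sps(dummy_step, num_envs=512, warmup_steps=2, measure_steps=50)
print(f"SPS: {sps:.0f}")
```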

vwxyzjn commented 2 years ago

@erikfrey thanks for the reply! The SPS thing makes sense to me.

I was able to get the notebook working in the CPU mode after incorporating your suggested fix

gym_env = gym.make("brax-ant-v0", batch_size=args.num_envs)

However, I still have trouble running under the GPU runtime with the same error presented in the screenshot above. Did you manage to get the notebook working under the GPU runtime?

UnfilteredStackTrace: RuntimeError: INTERNAL: Failed to launch CUDA kernel: fusion_162 with block dimensions: 128x1x1 and grid dimensions: 1x1x1: CUDA_ERROR_OUT_OF_MEMORY: out of memory

erikfrey commented 2 years ago

@vwxyzjn hmm, yes I am using a GPU runtime, and I am not able to reproduce that issue. Only two things I can think of:

import jax
print(jax.devices()[0].device_kind)

I get assigned a Tesla K80.

vwxyzjn commented 2 years ago

@erikfrey the assigned GPU seems to be the difference. I had the pro subscription and it was giving me a Tesla P100-PCIE-16GB. I switched back to a normal account, got assigned a Tesla K80, and it worked fine.

I also tested out the env.render('rgb_array') API and it works great with the PR https://github.com/google/brax/pull/84

[image]

vwxyzjn commented 2 years ago

As a sidenote, rendering images does significantly slow down the throughput. If rendering HTML is faster, I personally would prefer doing that instead...

This could also be achieved with a wrapper called RecordHTML that collects the rollouts from the first sub-environment of the vector env and, at the end of training, outputs HTML labeled by episode (basically doing HTML(html.render(env.sys, [s.qp for s in rollout]))).

erikfrey commented 2 years ago

@vwxyzjn sure, we can make such a wrapper. Quick update: I think some part of that colab's PPO algorithm is still causing device copies - when I comment out the optimizing code block, SPS goes from 8k to 250k. I have more time tomorrow to look into why that's happening, but if you'd like to take a look in the meantime, please do let me know if anything obvious sticks out.

erikfrey commented 2 years ago

@vwxyzjn fyi I pulled some of the remaining work items out into separate issues as this issue is getting large:

I'll update with progress on #88; #89 we can do after that. Feel free to hop over to those issues to discuss more, and also see https://github.com/google/brax/projects/1 for what we're tracking overall for this effort.

vwxyzjn commented 2 years ago

Thank you @erikfrey! This is very exciting. I'll try to help as much as I can :)

erwincoumans commented 2 years ago

As a sidenote, rendering images does significantly slow down the throughput.

Yes, the CPU pytinyrenderer is not intended for use during training, only afterwards to see the rollout.

2 things to make it much faster: halve the resolution (width and height), and disable anti-aliasing

Image(image.render(env.sys, [s.qp for s in rollout], width=160, height=120, ssaa=1))

@erikfrey What is needed to make those efficiency changes when using a Gym environment wrapper?

Perhaps add some members to the Gym wrappers to tune width, height and ssaa (instead of hardcoded 256,256)?
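
As a back-of-the-envelope illustration of why those two settings matter, here is a small sketch that estimates the relative number of rasterized pixels. It assumes that render cost scales roughly with pixel count and that ssaa=N renders at N times the resolution on each axis before downsampling (the usual meaning of supersampling, not checked against the brax implementation). The 256x256 size is the hardcoded value mentioned above; the baseline ssaa of 2 is purely an assumption for the sake of the example, and rendered_pixels is a hypothetical helper.

```python
def rendered_pixels(width, height, ssaa):
    """Pixels rasterized when supersampling by a factor of ssaa on each axis."""
    return (width * ssaa) * (height * ssaa)


# Assumed baseline: hardcoded 256x256 with an assumed supersampling factor of 2.
baseline = rendered_pixels(256, 256, ssaa=2)

# Reduced settings from the render call above: 160x120 with ssaa=1.
reduced = rendered_pixels(160, 120, ssaa=1)

# Rough estimate of how many times fewer pixels are rasterized.
pixel_ratio = baseline / reduced
print(f"baseline={baseline} reduced={reduced} ratio={pixel_ratio:.1f}x")
```

Under these assumptions the reduced settings rasterize more than ten times fewer pixels per frame, which is why exposing width, height and ssaa on the wrapper seems worthwhile.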