google / brax

Massively parallel rigidbody physics simulation on accelerator hardware.
Apache License 2.0

Brax V2 pbd is about 3x slower than V1 #371

Closed: imoneoi closed this issue 8 months ago

imoneoi commented 1 year ago

Why is the positional backend (PBD) in Brax v2 about 3x slower than v1 on humanoid? I noticed that the number of substeps and dt differ between the two versions, but v1 runs a total of 533 PBD steps and v2 a total of 666. That is a ratio of only about 1.25x, so the difference alone should not explain a gap this large.

Here are the benchmark results:

JAX 0.4.8 on 1x RTX 3090, Humanoid, PBD backend

| Pop size | V1 FPS | V2 FPS |
|---:|---:|---:|
| 1024 | 441713.6 | 212603.0 |
| 2048 | 873183.5 | 341043.9 |
| 4096 | 1237109.4 | 430086.4 |
| 10240 | 1310780.7 | 411593.8 |

JAX 0.4.11 on 1x RTX 3090, Humanoid, PBD backend

| Pop size | V1 FPS | V2 FPS |
|---:|---:|---:|
| 1024 | 467750.8 | 199116.0 |
| 2048 | 926400.9 | 328199.3 |
| 4096 | 1372671.0 | 415347.5 |
| 10240 | 1417486.8 | 403450.6 |
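To put the two numbers in the question side by side, the short Python sketch below computes the v1/v2 throughput ratio for every benchmark row and compares it with the 666/533 step-count ratio. The dictionary layout and the `slowdown_table` helper are illustrative only; the FPS figures are copied from the results above.

```python
# FPS copied from the benchmark results above, keyed by JAX version, Brax version, pop size.
results = {
    "JAX 0.4.8": {
        "v1": {1024: 441713.6, 2048: 873183.5, 4096: 1237109.4, 10240: 1310780.7},
        "v2": {1024: 212603.0, 2048: 341043.9, 4096: 430086.4, 10240: 411593.8},
    },
    "JAX 0.4.11": {
        "v1": {1024: 467750.8, 2048: 926400.9, 4096: 1372671.0, 10240: 1417486.8},
        "v2": {1024: 199116.0, 2048: 328199.3, 4096: 415347.5, 10240: 403450.6},
    },
}

# Total PBD steps v2 / v1, as stated in the issue: the slowdown the step count alone predicts.
STEP_RATIO = 666 / 533


def slowdown_table(results):
    """Return {(jax_version, pop_size): v1_fps / v2_fps}."""
    table = {}
    for jax_version, by_brax in results.items():
        for pop_size, v1_fps in by_brax["v1"].items():
            table[(jax_version, pop_size)] = v1_fps / by_brax["v2"][pop_size]
    return table


if __name__ == "__main__":
    print(f"Slowdown predicted by step count alone: {STEP_RATIO:.2f}x")
    for (jax_version, pop_size), ratio in slowdown_table(results).items():
        print(f"{jax_version}  pop {pop_size:>5}: v1/v2 = {ratio:.2f}x")
```

On these figures the ratio grows with population size, from about 2.1x to 2.3x at 1024 environments up to about 3.2x (JAX 0.4.8) and 3.5x (JAX 0.4.11) at 10240, while the step counts alone would predict only about 1.25x.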

The benchmark code is as follows:

```python
from functools import partial
import time

import jax

from brax import envs

def create_env(env_name, is_v2, max_steps_per_episode=1000):
    # Run this with the Brax version that matches is_v2 installed.
    if is_v2:
        # V2 API: select the positional (PBD) backend explicitly
        env = envs.get_environment(env_name, backend="positional")
        env = envs.wrapper.EpisodeWrapper(env, max_steps_per_episode, 1)
        env = envs.wrapper.VmapWrapper(env)
    else:
        # V1 API
        env = envs.get_environment(env_name)
        env = envs.wrappers.EpisodeWrapper(env, max_steps_per_episode, 1)
        env = envs.wrappers.VmapWrapper(env)

    return env

@partial(jax.jit, static_argnames=["env", "pop_size", "steps"])
def benchmark(
    env,
    seed:     int = 0,
    pop_size: int = 10240,
    steps:    int = 1000,
):
    # Reset pop_size environments, one PRNG key each
    init_state_key, act_seq_key = jax.random.split(jax.random.PRNGKey(seed))
    init_state                  = env.reset(jax.random.split(init_state_key, pop_size))

    # Random actions in [-1, 1] for every step and every environment
    act_seq                     = jax.random.uniform(act_seq_key, (steps, pop_size, env.action_size), minval=-1, maxval=1)

    # Roll out all environments for `steps` steps inside a single compiled scan
    def _step_env(carry, act):
        return env.step(carry, act), None

    return jax.lax.scan(_step_env, init_state, act_seq)

def main():
    env_name  = "humanoid"
    is_v2     = False  # set to True when running with Brax v2

    pop_sizes = [1024, 2048, 4096, 10240]
    steps     = 1000

    # create env
    env = create_env(env_name, is_v2)

    # bench
    for pop_size in pop_sizes:
        conf = dict(env=env, steps=steps, pop_size=pop_size)

        # JIT warmup (compilation is excluded from the timing below)
        result = benchmark(**conf)
        jax.tree_util.tree_map(lambda x: x.block_until_ready(), result)

        # Test time; block until done, since JAX dispatches asynchronously
        t_start = time.perf_counter()

        result = benchmark(**conf)
        jax.tree_util.tree_map(lambda x: x.block_until_ready(), result)

        fps = (pop_size * steps) / (time.perf_counter() - t_start)

        print(f"Pop size {pop_size} FPS {fps:.1f}")

if __name__ == "__main__":
    main()
```

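The script above times one call per population size after a single warm-up, so one noisy run can skew a row. A slightly more robust pattern, sketched below as a hypothetical helper `measure_fps` (not part of Brax or JAX), repeats the timed call several times and reports the median throughput. The `sync` argument is an assumption of this sketch: it must block until the result is really computed, which matters because JAX dispatches work asynchronously.

```python
import statistics
import time
from typing import Any, Callable


def measure_fps(
    run: Callable[[], Any],
    items_per_run: int,
    sync: Callable[[Any], Any] = lambda result: result,
    warmup: int = 1,
    repeats: int = 5,
) -> float:
    """Median throughput (items per second) of `run` over `repeats` timed calls."""
    # Untimed warm-up calls: trigger JIT compilation and fill caches.
    for _ in range(warmup):
        sync(run())
    timings = []
    for _ in range(repeats):
        start = time.perf_counter()
        sync(run())  # wait for the computation, not just its dispatch
        timings.append(time.perf_counter() - start)
    # The median is less sensitive than a single sample to one-off stalls.
    return items_per_run / statistics.median(timings)
```

Inside the benchmark loop it could be called as `measure_fps(lambda: benchmark(**conf), pop_size * steps, sync=lambda r: jax.tree_util.tree_map(lambda x: x.block_until_ready(), r))`.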
btaba commented 1 year ago

Hi @imoneoi , thanks for running these benchmarks on an RTX!

Indeed there is a slowdown compared to v1; here are a few notable reasons:

This is still a WIP! We have our eye on speeding things up for both the generalized and positional backends, and we're open to contributions that bring notable speed-ups.

imoneoi commented 1 year ago

Thanks! I also noticed that v2 pbd may be more accurate than v1.

BTW, are there any recommended tools for profiling and locating bottlenecks in Brax?

JoeyTeng commented 1 year ago

I'm not sure about Brax specifically, but the JAX documentation page "Profiling JAX Programs" would be a good place to start. If you are using Colab, see also google/jax#3694.
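Following that suggestion, here is a minimal sketch of capturing a trace with `jax.profiler.trace` and labelling ops with `jax.named_scope`. The toy `rollout` function is a hypothetical stand-in for a Brax step loop, not Brax code; in practice you would wrap the timed `benchmark(**conf)` call from the script above instead. The written trace can then be inspected with TensorBoard's profile plugin, as described in the JAX profiling documentation.

```python
import os
import tempfile

import jax
import jax.numpy as jnp


def rollout(x, steps=100):
    # Hypothetical stand-in for scanning env.step over a rollout.
    def body(carry, _):
        # named_scope attaches a readable label to these ops in the trace.
        with jax.named_scope("toy_step"):
            carry = jnp.sin(carry) * 0.99 + 0.01
        return carry, None

    final, _ = jax.lax.scan(body, x, None, length=steps)
    return final


rollout_jit = jax.jit(rollout, static_argnames="steps")


def profile_rollout(log_dir, n=1024):
    x = jnp.zeros((n,))
    rollout_jit(x).block_until_ready()  # warm-up: keep compilation out of the trace
    with jax.profiler.trace(log_dir):
        out = rollout_jit(x).block_until_ready()
    # Collect the files the profiler wrote under log_dir.
    trace_files = [
        os.path.join(root, name)
        for root, _, names in os.walk(log_dir)
        for name in names
    ]
    return out, trace_files


if __name__ == "__main__":
    out, files = profile_rollout(tempfile.mkdtemp())
    print(files)
```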