google / brax

Massively parallel rigidbody physics simulation on accelerator hardware.

Single environment performance on GPU worse than on CPU #246

Closed: nico-bohlinger closed this issue 1 year ago

nico-bohlinger commented 2 years ago

I want to use a single(!) ant environment for my project and measured the steps per second with this script:

import time
import functools
import gym
from brax import envs
import jax

jax.config.update("jax_platform_name", "cpu")

entry_point = functools.partial(envs.create_gym_env, env_name='ant')
gym.register('brax-ant-v0', entry_point=entry_point)

nr_envs = 1
nr_steps = 1000
total_steps = nr_envs * nr_steps
gym_env = gym.make("brax-ant-v0", batch_size=nr_envs)

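obs = gym_env.reset()

key = jax.random.PRNGKey(42)

# Warm up the random number generation so its compilation is not timed.
key, subkey = jax.random.split(key)
jax.random.uniform(subkey, gym_env.action_space.shape)

# Warm up the jitted step so its compilation is not timed.
jit_env_step = jax.jit(gym_env.step)
obs, reward, done, info = jit_env_step(gym_env.action_space.sample())
obs.block_until_ready()

before = time.time()

for _ in range(nr_steps):
    key, subkey = jax.random.split(key)
    action = jax.random.uniform(subkey, gym_env.action_space.shape)
    obs, reward, done, info = jit_env_step(action)

# JAX dispatches work asynchronously, so wait for the last result before stopping the clock.
obs.block_until_ready()
duration = time.time() - before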
print(f'time for {total_steps} steps: {duration:.2f}s ({int(total_steps / duration)} steps/sec)')

Using my CPU for jitting I get around 2800 steps per second, which is better than what I would get with a MuJoCo ant environment. But if I comment out the line that puts JAX in CPU mode and use my GPU, I only get 700 steps per second. Jitting on the GPU is amazing for many environments but seems to be not that great with only one environment. On the other hand, using the GPU is far better for the RL algorithm I want to use later.

Is there a way to improve single-environment performance on GPU? Or is there maybe a way to efficiently jit the environment on the CPU while still using the GPU for my RL algorithm in JAX?

btaba commented 2 years ago

Hi @nico-bohlinger, I think the lower steps per second is somewhat expected for a single environment, but you can batch the envs on GPU and get a higher overall steps per second. You will want to ask on the gym repo how to batch the envs, but in brax it would look something like:

env = VmapWrapper(env)
rngs = jax.random.split(jax.random.PRNGKey(seed), batch_size)
state = env.reset(rngs)  # one RNG key per environment
jit_env_step = jax.jit(env.step)
state = jit_env_step(state, action)  # action is [batch_size, n_action]
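
As a rough illustration of the batching idea, here is a minimal pure-JAX sketch. `toy_step` is a hypothetical stand-in for an environment step, not a Brax function, and the sizes are arbitrary assumptions. It only shows how `jax.vmap` plus `jax.jit` turn a per-environment function into a single batched call:

```python
import jax
import jax.numpy as jnp


def toy_step(state, action):
    """Hypothetical stand-in for one environment step (not a Brax function)."""
    return state + 0.1 * jnp.tanh(action)


batch_size, n_state = 1024, 8  # arbitrary sizes chosen for illustration

# vmap maps toy_step over the leading (batch) axis of both arguments,
# and jit compiles the batched function into one device computation.
batched_step = jax.jit(jax.vmap(toy_step))

key = jax.random.PRNGKey(0)
states = jnp.zeros((batch_size, n_state))
actions = jax.random.uniform(key, (batch_size, n_state), minval=-1.0, maxval=1.0)

states = batched_step(states, actions)
states.block_until_ready()  # wait for the asynchronous computation to finish
print(states.shape)
```

One call advances all `batch_size` toy environments at once, so the fixed per-call overhead is paid once per batch rather than once per environment.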

You likely want to keep the RL algorithm running on the same device as env.step to avoid data transfers between the host and the GPU/TPU device.
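
For reference, a sketch of what splitting the work across devices would involve in plain JAX. `toy_step` is again a hypothetical stand-in rather than a Brax function, and the names `cpu`, `accel`, and `obs_for_learner` are made up for this example. It relies on JAX running a jitted function on the device its inputs were placed on with `jax.device_put`:

```python
import jax
import jax.numpy as jnp

cpu = jax.devices("cpu")[0]  # the CPU backend is always available
accel = jax.devices()[0]     # default device: GPU/TPU if present, otherwise CPU


def toy_step(state, action):
    """Hypothetical stand-in for one environment step (not a Brax function)."""
    return state + 0.1 * jnp.tanh(action)


jit_step = jax.jit(toy_step)

# Arrays explicitly placed on the CPU make the jitted step execute on the CPU.
state = jax.device_put(jnp.zeros(8), cpu)
action = jax.device_put(jnp.ones(8), cpu)
state = jit_step(state, action)

# Handing the result to a learner on the accelerator is an explicit copy,
# and this copy would be paid on every environment step.
obs_for_learner = jax.device_put(state, accel)

print(state.devices(), obs_for_learner.devices())
```

The per-step `device_put` is the host-to-device transfer mentioned above. Whether it costs less than stepping a single environment on the GPU would need to be measured for the actual setup.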

The agents in https://github.com/google/brax/tree/main/brax/training/agents are a great reference.