Closed nico-bohlinger closed 1 year ago
Hi @nico-bohlinger , I think the lower steps-per-sec is somewhat expected for a single environment, but you'd be able to batch the envs on GPU and get a higher overall steps-per-sec. You will want to ask on the gym repo how to batch the envs, but in brax it would look something like:
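import jax
# VmapWrapper ships with brax; its import path depends on the brax version
env = VmapWrapper(env)
rngs = jax.random.split(jax.random.PRNGKey(seed), batch_size)
state = env.reset(rngs)  # one rng key per batched env
jit_env_step = jax.jit(env.step)
state = jit_env_step(state, action)  # action is [batch_size, n_action]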
You likely want to keep the RL algo running on the same device as env.step, to avoid data transfers between the host and the GPU/TPU device.
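To make the batching idea concrete without depending on a particular brax version, here is a minimal sketch in plain JAX. `toy_step` is a hypothetical stand-in for a physics step, not brax's `env.step`, and the shapes are made up for illustration. It shows `jax.vmap` batching a single-env step and `jax.lax.scan` keeping the rollout loop on the device:

```python
import jax
import jax.numpy as jnp

def toy_step(state, action):
    # Hypothetical dynamics standing in for a real physics step.
    return state + 0.1 * action

batch_size, n_action = 8, 4

# vmap maps the single-env step over the leading batch axis; jit compiles it.
batched_step = jax.jit(jax.vmap(toy_step))

states = jnp.zeros((batch_size, n_action))
actions = jnp.ones((batch_size, n_action))

def rollout(states, actions, num_steps):
    # lax.scan runs the whole loop on the device, so there is no
    # host round trip between consecutive steps.
    def body(carry, _):
        return batched_step(carry, actions), None
    final, _ = jax.lax.scan(body, states, None, length=num_steps)
    return final

final_states = rollout(states, actions, 10)
```

With a real policy, the action would be computed from the carried state inside `body`, so both the policy and the step stay on the same device.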
The agents in https://github.com/google/brax/tree/main/brax/training/agents are a great reference.
I want to use a single(!) ant environment for my project and checked the steps per second with this script:
Using my CPU for jitting, I get around 2800 steps per second, which is better than what I would get with a MuJoCo ant environment. But if I comment out the line that puts JAX in CPU mode and use my GPU, I only get 700 steps per second. Jitting on the GPU is amazing for many environments, but it seems to be not that great with only one environment. On the other hand, using the GPU is far better for the RL algorithm I want to use later.
Is there a way of improving single-environment performance on the GPU? Or is there maybe a way of efficiently jitting the environment on the CPU while still using the GPU for my RL algorithm in JAX?
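For the second question, one possible approach is to commit the environment's inputs to the CPU device, so that the jitted step runs there, and then copy the result to the accelerator for the learner. This is only a sketch under assumptions: `toy_step` is a hypothetical placeholder rather than a brax environment, and whether the per-step transfer cost is acceptable would need to be measured.

```python
import jax
import jax.numpy as jnp

cpu = jax.devices("cpu")[0]   # the CPU backend is always available
accel = jax.devices()[0]      # default device: GPU/TPU if present, else CPU

def toy_step(state, action):
    # Hypothetical dynamics standing in for a real env step.
    return state + 0.1 * action

jit_step = jax.jit(toy_step)

# Arrays committed to the CPU make the jitted computation run on the CPU.
state = jax.device_put(jnp.zeros(4), cpu)
action = jax.device_put(jnp.ones(4), cpu)
state = jit_step(state, action)

# Copy the result to the accelerator, where the learner would consume it.
obs_on_accel = jax.device_put(state, accel)
```

Each step then costs one device transfer in each direction, which is the overhead the reply above warns about.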