corl-team / xland-minigrid

JAX-accelerated Meta-Reinforcement Learning Environments Inspired by XLand and MiniGrid 🏎️
Apache License 2.0

The rollout speed is slower than gymnasium #17

Closed: alexxchen closed this issue 6 months ago

alexxchen commented 6 months ago

I am new to JAX. I can't see why the example code is so slow when I run it on CPU. The wall time for each step is very long.

Here is my code:

```python
import time

import jax
import jax.numpy as jnp
import xminigrid
from xminigrid.wrappers import GymAutoResetWrapper
from xminigrid.experimental.img_obs import RGBImgObservationWrapper

num_envs = 8
benchmark = xminigrid.load_benchmark(name="trivial-1m")

rng = jax.random.PRNGKey(0)
# use separate keys for ruleset sampling and for reset instead of splitting the same rng twice
ruleset_key, reset_key = jax.random.split(rng)
ruleset_rng = jax.random.split(ruleset_key, num=num_envs)
reset_rng = jax.random.split(reset_key, num=num_envs)
train = jax.vmap(benchmark.sample_ruleset)(ruleset_rng)

def make_env(rulesets, img_obs=False):
    env, env_params = xminigrid.make("XLand-MiniGrid-R1-9x9")
    env_params = env_params.replace(ruleset=rulesets)

    env = GymAutoResetWrapper(env)

    if img_obs:
        # render obs as rgb images if needed (warn: this will affect speed greatly)
        env = RGBImgObservationWrapper(env)
    return env, env_params

train_env, train_params = make_env(train)

timestep = jax.vmap(train_env.reset, in_axes=(0, 0))(train_params, reset_rng)
start = time.time()
for i in range(10):
    # step is called without jax.jit, so every operation inside it is dispatched eagerly
    timestep = jax.vmap(train_env.step, in_axes=0)(
        train_params, timestep, action=jnp.zeros((num_envs,), dtype=jnp.uint8)
    )
jax.block_until_ready(timestep)  # JAX dispatches asynchronously; wait before reading the clock
print(time.time() - start)
```

It takes 30 seconds to run 10 steps, whereas the MiniGrid environment in Gymnasium takes only 0.3 seconds for 1000 steps:

```python
import time

import gymnasium as gym
import minigrid  # noqa: F401  (importing registers the MiniGrid-* environments)

# env = gym.make("MiniGrid-Empty-5x5-v0")
env = gym.make("MiniGrid-Playground-v0")
observation, info = env.reset(seed=42)
start = time.time()
for _ in range(1000):
    action = env.action_space.sample()
    observation, reward, terminated, truncated, info = env.step(action)

    if terminated or truncated:
        observation, info = env.reset()
env.close()

print(time.time() - start)
```
Howuhh commented 6 months ago

Hi @alexxchen! In JAX, you should try to write everything so that it can be compiled with `jax.jit`. Without it, each operation is dispatched eagerly, one at a time, and the speed will often be even worse than that of plain NumPy, especially on CPU.
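A minimal sketch of the difference, using a hypothetical `toy_step` function (a stand-in for an environment step, not part of XLand-MiniGrid). Called eagerly, each `jnp` operation inside it is dispatched separately on every call; `jax.jit` instead traces the function once, compiles it with XLA, and reuses the compiled version on later calls with the same input shapes and dtypes. Because JAX dispatches work asynchronously, the timing helper waits with `block_until_ready()`, and the first (compiling) call of the jitted function is excluded from the measurement.

```python
import time

import jax
import jax.numpy as jnp


def toy_step(pos: jax.Array, action: jax.Array) -> jax.Array:
    """Hypothetical step on a 9x9 grid: action 0..3 moves up/right/down/left."""
    moves = jnp.array([[-1, 0], [0, 1], [1, 0], [0, -1]])
    return jnp.clip(pos + moves[action], 0, 8)


def time_calls(fn, pos, action, n=100):
    """Average wall time per call over n chained calls, plus the final position."""
    start = time.perf_counter()
    for _ in range(n):
        pos = fn(pos, action)
    pos.block_until_ready()  # wait for asynchronously dispatched work
    return (time.perf_counter() - start) / n, pos


pos = jnp.array([4, 4])
action = jnp.array(1)  # "move right"

jit_step = jax.jit(toy_step)
jit_step(pos, action).block_until_ready()  # first call traces and compiles

eager_time, eager_pos = time_calls(toy_step, pos, action)
jit_time, jit_pos = time_calls(jit_step, pos, action)
print(f"eager: {eager_time * 1e6:.1f} us/call, jit: {jit_time * 1e6:.1f} us/call")
```

The absolute timings depend on the machine; only the final positions are deterministic.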

For example, see how I wrote the evaluation function in the baselines: https://github.com/corl-team/xland-minigrid/blob/b21c1424f9b66546146188ea74cfc952fba5f888/training/utils.py#L110

After jitting, it will be much faster. Also, it is better and easier to write the rollout code for a single environment and apply `jax.vmap` only afterwards.
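One way to follow this advice, sketched under assumptions: the hypothetical helper `make_rollout` below writes the rollout for a single environment with `jax.lax.scan` instead of a Python loop, and batching is added afterwards with `jax.vmap`, followed by a single `jax.jit`. The toy environment is hypothetical so the sketch runs on its own; its `reset(params, key)` and `step(params, timestep, action)` signatures mirror how `train_env.reset` and `train_env.step` are called in the question's code, but the real XLand-MiniGrid timestep has different fields. Actions are sampled uniformly at random in place of a policy.

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp


class ToyTimestep(NamedTuple):
    """Hypothetical timestep pytree (stand-in for the environment's real timestep)."""
    pos: jax.Array
    reward: jax.Array


def toy_reset(params, key):
    # random start cell on a 9x9 grid; params (the goal cell) is unused here
    pos = jax.random.randint(key, (2,), 0, 9)
    return ToyTimestep(pos=pos, reward=jnp.float32(0.0))


def toy_step(params, timestep, action):
    moves = jnp.array([[-1, 0], [0, 1], [1, 0], [0, -1]])
    pos = jnp.clip(timestep.pos + moves[action], 0, 8)
    reward = jnp.all(pos == params).astype(jnp.float32)  # 1.0 when on the goal cell
    return ToyTimestep(pos=pos, reward=reward)


def make_rollout(reset_fn, step_fn, num_steps):
    """Build a rollout for ONE environment; batching is added later with vmap."""

    def rollout(params, key):
        reset_key, action_key = jax.random.split(key)
        timestep = reset_fn(params, reset_key)
        actions = jax.random.randint(action_key, (num_steps,), 0, 4)

        def body(timestep, action):
            timestep = step_fn(params, timestep, action)
            return timestep, timestep.reward

        # lax.scan replaces the Python loop, so the rollout compiles to one XLA program
        _, rewards = jax.lax.scan(body, timestep, actions)
        return rewards.sum()

    return rollout


num_envs, num_steps = 8, 1000
rollout = make_rollout(toy_reset, toy_step, num_steps)
# vmap over per-env params and keys, then compile the batched function once
batched_rollout = jax.jit(jax.vmap(rollout, in_axes=(0, 0)))

goals = jnp.tile(jnp.array([8, 8]), (num_envs, 1))
keys = jax.random.split(jax.random.PRNGKey(0), num_envs)
total_rewards = batched_rollout(goals, keys)
print(total_rewards.shape)
```

Keeping the single-env function free of batch dimensions makes it easy to test on its own; `jax.vmap` then adds the batch axis without any changes to the rollout logic.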

Howuhh commented 6 months ago

@alexxchen, as the question is more about general JAX knowledge than about the specifics of XLand-MiniGrid, I would suggest practising on simpler examples or reading the documentation first: https://jax.readthedocs.io/en/latest/jax-101/index.html

alexxchen commented 6 months ago

@Howuhh I see where the problem is. The speed is back to normal after adding this line: `reset_fn, step_fn = jax.jit(env.reset), jax.jit(env.step)`. Your information was very useful, thank you very much!
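Applied to the batched loop from the question, the fix might look like the sketch below. It uses a hypothetical toy environment so it runs standalone; with XLand-MiniGrid, `train_env.reset` and `train_env.step` would take the place of `toy_reset` and `toy_step`. Here `jax.vmap` is applied first and the batched function is jitted once, outside the loop; a warm-up call pays the compilation cost before timing starts, and `block_until_ready()` makes the clock wait for the asynchronously dispatched work.

```python
import time
from typing import NamedTuple

import jax
import jax.numpy as jnp


class ToyTimestep(NamedTuple):
    """Hypothetical timestep pytree (stand-in for the wrapped env's timestep)."""
    pos: jax.Array
    reward: jax.Array


def toy_reset(params, key):
    return ToyTimestep(pos=jax.random.randint(key, (2,), 0, 9), reward=jnp.float32(0.0))


def toy_step(params, timestep, action):
    moves = jnp.array([[-1, 0], [0, 1], [1, 0], [0, -1]])
    pos = jnp.clip(timestep.pos + moves[action], 0, 8)
    return ToyTimestep(pos=pos, reward=jnp.all(pos == params).astype(jnp.float32))


num_envs = 8
# Batch with vmap, then compile the batched functions once, outside the loop.
reset_fn = jax.jit(jax.vmap(toy_reset, in_axes=(0, 0)))
step_fn = jax.jit(jax.vmap(toy_step, in_axes=(0, 0, 0)))

params = jnp.tile(jnp.array([8, 8]), (num_envs, 1))
keys = jax.random.split(jax.random.PRNGKey(0), num_envs)
actions = jnp.ones((num_envs,), dtype=jnp.int32)  # "move right" in every env

timestep = reset_fn(params, keys)
timestep = step_fn(params, timestep, actions)  # warm-up call: compiles step_fn
timestep.pos.block_until_ready()

start = time.perf_counter()
for _ in range(1000):
    timestep = step_fn(params, timestep, actions)
timestep.pos.block_until_ready()  # wait for async dispatch before stopping the clock
print(f"{time.perf_counter() - start:.3f} s for 1000 batched steps")
```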

Howuhh commented 6 months ago

@alexxchen Yup! I think it will be even faster if you jit the entire rollout function, not only step/reset. Also, do not forget to jit the model itself.
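A sketch of jitting the whole rollout, model included, under stated assumptions: `policy_apply` is a hypothetical linear policy standing in for a real network's apply function, and the toy environment (with a one-hot position as observation) is again hypothetical. Reset, policy, action sampling, and environment steps all run inside one `jax.lax.scan`, so `jax.jit` compiles the entire rollout into a single XLA program instead of returning to Python after every step.

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp


class ToyTimestep(NamedTuple):
    """Hypothetical timestep pytree; obs is the 9x9 position one-hot encoded (81 values)."""
    pos: jax.Array
    obs: jax.Array
    reward: jax.Array


def _obs(pos):
    return jax.nn.one_hot(pos[0] * 9 + pos[1], 81)


def toy_reset(params, key):
    pos = jax.random.randint(key, (2,), 0, 9)
    return ToyTimestep(pos=pos, obs=_obs(pos), reward=jnp.float32(0.0))


def toy_step(params, timestep, action):
    moves = jnp.array([[-1, 0], [0, 1], [1, 0], [0, -1]])
    pos = jnp.clip(timestep.pos + moves[action], 0, 8)
    reward = jnp.all(pos == params).astype(jnp.float32)
    return ToyTimestep(pos=pos, obs=_obs(pos), reward=reward)


def policy_apply(policy_params, obs):
    """Hypothetical linear policy returning logits over the 4 actions."""
    return obs @ policy_params["w"] + policy_params["b"]


def rollout(policy_params, env_params, key, num_steps=100):
    reset_key, key = jax.random.split(key)
    timestep = toy_reset(env_params, reset_key)

    def body(carry, _):
        timestep, key = carry
        key, action_key = jax.random.split(key)
        logits = policy_apply(policy_params, timestep.obs)  # model runs inside the compiled loop
        action = jax.random.categorical(action_key, logits)
        timestep = toy_step(env_params, timestep, action)
        return (timestep, key), timestep.reward

    _, rewards = jax.lax.scan(body, (timestep, key), None, length=num_steps)
    return rewards.sum()


# policy_params is shared by all envs (in_axes None); env params and keys are per-env.
batched_rollout = jax.jit(jax.vmap(rollout, in_axes=(None, 0, 0)))

num_envs = 8
policy_params = {"w": jnp.zeros((81, 4)), "b": jnp.zeros(4)}
env_params = jnp.tile(jnp.array([8, 8]), (num_envs, 1))
keys = jax.random.split(jax.random.PRNGKey(0), num_envs)
returns = batched_rollout(policy_params, env_params, keys)
print(returns)
```

Passing `policy_params` as an argument, rather than closing over it, means new weights with the same shapes and dtypes can be passed in without recompiling, whereas values closed over at trace time would be baked into the compiled program.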