Are you fed up with slow CPU-based RL environment processes? Do you want to leverage massive vectorization for high-throughput RL experiments? `gymnax` brings the power of `jit` and `vmap`/`pmap` to the classic gym API. It supports a range of environments, including classic control, bsuite, MinAtar and a collection of classic/meta RL tasks. `gymnax` allows explicit functional control of environment settings (random seed or hyperparameters), which enables accelerated & parallelized rollouts for different configurations (e.g. for meta RL). By executing both environment and policy on the accelerator, it facilitates the Anakin sub-architecture proposed in the Podracer paper (Hessel et al., 2021) and highly distributed evolutionary optimization (using e.g. `evosax`). We provide training pipelines & checkpoints for both PPO & ES in `gymnax-blines`. Get started below.
gymnax API Usage

```python
import jax
import gymnax

rng = jax.random.PRNGKey(0)
rng, key_reset, key_act, key_step = jax.random.split(rng, 4)

# Instantiate the environment & its settings.
env, env_params = gymnax.make("Pendulum-v1")

# Reset the environment.
obs, state = env.reset(key_reset, env_params)

# Sample a random action.
action = env.action_space(env_params).sample(key_act)

# Perform the step transition.
n_obs, n_state, reward, done, _ = env.step(key_step, state, action, env_params)
```
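To make the functional contract concrete, here is a toy environment written in plain JAX that mimics the `reset(key, params)` / `step(key, state, action, params)` signatures used above. `EnvParams`, `EnvState`, `toy_reset` and `toy_step` are hypothetical names for illustration, not part of `gymnax`, and unlike `gymnax` environments this toy does not auto-reset when an episode ends. Because state, settings and randomness are all passed in explicitly, the same key and parameters always reproduce the same transition.

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp


class EnvParams(NamedTuple):
    """Hypothetical settings of a toy 1-D point-mass task."""
    max_speed: float = 0.5
    max_steps: int = 10


class EnvState(NamedTuple):
    """Hypothetical environment state: position and elapsed steps."""
    pos: jnp.ndarray
    time: jnp.ndarray


def toy_reset(key, params):
    """Sample a start position; returns (obs, state) like env.reset."""
    pos = jax.random.uniform(key, (), minval=-1.0, maxval=1.0)
    return pos[None], EnvState(pos=pos, time=jnp.array(0))


def toy_step(key, state, action, params):
    """Move by the clipped action plus small noise; returns a gymnax-style 5-tuple."""
    noise = 0.01 * jax.random.normal(key, ())
    pos = state.pos + jnp.clip(action, -params.max_speed, params.max_speed) + noise
    new_state = EnvState(pos=pos, time=state.time + 1)
    reward = -jnp.abs(pos)  # closer to the origin is better
    done = new_state.time >= params.max_steps
    return pos[None], new_state, reward, done, {}


params = EnvParams()
key_reset, key_step = jax.random.split(jax.random.PRNGKey(0))
obs, state = toy_reset(key_reset, params)
n_obs, n_state, reward, done, _ = toy_step(key_step, state, 0.3, params)

# Same key + same state + same params -> identical transition (no hidden RNG state)
n_obs_again, *_ = toy_step(key_step, state, 0.3, params)
print(bool(jnp.allclose(n_obs, n_obs_again)))
```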
* All displayed speeds are estimated for 1M step transitions (random policy) on an NVIDIA A100 GPU using `jit`-compiled episode rollouts with 2000 environment workers. For more detailed speed comparisons on different accelerators (CPU, RTX 2080Ti) and MLP policies, please refer to the `gymnax-blines` documentation.
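A similar measurement can be sketched for your own setup: with 2000 workers, 1M transitions amount to 500 steps per worker (1,000,000 / 2,000). The sketch below times a `jit`-compiled, `vmap`-ed `lax.scan` rollout of a hypothetical toy random-walk step function (not a `gymnax` environment). It calls the function once to trigger compilation before starting the timer, and uses `block_until_ready()` because JAX dispatches work asynchronously. The printed throughput depends entirely on your hardware.

```python
import time

import jax
import jax.numpy as jnp

NUM_WORKERS = 2000
TOTAL_STEPS = 1_000_000
STEPS_PER_WORKER = TOTAL_STEPS // NUM_WORKERS  # 500


def toy_step(pos, key):
    """Hypothetical toy transition: random-policy walk, reward = -|pos|."""
    action = jax.random.uniform(key, (), minval=-1.0, maxval=1.0)
    new_pos = pos + 0.1 * action
    return new_pos, -jnp.abs(new_pos)


def worker_rollout(key):
    """Scan one worker's episode of STEPS_PER_WORKER transitions."""
    keys = jax.random.split(key, STEPS_PER_WORKER)
    _, rewards = jax.lax.scan(toy_step, jnp.array(0.0), keys)
    return rewards.sum()


batched_rollout = jax.jit(jax.vmap(worker_rollout))
worker_keys = jax.random.split(jax.random.PRNGKey(0), NUM_WORKERS)

batched_rollout(worker_keys).block_until_ready()  # compile once, outside the timer
start = time.perf_counter()
returns = batched_rollout(worker_keys).block_until_ready()
elapsed = time.perf_counter() - start
print(f"{TOTAL_STEPS / elapsed:,.0f} steps/s on {jax.default_backend()}")
```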
The latest `gymnax` release can be installed directly from PyPI:

```shell
pip install gymnax
```

If you want the most recent commit, install directly from the repository:

```shell
pip install git+https://github.com/RobertTLange/gymnax.git@main
```

To use JAX on your accelerators, see the JAX documentation for more details.
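After installing, a quick way to check which devices and backend JAX picked up is the sketch below; `jax.devices()` and `jax.default_backend()` are part of JAX's public API. If the backend is `"cpu"`, JAX did not find an accelerator and environments will run on the CPU.

```python
import jax

# Devices JAX can see (CPU, GPU or TPU devices)
print(jax.devices())
# Name of the default backend, e.g. "cpu", "gpu" or "tpu"
print(jax.default_backend())
```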
The repository's examples cover the core `gymnax` API, `SpaceInvaders-MinAtar`, and meta-evolving an LSTM controller that controls 2-link pendula of different lengths. Pretrained agents are available in `gymnax-blines`.
Environment vectorization & acceleration: Easy composition of JAX primitives (e.g. `jit`, `vmap`, `pmap`):

```python
# Jit-accelerated step transition
jit_step = jax.jit(env.step)

# map (vmap/pmap) across random keys for batch rollouts
reset_rng = jax.vmap(env.reset, in_axes=(0, None))
step_rng = jax.vmap(env.step, in_axes=(0, 0, 0, None))

# map (vmap/pmap) across env parameters (e.g. for meta-learning)
reset_params = jax.vmap(env.reset, in_axes=(None, 0))
step_params = jax.vmap(env.step, in_axes=(None, 0, 0, 0))
```
For speed comparisons with standard vectorized NumPy environments, check out `gymnax-blines`.
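To see what these `in_axes` choices produce, the sketch below applies the same pattern to a hypothetical toy environment written in plain JAX (`toy_reset`/`toy_step` are illustrative, not `gymnax` functions) and checks the batched shapes. When batching over parameters with `in_axes=0`, every leaf of the params pytree needs a leading batch axis; outputs that do not depend on a batched input, such as the reset state here, are broadcast to the batch size by `vmap`.

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp


class EnvParams(NamedTuple):
    max_speed: jnp.ndarray  # hypothetical per-configuration setting


class EnvState(NamedTuple):
    pos: jnp.ndarray


def toy_reset(key, params):
    pos = jax.random.uniform(key, (), minval=-1.0, maxval=1.0)
    return pos[None], EnvState(pos)


def toy_step(key, state, action, params):
    pos = state.pos + jnp.clip(action, -params.max_speed, params.max_speed)
    pos = pos + 0.01 * jax.random.normal(key, ())
    return pos[None], EnvState(pos), -jnp.abs(pos), jnp.abs(pos) > 2.0, {}


# Batch over random keys (shared params): 8 independent environments
params = EnvParams(max_speed=jnp.array(0.5))
key_reset, key_step = jax.random.split(jax.random.PRNGKey(0))
keys_reset = jax.random.split(key_reset, 8)
keys_step = jax.random.split(key_step, 8)
reset_rng = jax.vmap(toy_reset, in_axes=(0, None))
step_rng = jax.jit(jax.vmap(toy_step, in_axes=(0, 0, 0, None)))
obs, state = reset_rng(keys_reset, params)
obs, state, reward, done, _ = step_rng(keys_step, state, jnp.full((8,), 0.3), params)
print(obs.shape, reward.shape)

# Batch over env parameters (shared key): 3 configurations
batch_params = EnvParams(max_speed=jnp.array([0.1, 0.5, 1.0]))
reset_params = jax.vmap(toy_reset, in_axes=(None, 0))
step_params = jax.vmap(toy_step, in_axes=(None, 0, 0, 0))
key_r, key_s = jax.random.split(jax.random.PRNGKey(1))
obs_p, state_p = reset_params(key_r, batch_params)
obs_p, state_p, reward_p, done_p, _ = step_params(
    key_s, state_p, jnp.full((3,), 1.0), batch_params
)
print(obs_p.shape)
```

Since all three configurations start from the same state and see the same noise, their positions after one step differ only by the clipped action, i.e. by the `max_speed` values.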
Scan through entire episode rollouts: You can also `lax.scan` through an entire `reset`/`step` episode loop for fast compilation:

```python
import jax


def rollout(rng_input, policy_params, env_params, steps_in_episode):
    """Rollout a jitted gymnax episode with lax.scan."""
    # Reset the environment
    rng_reset, rng_episode = jax.random.split(rng_input)
    obs, state = env.reset(rng_reset, env_params)

    def policy_step(state_input, _):
        """lax.scan compatible step transition in jax env."""
        obs, state, policy_params, rng = state_input
        rng, rng_step = jax.random.split(rng)
        # `model` is your policy network (e.g. a Flax module), defined elsewhere
        action = model.apply(policy_params, obs)
        next_obs, next_state, reward, done, _ = env.step(
            rng_step, state, action, env_params
        )
        carry = [next_obs, next_state, policy_params, rng]
        return carry, [obs, action, reward, next_obs, done]

    # Scan over episode step loop (steps_in_episode must be a static int)
    _, scan_out = jax.lax.scan(
        policy_step,
        [obs, state, policy_params, rng_episode],
        None,
        length=steps_in_episode,
    )
    # Return the per-step transitions stacked along the time axis
    obs, action, reward, next_obs, done = scan_out
    return obs, action, reward, next_obs, done
```
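Because `steps_in_episode` sets the scan length, it has to be a static Python integer when the rollout is `jit`-compiled. The self-contained sketch below keeps the structure of `rollout`, but uses a hypothetical toy environment and a hypothetical linear policy in place of `env` and `model.apply`. It binds `steps_in_episode` with `functools.partial` so it stays static, then `vmap`s the whole rollout over keys, keeping both policy and environment inside one compiled function. Note that `gymnax` environments auto-reset on `done`; the toy here does not, so `done` is only recorded.

```python
from functools import partial

import jax
import jax.numpy as jnp


def toy_reset(key):
    """Hypothetical toy env: 2-D position sampled in [-1, 1]^2."""
    pos = jax.random.uniform(key, (2,), minval=-1.0, maxval=1.0)
    return pos, pos  # (obs, state): the observation is the state itself


def toy_step(key, state, action):
    next_state = state + 0.1 * jnp.tanh(action) + 0.01 * jax.random.normal(key, (2,))
    reward = -jnp.linalg.norm(next_state)
    done = reward > -0.05  # "close enough" to the origin
    return next_state, next_state, reward, done, {}


def linear_policy(policy_params, obs):
    """Hypothetical stand-in for model.apply: a = W @ obs + b."""
    return policy_params["W"] @ obs + policy_params["b"]


def rollout(rng_input, policy_params, steps_in_episode):
    rng_reset, rng_episode = jax.random.split(rng_input)
    obs, state = toy_reset(rng_reset)

    def policy_step(carry, _):
        obs, state, rng = carry
        rng, rng_step = jax.random.split(rng)
        action = linear_policy(policy_params, obs)
        next_obs, next_state, reward, done, _ = toy_step(rng_step, state, action)
        return (next_obs, next_state, rng), (obs, action, reward, next_obs, done)

    _, traj = jax.lax.scan(
        policy_step, (obs, state, rng_episode), None, length=steps_in_episode
    )
    return traj


# Bind the static episode length, then batch over keys and compile once
batch_rollout = jax.jit(
    jax.vmap(partial(rollout, steps_in_episode=50), in_axes=(0, None))
)
policy_params = {"W": -jnp.eye(2), "b": jnp.zeros(2)}  # pushes toward the origin
keys = jax.random.split(jax.random.PRNGKey(0), 16)
obs, action, reward, next_obs, done = batch_rollout(keys, policy_params)
print(reward.shape)
```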
Built-in visualization tools: You can also smoothly generate GIF animations using the `Visualizer` tool, which covers all `classic_control`, `MinAtar` and most `misc` environments:

```python
import jax
import jax.numpy as jnp
from gymnax.visualize import Visualizer

# env, env_params and rng as defined in the API usage example above
state_seq, reward_seq = [], []
rng, rng_reset = jax.random.split(rng)
obs, env_state = env.reset(rng_reset, env_params)
while True:
    state_seq.append(env_state)
    rng, rng_act, rng_step = jax.random.split(rng, 3)
    action = env.action_space(env_params).sample(rng_act)
    next_obs, next_env_state, reward, done, info = env.step(
        rng_step, env_state, action, env_params
    )
    reward_seq.append(reward)
    if done:
        break
    else:
        obs = next_obs
        env_state = next_env_state

cum_rewards = jnp.cumsum(jnp.array(reward_seq))
vis = Visualizer(env, env_params, state_seq, cum_rewards)
vis.animate("docs/anim.gif")
```
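The Python `while` loop above dispatches one step at a time, which is fine for producing a single animation. An alternative sketch, shown on a hypothetical toy environment (not a `gymnax` API), collects a fixed number of steps with `lax.scan`, cuts the trajectory at the first `done`, and unstacks the batched state pytree into the same kind of per-step list the loop builds. Whether `Visualizer` accepts states produced this way for a given environment is an assumption worth checking.

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp


class EnvState(NamedTuple):
    """Hypothetical toy state: scalar position and step counter."""
    pos: jnp.ndarray
    time: jnp.ndarray


def toy_step(key, state):
    """Random-walk step; the episode ends after 7 steps (hypothetical rule)."""
    pos = state.pos + jax.random.uniform(key, (), minval=-0.1, maxval=0.1)
    new_state = EnvState(pos, state.time + 1)
    return new_state, -jnp.abs(pos), new_state.time >= 7


def scan_step(state, key):
    next_state, reward, done = toy_step(key, state)
    # Emit the pre-step state, mirroring state_seq.append(env_state) in the loop
    return next_state, (state, reward, done)


init_state = EnvState(jnp.array(0.0), jnp.array(0))
keys = jax.random.split(jax.random.PRNGKey(0), 20)  # upper bound on episode length
_, (states, rewards, dones) = jax.lax.scan(scan_step, init_state, keys)

# Keep steps up to and including the first done (like the loop's `break`);
# this assumes the episode terminates within the 20 scanned steps
length = int(jnp.argmax(dones.astype(jnp.int32))) + 1
state_seq = [jax.tree_util.tree_map(lambda x: x[i], states) for i in range(length)]
cum_rewards = jnp.cumsum(rewards[:length])
print(len(state_seq), cum_rewards.shape)
```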
Training pipelines & pretrained agents: Check out `gymnax-blines` for trained agents, expert rollout visualizations and PPO/ES pipelines. The agents are minimally tuned, but can help you get up and running.
Simple batch agent evaluation: Work-in-progress.

```python
import jax
import jax.numpy as jnp
from gymnax.experimental import RolloutWrapper

# Define rollout manager for pendulum env (`model` is your policy network)
manager = RolloutWrapper(model.apply, env_name="Pendulum-v1")

# Simple single episode rollout for policy
obs, action, reward, next_obs, done, cum_ret = manager.single_rollout(rng, policy_params)

# Multiple rollouts for same network (different rng, e.g. eval)
rng_batch = jax.random.split(rng, 10)
obs, action, reward, next_obs, done, cum_ret = manager.batch_rollout(
    rng_batch, policy_params
)

# Multiple rollouts for different networks + rng (e.g. for ES)
batch_params = jax.tree_util.tree_map(  # Stack 5 copies of the parameters (or use different ones)
    lambda x: jnp.stack([x] * 5), policy_params
)
obs, action, reward, next_obs, done, cum_ret = manager.population_rollout(
    rng_batch, batch_params
)
```
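The shapes involved in a population rollout can be illustrated without `gymnax`: a nested `vmap`, outer over a stacked population of parameters and inner over evaluation keys, yields a `(population, episodes)` array of returns. The sketch below uses a hypothetical `episode_return(key, params)` function standing in for one rollout; it illustrates the batching pattern and is not the `RolloutWrapper` implementation.

```python
import jax
import jax.numpy as jnp


def episode_return(key, params):
    """Hypothetical one-episode evaluation: score of a linear policy on random obs."""
    obs = jax.random.normal(key, (4,))
    action = params["w"] @ obs + params["b"]
    return -jnp.square(action - 1.0)


policy_params = {"w": jnp.zeros(4), "b": jnp.array(0.5)}

# Stack pop_size copies of every leaf along a new leading axis
pop_size, num_episodes = 5, 10
batch_params = jax.tree_util.tree_map(
    lambda x: jnp.stack([x] * pop_size), policy_params
)

# Perturb members so they differ (ES-style noise; key and scale are illustrative)
noise_key, eval_key = jax.random.split(jax.random.PRNGKey(0))
batch_params["b"] = batch_params["b"] + 0.1 * jax.random.normal(noise_key, (pop_size,))

rng_batch = jax.random.split(eval_key, num_episodes)
# Inner vmap: same params, different keys. Outer vmap: different params, shared keys.
batch_eval = jax.vmap(episode_return, in_axes=(0, None))
population_eval = jax.jit(jax.vmap(batch_eval, in_axes=(None, 0)))
returns = population_eval(rng_batch, batch_params)
print(returns.shape, returns.mean(axis=1).shape)
```

Averaging over the last axis gives one fitness value per population member, which is the quantity an ES optimizer would rank.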
Citing gymnax

If you use `gymnax` in your research, please cite it as follows:
```bibtex
@software{gymnax2022github,
  author = {Robert Tjarko Lange},
  title = {{gymnax}: A {JAX}-based Reinforcement Learning Environment Library},
  url = {http://github.com/RobertTLange/gymnax},
  version = {0.0.4},
  year = {2022},
}
```
We acknowledge financial support by the Google TRC and the Deutsche Forschungsgemeinschaft (DFG, German Research Foundation) under Germany's Excellence Strategy - EXC 2002/1 "Science of Intelligence" - project number 390523135.
You can run the test suite via `python -m pytest -vv --all`. If you find a bug or are missing your favourite feature, feel free to create an issue and/or start contributing.