keraJLi / rejax

Apache License 2.0

Running vmapped policies #15

Closed: remmyzen closed this issue 1 month ago

remmyzen commented 1 month ago

Hi,

Thanks for the nice repo! I am trying to run the "train vmapped policies" example from README.md, but I always get an error. Could you give a working example?

Basically, I am trying to combine the examples in the README.md and the examples/rejax_tour.ipynb notebook.

### Taken from README.md
from rejax import PPO
import jax

env_str = "CartPole-v1"
# Get train function and initialize config for training
algo = PPO.create(env=env_str, learning_rate=0.001)

# Jit the training function
train_fn = jax.jit(algo.train)

# Vmap training function over 50 initial seeds
vmapped_train_fn = jax.vmap(train_fn)

# Train 50 agents!
keys = jax.random.split(jax.random.PRNGKey(0), 50)
train_state, evaluation = vmapped_train_fn(keys)

# Get policy and jit it
policy = algo.make_act(train_state)
policy = jax.jit(policy)

### Taken from rejax_tour
# For demonstration purposes, we do a manual rollout of the policy
import gymnax

rng = jax.random.PRNGKey(0)
env, params = gymnax.make(env_str)
step = jax.jit(env.step)

obs, state = env.reset(rng, params)
episode_return = 0
done = False

while not done:
    rng, rng_action, rng_step = jax.random.split(rng, 3)
    action = policy(obs, rng_action)
    obs, state, reward, done, info = step(rng_step, state, action, params)
    episode_return += reward

print(f"Return achieved in one episode of {env_str}: {episode_return}")

It raises an error at the line action = policy(obs, rng_action) with this message:

ScopeParamShapeError: Initializer expected to generate shape (50, 4, 64) but got shape (4, 64) instead for parameter "kernel" in "/features/Dense_0". (https://flax.readthedocs.io/en/latest/api_reference/flax.errors.html#flax.errors.ScopeParamShapeError)

I also tried vmapping reset so that obs.shape is (50, 4) and splitting rng_action into 50 keys, but then the error reported shape (200, 64) instead, possibly because of the squeeze in the policy function.

Could you advise on how to run vmapped policies properly?

Thank you.
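For context on the error above: vmapping the training function gives every parameter leaf of the returned train state a leading axis of size len(keys), so a kernel of shape (4, 64) becomes (50, 4, 64), while the network inside the policy from algo.make_act expects the unbatched shape. A minimal plain-JAX sketch, using a hypothetical init_params function in place of rejax's actual network, shows the stacking and how one agent's parameters could be sliced out with jax.tree_util.tree_map:

```python
import jax
import jax.numpy as jnp

OBS_DIM, HIDDEN = 4, 64  # CartPole has 4 observation features; 64 matches the error message

def init_params(key):
    # Hypothetical stand-in for a network initializer (not rejax's network): one dense layer.
    return {
        "kernel": jax.random.normal(key, (OBS_DIM, HIDDEN)) * 0.1,
        "bias": jnp.zeros((HIDDEN,)),
    }

keys = jax.random.split(jax.random.PRNGKey(0), 50)

# Vmapping over 50 keys stacks every leaf along a new leading axis.
stacked = jax.vmap(init_params)(keys)
print(stacked["kernel"].shape)  # (50, 4, 64)

def select_agent(tree, i):
    # Index the leading (agent) axis of every leaf to recover one agent's parameters.
    return jax.tree_util.tree_map(lambda x: x[i], tree)

single = select_agent(stacked, 0)
print(single["kernel"].shape)  # (4, 64)
```

If the rejax train state is an ordinary JAX pytree (assumed here, not confirmed in this thread), the same slicing would yield a single agent that could be rolled out with the original single-environment loop; the workaround in the reply keeps all agents instead.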

keraJLi commented 1 month ago

Hi @remmyzen! The issue seems to be caused by the flattening of observations in the policy network. There should be an easy way to do what you want using something like vmap on make_act; I will address this in the future. For now, I believe an easy workaround is to vmap the policy along the first axis of the train state as well, like this:

@jax.jit
@jax.vmap
def policy(train_state, obs, rng):
    p = algo.make_act(train_state)
    return p(obs, rng)
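This works because jax.vmap maps the function over axis 0 of the parameters, observations and keys at the same time, so each call inside only ever sees one agent's unbatched slice. A minimal plain-JAX sketch of the same pattern, with a hypothetical linear softmax policy act standing in for the policy built by algo.make_act (an assumption, not rejax's network):

```python
import jax
import jax.numpy as jnp

N_AGENTS, OBS_DIM, N_ACTIONS = 50, 4, 2

def act(params, obs, rng):
    # Hypothetical single-agent policy: linear logits followed by a categorical sample.
    # Expects unbatched inputs: params["w"] is (OBS_DIM, N_ACTIONS), obs is (OBS_DIM,).
    logits = obs @ params["w"]
    return jax.random.categorical(rng, logits)

# Map over axis 0 of params, obs and rng together, as in the workaround above.
batched_act = jax.jit(jax.vmap(act))

# Stacked parameters for 50 agents, one observation and one key per agent.
params = {"w": jax.random.normal(jax.random.PRNGKey(1), (N_AGENTS, OBS_DIM, N_ACTIONS))}
obs = jnp.ones((N_AGENTS, OBS_DIM))
rngs = jax.random.split(jax.random.PRNGKey(2), N_AGENTS)

actions = batched_act(params, obs, rngs)
print(actions.shape)  # (50,)
```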

Since you want to run all policies, you have to vmap step and reset as well, and take some care with the loop termination:

rng = jax.random.PRNGKey(0)
env, params = gymnax.make(env_str)
step = jax.vmap(env.step, in_axes=(0, 0, 0, None))
step = jax.jit(step)

rng, rng_init = jax.random.split(rng)
rng_init = jax.random.split(rng_init, len(keys))
obs, state = jax.vmap(env.reset, in_axes=(0, None))(rng_init, params)

episode_return = 0
done = jax.numpy.zeros((len(keys),), dtype=bool)

while not jax.numpy.all(done):
    rng, rng_action, rng_step = jax.random.split(rng, 3)
    rng_action = jax.random.split(rng_action, len(keys))
    rng_step = jax.random.split(rng_step, len(keys))

    action = policy(train_state, obs, rng_action)
    obs, state, reward, new_done, info = step(rng_step, state, action, params)
    episode_return += reward * (1 - done)
    done = done | new_done

print(f"Returns achieved in one episode of {env_str}: {episode_return}")
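The termination bookkeeping in this loop comes down to two lines: reward * (1 - done) only adds reward for environments that had not finished before the current step (so the reward of the terminating step itself is still counted), and done = done | new_done keeps an environment marked as finished even though it keeps being stepped afterwards. A small sketch with made-up rewards and termination flags for three environments, no real environment involved, walks through that arithmetic:

```python
import jax.numpy as jnp

# Made-up per-step rewards and termination flags: 4 steps (rows) x 3 environments (columns).
rewards = jnp.ones((4, 3))
new_dones = jnp.array([
    [False, True,  False],
    [False, False, False],
    [True,  False, False],
    [False, True,  True],
])

episode_return = jnp.zeros(3)
done = jnp.zeros(3, dtype=bool)
for reward, new_done in zip(rewards, new_dones):
    # Uses `done` from before this step, so the terminating step's reward is counted.
    episode_return += reward * (1 - done)
    # Once an environment has finished, it stays finished (later flags are ignored).
    done = done | new_done

print(episode_return)  # [3. 1. 4.]
```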

Please note that the rollout using a Python while-loop is only meant for didactic purposes; combining jit with jax.lax.while_loop is much faster.
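A minimal sketch of the jit plus jax.lax.while_loop pattern, under clearly toy assumptions: toy_step is a hypothetical environment (a step counter that terminates after 5 + action steps) and toy_policy a hypothetical random policy, both standing in for gymnax and the rejax policy. Only the loop structure, the per-environment key splitting and the done masking are meant to carry over:

```python
import jax
import jax.numpy as jnp

N_ENVS = 4

def toy_step(key, state, action):
    # Hypothetical single-env step (not gymnax): state counts steps, reward is 1 per step,
    # and the episode ends once the counter reaches 5 + action.
    new_state = state + 1
    reward = jnp.ones_like(new_state, dtype=jnp.float32)
    done = new_state >= 5 + action
    return new_state, reward, done

def toy_policy(key, state):
    # Hypothetical policy: uniformly random action in {0, 1, 2}.
    return jax.random.randint(key, (), 0, 3)

batched_step = jax.vmap(toy_step)
batched_policy = jax.vmap(toy_policy)

@jax.jit
def rollout(rng):
    def cond_fn(carry):
        _, _, done, _ = carry
        return ~jnp.all(done)  # keep looping until every environment has finished

    def body_fn(carry):
        rng, state, done, episode_return = carry
        rng, rng_action, rng_step = jax.random.split(rng, 3)
        actions = batched_policy(jax.random.split(rng_action, N_ENVS), state)
        state, reward, new_done = batched_step(
            jax.random.split(rng_step, N_ENVS), state, actions
        )
        # Same masking as in the Python loop above.
        episode_return = episode_return + reward * (1 - done)
        done = done | new_done
        return rng, state, done, episode_return

    init = (
        rng,
        jnp.zeros(N_ENVS, dtype=jnp.int32),  # per-env step counters
        jnp.zeros(N_ENVS, dtype=bool),       # per-env done flags
        jnp.zeros(N_ENVS),                   # per-env episode returns
    )
    _, _, _, episode_return = jax.lax.while_loop(cond_fn, body_fn, init)
    return episode_return

returns = rollout(jax.random.PRNGKey(0))
print(returns)
```

Because the whole loop is traced once and compiled into a single XLA program, there is no Python dispatch per environment step, which is where the speedup over the Python loop comes from.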

Let me know if this works for you!

remmyzen commented 1 month ago

Hi @keraJLi,

Thank you for your reply. It works now!