Chulabhaya opened this issue 3 weeks ago
To follow up on this, I looked through the SMAX MAPPO code, and there a given timestep consists of both the last done and the last obs. I think this is correct, because when you later sample a batch of this data and feed it through the same ScannedRNN, the dones reset the hidden states at exactly the points where the observations that follow are the first observations of the new episodes (see the toy scan after the snippet below). So is one of these implementations correct and the other wrong, or are they intentionally different?
def _env_step(runner_state, unused):
    train_states, env_state, last_obs, last_done, hstates, rng = (
        runner_state
    )

    # SELECT ACTION
    rng, _rng = jax.random.split(rng)
    avail_actions = jax.vmap(env.get_avail_actions)(env_state.env_state)
    avail_actions = jax.lax.stop_gradient(
        batchify(avail_actions, env.agents, config["NUM_ACTORS"])
    )
    obs_batch = batchify(last_obs, env.agents, config["NUM_ACTORS"])
    ac_in = (
        obs_batch[np.newaxis, :],
        last_done[np.newaxis, :],
        avail_actions,
    )
    # print('env step ac in', ac_in)
    ac_hstate, pi = actor_network.apply(
        train_states[0].params, hstates[0], ac_in
    )
    action = pi.sample(seed=_rng)
    log_prob = pi.log_prob(action)
    env_act = unbatchify(
        action, env.agents, config["NUM_ENVS"], env.num_agents
    )
    env_act = {k: v.squeeze() for k, v in env_act.items()}

    # VALUE
    # output of wrapper is (num_envs, num_agents, world_state_size)
    # swap axes to (num_agents, num_envs, world_state_size) before reshaping to (num_actors, world_state_size)
    world_state = last_obs["world_state"].swapaxes(0, 1)
    world_state = world_state.reshape((config["NUM_ACTORS"], -1))
    cr_in = (
        world_state[None, :],
        last_done[np.newaxis, :],
    )
    cr_hstate, value = critic_network.apply(
        train_states[1].params, hstates[1], cr_in
    )

    # STEP ENV
    rng, _rng = jax.random.split(rng)
    rng_step = jax.random.split(_rng, config["NUM_ENVS"])
    obsv, env_state, reward, done, info = jax.vmap(
        env.step, in_axes=(0, 0, 0)
    )(rng_step, env_state, env_act)
    info = jax.tree.map(lambda x: x.reshape((config["NUM_ACTORS"])), info)
    done_batch = batchify(done, env.agents, config["NUM_ACTORS"]).squeeze()
    transition = Transition(
        jnp.tile(done["__all__"], env.num_agents),
        last_done,
        action.squeeze(),
        value.squeeze(),
        batchify(reward, env.agents, config["NUM_ACTORS"]).squeeze(),
        log_prob.squeeze(),
        obs_batch,
        world_state,
        info,
        avail_actions,
    )
    runner_state = (
        train_states,
        env_state,
        obsv,
        done_batch,
        (ac_hstate, cr_hstate),
        rng,
    )
    return runner_state, transition
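To double-check my own reading, here is a tiny toy scan (made-up values, and a trivial running-sum "RNN" in place of the real GRU, so none of this is from the repo) showing where the reset lands for each pairing of obs and done:

import jax
import jax.numpy as jnp

# Toy single-env, single-agent trace. The episode ends after t=2 and the
# auto-reset wrapper kicks in, so obs[3] is the first obs of the new episode.
obs = jnp.array([[0.1], [0.2], [0.3], [9.1], [9.2]])  # (T, obs_dim)
done = jnp.array([False, False, True, False, False])  # done returned by env.step at t
last_done = jnp.roll(done, 1).at[0].set(False)        # done carried over from the previous step

def scan_with_resets(obs_seq, reset_seq):
    # Mimic the ScannedRNN reset logic with a running sum instead of a GRU.
    def step(carry, x):
        o, reset = x
        carry = jnp.where(reset, 0.0, carry)  # reset the "hidden state" where reset is True
        carry = carry + o.sum()               # then process this timestep's obs
        return carry, carry
    _, outs = jax.lax.scan(step, 0.0, (obs_seq, reset_seq))
    return outs

print(scan_with_resets(obs, last_done))  # carry restarts at t=3, on obs[3] from the new episode
print(scan_with_resets(obs, done))       # carry restarts at t=2, on obs[2] from the old episode

With the last_done pairing the carry restarts exactly on the first observation of the new episode, which is what I would expect to be correct.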
@mttga mind checking this? Pretty sure the MAPPO way is correct, as we had to fix it to use last_done.
@amacrutherford Thanks for the follow-up, I'm happy to make a PR with the update if you guys determine it is a bug. I also ran a couple of experiments using last_dones and got better performance on 2/3 (and equal on the third).
Hey all, hope everyone is doing well! What follows may be a bit of a dumb question, but I just wanted to clarify how this works for my own algorithm development based on your excellent code.
The QMIX code uses a ScannedRNN class where you pass in a sequence of observations and dones; wherever a done is true, the hidden state is reset and the corresponding obs at that timestep is then passed through:
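(Roughly this pattern; I'm sketching it from memory against the current Flax linen API with an arbitrary hidden size, so details will differ from the actual repo code.)

import functools

import flax.linen as nn
import jax
import jax.numpy as jnp
import numpy as np


class ScannedRNN(nn.Module):
    # Sketch, not the repo's exact code: a GRU scanned over time that
    # re-initializes its carry wherever the reset flag is True.
    hidden_size: int = 64  # arbitrary size, purely for illustration

    @functools.partial(
        nn.scan,
        variable_broadcast="params",
        in_axes=0,
        out_axes=0,
        split_rngs={"params": False},
    )
    @nn.compact
    def __call__(self, carry, x):
        ins, resets = x  # ins: (batch, features), resets: (batch,)
        # Where an episode just ended, swap in a fresh zero carry before
        # processing this timestep's observation.
        carry = jnp.where(
            resets[:, np.newaxis],
            self.initialize_carry(ins.shape[0], self.hidden_size),
            carry,
        )
        new_carry, y = nn.GRUCell(features=self.hidden_size)(carry, ins)
        return new_carry, y

    @staticmethod
    def initialize_carry(batch_size, hidden_size):
        # The GRU carry init is zeros, so the PRNG key is a dummy.
        return nn.GRUCell(features=hidden_size).initialize_carry(
            jax.random.PRNGKey(0), (batch_size, hidden_size)
        )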
This makes sense to me. However, I noticed that when data is actually collected, a given timestep actually stores the last obs + the new done, instead of the last obs + the last done.
Doesn't that mean that when the RNN resets the hidden state and then processes the paired observation, we are actually feeding the previous observation (which belongs to the episode that just ended) as the first step of the RNN's new sequence, instead of the current observation from the new episode generated after the environment was reset on done = True?
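To make that concern concrete, here is how I picture the replay through the sketch above (toy shapes and made-up values, nothing from the repo): the dones sequence decides which observation becomes the first step after each reset.

# Reuses the ScannedRNN sketch above; T = 5 timesteps, batch of 1, 4 features.
rng = jax.random.PRNGKey(0)
T, B, F = 5, 1, 4
obs_seq = jax.random.normal(rng, (T, B, F))          # pretend obs_seq[3] starts a new episode
dones_new = jnp.array([0, 0, 1, 0, 0], dtype=bool)   # done returned by env.step at t
dones_last = jnp.array([0, 0, 0, 1, 0], dtype=bool)  # done carried over from the previous step

def replay(dones_1d):
    dones_seq = jnp.broadcast_to(dones_1d[:, None], (T, B))
    model = ScannedRNN(hidden_size=8)
    carry = ScannedRNN.initialize_carry(B, 8)
    params = model.init(rng, carry, (obs_seq, dones_seq))
    _, outputs = model.apply(params, carry, (obs_seq, dones_seq))
    return outputs  # (T, B, 8)

out_last = replay(dones_last)  # carry re-initialized right before obs_seq[3], the new-episode obs
out_new = replay(dones_new)    # carry re-initialized right before obs_seq[2], still old-episode

With dones_last the RNN starts its new sequence on the new-episode observation; with dones_new it starts on an observation from the episode that just ended, which is exactly the mismatch I'm asking about.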