FLAIROx / JaxMARL

Multi-Agent Reinforcement Learning with JAX
Apache License 2.0

Understanding some RNN logic for QMIX SMAX algorithm #119

Open · Chulabhaya opened this issue 3 weeks ago

Chulabhaya commented 3 weeks ago

Hey all, hope everyone is doing well! This may be a bit of a dumb question, but I just wanted to clarify how this works for my own algorithm development, which builds on your excellent code.

The QMIX code uses a ScannedRNN class that takes a sequence of observations and dones; wherever a done flag is true, the hidden state is reset before the corresponding observation at that timestep is passed through:

from functools import partial

import flax.linen as nn
import jax
import jax.numpy as jnp
import numpy as np


class ScannedRNN(nn.Module):
    @partial(
        nn.scan,
        variable_broadcast="params",
        in_axes=0,
        out_axes=0,
        split_rngs={"params": False},
    )
    @nn.compact
    def __call__(self, carry, x):
        """Applies the module."""
        rnn_state = carry
        ins, resets = x
        hidden_size = ins.shape[-1]
        # Zero the carry wherever a reset (done) flag is set, *before*
        # feeding the observation at this timestep into the GRU cell.
        rnn_state = jnp.where(
            resets[:, np.newaxis],
            self.initialize_carry(hidden_size, *ins.shape[:-1]),
            rnn_state,
        )
        new_rnn_state, y = nn.GRUCell(hidden_size)(rnn_state, ins)
        return new_rnn_state, y

    @staticmethod
    def initialize_carry(hidden_size, *batch_size):
        # A dummy key is fine here since the default carry init is zeros.
        return nn.GRUCell(hidden_size, parent=None).initialize_carry(
            jax.random.PRNGKey(0), (*batch_size, hidden_size)
        )
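
For concreteness, here is roughly how I picture this module being driven, reusing the class above; the toy shapes and inputs are my own, not taken from the repo:

import jax
import jax.numpy as jnp

T, B, H = 5, 4, 8   # time steps, batch entries (num_envs * num_agents), hidden/feature size

obs_seq = jnp.ones((T, B, H))                             # already-embedded observations
done_seq = jnp.zeros((T, B), dtype=bool).at[2].set(True)  # every batch entry resets at t = 2

rnn = ScannedRNN()
init_carry = ScannedRNN.initialize_carry(H, B)            # zero carry of shape (B, H)
params = rnn.init(jax.random.PRNGKey(0), init_carry, (obs_seq, done_seq))
final_carry, outputs = rnn.apply(params, init_carry, (obs_seq, done_seq))

# At t = 2 the carry is zeroed *before* obs_seq[2] reaches the GRU cell, so
# whatever observation sits at index 2 becomes the first step of the new
# hidden-state sequence.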

This makes sense to me. However, I noticed that when data is actually collected, a given timestep stores the last obs together with the new done, rather than the last obs with the last done.

[Screenshot of the QMIX data-collection code]

Doesn't that mean that when this RNN resets the hidden state and then processes the observation at that index, we're actually feeding it the previous observation (which belongs to the episode that just ended) as the first step of the new sequence, instead of the current/new observation from the freshly reset environment, generated after the episode ended with done being True?
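
To make the concern concrete, here is a small property check I would expect to hold with the class above (toy data of my own): the outputs from a reset index onward depend only on the observations stored from that index onward.

import jax
import jax.numpy as jnp

T, B, H = 4, 1, 3
rnn = ScannedRNN()

# Pretend indices 0-1 hold episode-A observations and indices 2-3 episode-B ones.
obs = jnp.stack([jnp.full((B, H), v) for v in (1.0, 1.0, 2.0, 2.0)])  # (T, B, H)
dones = jnp.zeros((T, B), dtype=bool).at[2].set(True)                 # reset flagged at index 2

carry = jnp.zeros((B, H))
params = rnn.init(jax.random.PRNGKey(0), carry, (obs, dones))
_, full_out = rnn.apply(params, carry, (obs, dones))

# Replaying only obs[2:] from a zero carry reproduces the tail of the outputs,
# i.e. the "new episode" really does start with whatever obs sits at the reset index.
_, suffix_out = rnn.apply(
    params, jnp.zeros((B, H)), (obs[2:], jnp.zeros((T - 2, B), dtype=bool))
)
assert jnp.allclose(full_out[2:], suffix_out)

So if the observation stored at the reset index still belongs to the episode that just ended (as I read the QMIX rollout), the fresh hidden state starts from stale data.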

Chulabhaya commented 3 weeks ago

To follow up on this, I looked through the SMAX MAPPO code, and there a given timestep consists of both the last done and the last obs. I think this is correct: when you sample a batch of this data and feed it into the same ScannedRNN, the dones reset the hidden states at exactly the points where the observations passed in afterwards are the new observations from the new episodes (see the toy walkthrough after the snippet below). So is one of these implementations correct and the other wrong, or are they done differently for intentional reasons?

def _env_step(runner_state, unused):
      train_states, env_state, last_obs, last_done, hstates, rng = (
          runner_state
      )

      # SELECT ACTION
      rng, _rng = jax.random.split(rng)
      avail_actions = jax.vmap(env.get_avail_actions)(env_state.env_state)
      avail_actions = jax.lax.stop_gradient(
          batchify(avail_actions, env.agents, config["NUM_ACTORS"])
      )
      obs_batch = batchify(last_obs, env.agents, config["NUM_ACTORS"])
      ac_in = (
          obs_batch[np.newaxis, :],
          last_done[np.newaxis, :],
          avail_actions,
      )
      # print('env step ac in', ac_in)
      ac_hstate, pi = actor_network.apply(
          train_states[0].params, hstates[0], ac_in
      )
      action = pi.sample(seed=_rng)
      log_prob = pi.log_prob(action)
      env_act = unbatchify(
          action, env.agents, config["NUM_ENVS"], env.num_agents
      )
      env_act = {k: v.squeeze() for k, v in env_act.items()}

      # VALUE
      # output of wrapper is (num_envs, num_agents, world_state_size)
      # swap axes to (num_agents, num_envs, world_state_size) before reshaping to (num_actors, world_state_size)
      world_state = last_obs["world_state"].swapaxes(0, 1)
      world_state = world_state.reshape((config["NUM_ACTORS"], -1))

      cr_in = (
          world_state[None, :],
          last_done[np.newaxis, :],
      )
      cr_hstate, value = critic_network.apply(
          train_states[1].params, hstates[1], cr_in
      )

      # STEP ENV
      rng, _rng = jax.random.split(rng)
      rng_step = jax.random.split(_rng, config["NUM_ENVS"])
      obsv, env_state, reward, done, info = jax.vmap(
          env.step, in_axes=(0, 0, 0)
      )(rng_step, env_state, env_act)
      info = jax.tree.map(lambda x: x.reshape((config["NUM_ACTORS"])), info)
      done_batch = batchify(done, env.agents, config["NUM_ACTORS"]).squeeze()
      transition = Transition(
          jnp.tile(done["__all__"], env.num_agents),
          last_done,
          action.squeeze(),
          value.squeeze(),
          batchify(reward, env.agents, config["NUM_ACTORS"]).squeeze(),
          log_prob.squeeze(),
          obs_batch,
          world_state,
          info,
          avail_actions,
      )
      runner_state = (
          train_states,
          env_state,
          obsv,
          done_batch,
          (ac_hstate, cr_hstate),
          rng,
      )
      return runner_state, transition
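
As far as I can tell, the indexing then works out as follows when the stored batch is scanned through ScannedRNN. This is my own toy walkthrough, assuming the environment auto-resets on done:

# One env, a two-step episode A followed by a new episode B.
# At iteration k the agent acts on last_obs, and env.step returns (next_obs, done);
# on done, next_obs is already the first observation of the new episode.
rollout = [
    # (obs acted on, done returned, obs returned)
    ("A0", False, "A1"),
    ("A1", True,  "B0"),   # episode A ends; B0 comes from the auto-reset
    ("B0", False, "B1"),
]

# MAPPO-style storage, (last obs, last done):
mappo = [("A0", False), ("A1", False), ("B0", True)]
# -> the reset flag sits next to B0, so ScannedRNN zeroes the carry right before
#    the first observation of the new episode.

# QMIX-style storage as I read it, (last obs, new done):
qmix = [("A0", False), ("A1", True), ("B0", False)]
# -> the reset flag sits next to A1, so the fresh hidden state first consumes an
#    observation from the episode that just ended.
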
amacrutherford commented 2 weeks ago

@mttga mind checking this? Pretty sure the MAPPO way is correct, as we had to fix it to use last_done.

Chulabhaya commented 2 weeks ago

@amacrutherford Thanks for the follow-up. I'm happy to make a PR with the update if you determine it is a bug. I also ran a couple of experiments using last_dones and got better performance on two out of three (and equal performance on the third).