smearle / control-pcgrl

Train or evolve controllable and diverse level-generators.
MIT License

MultiAgentWrapper raises AssertionError during call to step #6

Open rohin-dasari opened 1 year ago

rohin-dasari commented 1 year ago

Expected Behavior

When running train_pcgrl multiagent.n_agents=2, the MultiAgentWrapper step method should execute without error.

Actual Behavior

When running train_pcgrl multiagent.n_agents=2, the environment is wrapped in a MultiAgentWrapper. The wrapper's step method calls the step method of the parent class:

def step(self, action):
    obs, rew, done, info = {}, {}, {}, {}
    for k, v in action.items():
        self.unwrapped._rep.set_active_agent(k)
        obs_k, rew[k], done[k], info[k] = super().step(action={k: v})  # THIS LINE HERE
        obs.update(obs_k)
    done['__all__'] = np.all(list(done.values()))
    return obs, rew, done, info

The step method is expected to return the observation for a single agent; however, the parent class's observation space contains both agents. This raises the following exception:

AssertionError: The obs returned by the `step()` method observation keys is not same as the observation space keys, obs keys: ['agent_0'], space keys: ['agent_0', 'agent_1']
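
To make the mismatch concrete, here is a minimal illustration; the Box shapes below are placeholders, not the repo's actual PCGRL spaces:

import gym

# The wrapper's observation space is a Dict keyed by every agent...
observation_space = gym.spaces.Dict({
    'agent_0': gym.spaces.Box(0.0, 1.0, shape=(3, 3)),
    'agent_1': gym.spaces.Box(0.0, 1.0, shape=(3, 3)),
})

# ...but a single super().step(action={k: v}) call only returns the active
# agent's observation, so the passive checker sees a missing key and asserts.
obs = {'agent_0': observation_space['agent_0'].sample()}
assert set(obs.keys()) == set(observation_space.spaces.keys())  # fails: 'agent_1' missing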

Environment

Ray 2.1.0, gym 0.26.2

Potential Fix

This implementation of the MultiAgentWrapper calls the step method of the parent class. Instead, we could call the step method of the unwrapped class. Note that this is not the base environment; it is just one layer deeper:

def step(self, action):
    obs, rew, done, info = {}, {}, {}, {}
    for k, v in action.items():
        self.unwrapped._rep.set_active_agent(k)
        obs_k, rew[k], done[k], info[k] = self.unwrapped.step(action={k: v})  # CHANGED THIS LINE
        obs.update(obs_k)

    done['__all__'] = np.all(list(done.values()))

    return obs, rew, done, info

Note that since the representation is also wrapped, we can still pass in an action dictionary to the environment.

smearle commented 1 year ago

unwrapped does seem to return the base environment:

>>> self
<MultiAgentWrapper<ControlWrapper<CroppedImagePCGRLWrapper<ToImage<OneHotEncoding<Cropped<OrderEnforcing<PassiveEnvChecker<PcgrlCtrlEnv<binary-turtle-v0>>>>>>>>>>
>>> self.unwrapped
<control_pcgrl.envs.pcgrl_ctrl_env.PcgrlCtrlEnv object at 0x28f728940>

Calling self.unwrapped.step in your fix above works because now you're bypassing PassiveEnvChecker, and none of the wrappers in between seem to modify the action, but I worry it could lead to trouble later on.

Still a hack, but maybe have the env return dummy observations for other agents (then ignore them in the MultiAgentWrapper)?
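
A rough sketch of the wrapper-side half of that idea, assuming the base env were changed to return an entry for every agent (the keep-only-the-active-key logic below is an assumption, not existing code):

def step(self, action):
    obs, rew, done, info = {}, {}, {}, {}
    for k, v in action.items():
        self.unwrapped._rep.set_active_agent(k)
        # if the env padded its return with dummy entries for inactive agents,
        # the full obs dict would satisfy the checker; keep only the active key
        full_obs, rew[k], done[k], info[k] = super().step(action={k: v})
        obs[k] = full_obs[k]
    done['__all__'] = np.all(list(done.values()))
    return obs, rew, done, info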

rohin-dasari commented 1 year ago

That would require changes to the base environment, correct? Another hack I found:

def step(self, action):
    obs, rew, done, info = {}, {}, {}, {}

    try:
        # trigger the checker's one-time validation with a throwaway single-agent
        # action (separate variable so the real action dict is not clobbered)
        sample_action = self.action_space.sample()
        dummy_action = {'agent_0': sample_action['agent_0']}
        super().step(dummy_action)
    except AssertionError:
        pass

    for k, v in action.items():
        self.unwrapped._rep.set_active_agent(k)
        obs_k, rew[k], done[k], info[k] = super().step(action={k: v})
        obs.update(obs_k)
    done['__all__'] = np.all(list(done.values()))
    return obs, rew, done, info

It looks like the assertion is only raised during the first execution of the step method, so this solution catches it on the first call and then pretends it never happened, effectively bypassing the PassiveEnvChecker. Definitely not the best solution, but it keeps the wrappers around the environment intact.
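
For context, gym 0.26's PassiveEnvChecker only runs the full validation on the first call to step and passes everything through afterwards, roughly like this (paraphrased, not copied from gym's source):

class PassiveEnvChecker(gym.Wrapper):
    def step(self, action):
        if not self.checked_step:
            self.checked_step = True
            # full validation, including the obs-keys-vs-space-keys assertion
            return env_step_passive_checker(self.env, action)
        return self.env.step(action)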

Update: This works when using the debug flag, but it fails with the same assertion error when starting a real training run.

rohin-dasari commented 1 year ago

Can you share the exact command you were running for your multiagent runs? Maybe there was a flag that was set that I'm missing or something.

rohin-dasari commented 1 year ago

Ah, after messing around for a bit, I found the problem. It looks like it was loading an old checkpoint generated from a single-agent run. I added the load=False flag and everything seems to be working now. Closing the issue.

rohin-dasari commented 1 year ago

Sorry, I made a mistake; the error is actually still raised with the load=False flag set. I was able to get it working with this fix, though:

def disable_passive_env_checker(env):
    # remove the passive environment checker wrapper from the env attribute of an env
    # base case -> the environment is not a wrapper
    if not hasattr(env, 'env'):
        return env

    root = env
    prev = env 
    while hasattr(prev, 'env'):
        next_ = prev.env
        if isinstance(next_, gym.wrappers.env_checker.PassiveEnvChecker):
            prev.env = next_.env
        prev = next_

    return root

class MultiAgentWrapper(gym.Wrapper, MultiAgentEnv):
    def __init__(self, game, **kwargs):
        multiagent_args = kwargs.get('multiagent')
        self.env = disable_passive_env_checker(game)  # DISABLE ENVIRONMENT CHECKING
        gym.Wrapper.__init__(self, self.env)
        MultiAgentEnv.__init__(self.env)
        self.n_agents = multiagent_args.get('n_agents', 2)
        self.observation_space = gym.spaces.Dict({})
        self.action_space = gym.spaces.Dict({})
        for i in range(self.n_agents):
            self.observation_space.spaces[f'agent_{i}'] = self.env.observation_space
            self.action_space.spaces[f'agent_{i}'] = self.env.action_space
        # otherwise gym utils throws an error???
        self.unwrapped.observation_space = self.observation_space
        self.unwrapped.action_space = self.action_space

    def reset(self):
        obs = super().reset()
        return obs

    def step(self, action):
        obs, rew, done, info = {}, {}, {}, {}

        for k, v in action.items():
            self.unwrapped._rep.set_active_agent(k)
            obs_k, rew[k], done[k], info[k] = super().step(action={k: v})
            obs.update(obs_k)

        done['__all__'] = np.all(list(done.values()))

        return obs, rew, done, info

I implemented a function to remove the passive environment checker wrapper. Everything seems to be working now, but this is likely not a long-term solution.

Here's the command that reproduces the error on Echidna:

python bin/train_pcgrl multiagent.n_agents=2 load=False

After adding in the disable_passive_env_checker function, training proceeds as expected, but a warning is raised mentioning that the passive environment checker wrapper is missing.

smearle commented 1 year ago

That sounds OK for now. (And you're right that we don't want to modify the underlying env, or we'd break the single-agent setting.) Maybe the most proper thing to do would be to have PassiveEnvChecker as the outermost wrapper, but I'm not sure if that's possible.

smearle commented 1 year ago

Alternatively, we could do gym.make(env, disable_env_checker=True) in the handful of places where we call gym.make.
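
For reference, gym 0.26 accepts that flag at construction time, so the call would look something like the following (the env id comes from the wrapper printout above; the exact call sites in this repo and the registering import are assumptions):

import gym
import control_pcgrl  # assumed to register the PCGRL envs

env = gym.make('binary-turtle-v0', disable_env_checker=True)
# no PassiveEnvChecker in the wrapper stack, so MultiAgentWrapper.step can
# return per-agent observations without tripping the key check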

rohin-dasari commented 1 year ago

Ah, didn't realize that parameter existed. Good call. I'll test it out and make a PR in a bit.