ray-project / ray

Ray is a unified framework for scaling AI and Python applications. Ray consists of a core distributed runtime and a set of AI Libraries for accelerating ML workloads.
https://ray.io
Apache License 2.0
33.45k stars 5.67k forks source link

ValueError: Must set agent_id on policy config #39246

Open ajvish91 opened 1 year ago

ajvish91 commented 1 year ago
import os

import ray
import supersuit as ss
from ray import tune
from ray.rllib.algorithms.maddpg.maddpg import MADDPGConfig
from ray.rllib.env.wrappers.pettingzoo_env import ParallelPettingZooEnv
from ray.rllib.models import ModelCatalog
from ray.rllib.models.torch.torch_modelv2 import TorchModelV2
from ray.tune.registry import register_env
from torch import nn

from pettingzoo.butterfly import knights_archers_zombies_v10

class CNNModelV2(TorchModelV2, nn.Module):
    def __init__(self, obs_space, act_space, num_outputs, *args, **kwargs):
        TorchModelV2.__init__(self, obs_space, act_space, num_outputs, *args, **kwargs)
        nn.Module.__init__(self)
        self.model = nn.Sequential(
            nn.Conv2d(3, 32, [8, 8], stride=(4, 4)),
            nn.ReLU(),
            nn.Conv2d(32, 64, [4, 4], stride=(2, 2)),
            nn.ReLU(),
            nn.Conv2d(64, 64, [3, 3], stride=(1, 1)),
            nn.ReLU(),
            nn.Flatten(),
            (nn.Linear(3136, 512)),
            nn.ReLU(),
        )
        self.policy_fn = nn.Linear(512, num_outputs)
        self.value_fn = nn.Linear(512, 1)

    def forward(self, input_dict, state, seq_lens):
        model_out = self.model(input_dict["obs"].permute(0, 3, 1, 2))
        self._value_out = self.value_fn(model_out)
        return self.policy_fn(model_out), state

    def value_function(self):
        return self._value_out.flatten()

def env_creator(args):
    env = knights_archers_zombies_v10.parallel_env(
        max_cycles=100,
        max_zombies=4,
        vector_state=False
    )
    env = ss.color_reduction_v0(env, mode="B")
    env = ss.dtype_v0(env, "float32")
    env = ss.resize_v1(env, x_size=84, y_size=84)
    env = ss.normalize_obs_v0(env, env_min=0, env_max=1)
    env = ss.frame_stack_v1(env, 3)
    return env

if __name__ == "__main__":
    ray.init()

    env_name = "knights_archers_zombies_v10"

    register_env(env_name, lambda config: ParallelPettingZooEnv(env_creator(config)))
    ModelCatalog.register_custom_model("CNNModelV2", CNNModelV2)

    action_space = knights_archers_zombies_v10.env().action_space("archer_1")

    observer_space = knights_archers_zombies_v10.env().observation_space("archer_1")
    print(action_space, observer_space)

    config = (
        MADDPGConfig()
        .environment(env=env_name, clip_actions=True, disable_env_checking=True)
        .rollouts(num_rollout_workers=4, rollout_fragment_length=128)
        .multi_agent(
            policies={
                "archer0pol": (None, observer_space, action_space, {}),
                "archer1pol": (None, observer_space, action_space, {}),
                "knight0pol": (None, observer_space, action_space, {}),
                "knight1pol": (None, observer_space, action_space, {}),
                "poldef": (None, observer_space, action_space, {}),
            },
            policy_mapping_fn=lambda agent_id, episode, worker, **kwargs: "archer0pol"
            if agent_id == "archer_0"
            else lambda agent_id, episode, worker, **kwargs: "archer1pol"
            if agent_id == "archer_1"
            else lambda agent_id, episode, worker, **kwargs: "knight0pol"
            if agent_id == "knight_0"
            else lambda agent_id, episode, worker, **kwargs: "knight1pol"
            if agent_id == "knight_1"
            else "poldef",
        )
        .training(
            train_batch_size=128,
            lr=3e-4,
            gamma=0.99
        )
        .debugging(log_level="ERROR")
        .framework(framework="torch")
        .resources(num_gpus=int(os.environ.get("RLLIB_NUM_GPUS", "0")))
    )

    tune.run(
        "MADDPG",
        name="MADDPG",
        stop={"timesteps_total": 20000},
        checkpoint_freq=10,
        local_dir="~/ray_results/" + env_name,
        config=config.to_dict(),
    )

Leads to:

>ValueError: Must set `agent_id` in the policy config

Is my policy mapping okay? Getting the above value error.

ajvish91 commented 1 year ago

ray 3.0.0.dev0 pettingzoo 1.24.0