Open ajvish91 opened 1 year ago
import os import ray import supersuit as ss from ray import tune from ray.rllib.algorithms.maddpg.maddpg import MADDPGConfig from ray.rllib.env.wrappers.pettingzoo_env import ParallelPettingZooEnv from ray.rllib.models import ModelCatalog from ray.rllib.models.torch.torch_modelv2 import TorchModelV2 from ray.tune.registry import register_env from torch import nn from pettingzoo.butterfly import knights_archers_zombies_v10 class CNNModelV2(TorchModelV2, nn.Module): def __init__(self, obs_space, act_space, num_outputs, *args, **kwargs): TorchModelV2.__init__(self, obs_space, act_space, num_outputs, *args, **kwargs) nn.Module.__init__(self) self.model = nn.Sequential( nn.Conv2d(3, 32, [8, 8], stride=(4, 4)), nn.ReLU(), nn.Conv2d(32, 64, [4, 4], stride=(2, 2)), nn.ReLU(), nn.Conv2d(64, 64, [3, 3], stride=(1, 1)), nn.ReLU(), nn.Flatten(), (nn.Linear(3136, 512)), nn.ReLU(), ) self.policy_fn = nn.Linear(512, num_outputs) self.value_fn = nn.Linear(512, 1) def forward(self, input_dict, state, seq_lens): model_out = self.model(input_dict["obs"].permute(0, 3, 1, 2)) self._value_out = self.value_fn(model_out) return self.policy_fn(model_out), state def value_function(self): return self._value_out.flatten() def env_creator(args): env = knights_archers_zombies_v10.parallel_env( max_cycles=100, max_zombies=4, vector_state=False ) env = ss.color_reduction_v0(env, mode="B") env = ss.dtype_v0(env, "float32") env = ss.resize_v1(env, x_size=84, y_size=84) env = ss.normalize_obs_v0(env, env_min=0, env_max=1) env = ss.frame_stack_v1(env, 3) return env if __name__ == "__main__": ray.init() env_name = "knights_archers_zombies_v10" register_env(env_name, lambda config: ParallelPettingZooEnv(env_creator(config))) ModelCatalog.register_custom_model("CNNModelV2", CNNModelV2) action_space = knights_archers_zombies_v10.env().action_space("archer_1") observer_space = knights_archers_zombies_v10.env().observation_space("archer_1") print(action_space, observer_space) config = ( MADDPGConfig() .environment(env=env_name, clip_actions=True, disable_env_checking=True) .rollouts(num_rollout_workers=4, rollout_fragment_length=128) .multi_agent( policies={ "archer0pol": (None, observer_space, action_space, {}), "archer1pol": (None, observer_space, action_space, {}), "knight0pol": (None, observer_space, action_space, {}), "knight1pol": (None, observer_space, action_space, {}), "poldef": (None, observer_space, action_space, {}), }, policy_mapping_fn=lambda agent_id, episode, worker, **kwargs: "archer0pol" if agent_id == "archer_0" else lambda agent_id, episode, worker, **kwargs: "archer1pol" if agent_id == "archer_1" else lambda agent_id, episode, worker, **kwargs: "knight0pol" if agent_id == "knight_0" else lambda agent_id, episode, worker, **kwargs: "knight1pol" if agent_id == "knight_1" else "poldef", ) .training( train_batch_size=128, lr=3e-4, gamma=0.99 ) .debugging(log_level="ERROR") .framework(framework="torch") .resources(num_gpus=int(os.environ.get("RLLIB_NUM_GPUS", "0"))) ) tune.run( "MADDPG", name="MADDPG", stop={"timesteps_total": 20000}, checkpoint_freq=10, local_dir="~/ray_results/" + env_name, config=config.to_dict(), )
Leads to:
>ValueError: Must set `agent_id` in the policy config
Is my policy mapping okay? Getting the above value error.
ray 3.0.0.dev0 pettingzoo 1.24.0
Leads to:
Is my policy mapping okay? Getting the above value error.