oxwhirl / smacv2

MIT License
211 stars 32 forks source link

[Bug] Example script bugs #28

Open matteobettini opened 1 year ago

matteobettini commented 1 year ago
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from os import replace

from smac.env import StarCraft2Env
import numpy as np
from absl import logging
import time

from smac.env.starcraft2.wrapper import StarCraftCapabilityEnvWrapper

# Set absl's logging threshold to DEBUG for this example run.
logging.set_verbosity(logging.DEBUG)

def main():
    """Run a few episodes of the capability-wrapped environment with random actions.

    Builds a ``StarCraftCapabilityEnvWrapper`` for the ``10gen_terran`` map,
    then plays ``n_episodes`` episodes in which every agent picks a uniformly
    random action among those currently available to it. The accumulated
    reward of each episode is printed when the episode terminates.
    """
    distribution_config = {
        "n_units": 5,
        # The wrapper's ``_parse_distribution_config`` reads ``n_enemies``
        # from the top level of this dict; leaving it out raises
        # ``KeyError: 'n_enemies'`` while the wrapper is being constructed.
        "n_enemies": 5,
        "team_gen": {
            "dist_type": "weighted_teams",
            "unit_types": ["marine", "marauder", "medivac"],
            "exception_unit_types": ["medivac"],
            "weights": [0.45, 0.45, 0.1],
            "observe": True,
        },
        "start_positions": {
            "dist_type": "surrounded_and_reflect",
            "p": 0.5,
            # Kept as in the original example; presumably this section also
            # consumes the enemy count -- confirm against the wrapper.
            "n_enemies": 5,
            "map_x": 32,
            "map_y": 32,
        },
    }
    env = StarCraftCapabilityEnvWrapper(
        capability_config=distribution_config,
        map_name="10gen_terran",
        debug=True,
        conic_fov=False,
        obs_own_pos=True,
        use_unit_ranges=True,
        min_attack_range=2,
    )

    env_info = env.get_env_info()

    # ``n_actions`` is not needed by the random policy below; it is fetched
    # only to show what ``get_env_info`` exposes.
    n_actions = env_info["n_actions"]
    n_agents = env_info["n_agents"]

    n_episodes = 10

    print("Training episodes")
    for e in range(n_episodes):
        env.reset()
        terminated = False
        episode_reward = 0

        while not terminated:
            # Queried to illustrate the API; the random policy ignores both.
            obs = env.get_obs()
            state = env.get_state()
            # env.render()  # Uncomment for rendering

            actions = []
            for agent_id in range(n_agents):
                # ``avail_actions`` is a mask; its nonzero indices are the
                # actions this agent may legally take on this step.
                avail_actions = env.get_avail_agent_actions(agent_id)
                avail_actions_ind = np.nonzero(avail_actions)[0]
                action = np.random.choice(avail_actions_ind)
                actions.append(action)

            reward, terminated, _ = env.step(actions)
            # Pause between steps so the episode can be followed by eye.
            time.sleep(0.15)
            episode_reward += reward
        print("Total reward in episode {} = {}".format(e, episode_reward))

# Run the example only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()

Bugs:

  File "/Users/matbet/PycharmProjects/rl/torchrl/envs/libs/smac_try.py", line 77, in <module>
    main()
  File "/Users/matbet/PycharmProjects/rl/torchrl/envs/libs/smac_try.py", line 35, in main
    env = StarCraftCapabilityEnvWrapper(
  File "/Users/matbet/miniconda3/envs/torchrl/lib/python3.9/site-packages/smacv2/env/starcraft2/wrapper.py", line 10, in __init__
    self._parse_distribution_config()
  File "/Users/matbet/miniconda3/envs/torchrl/lib/python3.9/site-packages/smacv2/env/starcraft2/wrapper.py", line 24, in _parse_distribution_config
    config["n_enemies"] = self.distribution_config["n_enemies"]
KeyError: 'n_enemies'

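`n_enemies` seems to be in the wrong place in the config: according to the traceback, the wrapper reads it from the top level of the capability config, not from inside `start_positions`.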