hill-a / stable-baselines

A fork of OpenAI Baselines, implementations of reinforcement learning algorithms
http://stable-baselines.readthedocs.io/
MIT License
4.14k stars 723 forks source link

Custom gym Env: AssertionError regarding the reset() method #1173

Open sheila-janota opened 2 years ago

sheila-janota commented 2 years ago

Hello,

I am having some issues when checking my custom environment. I have looked at several of the solutions suggested by other people here, but they don't seem to solve the issue I'm having.

It says: AssertionError: The observation returned by the reset() method does not match the given observation space.

Here are the lines of code I used to create my custom env:

class EnvWrapperSB2(gym.Env):
    """Gym environment that drives one ns3-gym simulation per thread.

    On construction the wrapper launches ``no_threads`` simulation scripts
    (one shell command per port) and connects an ``ns3env.Ns3Env`` to each of
    them. ``reset`` and ``step`` fan out to every connected environment and
    stack the per-environment results into numpy arrays.

    :param no_threads: number of simulation instances to start
    :param params: simulation arguments; every ``key=value`` pair is passed
        to the simulation script as ``--key=value``. ``envStepTime`` is
        required because it is also used as the ``stepTime`` of each
        ``Ns3Env``.
    """

    def __init__(self, no_threads, **params):
        super(EnvWrapperSB2, self).__init__()
        self.params = params
        self.no_threads = no_threads
        # Draw a single random base port and offset it per thread. Drawing an
        # independent random offset for every thread could hand the same port
        # to two simulations.
        base_port = 13968 + np.random.randint(40000)
        self.ports = [base_port + i for i in range(no_threads)]
        self.commands = self._craft_commands(params)
        # SAC works on continuous Box actions; an integer dtype would collapse
        # the [0.0, 1.0] interval to the two values {0, 1}.
        self.action_space = spaces.Box(low=0.0, high=1.0, shape=(1,), dtype=np.float32)
        self.observation_space = spaces.Box(low=0.0, high=1.0, shape=(1,), dtype=np.float32)

        self.SCRIPT_RUNNING = False
        # Must exist before anything can trigger __getattr__, which reads it.
        self.envs = []

        # run() starts the simulation scripts and sets SCRIPT_RUNNING.
        self.run()
        for port in self.ports:
            env = ns3env.Ns3Env(port=port, stepTime=params['envStepTime'], startSim=0, simSeed=0, simArgs=params, debug=False)
            self.envs.append(env)

    def run(self):
        """Start one simulation script per crafted command.

        :raises AlreadyRunningException: if the scripts were already started
        """
        if self.SCRIPT_RUNNING:
            raise AlreadyRunningException("Script is already running")

        for cmd in self.commands:
            subprocess.Popen(['bash', '-c', cmd])
        self.SCRIPT_RUNNING = True

    def _craft_commands(self, params):
        """Build the shell command that launches the simulation for each port.

        :param params: simulation arguments, appended as ``--key=value`` flags
        :return: list with one command string per entry of ``self.ports``
        """
        try:
            waf_pwd = find_waf_path("./")
        except FileNotFoundError:
            # Fall back to looking two directories up.
            import sys
            sys.path.append("../../")
            waf_pwd = find_waf_path("../../")

        flags = "".join(f" --{key}={val}" for key, val in params.items())
        command = f'{waf_pwd} --run "RLinWiFi-master-original-queue-size' + flags

        # The closing double quote of the --run argument goes after the port.
        return [command + f' --openGymPort={p}"' for p in self.ports]

    def reset(self):
        """Reset every wrapped environment.

        :return: numpy array with one entry per wrapped environment
        """
        # NOTE(review): the result has a leading per-environment axis, while
        # ``observation_space`` is declared with shape (1,). Unless each
        # Ns3Env.reset() returns a scalar inside [0, 1], this array will not
        # be contained in the declared space -- confirm the per-env
        # observation shape and declare ``observation_space`` to match.
        obs = [env.reset() for env in self.envs]
        return np.array(obs)

    def step(self, actions):
        """Step every wrapped environment with its own action.

        :param actions: indexable collection; ``actions[i]`` must provide
            ``tolist()`` and is forwarded to the i-th environment
        :return: tuple ``(next_obs, reward, done, info)`` of numpy arrays,
            each with one entry per wrapped environment
        """
        next_obs, reward, done, info = [], [], [], []

        for i, env in enumerate(self.envs):
            no, rew, dn, inf = env.step(actions[i].tolist())
            next_obs.append(no)
            reward.append(rew)
            done.append(dn)
            info.append(inf)

        return np.array(next_obs), np.array(reward), np.array(done), np.array(info)

    def close(self):
        """Close every wrapped environment after a five second pause."""
        time.sleep(5)
        for env in self.envs:
            env.close()
        # NOTE(review): the processes started in run() are not tracked, so
        # they are not terminated here.

        self.SCRIPT_RUNNING = False

    def __getattr__(self, attr):
        """Forward unknown attribute lookups to the wrapped environments.

        Python only calls this when normal lookup has already failed.

        :param attr: name of the requested attribute
        :return: list holding that attribute of each wrapped environment
        :raises AttributeError: for ``envs`` itself, for dunder names, when no
            environment is connected yet, or when a wrapped environment lacks
            the attribute
        """
        # Looking up ``self.envs`` below would re-enter this method forever if
        # ``envs`` has not been assigned yet; dunder probes (copy, pickle, ...)
        # must fail normally instead of being forwarded.
        if attr == "envs" or (attr.startswith("__") and attr.endswith("__")):
            raise AttributeError(attr)

        envs = self.envs
        if not envs:
            raise AttributeError(attr)
        return [getattr(env, attr) for env in envs]

Then I check the environment, with the intention of using it with a SAC agent.

# Simulation arguments handed to EnvWrapperSB2 as **params; the wrapper turns
# every key/value pair into a "--key=value" flag of the simulation command.
sim_args = {
    "simTime": simTime,
    "envStepTime": stepTime,  # also read by the wrapper as params['envStepTime']
    "historyLength": history_length,
    "scenario": "basic",
    "nWifi": 5,
}
# Number of simulation instances (and ports) the wrapper creates.
threads_no = 1
env = EnvWrapperSB2(threads_no, **sim_args)

from stable_baselines.common.env_checker import check_env

# If the environment doesn't follow the interface, an error will be thrown
# (here: the AssertionError about reset() quoted below).
check_env(env, warn=True)

AssertionError: The observation returned by the `reset()` method does not match the given observation space

# Build the SAC agent with MlpPolicy on top of the custom environment.
model = SAC(MlpPolicy, env, verbose=1)

System Info

Can someone please help me? Perhaps I'm doing something wrong.