Stable-Baselines-Team / stable-baselines3-contrib

Contrib package for Stable-Baselines3 - Experimental reinforcement learning (RL) code
MIT License

Speed up when using MaskablePPO #205

Open vahidqo opened 1 year ago

vahidqo commented 1 year ago

❓ Question

Hi, I'm using MaskablePPO on a powerful computer but the speed of the training doesn't change compared to a normal computer. Is there any option or line of code that increases the speed of training? Thank you.

class customenv(gym.Env):
    ...

env = customenv()
env = ActionMasker(env, mask_fn)
model = MaskablePPO(MaskableActorCriticPolicy, env, verbose=0)
model.learn(1000000)


araffin commented 1 year ago


I'm using MaskablePPO on a powerful computer but the speed of the training doesn't change compared to a normal computer. Is there any option or line of code that increases the speed of training?

Related: DLR-RM/stable-baselines3#1245 and DLR-RM/stable-baselines3#90 (comment) and DLR-RM/stable-baselines3#682

You should probably use multiple envs too. In that case, you should define the action mask function directly in the env; see #49 (comment).

vahidqo commented 11 months ago


I'm using MaskablePPO on a powerful computer but the speed of the training doesn't change compared to a normal computer. Is there any option or line of code that increases the speed of training?

Related: DLR-RM/stable-baselines3#1245 and DLR-RM/stable-baselines3#90 (comment) and DLR-RM/stable-baselines3#682

You should probably use multiple envs too, in that case, you should define the action mask function directly in the env, see #49 (comment)

Sorry for reopening this, but I get an error when using this with my custom env. All my environment methods are the same as in "InvalidActionEnvDiscrete":

EOFError                                  Traceback (most recent call last)
<ipython-input-44-67cc0d019c11> in <cell line: 2>()
      1 model = MaskablePPO("MlpPolicy", env, verbose=1, tensorboard_log="/content/drive/MyDrive/Colab Notebooks/JOM/test")
----> 2 model.learn(100000)

7 frames
/usr/local/lib/python3.10/dist-packages/sb3_contrib/ppo_mask/ in learn(self, total_timesteps, callback, log_interval, tb_log_name, reset_num_timesteps, use_masking, progress_bar)
    525         while self.num_timesteps < total_timesteps:
--> 526             continue_training = self.collect_rollouts(self.env, callback, self.rollout_buffer, self.n_steps, use_masking)
    528             if continue_training is False:

/usr/local/lib/python3.10/dist-packages/sb3_contrib/ppo_mask/ in collect_rollouts(self, env, callback, rollout_buffer, n_rollout_steps, use_masking)
    287         rollout_buffer.reset()
--> 289         if use_masking and not is_masking_supported(env):
    290             raise ValueError("Environment does not support action masking. Consider using ActionMasker wrapper")

/usr/local/lib/python3.10/dist-packages/sb3_contrib/common/maskable/ in is_masking_supported(env)
     31         try:
     32             # TODO: add VecEnv.has_attr()
---> 33             env.get_attr(EXPECTED_METHOD_NAME)
     34             return True
     35         except AttributeError:

/usr/local/lib/python3.10/dist-packages/stable_baselines3/common/vec_env/ in get_attr(self, attr_name, indices)
    171         for remote in target_remotes:
    172             remote.send(("get_attr", attr_name))
--> 173         return [remote.recv() for remote in target_remotes]
    175     def set_attr(self, attr_name: str, value: Any, indices: VecEnvIndices = None) -> None:

/usr/local/lib/python3.10/dist-packages/stable_baselines3/common/vec_env/ in <listcomp>(.0)
    171         for remote in target_remotes:
    172             remote.send(("get_attr", attr_name))
--> 173         return [remote.recv() for remote in target_remotes]
    175     def set_attr(self, attr_name: str, value: Any, indices: VecEnvIndices = None) -> None:

/usr/lib/python3.10/multiprocessing/ in recv(self)
    248         self._check_closed()
    249         self._check_readable()
--> 250         buf = self._recv_bytes()
    251         return _ForkingPickler.loads(buf.getbuffer())

/usr/lib/python3.10/multiprocessing/ in _recv_bytes(self, maxsize)
    413     def _recv_bytes(self, maxsize=None):
--> 414         buf = self._recv(4)
    415         size, = struct.unpack("!i", buf.getvalue())
    416         if size == -1:

/usr/lib/python3.10/multiprocessing/ in _recv(self, size, read)
    381             if n == 0:
    382                 if remaining == size:
--> 383                     raise EOFError
    384                 else:
    385                     raise OSError("got end of file during message")
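
For context on the last frames of the traceback: `recv()` on a `multiprocessing` connection raises `EOFError` when the other end has been closed and nothing is left to read. The sketch below reproduces only that mechanism with a bare `Pipe`, using the hypothetical helper name `recv_after_peer_closed`. It does not show why the worker end went away in the run above; a worker process that exited or crashed before answering the `get_attr` request is one possible reading, but that is an assumption.

```python
from multiprocessing import Pipe


def recv_after_peer_closed() -> str:
    """Return the name of the exception raised when the peer end is closed."""
    parent_end, worker_end = Pipe()
    # Simulate a worker that goes away before sending any reply.
    worker_end.close()
    try:
        parent_end.recv()
    except EOFError as exc:
        return type(exc).__name__
    finally:
        parent_end.close()
    return "no error"


print(recv_after_peer_closed())
```

If a dead worker is indeed the cause, the real error is usually printed by the worker process itself (stderr or the notebook runtime logs) rather than in the parent's traceback. Creating the env once in the main process, for example with `DummyVecEnv` instead of a subprocess-based vec env, tends to surface that underlying exception directly.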
