Stable-Baselines-Team / stable-baselines3-contrib

Contrib package for Stable-Baselines3 - Experimental reinforcement learning (RL) code
https://sb3-contrib.readthedocs.io
MIT License

Speed up when using MaskablePPO #205


vahidqo commented 1 year ago

❓ Question

Hi, I'm using MaskablePPO on a powerful computer, but the training speed doesn't change compared to a normal computer. Is there an option or a line of code that would speed up training? Thank you.

```python
from sb3_contrib import MaskablePPO
from sb3_contrib.common.maskable.policies import MaskableActorCriticPolicy
from sb3_contrib.common.wrappers import ActionMasker
import gymnasium as gym  # `import gym` on SB3 < 2.0


class customenv(gym.Env):
    ...

env = customenv()
env = ActionMasker(env, mask_fn)  # mask_fn: user-defined action-mask function
model = MaskablePPO(MaskableActorCriticPolicy, env, verbose=0)
model.learn(1000000)
```


araffin commented 1 year ago

Hello,

> I'm using MaskablePPO on a powerful computer, but the training speed doesn't change compared to a normal computer. Is there an option or a line of code that would speed up training?

Related: https://github.com/DLR-RM/stable-baselines3/issues/1245 and https://github.com/DLR-RM/stable-baselines3/issues/90#issuecomment-659607948 and https://github.com/DLR-RM/stable-baselines3/issues/682

You should probably use multiple envs too; in that case, you should define the action mask function directly in the env, see https://github.com/Stable-Baselines-Team/stable-baselines3-contrib/issues/49#issuecomment-1422869188
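
For reference, a minimal sketch of that setup (the env name, spaces, and mask logic below are placeholders, and it assumes a recent sb3-contrib with gymnasium): the env exposes an `action_masks()` method itself, so it pickles cleanly into `SubprocVecEnv` workers and no `ActionMasker` wrapper is needed:

```python
import numpy as np
import gymnasium as gym
from gymnasium import spaces

from sb3_contrib import MaskablePPO
from stable_baselines3.common.env_util import make_vec_env
from stable_baselines3.common.vec_env import SubprocVecEnv


class CustomEnv(gym.Env):
    """Hypothetical stand-in for the user's env; the mask logic is a placeholder."""

    def __init__(self):
        super().__init__()
        self.observation_space = spaces.Box(low=0.0, high=1.0, shape=(4,), dtype=np.float32)
        self.action_space = spaces.Discrete(5)

    def action_masks(self):
        # MaskablePPO detects masking support by this method name,
        # so no ActionMasker wrapper is required.
        return np.array([True, True, True, False, False])

    def reset(self, seed=None, options=None):
        super().reset(seed=seed)
        return self.observation_space.sample(), {}

    def step(self, action):
        return self.observation_space.sample(), 0.0, False, False, {}


if __name__ == "__main__":
    # Several envs in subprocesses; each worker pickles the env
    # together with its action_masks method.
    vec_env = make_vec_env(CustomEnv, n_envs=4, vec_env_cls=SubprocVecEnv)
    model = MaskablePPO("MlpPolicy", vec_env, verbose=0)
    model.learn(1_000_000)
```

The speed-up from `n_envs > 1` comes from parallelizing rollout collection; gradient updates still run in the main process.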

vahidqo commented 11 months ago

> Hello,
>
> > I'm using MaskablePPO on a powerful computer, but the training speed doesn't change compared to a normal computer. Is there an option or a line of code that would speed up training?
>
> Related: DLR-RM/stable-baselines3#1245, DLR-RM/stable-baselines3#90 (comment), and DLR-RM/stable-baselines3#682
>
> You should probably use multiple envs too; in that case, you should define the action mask function directly in the env, see #49 (comment)

Sorry for reopening this, but I get an error when using it with a custom env. All my environment methods are the same as in `InvalidActionEnvDiscrete`:

```
---------------------------------------------------------------------------
EOFError                                  Traceback (most recent call last)
<ipython-input-44-67cc0d019c11> in <cell line: 2>()
      1 model = MaskablePPO("MlpPolicy", env, verbose=1, tensorboard_log="/content/drive/MyDrive/Colab Notebooks/JOM/test")
----> 2 model.learn(100000)

7 frames
/usr/local/lib/python3.10/dist-packages/sb3_contrib/ppo_mask/ppo_mask.py in learn(self, total_timesteps, callback, log_interval, tb_log_name, reset_num_timesteps, use_masking, progress_bar)
    524 
    525         while self.num_timesteps < total_timesteps:
--> 526             continue_training = self.collect_rollouts(self.env, callback, self.rollout_buffer, self.n_steps, use_masking)
    527 
    528             if continue_training is False:

/usr/local/lib/python3.10/dist-packages/sb3_contrib/ppo_mask/ppo_mask.py in collect_rollouts(self, env, callback, rollout_buffer, n_rollout_steps, use_masking)
    287         rollout_buffer.reset()
    288 
--> 289         if use_masking and not is_masking_supported(env):
    290             raise ValueError("Environment does not support action masking. Consider using ActionMasker wrapper")
    291 

/usr/local/lib/python3.10/dist-packages/sb3_contrib/common/maskable/utils.py in is_masking_supported(env)
     31         try:
     32             # TODO: add VecEnv.has_attr()
---> 33             env.get_attr(EXPECTED_METHOD_NAME)
     34             return True
     35         except AttributeError:

/usr/local/lib/python3.10/dist-packages/stable_baselines3/common/vec_env/subproc_vec_env.py in get_attr(self, attr_name, indices)
    171         for remote in target_remotes:
    172             remote.send(("get_attr", attr_name))
--> 173         return [remote.recv() for remote in target_remotes]
    174 
    175     def set_attr(self, attr_name: str, value: Any, indices: VecEnvIndices = None) -> None:

/usr/local/lib/python3.10/dist-packages/stable_baselines3/common/vec_env/subproc_vec_env.py in <listcomp>(.0)
    171         for remote in target_remotes:
    172             remote.send(("get_attr", attr_name))
--> 173         return [remote.recv() for remote in target_remotes]
    174 
    175     def set_attr(self, attr_name: str, value: Any, indices: VecEnvIndices = None) -> None:

/usr/lib/python3.10/multiprocessing/connection.py in recv(self)
    248         self._check_closed()
    249         self._check_readable()
--> 250         buf = self._recv_bytes()
    251         return _ForkingPickler.loads(buf.getbuffer())
    252 

/usr/lib/python3.10/multiprocessing/connection.py in _recv_bytes(self, maxsize)
    412 
    413     def _recv_bytes(self, maxsize=None):
--> 414         buf = self._recv(4)
    415         size, = struct.unpack("!i", buf.getvalue())
    416         if size == -1:

/usr/lib/python3.10/multiprocessing/connection.py in _recv(self, size, read)
    381             if n == 0:
    382                 if remaining == size:
--> 383                     raise EOFError
    384                 else:
    385                     raise OSError("got end of file during message")

EOFError: 
```
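
A bare `EOFError` from `remote.recv()` usually means a `SubprocVecEnv` worker process died before it could answer `get_attr`; running the same setup in a `DummyVecEnv` (everything in one process) typically surfaces the real exception. A minimal debugging sketch, reusing the placeholder `CustomEnv` from the earlier example:

```python
from sb3_contrib import MaskablePPO
from stable_baselines3.common.env_util import make_vec_env
from stable_baselines3.common.vec_env import DummyVecEnv

# DummyVecEnv keeps every env in the main process, so the exception that
# killed the SubprocVecEnv worker is raised directly instead of an EOFError.
vec_env = make_vec_env(CustomEnv, n_envs=4, vec_env_cls=DummyVecEnv)  # CustomEnv: placeholder env from above
model = MaskablePPO("MlpPolicy", vec_env, verbose=1)
model.learn(1_000)
```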