DLR-RM / stable-baselines3

PyTorch version of Stable Baselines, reliable implementations of reinforcement learning algorithms.
https://stable-baselines3.readthedocs.io
MIT License

[Question] Running Multi-threaded PPO training independently with no interference #1931

Open n-kish opened 1 month ago

n-kish commented 1 month ago

❓ Question

I am trying to parallelise execution of PPO training on MuJoCo environments, where each multiprocessing worker trains PPO with a slightly modified XML file. For this, I currently use:

import multiprocessing

num_processes = min(200, multiprocessing.cpu_count())
with multiprocessing.Pool(processes=num_processes) as pool:
    eprewmeans = pool.map(simulate_robot, xml_robots)

Here, the simulate_robot function fires up the same Python file, passing one xml_robot from xml_robots as an argument. This Python file (train_ppo.py) currently looks like this:

import os
import argparse
import json
import gymnasium as gym
from stable_baselines3 import PPO
from stable_baselines3.common.env_util import make_vec_env
from stable_baselines3.common.evaluation import evaluate_policy
from stable_baselines3.common.logger import configure

def main():

    parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument('--env_id', help='environment ID', default=None)
    parser.add_argument('--total_timesteps', help='maximum step size', type=int)
    parser.add_argument('--network', help='path for data')
    parser.add_argument('--xml_file_path', help='path for xml')
    parser.add_argument('--perf_log_path', help='path for performance logs')
    parser.add_argument('--ctrl_cost_weight', help='ctrl cost weight for gym env')
    args = parser.parse_args()

    # Logger initialisation
    tmp_path = args.perf_log_path
    new_logger = configure(tmp_path, ["stdout", "csv", "tensorboard"])

    # Create Mujoco environment
    env = gym.make(args.env_id, xml_file=args.xml_file_path)

    # Instantiate the model
    model = PPO("MlpPolicy", env, verbose=1)
    model.set_logger(new_logger)

    # Train the model
    model.learn(total_timesteps=int(args.total_timesteps))

    # Evaluate the policy
    mean_reward, std_reward = evaluate_policy(model, env, n_eval_episodes=10)
    print(f"mean_reward={mean_reward:.2f} +/- {std_reward:.2f}")

if __name__ == '__main__':
    main()
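
For context, simulate_robot is essentially a thin wrapper along these lines (simplified sketch: the env id, log paths and the reward-retrieval helper stand in for the real ones):

import runpy
import sys
from pathlib import Path

def simulate_robot(xml_robot):
    # Run train_ppo.py inside this pool worker for one modified XML file
    perf_log_path = f"./logs/{Path(xml_robot).stem}"
    sys.argv = [
        "train_ppo.py",
        "--env_id", "Ant-v4",                    # example env id
        "--total_timesteps", "100000",
        "--xml_file_path", str(xml_robot),
        "--perf_log_path", perf_log_path,
    ]
    runpy.run_path("train_ppo.py", run_name="__main__")
    # Read the final mean episode reward back from the CSV logger output
    return read_eprewmean(perf_log_path)         # hypothetical helper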

When I ran this code, I found that with more than one process, the optimizer call in stable_baselines3/ppo/ppo.py takes longer and longer as the number of processes increases. I have ensured there is no interference in any other part of the code, except for some interesting time delays in the code block below.

stable-baselines3/stable_baselines3/ppo/ppo.py Lines 278 to 282

self.policy.optimizer.zero_grad()
loss.backward()
# Clip grad norm
th.nn.utils.clip_grad_norm_(self.policy.parameters(), self.max_grad_norm)
self.policy.optimizer.step()

Can you please help me understand what I might be dealing with here, and suggest a possible solution or alternate path to achieve the desired multiprocessing capability with SB3?

araffin commented 1 month ago

Hello, if you want to parallelize gradient steps, you need to have a look at https://github.com/Stable-Baselines-Team/stable-baselines3-contrib/issues/11 and the linked issues. If you want to parallelize data collection, you need to use a VecEnv.
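
For the data-collection case, a minimal sketch looks like this (the env id and number of workers are just examples):

from stable_baselines3 import PPO
from stable_baselines3.common.env_util import make_vec_env
from stable_baselines3.common.vec_env import SubprocVecEnv

if __name__ == "__main__":
    # One PPO model; experience is collected by 8 environment worker processes
    vec_env = make_vec_env("Ant-v4", n_envs=8, vec_env_cls=SubprocVecEnv)
    model = PPO("MlpPolicy", vec_env, verbose=1)
    model.learn(total_timesteps=100_000)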

I'm also not sure why you have some TF code in there...

n-kish commented 1 month ago

Hi @araffin, thanks for your response.

I may have not been quite clear about what I wanted to achieve.

I am parallelizing externally to the model instance, not within it. That is, for each different XML file, I run one train_ppo.py instance per worker. As the XML file count grows, the number of parallel processes increases, so gym envs are created and PPO models are trained in parallel. It is here that I face the problem.

I notice that the gradient steps somehow take longer and longer as the worker count (i.e. the number of independent gym envs) increases, which shouldn't normally be the case, because each worker should be treated independently and the models should train independently. (Please note the time_elapsed in seconds for just 4000 env steps in the attached screenshot.)

[Screenshot from 2024-05-23 13-25-58: training logs showing time_elapsed]

Hence your suggestions about parallel gradient steps and data collection via Stable-Baselines-Team/stable-baselines3-contrib/issues/11 and VecEnv, though useful, don't address my problem, because I still have num_cpu=1 even in the VecEnv.

Hope this clarifies things further. Please let me know how I may go about this problem, thanks.

And yes, the TF code is a blunder, please ignore it.

n-kish commented 1 month ago

For anyone that may be interested in this later: the problem I faced is due to the global autograd engine of Torch (as discussed here: Assumptions around Autograd and Python multi-threading).

I solved this by calling each different run_ppo.py file as a separate bash process instead of relying on the multiprocessing module.
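
For completeness, the same effect can be had from Python by launching one fully independent interpreter per XML file with subprocess (sketch; the script name, env id and paths are illustrative, and you may want to cap how many runs start at once):

import subprocess
from pathlib import Path

xml_robots = sorted(Path("./robots").glob("*.xml"))   # example XML locations

# One independent Python process per XML file, no multiprocessing.Pool
procs = []
for i, xml_robot in enumerate(xml_robots):
    procs.append(subprocess.Popen([
        "python", "run_ppo.py",
        "--env_id", "Ant-v4",                          # example env id
        "--total_timesteps", "100000",
        "--xml_file_path", str(xml_robot),
        "--perf_log_path", f"./logs/run_{i}",
    ]))

# Wait for all runs to finish
for p in procs:
    p.wait()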