DLR-RM / stable-baselines3

PyTorch version of Stable Baselines, reliable implementations of reinforcement learning algorithms.
https://stable-baselines3.readthedocs.io
MIT License

Implement sampling and training asynchronously using the SAC algorithm #715

Open gilzamir18 opened 2 years ago

gilzamir18 commented 2 years ago

Question

I'm trying to run sampling and training asynchronously with the SAC algorithm, as in the code below, but I always get an error. There seems to be confusion between training and evaluation modes: the training mode (True or False) is set on the policy, which is shared between the train and collect_rollouts methods. Is it possible to run collect_rollouts asynchronously?

Additional context

Reference code:

import queue
import threading
import time

# Queue of rollouts produced by the sampling thread and consumed by the training thread.
_rollouts_ = queue.Queue()

def train_async(sac, max_steps):
    global _rollouts_
    iteration = 0
    while True:
        if not _rollouts_.empty():
            rollout = _rollouts_.get()
            if rollout is not None:
                iteration += 1
                gradient_steps = sac.gradient_steps if sac.gradient_steps >= 0 else rollout.episode_timesteps
                if gradient_steps > 0:
                    sac.train(gradient_steps, sac.batch_size)
            else:
                print("Training ending")
                break
        else:
            print("waiting for rollouts....")
            time.sleep(1)

def rollouts_async(sac, max_steps, callback, log_interval=None):
    global _rollouts_
    steps = 0
    while True:
        rollout = sac.collect_rollouts(sac.env, callback, sac.train_freq, sac.replay_buffer, sac.action_noise, sac.learning_starts, log_interval)
        if rollout.continue_training is False:
            _rollouts_.put(None)
            break
        else:
            _rollouts_.put(rollout)
            steps += 1
            if steps >= max_steps:
                _rollouts_.put(None)
                callback.on_training_end()
                break

def learn_async(sac, total_timesteps=1000000, callback=None, log_interval=None, tb_log_name="run", reset_num_timesteps=True):
    total_timesteps, callback = sac._setup_learn(total_timesteps, None, callback, 0, 0, None, reset_num_timesteps, tb_log_name)
    callback.on_training_start(locals(), globals())
    # Sampling and training run in two threads sharing the same policy and replay buffer.
    t1 = threading.Thread(target=rollouts_async, args=(sac, total_timesteps, callback, log_interval))
    t1.start()
    t2 = threading.Thread(target=train_async, args=(sac, total_timesteps))
    t2.start()
    t1.join()
    t2.join()

Error:

Traceback (most recent call last):
  File "C:\Users\gilza\anaconda3\lib\threading.py", line 973, in _bootstrap_inner
    self.run()
  File "C:\Users\gilza\anaconda3\lib\threading.py", line 910, in run
    self._target(*self._args, **self._kwargs)
  File "C:\Users\gilza\doc\lab\nav\NavProAI4U\scripts\sb3sacutils.py", line 134, in rollouts_async
    rollout = sac.collect_rollouts(sac.env, callback, sac.train_freq, sac.replay_buffer, sac.action_noise, sac.learning_starts, log_interval)
  File "C:\Users\gilza\anaconda3\lib\site-packages\stable_baselines3\common\off_policy_algorithm.py", line 589, in collect_rollouts
    self._store_transition(replay_buffer, buffer_action, new_obs, reward, done, infos)
  File "C:\Users\gilza\anaconda3\lib\site-packages\stable_baselines3\common\off_policy_algorithm.py", line 498, in _store_transition
    replay_buffer.add(
  File "C:\Users\gilza\anaconda3\lib\site-packages\stable_baselines3\common\buffers.py", line 562, in add
    self.actions[self.pos] = np.array(action).copy()
ValueError: could not broadcast input array from shape (256,4) into shape (4,)

araffin commented 2 years ago

Hello, you can find a working proof of concept here: https://github.com/DLR-RM/rl-baselines3-zoo/blob/87001ed8a40f817d46c950e283d1ca29e405ad71/utils/callbacks.py#L95

(it is not polished but it works)
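
For reference, the idea behind that proof of concept is: the callback keeps a second copy of the model, runs the gradient updates on that copy in a background thread, and copies the parameters back once the updates are done, while the main thread keeps collecting data. A very simplified sketch of that shape (this is not the zoo implementation; the class name SimpleParallelTrainCallback is made up here, and SAC with a standard replay buffer is assumed):

import tempfile
from copy import deepcopy
from threading import Thread

from stable_baselines3 import SAC
from stable_baselines3.common.callbacks import BaseCallback

class SimpleParallelTrainCallback(BaseCallback):
    """Sketch: run gradient updates in a background thread on a copy of the
    model while the main thread keeps collecting rollouts."""

    def __init__(self, gradient_steps=64, verbose=0):
        super().__init__(verbose)
        self.gradient_steps = gradient_steps
        self._model = None
        self._model_ready = True

    def _init_callback(self):
        # Train on a separate copy of the model so the gradient updates never
        # touch the policy that is currently collecting data.
        with tempfile.TemporaryDirectory() as tmp_dir:
            self.model.save(f"{tmp_dir}/model.zip")
            self._model = SAC.load(f"{tmp_dir}/model.zip")
        # Disable the gradient updates of the data-collecting model.
        self.model.train = lambda *args, **kwargs: None

    def _train_thread(self):
        self._model.train(gradient_steps=self.gradient_steps, batch_size=self._model.batch_size)
        self._model_ready = True

    def _on_rollout_end(self):
        if self._model_ready and self.model.replay_buffer.size() > self.model.learning_starts:
            # Hand over the freshly collected data, pull back the updated weights,
            # and launch the next batch of gradient updates in the background.
            self._model.replay_buffer = deepcopy(self.model.replay_buffer)
            self.model.set_parameters(deepcopy(self._model.get_parameters()))
            self._model_ready = False
            Thread(target=self._train_thread, daemon=True).start()

    def _on_step(self):
        return True

The working version linked above also handles things like the sleep between env steps (the sleep_time argument used later in this thread), so refer to it rather than this sketch for actual training.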

gilzamir18 commented 2 years ago

Hello, you can find a working proof of concept here: https://github.com/DLR-RM/rl-baselines3-zoo/blob/87001ed8a40f817d46c950e283d1ca29e405ad71/utils/callbacks.py#L95

(it is not polished but it works)

I didn't know you could do that with callbacks. I spent several days looking for a way to do this. It would be interesting if it was documented. I'm already using it and it's exactly what I wanted.

Thank you very much.

araffin commented 2 years ago

I didn't know you could do that with callbacks.

yes, callbacks & wrappers are quite powerful...

It would be interesting if it was documented.

Well, I didn't have time to polish it properly, and I usually don't advertise experimental features... but if you do polish it, we would be happy to receive a PR that adds it as an example in the documentation ;)

I'm already using it and it's exactly what I wanted.

Good to hear =)

gilzamir18 commented 2 years ago

Hello all. A while ago I tested threaded training and sampling with DQN using TensorFlow, but I couldn't get convergence when using threading. I thought it would be different now with Stable-Baselines3 and SAC, but I ran two experiments with the same settings, the first with threads and the second without: the first does not converge, the second converges. I see two main possibilities: either PyTorch, like TensorFlow/Keras, doesn't handle gradient descent well when used from threads, or the stable-baselines callback calls have some bug. I have no idea which it is, so for now I'm training without threads. It's slower, but at least it converges.

Results: [plot: Training Without Threading] [plot: Training With Threading]

Code with Threading

import gym
from sb3sacutils import GMultiInputPolicy, ParallelTrainCallback
import numpy as np
from stable_baselines3 import SAC
from sb3utils import CombinedExtractorWithFilters
import AI4UGym
from AI4UGym import BasicAgent
from collections import deque
from functools import reduce
import sys
from stable_baselines3.common.callbacks import CheckpointCallback
from stable_baselines3.sac.policies import MlpPolicy, CnnPolicy
from learning import prepare_env
import torch as th
import torch.nn as nn
from stable_baselines3.common.torch_layers import BaseFeaturesExtractor

#th.set_num_threads(2)
#th.set_num_interop_threads(2)

def train():
    env_config = {'verbose':False}
    prepare_env(env_config)
    env = gym.make("AI4U-v0")

    policy = None
    input_type = env.configuration['input_type']
    if input_type == "dict_linear_visual":
        policy = GMultiInputPolicy
    elif input_type == "linear":
        policy = MlpPolicy
    elif input_type == "visual":
        policy = CnnPolicy

    # checkpoint_callback = CheckpointCallback(save_freq=10000, save_path='./logs/',
    #                                          name_prefix='rl_model')

    parallelTrainCallback = ParallelTrainCallback(gradient_steps=64, verbose=0, sleep_time=0.0)  # same settings as in rl-zoo
    filters = [[16, 3, 1, 0], [16, 2, 2, 0]]
    policy_kwargs = dict(features_extractor_class=CombinedExtractorWithFilters, 
                         features_extractor_kwargs=dict(visual_features_dim=128, filters=filters), 
                         log_std_init=-2, 
                         net_arch=[64, 64])

    model = SAC(policy, 
                env,
                learning_starts=1000,  
                policy_kwargs=policy_kwargs, 
                gradient_steps=64, 
                learning_rate=7.3e-4, 
                tau=0.002, 
                gamma=0.98, 
                train_freq=(64, "step"), 
                use_sde=True, 
                batch_size=256, 
                ent_coef='auto', 
                buffer_size=50000, 
                verbose=1,
                use_sde_at_warmup=True,
                device="cuda", 
                tensorboard_log="./saclog")

    policy = model.policy
    model.learn(total_timesteps=10000000, callback=parallelTrainCallback, log_interval=4)
    model.save("sac_ai4u")
    del model # remove to demonstrate saving and loading

Code without threading:

import gym
from sb3sacutils import GMultiInputPolicy, ParallelTrainCallback
import numpy as np
from stable_baselines3 import SAC
from sb3utils import CombinedExtractorWithFilters
import AI4UGym
from AI4UGym import BasicAgent
from collections import deque
from functools import reduce
import sys
from stable_baselines3.common.callbacks import CheckpointCallback
from stable_baselines3.sac.policies import MlpPolicy, CnnPolicy
from learning import prepare_env
import torch as th
import torch.nn as nn
from stable_baselines3.common.torch_layers import BaseFeaturesExtractor

#th.set_num_threads(2)
#th.set_num_interop_threads(2)

def train():
    env_config = {'verbose':False}
    prepare_env(env_config)
    env = gym.make("AI4U-v0")

    policy = None
    input_type = env.configuration['input_type']
    if input_type == "dict_linear_visual":
        policy = GMultiInputPolicy
    elif input_type == "linear":
        policy = MlpPolicy
    elif input_type == "visual":
        policy = CnnPolicy

    checkpoint_callback = CheckpointCallback(save_freq=10000, save_path='./logs/',
                                            name_prefix='rl_model')

    #parallelTrainCallback = ParallelTrainCallback(gradient_steps=64, verbose=0, sleep_time = 0.0)
    filters = [[16, 3, 1, 0], [16, 2, 2, 0]]
    policy_kwargs = dict(features_extractor_class=CombinedExtractorWithFilters, 
                         features_extractor_kwargs=dict(visual_features_dim=128, filters=filters), 
                         log_std_init=-2, 
                         net_arch=[64, 64])

    model = SAC(policy, 
                env,
                learning_starts=1000,  
                policy_kwargs=policy_kwargs, 
                gradient_steps=64, 
                learning_rate=7.3e-4, 
                tau=0.002, 
                gamma=0.98, 
                train_freq=(64, "step"), 
                use_sde=True, 
                batch_size=256, 
                ent_coef='auto', 
                buffer_size=50000, 
                verbose=1,
                use_sde_at_warmup=True,
                device="cuda", 
                tensorboard_log="./saclog")

    policy = model.policy
    model.learn(total_timesteps=10000000, callback=checkpoint_callback, log_interval=4)
    model.save("sac_ai4u")
    del model # remove to demonstrate saving and loading

araffin commented 2 years ago

Hello, I'm actively using the callback. The important thing is to compare training runs with the same number of gradient updates, and also to compare how long a gradient update takes versus how long it takes to collect data.
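
As a rough illustration of that comparison (a hypothetical sketch, not code from this thread; Pendulum-v1 is just a stand-in env), the two phases can be timed separately:

import time

from stable_baselines3 import SAC

# Time 64 gradient steps on a model whose replay buffer already has data.
model = SAC("MlpPolicy", "Pendulum-v1", train_freq=64, gradient_steps=64)
model.learn(total_timesteps=2_000)

start = time.perf_counter()
model.train(gradient_steps=64, batch_size=model.batch_size)
print(f"64 gradient steps: {time.perf_counter() - start:.3f}s")

# Time collecting 64 environment steps with the gradient updates disabled
# (gradient_steps=0 should skip the update phase of the off-policy learn loop).
collector = SAC("MlpPolicy", "Pendulum-v1", train_freq=64, gradient_steps=0, learning_starts=0)
start = time.perf_counter()
collector.learn(total_timesteps=64)
print(f"64 env steps: {time.perf_counter() - start:.3f}s")

If collecting data is much slower than the updates (for example with a slow simulator), the parallel callback can help; if the updates dominate, it will not buy much.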

I just used the callback yesterday and it worked fine.

Please note that we do not do technical support (so unless you provide a minimal example to reproduce the issue without a custom env, we won't give further answers).

araffin commented 1 year ago

The callback is now available via the rl_zoo3 package (rl_zoo3.callbacks); VecNormalize checkpoints are missing, though.
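
For reference, using it from the zoo package would look roughly like this (a sketch; it assumes the class is exposed as ParallelTrainCallback with the same constructor arguments used earlier in this thread, so check the installed rl_zoo3 version if the import fails):

from rl_zoo3.callbacks import ParallelTrainCallback
from stable_baselines3 import SAC

# Gradient updates run in parallel via the callback; the env here is a stand-in.
callback = ParallelTrainCallback(gradient_steps=64, verbose=0, sleep_time=0.0)
model = SAC("MlpPolicy", "Pendulum-v1", train_freq=(64, "step"), gradient_steps=64)
model.learn(total_timesteps=100_000, callback=callback)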