hill-a / stable-baselines

A fork of OpenAI Baselines, implementations of reinforcement learning algorithms
http://stable-baselines.readthedocs.io/
MIT License

Param net_arch in custom DDPG policies #174

Closed: PedroLucas closed this issue 5 years ago

PedroLucas commented 5 years ago

The DDPG page in the documentation says that custom actor-critic networks can be defined with net_arch, analogously to the common policies. However, the implementation of FeedForwardPolicy does not include net_arch as a parameter, which results in an error when the example code from the website is executed:

ValueError: Unknown keywords for policy: {'net_arch': [{'pi': [128, 128, 128], 'vf': [128, 128, 128]}]}
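
For reference, the failing setup looks roughly like the general custom-policy example from the docs applied to DDPG (a sketch; the exact snippet on the website may differ slightly):

from stable_baselines import DDPG
from stable_baselines.ddpg.policies import FeedForwardPolicy

# Custom policy written as in the general "Custom Policy Network" docs;
# DDPG's FeedForwardPolicy has no net_arch parameter, so the keyword ends up
# in **kwargs and _kwargs_check raises the ValueError above
class CustomDDPGPolicy(FeedForwardPolicy):
    def __init__(self, *args, **kwargs):
        super(CustomDDPGPolicy, self).__init__(*args, **kwargs,
                                               net_arch=[dict(pi=[128, 128, 128],
                                                              vf=[128, 128, 128])],
                                               feature_extraction="mlp")

model = DDPG(CustomDDPGPolicy, 'Pendulum-v0', verbose=1)  # raises during policy construction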

araffin commented 5 years ago

Hello, oh, that's true, the doc needs to be updated (DQN as well...).

Feel free to open a PR for that ;)

EDIT: the right parameter is layers for those policies.
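
For DDPG, the corrected custom policy would then look something like this (a rough sketch using layers instead of net_arch; the same architecture is shared by the actor and the critic):

from stable_baselines import DDPG
from stable_baselines.ddpg.policies import FeedForwardPolicy

# layers is the list of hidden layer sizes, used for both the actor and the critic
class CustomDDPGPolicy(FeedForwardPolicy):
    def __init__(self, *args, **kwargs):
        super(CustomDDPGPolicy, self).__init__(*args, **kwargs,
                                               layers=[128, 128, 128],
                                               layer_norm=False,
                                               feature_extraction="mlp")

model = DDPG(CustomDDPGPolicy, 'Pendulum-v0', verbose=1)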

kevin5k commented 5 years ago

Just to highlight that the "Custom Policy Network" section in the docs still refers to net_arch ;)

araffin commented 5 years ago

That's normal, because that is true for the general case; if you read the DDPG documentation, there is a warning about it.

ConorZAM commented 5 years ago

Is there a way to specify different architectures for the policy and critic networks for DDPG using the layers dict?

araffin commented 5 years ago

Is there a way to specify different architectures for the policy and critic networks for DDPG using the layers dict?

Not yet. However, you can still easily define a custom policy with different architectures for the actor and the critic (cf. the DDPG doc). Anyway, I would recommend using TD3, which is the successor of DDPG.
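
As a quick sketch (layer sizes are arbitrary here), the built-in TD3 MlpPolicy with the shared layers parameter looks like this; the same sizes are used for the actor and both critics, which is why a custom policy is needed for different architectures:

import numpy as np
from stable_baselines import TD3
from stable_baselines.ddpg.noise import NormalActionNoise

# layers is shared: it defines the hidden layers of the actor and of both critics
n_actions = 1  # Pendulum-v0 has a single action dimension
action_noise = NormalActionNoise(mean=np.zeros(n_actions), sigma=0.1 * np.ones(n_actions))
model = TD3('MlpPolicy', 'Pendulum-v0',
            policy_kwargs=dict(layers=[400, 300]),
            action_noise=action_noise, verbose=1)
model.learn(total_timesteps=10000)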

ConorZAM commented 5 years ago

Thanks for the heads up!

AvisekNaug commented 5 years ago

Here is a simple snippet to specify different network architectures by changing a few lines of code in the original FeedForwardPolicy in TD3. Note that CustomFFPTD3 is just the FeedForwardPolicy from the TD3 policies with the parameter layers replaced by net_arch. However, I have made net_arch a dictionary here, not a list.

import gym
import numpy as np
import tensorflow as tf

from stable_baselines import TD3
from stable_baselines.td3.policies import TD3Policy
from stable_baselines.sac.policies import mlp
from stable_baselines.common.policies import nature_cnn
from stable_baselines.common.vec_env import DummyVecEnv
from stable_baselines.ddpg.noise import NormalActionNoise


class CustomFFPTD3(TD3Policy):
    """
    Policy object that implements a DDPG-like actor-critic, using a feed-forward neural network. It is only different
    from the existing FeedForwardPolicy for TD3 in the way network architectures are defined: here we can define 
    separate architectures for actor and critic networks.

    :param sess: (TensorFlow session) The current TensorFlow session
    :param ob_space: (Gym Space) The observation space of the environment
    :param ac_space: (Gym Space) The action space of the environment
    :param n_env: (int) The number of environments to run
    :param n_steps: (int) The number of steps to run for each environment
    :param n_batch: (int) The number of batch to run (n_envs * n_steps)
    :param reuse: (bool) If the policy is reusable or not
    :param net_arch: (dict) The architectures of the actor ('pi') and critic ('vf') networks for the policy (if None, defaults to dict(pi=[64, 64], vf=[64, 64]))
    :param cnn_extractor: (function (TensorFlow Tensor, ``**kwargs``): (TensorFlow Tensor)) the CNN feature extraction
    :param feature_extraction: (str) The feature extraction type ("cnn" or "mlp")
    :param layer_norm: (bool) enable layer normalisation
    :param act_fun: (tf.func) the activation function to use in the neural network.
    :param kwargs: (dict) Extra keyword arguments for the nature CNN feature extraction
    """

    def __init__(self, sess, ob_space, ac_space, n_env=1, n_steps=1, n_batch=None, reuse=False, net_arch=None,
                 cnn_extractor=nature_cnn, feature_extraction="cnn",
                 layer_norm=False, act_fun=tf.nn.relu, **kwargs):
        super(CustomFFPTD3, self).__init__(sess, ob_space, ac_space, n_env, n_steps, n_batch,
                                                reuse=reuse, scale=(feature_extraction == "cnn"))

        self._kwargs_check(feature_extraction, kwargs)
        self.layer_norm = layer_norm
        self.feature_extraction = feature_extraction
        self.cnn_kwargs = kwargs
        self.cnn_extractor = cnn_extractor
        self.reuse = reuse
        if net_arch is None:
            net_arch = dict(pi=[64, 64], vf=[64, 64])
        self.net_arch = net_arch

        assert 'pi' in self.net_arch and 'vf' in self.net_arch, "Error: net_arch must contain both 'pi' and 'vf' keys"
        assert len(self.net_arch['pi']) >= 1, "Error: must have at least one hidden layer for the actor network."
        assert len(self.net_arch['vf']) >= 1, "Error: must have at least one hidden layer for the critics network."

        self.activ_fn = act_fun

    def make_actor(self, obs=None, reuse=False, scope="pi"):
        if obs is None:
            obs = self.processed_obs

        with tf.variable_scope(scope, reuse=reuse):
            if self.feature_extraction == "cnn":
                pi_h = self.cnn_extractor(obs, **self.cnn_kwargs)
            else:
                pi_h = tf.layers.flatten(obs)

            pi_h = mlp(pi_h, self.net_arch['pi'], self.activ_fn, layer_norm=self.layer_norm)

            self.policy = policy = tf.layers.dense(pi_h, self.ac_space.shape[0], activation=tf.tanh)

        return policy

    def make_critics(self, obs=None, action=None, reuse=False, scope="values_fn"):
        if obs is None:
            obs = self.processed_obs

        with tf.variable_scope(scope, reuse=reuse):
            if self.feature_extraction == "cnn":
                critics_h = self.cnn_extractor(obs, **self.cnn_kwargs)
            else:
                critics_h = tf.layers.flatten(obs)

            # Concatenate preprocessed state and action
            qf_h = tf.concat([critics_h, action], axis=-1)

            # Double Q values to reduce overestimation
            with tf.variable_scope('qf1', reuse=reuse):
                qf1_h = mlp(qf_h, self.net_arch['vf'], self.activ_fn, layer_norm=self.layer_norm)
                qf1 = tf.layers.dense(qf1_h, 1, name="qf1")

            with tf.variable_scope('qf2', reuse=reuse):
                qf2_h = mlp(qf_h, self.net_arch['vf'], self.activ_fn, layer_norm=self.layer_norm)
                qf2 = tf.layers.dense(qf2_h, 1, name="qf2")

            self.qf1 = qf1
            self.qf2 = qf2

        return self.qf1, self.qf2

    def step(self, obs, state=None, mask=None):
        return self.sess.run(self.policy, {self.obs_ph: obs})


# Custom MLP policy: two 16-unit layers for the actor, two 32-unit layers for the critics
class CustomTD3Policy(CustomFFPTD3):
    def __init__(self, *args, **kwargs):
        super(CustomTD3Policy, self).__init__(*args, **kwargs,
                                              net_arch=dict(pi=[16, 16], vf=[32, 32]),
                                              layer_norm=False,
                                              feature_extraction="mlp")


# Create and wrap the environment
env = gym.make('Pendulum-v0')
env = DummyVecEnv([lambda: env])

# The noise objects for TD3
n_actions = env.action_space.shape[-1]
action_noise = NormalActionNoise(mean=np.zeros(n_actions), sigma=0.1 * np.ones(n_actions))

model = TD3(CustomTD3Policy, env, action_noise=action_noise, verbose=1)

# Train the agent
model.learn(total_timesteps=80000)
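
After training, continuing the script above, the agent can be evaluated with the usual predict loop (just a short sketch):

obs = env.reset()
for _ in range(1000):
    action, _states = model.predict(obs)
    obs, rewards, dones, info = env.step(action)
    env.render()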