tensorflow / agents

TF-Agents: A reliable, scalable and easy to use TensorFlow library for Contextual Bandits and Reinforcement Learning.
Apache License 2.0

SAC for discrete action (GumbelSoftmax reparameterization trick) #507

Open Kang-SungKu opened 3 years ago

Kang-SungKu commented 3 years ago

Hello, I need to make SacAgent work with discrete actions, so I tried to implement the GumbelSoftmax reparameterization trick by redefining the relevant classes. However, computing agent.train(experience).loss fails after a few iterations, and I would appreciate any comments on how to address the issue. I understand that SacAgent is designed to work only with continuous actions, but any input will be really appreciated.

Let me provide the current situation one-by-one: 1) my configuration, 2) what I have done, and 3) the error.

First, here are the versions I am working on:

Also, here are the action spec and observation spec of my environment. The environment is a simplified version of Tetris: the observation is a 2D board (the spec is shown as 3D, but self.depth is 1 at this point), and an action simply drops a block from the top of the board, so the number of actions = board_width * the number of available blocks. Please note that I use observation_and_action_constraint_splitter to apply an action mask so that infeasible actions are not allowed. That is why the observation spec contains two fields, 'observation' and 'mask':

self._action_spec = array_spec.BoundedArraySpec(
    shape=(), dtype=np.int32, minimum=0, maximum=self.action_spec_max, name='action')

# State space (3D array): x_coord X y_coord X z_coord, range(num_moves+1)
# Calculate theoretically maximum moves per game to bound the state
self._observation_spec = {
    'observation': array_spec.BoundedArraySpec(
        shape=(self.width, self.depth, self.height), dtype=np.int32,
        minimum=0, maximum=np.floor(self.num_moves_max), name='observation'),
    'mask': array_spec.BoundedArraySpec(shape=(self.action_spec_max+1,),
        dtype=np.int32, minimum=0, maximum=1, name='mask')
}
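The splitter separates the 'observation' and 'mask' fields, and the mask is then used to exclude infeasible actions from the categorical policy. Below is a minimal NumPy sketch of that idea, assuming one logit per flat action index; the helper names `masked_logits` and `softmax` are hypothetical and this is not the tf-agents implementation.

```python
import numpy as np


def masked_logits(logits: np.ndarray, mask: np.ndarray) -> np.ndarray:
    """Replace logits of infeasible actions (mask == 0) with a very large negative value."""
    # A finite large negative number (rather than -inf) keeps later arithmetic finite.
    very_negative = np.finfo(logits.dtype).min
    return np.where(mask.astype(bool), logits, very_negative)


def softmax(x: np.ndarray) -> np.ndarray:
    """Row-wise softmax; subtracting the row max avoids overflow in exp."""
    z = x - x.max(axis=-1, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=-1, keepdims=True)


# Example: 4 flat actions, of which the last two are infeasible.
logits = np.array([[1.0, 2.0, 0.5, 3.0]])
mask = np.array([[1, 1, 0, 0]], dtype=np.int32)
probs = softmax(masked_logits(logits, mask))
```

Masked actions end up with probability exactly 0, so they can never be sampled, while the remaining probabilities still sum to 1.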

Second, I redefined several classes related to SacAgent to allow discrete actions. Below are only the parts of the classes I modified (lines identical to the original classes are omitted).

class SacAgentDiscrete(sac_agent.SacAgent):
    def __init__(self,
                 ...,
                 temperature: types.Float = 0.01,
                 observation_and_action_constraint_splitter: Optional[types.Splitter] = None,
                 name: Optional[Text] = None):

        (...)

        """
        Modification for Discrete Action 0 (with action mask):
        Action mask (splitter) is used as a keyword and then applied to actor policy
        """
        self.observation_and_action_constraint_splitter = observation_and_action_constraint_splitter
        policy = actor_policy_ctor(
            time_step_spec=time_step_spec,
            action_spec=action_spec,
            actor_network=self._actor_network,
            observation_and_action_constraint_splitter=observation_and_action_constraint_splitter,
            training=False)

        self._train_policy = actor_policy_ctor(
            time_step_spec=time_step_spec,
            action_spec=action_spec,
            actor_network=self._actor_network,
            observation_and_action_constraint_splitter=observation_and_action_constraint_splitter,
            training=True)

        (...)

        """
        Modification for Discrete Action 1:
        default target entropy is -dim(A) or -dim(A)/2
        Since action spec is 1-dim with integer values, dim is manually calculated as follows:
        """
        self.temperature = temperature
        if target_entropy is None:
            self._action_dim = action_spec.maximum - action_spec.minimum + 1
            target_entropy = -self._action_dim/2

        (...)

    def _check_action_spec(self, action_spec):
        flat_action_spec = tf.nest.flatten(action_spec)
        for spec in flat_action_spec:
            if spec.dtype.is_integer:
                # discrete action is allowed by removing exception
                print("**NOTE: This is discrete version of SAC!")

    # This function is the same as the original one, but created to let critic and actor calculate action and log_pi differently
    def critic_actions_and_log_probs(self, time_steps):
        """Get actions and corresponding log probabilities from policy."""
        # Get raw action distribution from policy, and initialize bijectors list.
        batch_size = nest_utils.get_outer_shape(time_steps, self._time_step_spec)[0]
        policy_state = self._train_policy.get_initial_state(batch_size)
        distribution = self._train_policy.distribution(time_steps, policy_state=policy_state)
        action_distribution = distribution.action
        print("\t\tcritic::dist:", distribution)
        print("\t\tcritic::action_dist:", action_distribution)
        print("\t\tcritic::logit_param:", action_distribution.logits_parameter())

        # Sample actions and log_pis from transformed distribution.
        actions = tf.nest.map_structure(lambda d: d.sample(), action_distribution)
        log_pi = common.log_probability(action_distribution, actions, self.action_spec)
        return actions, log_pi

    # This function computes the actions and log_pi for the actor, using a GumbelSoftmax distribution.
    def actor_actions_and_log_probs(self, time_steps):
        """Get actions and corresponding log probabilities from policy."""
        # Get raw action distribution from policy, and initialize bijectors list.
        batch_size = nest_utils.get_outer_shape(time_steps, self._time_step_spec)[0]
        policy_state = self._train_policy.get_initial_state(batch_size)
        distribution = self._train_policy.distribution(time_steps, policy_state=policy_state)
        action_distribution = distribution.action

        # Sample actions and log_pis from transformed distribution.
        """
        Modification for Discrete Action 2:
        Here, GumbelSoftmax is used to sample the actions and calculate log probability
        """
        dist = gumbel_softmax.GumbelSoftmax(self.temperature, logits=action_distribution.logits_parameter())

        # Here, the shape of action == [batch, num_action], which is not compatible with action spec
        action = dist.sample()

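        # log_pi of the relaxed (soft one-hot) sample under the GumbelSoftmax distribution
        log_pi = dist.log_prob(action)

        # Here, argmax is used to reduce the action from [batch, num_action] into [batch]
        actions = tf.math.argmax(action, axis=1, output_type=tf.int32)
        return actions, log_pi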

    def critic_loss(self, ...):
        (...)
            # use critic_actions_and_log_probs, defined specifically for the critic
            next_actions, next_log_pis = self.critic_actions_and_log_probs(next_time_steps)
        (...)
            return critic_loss

    def actor_loss(self, ...):
        (...)
            # use actor_actions_and_log_probs, defined specifically for the actor
            actions, log_pi = self.actor_actions_and_log_probs(time_steps)
        (...)
            return actor_loss

# actor network is re-defined to allow discrete action
class ActorNetworkDiscrete(actor_distribution_network.ActorDistributionNetwork):
    def __init__(self,
                 input_tensor_spec,
                 output_tensor_spec,
                 preprocessing_layers=None,
                 preprocessing_combiner=None,
                 conv_layer_params=None,
                 fc_layer_params=(200, 100),
                 dropout_layer_params=None,
                 activation_fn=tf.keras.activations.relu,
                 kernel_initializer=None,
                 batch_squash=True,
                 dtype=tf.float32,
                 discrete_projection_net=_categorical_projection_net,
                 continuous_projection_net=_normal_projection_net,
                 name='ActorDistributionNetwork'):

        if not kernel_initializer:
            kernel_initializer = tf.compat.v1.keras.initializers.glorot_uniform()

        encoder = encoding_network.EncodingNetwork(
            input_tensor_spec,
            preprocessing_layers=preprocessing_layers,
            preprocessing_combiner=preprocessing_combiner,
            conv_layer_params=conv_layer_params,
            fc_layer_params=fc_layer_params,
            dropout_layer_params=dropout_layer_params,
            activation_fn=activation_fn,
            kernel_initializer=kernel_initializer,
            batch_squash=batch_squash,
            dtype=dtype)

        def map_proj(spec):
            if tensor_spec.is_discrete(spec):
                return discrete_projection_net(spec)
            else:
                return continuous_projection_net(spec)

        projection_networks = tf.nest.map_structure(map_proj, output_tensor_spec)
        output_spec = tf.nest.map_structure(lambda proj_net: proj_net.output_spec,
                                            projection_networks)

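        # Skip ActorDistributionNetwork.__init__ (as the original class does) and
        # call DistributionNetwork.__init__ directly.
        super(actor_distribution_network.ActorDistributionNetwork, self).__init__(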
            input_tensor_spec=input_tensor_spec,
            state_spec=(),
            output_spec=output_spec,
            name=name)

        self._encoder = encoder
        self._projection_networks = projection_networks
        self._output_tensor_spec = output_tensor_spec
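For reference, the Gumbel-Softmax trick used in actor_actions_and_log_probs draws a relaxed one-hot sample y = softmax((logits + g) / τ), where g is standard Gumbel noise. A minimal NumPy sketch follows; the function names are hypothetical and this is not the tf-agents `GumbelSoftmax` class. Two hedged observations about the code above: `tf.math.argmax` is not differentiable, so no gradient flows from `actions` back into the logits (a common workaround is the straight-through estimator, e.g. `y_hard = tf.stop_gradient(y_hard - y_soft) + y_soft`), and a very low temperature such as the default `temperature=0.01` makes samples almost exactly one-hot at the cost of higher-variance gradients.

```python
import numpy as np


def gumbel_softmax_sample(logits, temperature, rng):
    """Draw a relaxed one-hot sample y = softmax((logits + g) / temperature)."""
    # Standard Gumbel noise via the inverse CDF: g = -log(-log(u)), u ~ Uniform(0, 1).
    # The lower bound 'tiny' keeps log(u) finite.
    u = rng.uniform(low=np.finfo(float).tiny, high=1.0, size=logits.shape)
    g = -np.log(-np.log(u))
    z = (logits + g) / temperature
    z = z - z.max(axis=-1, keepdims=True)  # stabilise the softmax
    e = np.exp(z)
    return e / e.sum(axis=-1, keepdims=True)


def to_one_hot(soft_sample):
    """Hard (argmax) version of the sample, i.e. the forward value in straight-through."""
    idx = soft_sample.argmax(axis=-1)
    return np.eye(soft_sample.shape[-1])[idx]


rng = np.random.default_rng(0)
logits = np.log(np.array([[0.7, 0.2, 0.1]]))
y = gumbel_softmax_sample(logits, temperature=0.5, rng=rng)
hard = to_one_hot(y)
```

Because softmax is monotonic, the argmax of the relaxed sample equals the argmax of logits + g, which (the Gumbel-max trick) is distributed exactly as Categorical(softmax(logits)) at any temperature.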

Lastly, here is the error I got from agent.train(experience).loss. The message says the actor loss became inf or NaN, so I suspect there is a minor issue that needs to be fixed.

  File "C:/Users/SungKu Kang/Northeastern University/ABLE_Lab - S_tetris/Workspace/Tetris_TO/masonry_rl_agent_SAC.py", line 802, in main
    train_loss = agent.train(experience).loss
  File "C:\Users\SungKu Kang\miniconda3\envs\TetrisMason\lib\site-packages\tensorflow\python\eager\def_function.py", line 780, in __call__
    result = self._call(*args, **kwds)
  File "C:\Users\SungKu Kang\miniconda3\envs\TetrisMason\lib\site-packages\tensorflow\python\eager\def_function.py", line 840, in _call
    return self._stateless_fn(*args, **kwds)
  File "C:\Users\SungKu Kang\miniconda3\envs\TetrisMason\lib\site-packages\tensorflow\python\eager\function.py", line 2829, in __call__
    return graph_function._filtered_call(args, kwargs)  # pylint: disable=protected-access
  File "C:\Users\SungKu Kang\miniconda3\envs\TetrisMason\lib\site-packages\tensorflow\python\eager\function.py", line 1843, in _filtered_call
    return self._call_flat(
  File "C:\Users\SungKu Kang\miniconda3\envs\TetrisMason\lib\site-packages\tensorflow\python\eager\function.py", line 1923, in _call_flat
    return self._build_call_outputs(self._inference_function.call(
  File "C:\Users\SungKu Kang\miniconda3\envs\TetrisMason\lib\site-packages\tensorflow\python\eager\function.py", line 545, in call
    outputs = execute.execute(
  File "C:\Users\SungKu Kang\miniconda3\envs\TetrisMason\lib\site-packages\tensorflow\python\eager\execute.py", line 59, in quick_execute
    tensors = pywrap_tfe.TFE_Py_Execute(ctx._handle, device_name, op_name,
tensorflow.python.framework.errors_impl.InvalidArgumentError:  Actor loss is inf or nan. : Tensor had Inf values
     [[node CheckNumerics_1 (defined at C:/Users/SungKu Kang/Northeastern University/ABLE_Lab - S_tetris/Workspace/Tetris_TO/masonry_rl_agent_SAC.py:343) ]] [Op:__inference_train_11285]
Errors may have originated from an input operation.
Input Source operations connected to node CheckNumerics_1:
 mul_1 (defined at C:/Users/SungKu Kang/Northeastern University/ABLE_Lab - S_tetris/Workspace/Tetris_TO/masonry_rl_agent_SAC.py:339)    
 CheckNumerics (defined at C:/Users/SungKu Kang/Northeastern University/ABLE_Lab - S_tetris/Workspace/Tetris_TO/masonry_rl_agent_SAC.py:330)
Function call stack:
train

I tried my best to provide the necessary details, but if any additional information is needed, please let me know. Extending the reparameterization trick to other methods (e.g., DDPG, TD3) will be also helpful. I really appreciate your time!
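A general debugging hint (not a diagnosis of the traceback above): with masked discrete policies, inf/NaN losses often come from terms like π(a)·log π(a) for masked actions, where π(a) = 0 and log π(a) = -inf, and 0 · (-inf) evaluates to NaN. A minimal NumPy sketch that computes log-probabilities with a stable log-softmax and treats 0 · log 0 as 0 (function names are hypothetical):

```python
import numpy as np


def log_softmax(logits):
    """log pi(a) = logits - logsumexp(logits), computed stably per row."""
    m = logits.max(axis=-1, keepdims=True)
    return logits - m - np.log(np.exp(logits - m).sum(axis=-1, keepdims=True))


def expected_log_prob(logits):
    """sum_a pi(a) * log pi(a) per row, treating 0 * log 0 as 0."""
    log_pi = log_softmax(logits)
    pi = np.exp(log_pi)
    # Only multiply where pi > 0; masked entries keep the 0.0 from 'out'.
    terms = np.multiply(pi, log_pi, out=np.zeros_like(pi), where=pi > 0.0)
    return terms.sum(axis=-1)


# Third action is masked out with a -inf logit.
logits = np.array([[0.0, 0.0, -np.inf]])
safe = expected_log_prob(logits)  # equals -log(2)
```

The same guard can be written in TensorFlow with `tf.where`, or avoided altogether by masking with a large finite negative logit instead of -inf.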

sguada commented 3 years ago

SAC is not designed for discrete actions; maybe it would be better to try PPO.

cosmir17 commented 3 years ago

Can the following paper be considered? It suggests that this is possible: https://arxiv.org/abs/1910.07207

The author also publicised his code. https://github.com/p-christ/Deep-Reinforcement-Learning-Algorithms-with-PyTorch/blob/master/agents/actor_critic_agents/SAC_Discrete.py

I don't know if they are feasible.

DBraun commented 3 years ago

A more general approach that handles continuous and discrete outputs is described here https://arxiv.org/abs/1912.11077v1

From the paper: "... when there are only discrete actions, our approach is equivalent to the one proposed concurrently by (Christodoulou 2019)."

I'm trying to get it working by modifying the SacAgent and the actor network.

cosmir17 commented 3 years ago

This is awesome 👍 @DBraun

DBraun commented 3 years ago

I think it's working! I followed this instruction from the paper (also look at Figure 1).

If the continuous action a_c must depend on the discrete action chosen by the agent, then a_d can be used as input when computing mu_c and sigma_c.

I interpreted this as an instruction to modify the actor network. In my experiment I modified the actor_distribution_rnn_network.ActorDistributionRnnNetwork. In its call function:

  1. Run the core network (such as an RNN/LSTMs)
  2. Pass the output of the core network to the projection networks that lead to discrete actions (Categorical in tf-agents)
  3. Convert the outputs of the categorical projection networks to one-hot vectors. The categorical projection networks produce logits. During training, the logits go through a softmax; otherwise they go through tf.argmax and tf.one_hot to become one-hot vectors. Either way, the result has the same size and still represents the discrete actions.
  4. Concatenate that to the outputs of the core actor network.
  5. Run that through the projection networks that lead to continuous actions. These values represent continuous actions.
  6. Merge the discrete actions (step 3) and continuous actions (step 5).
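The steps above can be sketched for a single time step in NumPy, assuming linear projection layers with made-up weight matrices (`w_discrete`, `w_continuous` and `hybrid_actor_head` are hypothetical), skipping the RNN core from step 1, and outputting only the mean of the continuous action (a real head would also output a scale):

```python
import numpy as np


def softmax(x):
    z = x - x.max(axis=-1, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=-1, keepdims=True)


def hybrid_actor_head(core_out, w_discrete, w_continuous, training):
    """Steps 2-6 for one discrete head and one continuous head.

    core_out: [batch, hidden]; w_discrete: [hidden, n_discrete];
    w_continuous: [hidden + n_discrete, n_continuous].
    """
    # Step 2: discrete projection -> logits over the discrete actions.
    logits = core_out @ w_discrete
    # Step 3: soft probabilities while training, hard one-hot vectors otherwise.
    if training:
        discrete_repr = softmax(logits)
    else:
        discrete_repr = np.eye(logits.shape[-1])[logits.argmax(axis=-1)]
    # Step 4: concatenate the discrete representation to the core output.
    conditioned = np.concatenate([core_out, discrete_repr], axis=-1)
    # Step 5: continuous projection (mean only, squashed to [-1, 1] with tanh).
    continuous_mean = np.tanh(conditioned @ w_continuous)
    # Step 6: return both parts together.
    return discrete_repr, continuous_mean


rng = np.random.default_rng(1)
core = rng.normal(size=(2, 4))
w_d = rng.normal(size=(4, 3))
w_c = rng.normal(size=(4 + 3, 2))
d_train, c_train = hybrid_actor_head(core, w_d, w_c, training=True)
d_eval, c_eval = hybrid_actor_head(core, w_d, w_c, training=False)
```

Keeping the discrete representation the same width in both modes is what lets the continuous head be reused unchanged between training and evaluation.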

Does this sound correct? I didn't need to modify SacAgent much, just _check_action_spec. I made similar changes to critic rnn and actor rnn.

Should I try to clean up the code for a PR? Should I try to subclass ActorDistributionRnnNetwork/SacAgent with new classes or just copy a lot of code over? Or should I try to modify them in-place so that they become more flexible and allow both discrete and continuous actions?

Fabien-Couthouis commented 3 years ago

Hey @DBraun, good work!

Do you still plan to make a PR? It would be awesome to have a discrete version of SAC! I would be very interested in your code even if you haven't cleaned it up.

DBraun commented 3 years ago

@cosmir17 @Fabien-Couthouis I've finally made a PR. You're welcome to try out the code. It's probably pretty close to working.

Fabien-Couthouis commented 3 years ago

Hey @DBraun, thanks for the code! What do you mean by "pretty close to working"? Did you manage to make it converge on some environments?

I personally ended up implementing a discrete SAC agent based on the continuous version in tf-agents. Note that I changed the loss according to Delalleau et al., and the default target entropy to the one suggested by Christodoulou (2019). It converges on CartPole and LunarLander-v2 with some tuning of the target_update_period/target_update_tau hyperparameters.
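The code excerpt below ends before the loss functions, so the following is not Fabien's code. It is a NumPy sketch of the discrete SAC quantities as formulated by Christodoulou (2019), which Delalleau et al. state is equivalent to their method when all actions are discrete. Because the policy outputs a full categorical distribution, expectations over actions are computed exactly, so no reparameterization trick (such as Gumbel-Softmax) is needed. All function names and the argument layout are assumptions.

```python
import numpy as np


def default_target_entropy(num_actions, ratio=0.98):
    """ratio * log(num_actions): a fraction of the maximum entropy of a categorical policy."""
    return -np.log(1.0 / num_actions) * ratio


def discrete_sac_losses(log_pi, q1, q2, next_log_pi, next_q1_target, next_q2_target,
                        rewards, discounts, actions, alpha, target_entropy, gamma=0.99):
    """Batch-averaged critic, actor and alpha losses for discrete SAC.

    log_pi, q1, q2 and the next_* arrays have shape [batch, num_actions];
    rewards, discounts and actions (integer indices) have shape [batch].
    """
    pi = np.exp(log_pi)
    next_pi = np.exp(next_log_pi)
    # Soft value of the next state: V(s') = sum_a pi(a|s') [min Q_target(s', a) - alpha log pi(a|s')].
    next_q = np.minimum(next_q1_target, next_q2_target)
    next_v = np.sum(next_pi * (next_q - alpha * next_log_pi), axis=-1)
    td_target = rewards + gamma * discounts * next_v
    # Critic loss: squared TD error of the action actually taken, summed over both critics.
    idx = np.arange(len(actions))
    critic_loss = np.mean((q1[idx, actions] - td_target) ** 2
                          + (q2[idx, actions] - td_target) ** 2)
    # Actor loss: exact expectation over actions of alpha * log pi - min Q.
    min_q = np.minimum(q1, q2)
    actor_loss = np.mean(np.sum(pi * (alpha * log_pi - min_q), axis=-1))
    # Alpha loss alpha * (H(pi) - target): minimizing it lowers alpha when the
    # entropy is above target and raises it when the entropy is below target.
    entropy = -np.sum(pi * log_pi, axis=-1)
    alpha_loss = np.mean(alpha * (entropy - target_entropy))
    return critic_loss, actor_loss, alpha_loss


# Sanity check: uniform policy over 2 actions with all-zero Q-values.
n = 2
log_pi = np.full((3, n), np.log(0.5))
zeros = np.zeros((3, n))
losses = discrete_sac_losses(log_pi, zeros, zeros, log_pi, zeros, zeros,
                             rewards=np.ones(3), discounts=np.ones(3),
                             actions=np.array([0, 1, 0]), alpha=0.1,
                             target_entropy=default_target_entropy(n), gamma=0.5)
```

With the 0.98 ratio, CartPole (2 actions) gets a target entropy of about 0.679 and LunarLander-v2 (4 actions) about 1.359. By contrast, the continuous-style heuristic -dim(A)/2 from the first post gives a negative target; since the entropy of a categorical policy is never negative, that target is always exceeded and this alpha loss keeps pushing alpha down.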

Here is my code in case anyone is interested in a discrete SAC agent:

# discrete_sac_agent.py
"""
A Soft Actor-Critic Agent.

Implements the discrete version of Soft Actor-Critic (SAC) algorithm based on
"Discrete and Continuous Action Representation for Practical RL in Video Games" by Olivier Delalleau, Maxim Peter, Eloi Alonso, Adrien Logut (2020).
Paper: https://montreal.ubisoft.com/en/discrete-and-continuous-action-representation-for-practical-reinforcement-learning-in-video-games/
"""
# Using Type Annotations.
from __future__ import absolute_import, division, print_function

import collections
from typing import Callable, Optional, Text

import gin
import numpy as np
import tensorflow as tf  # pylint: disable=g-explicit-tensorflow-version-import
import tensorflow_probability as tfp
from six.moves import zip
from tf_agents.agents import data_converter, tf_agent
from tf_agents.networks import network
from tf_agents.policies import actor_policy, tf_policy
from tf_agents.trajectories import time_step as ts
from tf_agents.typing import types
from tf_agents.utils import common, eager_utils, nest_utils, object_identity

SacLossInfo = collections.namedtuple(
    'SacLossInfo', ('critic_loss', 'actor_loss', 'alpha_loss'))

@gin.configurable
class DiscreteSacAgent(tf_agent.TFAgent):
    """A SAC Agent that supports discrete action spaces."""

    def __init__(self,
                 time_step_spec: ts.TimeStep,
                 action_spec: types.NestedTensorSpec,
                 critic_network: network.Network,
                 actor_network: network.Network,
                 actor_optimizer: types.Optimizer,
                 critic_optimizer: types.Optimizer,
                 alpha_optimizer: types.Optimizer,
                 actor_loss_weight: types.Float = 1.0,
                 critic_loss_weight: types.Float = 0.5,
                 alpha_loss_weight: types.Float = 1.0,
                 actor_policy_ctor: Callable[
                     ..., tf_policy.TFPolicy] = actor_policy.ActorPolicy,
                 critic_network_2: Optional[network.Network] = None,
                 target_critic_network: Optional[network.Network] = None,
                 target_critic_network_2: Optional[network.Network] = None,
                 target_update_tau: types.Float = 1.0,
                 target_update_period: types.Int = 1,
                 td_errors_loss_fn: types.LossFn = tf.math.squared_difference,
                 gamma: types.Float = 1.0,
                 reward_scale_factor: types.Float = 1.0,
                 initial_log_alpha: types.Float = 0.0,
                 use_log_alpha_in_alpha_loss: bool = True,
                 target_entropy: Optional[types.Float] = None,
                 gradient_clipping: Optional[types.Float] = None,
                 debug_summaries: bool = False,
                 summarize_grads_and_vars: bool = False,
                 train_step_counter: Optional[tf.Variable] = None,
                 name: Optional[Text] = None):
        """Creates a SAC Agent.

        Args:
          time_step_spec: A `TimeStep` spec of the expected time_steps.
          action_spec: A nest of BoundedTensorSpec representing the actions.
          critic_network: A network that takes observations and returns the
            Q-values for every discrete action.
          actor_network: A function actor_network(observation, action_spec) that
            returns action distribution.
          actor_optimizer: The optimizer to use for the actor network.
          critic_optimizer: The default optimizer to use for the critic network.
          alpha_optimizer: The default optimizer to use for the alpha variable.
          actor_loss_weight: The weight on actor loss.
          critic_loss_weight: The weight on critic loss.
          alpha_loss_weight: The weight on alpha loss.
          actor_policy_ctor: The policy class to use.
          critic_network_2: (Optional.)  A `tf_agents.network.Network` to be used as
            the second critic network during Q learning.  The weights from
            `critic_network` are copied if this is not provided.
          target_critic_network: (Optional.)  A `tf_agents.network.Network` to be
            used as the target critic network during Q learning. Every
            `target_update_period` train steps, the weights from `critic_network`
            are copied (possibly with smoothing via `target_update_tau`) to
            `target_critic_network`.  If `target_critic_network` is not provided, it
            is created by making a copy of `critic_network`, which initializes a new
            network with the same structure and its own layers and weights.
            Performing a `Network.copy` does not work when the network instance
            already has trainable parameters (e.g., has already been built, or when
            the network is sharing layers with another).  In these cases, it is up
            to you to build a copy having weights that are not shared with the
            original `critic_network`, so that this can be used as a target network.
            If you provide a `target_critic_network` that shares any weights with
            `critic_network`, a warning will be logged but no exception is thrown.
          target_critic_network_2: (Optional.) Similar network as
            target_critic_network but for the critic_network_2. See documentation
            for target_critic_network. Will only be used if 'critic_network_2' is
            also specified.
          target_update_tau: Factor for soft update of the target networks.
          target_update_period: Period for soft update of the target networks.
          td_errors_loss_fn:  A function for computing the elementwise TD errors
            loss.
          gamma: A discount factor for future rewards.
          reward_scale_factor: Multiplicative scale for the reward.
          initial_log_alpha: Initial value for log_alpha.
          use_log_alpha_in_alpha_loss: A boolean, whether using log_alpha or alpha
            in alpha loss. Certain implementations of SAC use log_alpha as log
            values are generally nicer to work with.
          target_entropy: The target average policy entropy, for updating alpha. The
            default value is 0.98 * log(num_actions), i.e. 98% of the maximum
            entropy of a uniform policy (see `_get_default_target_entropy`).
          gradient_clipping: Norm length to clip gradients.
          debug_summaries: A bool to gather debug summaries.
          summarize_grads_and_vars: If True, gradient and network variable summaries
            will be written during training.
          train_step_counter: An optional counter to increment every time the train
            op is run.  Defaults to the global_step.
          name: The name of this agent. All variables in this module will fall under
            that name. Defaults to the class name.
        """
        flat_action_spec = tf.nest.flatten(action_spec)
        self._num_actions = np.sum([
            single_spec.maximum-single_spec.minimum+1
            for single_spec in flat_action_spec
        ])

        self._check_action_spec(action_spec)

        self._critic_network_1 = critic_network
        self._critic_network_1.create_variables(
            time_step_spec.observation)
        if target_critic_network:
            target_critic_network.create_variables(
                time_step_spec.observation)
        self._target_critic_network_1 = (
            common.maybe_copy_target_network_with_checks(self._critic_network_1,
                                                         target_critic_network,
                                                         'TargetCriticNetwork1'))

        if critic_network_2 is not None:
            self._critic_network_2 = critic_network_2
        else:
            self._critic_network_2 = critic_network.copy(name='CriticNetwork2')
            # Do not use target_critic_network_2 if critic_network_2 is None.
            target_critic_network_2 = None
        self._critic_network_2.create_variables(
            time_step_spec.observation)
        if target_critic_network_2:
            target_critic_network_2.create_variables(
                time_step_spec.observation)
        self._target_critic_network_2 = (
            common.maybe_copy_target_network_with_checks(self._critic_network_2,
                                                         target_critic_network_2,
                                                         'TargetCriticNetwork2'))

        if actor_network:
            actor_network.create_variables(time_step_spec.observation)
        self._actor_network = actor_network

        policy = actor_policy_ctor(
            time_step_spec=time_step_spec,
            action_spec=action_spec,
            actor_network=self._actor_network,
            training=False)

        self._train_policy = actor_policy_ctor(
            time_step_spec=time_step_spec,
            action_spec=action_spec,
            actor_network=self._actor_network,
            training=True)

        self._log_alpha = common.create_variable(
            'initial_log_alpha',
            initial_value=initial_log_alpha,
            dtype=tf.float32,
            trainable=True)

        if target_entropy is None:
            target_entropy = self._get_default_target_entropy(action_spec)

        self._use_log_alpha_in_alpha_loss = use_log_alpha_in_alpha_loss
        self._target_update_tau = target_update_tau
        self._target_update_period = target_update_period
        self._actor_optimizer = actor_optimizer
        self._critic_optimizer = critic_optimizer
        self._alpha_optimizer = alpha_optimizer
        self._actor_loss_weight = actor_loss_weight
        self._critic_loss_weight = critic_loss_weight
        self._alpha_loss_weight = alpha_loss_weight
        self._td_errors_loss_fn = td_errors_loss_fn
        self._gamma = gamma
        self._reward_scale_factor = reward_scale_factor
        self._target_entropy = target_entropy
        self._gradient_clipping = gradient_clipping
        self._debug_summaries = debug_summaries
        self._summarize_grads_and_vars = summarize_grads_and_vars
        self._update_target = self._get_target_updater(
            tau=self._target_update_tau, period=self._target_update_period)

        train_sequence_length = 2 if not critic_network.state_spec else None

        super().__init__(
            time_step_spec,
            action_spec,
            policy=policy,
            collect_policy=policy,
            train_sequence_length=train_sequence_length,
            debug_summaries=debug_summaries,
            summarize_grads_and_vars=summarize_grads_and_vars,
            train_step_counter=train_step_counter,
            validate_args=False
        )

        self._as_transition = data_converter.AsTransition(
            self.data_context, squeeze_time_dim=(train_sequence_length == 2))

    def _check_action_spec(self, action_spec):
        flat_action_spec = tf.nest.flatten(action_spec)
        for spec in flat_action_spec:
            if spec.dtype.is_floating:
                raise NotImplementedError(
                    'DiscreteSacAgent does not support continuous actions. '
                    'Action spec: {}'.format(action_spec))

    def _get_default_target_entropy(self, action_spec):
        # Target entropy is -log(1/|A|) * ratio (= maximum entropy * ratio).
        # ratio=0.98 is the value used by Christodoulou, 2019, so we use it by default.
        target_entropy = - np.log(1/self._num_actions) * 0.98
        return target_entropy

    def _actions_dist(self, time_steps):
        """Get actions distributions from policy."""
        # Get the raw action distribution from the training policy.
        batch_size = nest_utils.get_outer_shape(
            time_steps, self._time_step_spec)[0]
        policy_state = self._train_policy.get_initial_state(batch_size)
        action_distribution = self._train_policy.distribution(
            time_steps, policy_state=policy_state).action

        return action_distribution

    def _initialize(self):
        """Returns an op to initialize the agent.

        Copies weights from the Q networks to the target Q network.
        """
        common.soft_variables_update(
            self._critic_network_1.variables,
            self._target_critic_network_1.variables,
            tau=1.0)
        common.soft_variables_update(
            self._critic_network_2.variables,
            self._target_critic_network_2.variables,
            tau=1.0)

    def _train(self, experience, weights):
        """Returns a train op to update the agent's networks.

        This method trains with the provided batched experience.

        Args:
          experience: A time-stacked trajectory object.
          weights: Optional scalar or elementwise (per-batch-entry) importance
            weights.

        Returns:
          A train_op.

        Raises:
          ValueError: If optimizers are None and no default value was provided to
            the constructor.
        """
        transition = self._as_transition(experience)
        time_steps, policy_steps, next_time_steps = transition
        actions = policy_steps.action

        trainable_critic_variables = list(object_identity.ObjectIdentitySet(
            self._critic_network_1.trainable_variables +
            self._critic_network_2.trainable_variables))

        with tf.GradientTape(watch_accessed_variables=False) as tape:
            assert trainable_critic_variables, ('No trainable critic variables to '
                                                'optimize.')
            tape.watch(trainable_critic_variables)
            critic_loss = self._critic_loss_weight*self.critic_loss(
                time_steps,
                actions,
                next_time_steps,
                td_errors_loss_fn=self._td_errors_loss_fn,
                gamma=self._gamma,
                reward_scale_factor=self._reward_scale_factor,
                weights=weights,
                training=True)

        tf.debugging.check_numerics(critic_loss, 'Critic loss is inf or nan.')
        critic_grads = tape.gradient(critic_loss, trainable_critic_variables)
        self._apply_gradients(critic_grads, trainable_critic_variables,
                              self._critic_optimizer)

        trainable_actor_variables = self._actor_network.trainable_variables
        with tf.GradientTape(watch_accessed_variables=False) as tape:
            assert trainable_actor_variables, ('No trainable actor variables to '
                                               'optimize.')
            tape.watch(trainable_actor_variables)
            actor_loss = self._actor_loss_weight*self.actor_loss(
                time_steps, weights=weights)
        tf.debugging.check_numerics(actor_loss, 'Actor loss is inf or nan.')
        actor_grads = tape.gradient(actor_loss, trainable_actor_variables)
        self._apply_gradients(actor_grads, trainable_actor_variables,
                              self._actor_optimizer)

        alpha_variable = [self._log_alpha]
        with tf.GradientTape(watch_accessed_variables=False) as tape:
            assert alpha_variable, 'No alpha variable to optimize.'
            tape.watch(alpha_variable)
            alpha_loss = self._alpha_loss_weight*self.alpha_loss(
                time_steps, weights=weights)
        tf.debugging.check_numerics(alpha_loss, 'Alpha loss is inf or nan.')
        alpha_grads = tape.gradient(alpha_loss, alpha_variable)
        self._apply_gradients(alpha_grads, alpha_variable,
                              self._alpha_optimizer)

        total_loss = critic_loss + actor_loss + alpha_loss

        with tf.name_scope('Losses'):
            tf.compat.v2.summary.scalar(
                name='critic_loss', data=critic_loss, step=self.train_step_counter)
            tf.compat.v2.summary.scalar(
                name='actor_loss', data=actor_loss, step=self.train_step_counter)
            tf.compat.v2.summary.scalar(
                name='alpha_loss', data=alpha_loss, step=self.train_step_counter)
            tf.compat.v2.summary.scalar(
                name='total_loss', data=total_loss, step=self.train_step_counter)

        self.train_step_counter.assign_add(1)
        self._update_target()

        extra = SacLossInfo(
            critic_loss=critic_loss, actor_loss=actor_loss, alpha_loss=alpha_loss)

        return tf_agent.LossInfo(loss=total_loss, extra=extra)

    def _apply_gradients(self, gradients, variables, optimizer):
        # list(...) is required for Python3.
        grads_and_vars = list(zip(gradients, variables))
        if self._gradient_clipping is not None:
            grads_and_vars = eager_utils.clip_gradient_norms(grads_and_vars,
                                                             self._gradient_clipping)

        if self._summarize_grads_and_vars:
            eager_utils.add_variables_summaries(grads_and_vars,
                                                self.train_step_counter)
            eager_utils.add_gradients_summaries(grads_and_vars,
                                                self.train_step_counter)

        optimizer.apply_gradients(grads_and_vars)

    def _get_target_updater(self, tau=1.0, period=1):
        """Performs a soft update of the target network parameters.

        For each weight w_s in the original network, and its corresponding
        weight w_t in the target network, a soft update is:
        w_t = (1 - tau) * w_t + tau * w_s

        Args:
          tau: A float scalar in [0, 1]. Default `tau=1.0` means hard update.
          period: Step interval at which the target network is updated.

        Returns:
          A callable that performs a soft update of the target network parameters.
        """
        with tf.name_scope('update_target'):

            def update():
                """Update target network."""
                critic_update_1 = common.soft_variables_update(
                    self._critic_network_1.variables,
                    self._target_critic_network_1.variables,
                    tau,
                    tau_non_trainable=1.0)

                critic_2_update_vars = common.deduped_network_variables(
                    self._critic_network_2, self._critic_network_1)

                target_critic_2_update_vars = common.deduped_network_variables(
                    self._target_critic_network_2, self._target_critic_network_1)

                critic_update_2 = common.soft_variables_update(
                    critic_2_update_vars,
                    target_critic_2_update_vars,
                    tau,
                    tau_non_trainable=1.0)

                return tf.group(critic_update_1, critic_update_2)

            return common.Periodically(update, period, 'update_targets')

    def critic_loss(self,
                    time_steps: ts.TimeStep,
                    actions: types.Tensor,
                    next_time_steps: ts.TimeStep,
                    td_errors_loss_fn: types.LossFn,
                    gamma: types.Float = 1.0,
                    reward_scale_factor: types.Float = 1.0,
                    weights: Optional[types.Tensor] = None,
                    training: bool = False) -> types.Tensor:
        """Computes the critic loss for SAC training.

        Args:
            time_steps: A batch of timesteps.
            actions: A batch of actions.
            next_time_steps: A batch of next timesteps.
            td_errors_loss_fn: A function(td_targets, predictions) to compute
              elementwise (per-batch-entry) loss.
            gamma: Discount for future rewards.
            reward_scale_factor: Multiplicative factor to scale rewards.
            weights: Optional scalar or elementwise (per-batch-entry) importance
              weights.
            training: Whether this loss is being used for training.

        Returns:
            critic_loss: A scalar critic loss.
        """
        with tf.name_scope('critic_loss'):
            nest_utils.assert_same_structure(actions, self.action_spec)
            nest_utils.assert_same_structure(
                time_steps, self.time_step_spec)
            nest_utils.assert_same_structure(
                next_time_steps, self.time_step_spec)

            alpha = tf.math.exp(self._log_alpha)
            next_dist = self._actions_dist(next_time_steps)

            target_q_values1, _ = self._target_critic_network_1(
                next_time_steps.observation, next_time_steps.step_type, training=False)
            target_q_values2, _ = self._target_critic_network_2(
                next_time_steps.observation, next_time_steps.step_type, training=False)
            v_approx_next_state = tf.minimum(
                target_q_values1, target_q_values2)

            next_probs = next_dist.probs_parameter()
            v_approx_next_state = tf.reduce_sum(
                v_approx_next_state * next_probs, axis=-1)  # (?, )
            v_approx_next_state += alpha * next_dist.entropy()  # (?, )

            discounts = next_time_steps.discount * \
                tf.constant(gamma, dtype=tf.float32)

            # The mask is 0.0 where next_time_steps is the last step of an
            # episode, so no value is bootstrapped across episode boundaries.
            episode_mask = common.get_episode_mask(next_time_steps)

            td_targets = tf.stop_gradient(
                reward_scale_factor * next_time_steps.reward +
                discounts * v_approx_next_state * episode_mask)  # (?, )

            pred_td_targets1, _ = self._critic_network_1(
                time_steps.observation, time_steps.step_type, training=training)
            pred_td_targets2, _ = self._critic_network_2(
                time_steps.observation, time_steps.step_type, training=training)

            # Actually selected Q-values (from the actions batch).
            temp_one_hot = tf.one_hot(actions, depth=self._num_actions,
                                      dtype=tf.float32)  # (?, nb_actions)

            pred_td_targets1 = tf.reduce_sum(
                pred_td_targets1 * temp_one_hot, axis=-1)  # (?, )
            pred_td_targets2 = tf.reduce_sum(
                pred_td_targets2 * temp_one_hot, axis=-1)  # (?, )

            critic_loss1 = td_errors_loss_fn(
                td_targets, pred_td_targets1)
            critic_loss2 = td_errors_loss_fn(
                td_targets, pred_td_targets2)
            critic_loss = critic_loss1 + critic_loss2  # (?, ) or (?, 1)

            if critic_loss.shape.rank > 1:
                # Sum over the time dimension.
                critic_loss = tf.reduce_sum(
                    critic_loss, axis=range(1, critic_loss.shape.rank))

            agg_loss = common.aggregate_losses(
                per_example_loss=critic_loss,
                sample_weight=weights,
                regularization_loss=(self._critic_network_1.losses +
                                     self._critic_network_2.losses))
            critic_loss = agg_loss.total_loss

            self._critic_loss_debug_summaries(td_targets, pred_td_targets1,
                                              pred_td_targets2)

            return critic_loss

    def actor_loss(self,
                   time_steps: ts.TimeStep,
                   weights: Optional[types.Tensor] = None) -> types.Tensor:
        """Computes the actor_loss for SAC training.

        Args:
          time_steps: A batch of timesteps.
          weights: Optional scalar or elementwise (per-batch-entry) importance
            weights.

        Returns:
          actor_loss: A scalar actor loss.
        """
        with tf.name_scope('actor_loss'):
            nest_utils.assert_same_structure(
                time_steps, self.time_step_spec)

            dist = self._actions_dist(time_steps)
            alpha = tf.exp(self._log_alpha)

            target_q_values1, _ = self._critic_network_1(
                time_steps.observation, time_steps.step_type, training=False)
            target_q_values2, _ = self._critic_network_2(
                time_steps.observation, time_steps.step_type, training=False)
            target_q_values = tf.minimum(
                target_q_values1, target_q_values2)  # (?, q_outputs)

            logits_q = tf.stop_gradient(
                target_q_values / alpha)
            dist_d_qs = tfp.distributions.Categorical(logits=logits_q)  # (?, )
            kl = tfp.distributions.kl_divergence(dist, dist_d_qs)  # (?, )
            actor_loss = alpha * kl  # (?, )

            if actor_loss.shape.rank > 1:
                # Sum over the time dimension.
                actor_loss = tf.reduce_sum(
                    actor_loss, axis=range(1, actor_loss.shape.rank))
            reg_loss = self._actor_network.losses if self._actor_network else None
            agg_loss = common.aggregate_losses(
                per_example_loss=actor_loss,
                sample_weight=weights,
                regularization_loss=reg_loss)
            actor_loss = agg_loss.total_loss
            self._actor_loss_debug_summaries(actor_loss, dist,
                                             target_q_values, time_steps)

            return actor_loss

    def alpha_loss(self,
                   time_steps: ts.TimeStep,
                   weights: Optional[types.Tensor] = None) -> types.Tensor:
        """Computes the alpha_loss for EC-SAC training (discrete actions).

        Args:
          time_steps: A batch of timesteps.
          weights: Optional scalar or elementwise (per-batch-entry) importance
            weights.

        Returns:
          alpha_loss: A scalar alpha loss (discrete actions).
        """
        with tf.name_scope('alpha_loss'):
            nest_utils.assert_same_structure(time_steps, self.time_step_spec)
            # alpha_loss = alpha * (H - H_target)
            #            = -alpha * (sum_a pi * log(pi) + H_target), equivalent to Christodoulou, 2019
            dist = self._actions_dist(time_steps)
            entropy = dist.entropy()
            entropy_diff = tf.stop_gradient(
                entropy - self._target_entropy)
            if self._use_log_alpha_in_alpha_loss:
                alpha_loss = self._log_alpha*entropy_diff
            else:
                alpha_loss = tf.exp(self._log_alpha)*entropy_diff

            if alpha_loss.shape.rank > 1:
                # Average over the time dimension.
                alpha_loss = tf.reduce_mean(
                    alpha_loss, axis=range(1, alpha_loss.shape.rank))

            agg_loss = common.aggregate_losses(
                per_example_loss=alpha_loss, sample_weight=weights)
            alpha_loss = agg_loss.total_loss

            self._alpha_loss_debug_summaries(
                alpha_loss, entropy_diff)

            return alpha_loss

    def _actor_loss_debug_summaries(self, actor_loss, dist,
                                    target_q_values, time_steps):
        if self._debug_summaries:
            common.generate_tensor_summaries('actor_loss', actor_loss,
                                             self.train_step_counter)
            try:
                tf.compat.v2.summary.histogram(
                    name='actions_log_prob_discrete',
                    data=dist.logits,
                    step=self.train_step_counter)
            except ValueError:
                pass  # Guard against internal SAC variants that do not directly
                # generate actions.

            common.generate_tensor_summaries('target_q_values', target_q_values,
                                             self.train_step_counter)
            common.generate_tensor_summaries('act_mode', dist.mode(),
                                             self.train_step_counter)
            try:
                common.generate_tensor_summaries('entropy_action',
                                                 dist.entropy(),
                                                 self.train_step_counter)
            except NotImplementedError:
                pass  # Some distributions do not have an analytic entropy.

    def _alpha_loss_debug_summaries(self, alpha_loss, entropy_diff):
        if self._debug_summaries:
            common.generate_tensor_summaries('alpha_loss', alpha_loss,
                                             self.train_step_counter)
            common.generate_tensor_summaries('entropy_diff', entropy_diff,
                                             self.train_step_counter)
            tf.compat.v2.summary.scalar(
                name='log_alpha', data=self._log_alpha, step=self.train_step_counter)

    def _critic_loss_debug_summaries(self, td_targets, pred_td_targets1,
                                     pred_td_targets2):
        if self._debug_summaries:
            td_errors1 = td_targets - pred_td_targets1
            td_errors2 = td_targets - pred_td_targets2
            td_errors = tf.concat([td_errors1, td_errors2], axis=0)
            common.generate_tensor_summaries('td_errors', td_errors,
                                             self.train_step_counter)
            common.generate_tensor_summaries('td_targets', td_targets,
                                             self.train_step_counter)
            common.generate_tensor_summaries('pred_td_targets1', pred_td_targets1,
                                             self.train_step_counter)
            common.generate_tensor_summaries('pred_td_targets2', pred_td_targets2,
                                             self.train_step_counter)
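A minimal NumPy sketch (plain NumPy, not TF-Agents; helper names such as `soft_state_value` and `actor_loss_expected` are hypothetical) that mirrors the discrete `critic_loss`, `actor_loss` and `alpha_loss` above on a random toy batch. It checks two things: the KL-based actor loss differs from the usual E_pi[alpha * log pi - Q] objective only by alpha * logsumexp(Q / alpha), a term that is constant with respect to the actor, and alpha * (H - H_target) equals -alpha * (sum_a pi * log pi + H_target).

```python
import numpy as np


def softmax(x):
    # Numerically stable softmax over the last (action) axis.
    z = np.exp(x - x.max(axis=-1, keepdims=True))
    return z / z.sum(axis=-1, keepdims=True)


def logsumexp(x):
    # Numerically stable log-sum-exp over the last (action) axis.
    m = x.max(axis=-1, keepdims=True)
    return (m + np.log(np.exp(x - m).sum(axis=-1, keepdims=True)))[..., 0]


def entropy(probs):
    # H(pi) = -sum_a pi(a) * log pi(a), per batch entry.
    return -(probs * np.log(probs)).sum(axis=-1)


def soft_state_value(q1, q2, probs, alpha):
    # V(s') = sum_a pi(a|s') * min(Q1, Q2)(s', a) + alpha * H(pi(.|s')),
    # mirroring v_approx_next_state in critic_loss.
    return (probs * np.minimum(q1, q2)).sum(axis=-1) + alpha * entropy(probs)


def actor_loss_kl(logits, q, alpha):
    # alpha * KL(pi || softmax(Q / alpha)), mirroring actor_loss.
    pi = softmax(logits)
    log_target = q / alpha - logsumexp(q / alpha)[..., None]
    return alpha * (pi * (np.log(pi) - log_target)).sum(axis=-1)


def actor_loss_expected(logits, q, alpha):
    # E_pi[alpha * log pi - Q]: the usual discrete-SAC actor objective.
    pi = softmax(logits)
    return (pi * (alpha * np.log(pi) - q)).sum(axis=-1)


def alpha_loss(log_alpha, probs, target_entropy):
    # exp(log_alpha) * (H - H_target), mirroring the non-log branch of alpha_loss.
    return np.exp(log_alpha) * (entropy(probs) - target_entropy)


rng = np.random.default_rng(0)
batch, num_actions, alpha = 4, 5, 0.2
logits = rng.normal(size=(batch, num_actions))
q = np.minimum(rng.normal(size=(batch, num_actions)),
               rng.normal(size=(batch, num_actions)))

# The two actor losses differ by alpha * logsumexp(Q / alpha), which does not
# depend on the actor parameters, so their gradients w.r.t. the actor match.
gap = actor_loss_kl(logits, q, alpha) - actor_loss_expected(logits, q, alpha)
print(np.allclose(gap, alpha * logsumexp(q / alpha)))  # True

# alpha * (H - H_target) == -alpha * (sum_a pi * log pi + H_target).
pi = softmax(logits)
target_entropy = 0.5 * np.log(num_actions)  # arbitrary value for this check
rewritten = -alpha * ((pi * np.log(pi)).sum(axis=-1) + target_entropy)
print(np.allclose(alpha_loss(np.log(alpha), pi, target_entropy), rewritten))  # True
```

Computing the exact expectation over all discrete actions like this avoids sampling, which is why the actor loss above does not need a reparameterized sample for its gradient.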

I also had to modify the critic network so that it outputs one Q-value per discrete action instead of a single value. It encodes observations with an `EncodingNetwork`, so preprocessing and convolution layers are supported:

# discrete_sac_critic_network.py
"""Sample Critic/Q network to use with discrete SAC agent."""

import gin
import tensorflow as tf  # pylint: disable=g-explicit-tensorflow-version-import
from tf_agents.networks import encoding_network, network, utils
import numpy as np

@gin.configurable
class DiscreteSacCriticNetwork(network.Network):
    """Creates a critic network."""

    def __init__(self,
                 input_tensor_spec,
                 observation_preprocessing_layers=None,
                 observation_preprocessing_combiner=None,
                 observation_conv_layer_params=None,
                 observation_fc_layer_params=(75, 40),
                 observation_dropout_layer_params=None,
                 action_fc_layer_params=None,
                 action_dropout_layer_params=None,
                 joint_fc_layer_params=(75, 40),
                 joint_dropout_layer_params=None,
                 activation_fn=tf.nn.relu,
                 output_activation_fn=None,
                 kernel_initializer=None,
                 last_kernel_initializer=None,
                 batch_squash=True,
                 dtype=tf.float32,
                 name='CriticNetwork'):
        """Creates an instance of `CriticNetwork`.

        Args:
           input_tensor_spec: A tuple of (observation, action) each a nest of
            `tensor_spec.TensorSpec` representing the inputs.
          preprocessing_layers: (Optional.) A nest of `tf.keras.layers.Layer`
            representing preprocessing for the different observations.
            All of these layers must not be already built. For more details see
            the documentation of `networks.EncodingNetwork`.
          preprocessing_combiner: (Optional.) A keras layer that takes a flat list
            of tensors and combines them. Good options include
            `tf.keras.layers.Add` and `tf.keras.layers.Concatenate(axis=-1)`.
            This layer must not be already built. For more details see
            the documentation of `networks.EncodingNetwork`.
          observation_conv_layer_params: Optional list of convolution layer
            parameters for observations, where each item is a length-three tuple
            indicating (num_units, kernel_size, stride).
          observation_fc_layer_params: Optional list of fully connected parameters
            for observations, where each item is the number of units in the layer.
          observation_dropout_layer_params: Optional list of dropout layer
            parameters, each item is the fraction of input units to drop or a
            dictionary of parameters according to the keras.Dropout documentation.
            The additional parameter `permanent`, if set to True, allows to apply
            dropout at inference for approximated Bayesian inference. The dropout
            layers are interleaved with the fully connected layers; there is a
            dropout layer after each fully connected layer, except if the entry in
            the list is None. This list must have the same length of
            observation_fc_layer_params, or be None.
          action_fc_layer_params: Currently unused: actions are not fed to this
            network, which outputs one Q-value per action instead.
          action_dropout_layer_params: Currently unused (see
            `action_fc_layer_params`).
          joint_fc_layer_params: Currently unused (see
            `action_fc_layer_params`).
          joint_dropout_layer_params: Currently unused (see
            `action_fc_layer_params`).
          activation_fn: Activation function, e.g. tf.nn.relu, slim.leaky_relu, ...
          output_activation_fn: Activation function for the last layer. This can be
            used to restrict the range of the output. For example, one can pass
            tf.keras.activations.sigmoid here to restrict the output to be bounded
            between 0 and 1.
          kernel_initializer: kernel initializer for all layers except for the value
            regression layer. If None, a VarianceScaling initializer will be used.
          last_kernel_initializer: kernel initializer for the value regression
            layer. If None, a RandomUniform initializer will be used.
          batch_squash: If True the outer_ranks of the observation are squashed into
            the batch dimension. This allows encoding networks to be used with
            observations with shape [BxTx...].
          dtype: The dtype to use by the layers.
          name: A string representing name of the network.

        Raises:
          ValueError: If `observation_spec` or `action_spec` contains more than one
            observation.
        """
        observation_spec, action_spec = input_tensor_spec

        super().__init__(
            input_tensor_spec=observation_spec,
            state_spec=(),
            name=name)

        if kernel_initializer is None:
            kernel_initializer = tf.compat.v1.keras.initializers.VarianceScaling(
                scale=1. / 3., mode='fan_in', distribution='uniform')
        if last_kernel_initializer is None:
            last_kernel_initializer = tf.keras.initializers.RandomUniform(
                minval=-0.003, maxval=0.003)

        self._observation_encoder = encoding_network.EncodingNetwork(
            observation_spec,
            preprocessing_layers=observation_preprocessing_layers,
            preprocessing_combiner=observation_preprocessing_combiner,
            conv_layer_params=observation_conv_layer_params,
            fc_layer_params=observation_fc_layer_params,
            dropout_layer_params=observation_dropout_layer_params,
            activation_fn=activation_fn,
            kernel_initializer=kernel_initializer,
            batch_squash=batch_squash,
            dtype=dtype)

        flat_action_spec = tf.nest.flatten(action_spec)
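        # One Q-value per discrete action; assumes scalar (shape ()) action specs.
        # Cast to int because Dense expects an integer number of units.
        q_output_size = int(np.sum([
            single_spec.maximum - single_spec.minimum + 1
            for single_spec in flat_action_spec
        ]))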
        self._output_layer = tf.keras.layers.Dense(
            q_output_size,
            activation=output_activation_fn,
            kernel_initializer=last_kernel_initializer,
            name='value')

    def call(self, inputs, step_type=(), network_state=(), training=False):
        state, network_state = self._observation_encoder(
            inputs, step_type=step_type, network_state=network_state,
            training=training)

        q_values = self._output_layer(state)
        return q_values, network_state
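As a hedged illustration of the output head above, the NumPy sketch below mirrors how `DiscreteSacCriticNetwork` sizes its output layer (`maximum - minimum + 1` Q-values per scalar discrete spec) and how `critic_loss` then picks Q(s, a) for the taken action with a one-hot mask. The helper names and the 10 x 7 action count are hypothetical, chosen only to resemble the Tetris setup (board width times number of blocks).

```python
import numpy as np


def q_output_size(bounds):
    # One Q-value per discrete action for each (minimum, maximum) scalar spec,
    # cast to a plain int as Dense expects an integer number of units.
    return int(sum(hi - lo + 1 for lo, hi in bounds))


def select_q(q_values, actions):
    # Equivalent to reduce_sum(q_values * one_hot(actions), axis=-1).
    one_hot = np.eye(q_values.shape[-1], dtype=q_values.dtype)[actions]
    return (q_values * one_hot).sum(axis=-1)


# Hypothetical action spec: 10 columns x 7 block types -> actions 0..69.
num_actions = q_output_size([(0, 69)])
q_values = np.arange(2 * num_actions, dtype=np.float32).reshape(2, num_actions)
taken = np.array([3, 0])
print(num_actions, select_q(q_values, taken))
```

With several discrete specs the head would simply concatenate their Q-values, so a single `tf.one_hot(actions, depth=self._num_actions)` selection as in `critic_loss` presumably only fits the single scalar action spec used in this issue.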
DBraun commented 3 years ago

@Fabien-Couthouis Thanks for sharing your code. I haven't tested it on the standard gym environments. I'm using my code in a custom environment where both the continuous and the discrete action outputs are simple arrays, so everything is one-dimensional, ignoring the batch size. If that isn't the case, my code has issues, because the hidden state gets flattened to 1-D and concatenated with the flattened discrete actions.

I ran https://github.com/tensorflow/agents/blob/v0.7.1/tf_agents/networks/actor_distribution_rnn_network_test.py and it failed here, possibly because concatenating the discrete actions to the hidden state before the continuous projection changes the size of the continuous projection's variables. This is probably a good argument for forking the existing classes into new ones rather than changing them.
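The shape argument can be illustrated with a minimal NumPy sketch (hypothetical sizes and helper name; this is not DBraun's actual network): appending a one-hot discrete action to the flattened hidden state widens the input of the continuous projection, so its kernel no longer has the shape the unmodified network expects, and an extra time axis gets merged into the feature axis by the flatten.

```python
import numpy as np


def projection_input(hidden, discrete_one_hot=None):
    # Input to the continuous projection; optionally append discrete actions.
    flat = hidden.reshape(hidden.shape[0], -1)  # flatten everything but batch
    if discrete_one_hot is None:
        return flat
    extra = discrete_one_hot.reshape(flat.shape[0], -1)
    return np.concatenate([flat, extra], axis=-1)


rng = np.random.default_rng(0)
batch, hidden_size, num_discrete, continuous_dims = 3, 8, 4, 2
hidden = rng.normal(size=(batch, hidden_size))
one_hot = np.eye(num_discrete)[rng.integers(num_discrete, size=batch)]

# A Dense kernel has shape (input_width, units); units = 2 * continuous_dims
# here, e.g. a mean and a scale per continuous dimension (an assumption).
units = 2 * continuous_dims
plain_kernel_shape = (projection_input(hidden).shape[-1], units)
hybrid_kernel_shape = (projection_input(hidden, one_hot).shape[-1], units)
print(plain_kernel_shape, hybrid_kernel_shape)  # (8, 4) (12, 4)

# With an extra time axis ([B, T, H]) the flatten merges T into the features,
# which is why this only works for one-dimensional (per-batch) hidden states.
print(projection_input(rng.normal(size=(batch, 5, hidden_size)), one_hot).shape)  # (3, 44)
```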

tagomatech commented 1 year ago

Hi. What is the status of Discrete SAC in tf-agents as of today, please? Has @DBraun's or @Fabien-Couthouis's work been added to the tf-agents library? (If not, why?) Thank you.

fede72bari commented 1 year ago

Nobody has answered.