tensorflow / agents

TF-Agents: A reliable, scalable and easy to use TensorFlow library for Contextual Bandits and Reinforcement Learning.
Apache License 2.0

Can we change tf.compat.v1.where to tf.compat.v2.where in EpsilonGreedyPolicy._action #819

Closed JustinACoder closed 1 year ago

JustinACoder commented 1 year ago

It seems this issue is already known, as it is described in a TODO in the _action method of EpsilonGreedyPolicy (tf_agents/policies/epsilon_greedy_policy.py):

tf.compat.v1.where only supports a condition which is either a scalar or a vector. Use tf.compat.v2 so that it can support any condition whose leading dimensions are the same as the other operands of tf.where

  def _action(self, time_step, policy_state, seed):
    seed_stream = tfp.util.SeedStream(seed=seed, salt='epsilon_greedy')
    greedy_action = self._greedy_policy.action(time_step, policy_state)
    random_action = self._random_policy.action(time_step, (), seed_stream())

    outer_shape = nest_utils.get_outer_shape(time_step, self._time_step_spec)
    rng = tf.random.uniform(
        outer_shape, maxval=1.0, seed=seed_stream(), name='epsilon_rng')
    cond = tf.greater_equal(rng, self._get_epsilon())

    # Selects the action/info from the random policy with probability epsilon.
    # TODO(b/133175894): tf.compat.v1.where only supports a condition which is
    # either a scalar or a vector. Use tf.compat.v2 so that it can support any
    # condition whose leading dimensions are the same as the other operands of
    # tf.where.
    outer_ndims = int(outer_shape.shape[0])
    if outer_ndims >= 2:
      raise ValueError(
          'Only supports batched time steps with a single batch dimension')
    action = tf.nest.map_structure(lambda g, r: tf.compat.v1.where(cond, g, r),
                                   greedy_action.action, random_action.action)

I'm bringing this up because I've had some issues and noticed that changing tf.compat.v1.where to tf.compat.v2.where solves them.

In my case, the problem was that greedy_action.action had shape (1,) but random_action.action had shape ().

>>> random_action.action
<tf.Tensor: shape=(), dtype=int64, numpy=5>
>>> greedy_action.action
<tf.Tensor: shape=(1,), dtype=int64, numpy=array([26], dtype=int64)>

tf.compat.v1.where can't handle this shape mismatch, while tf.compat.v2.where broadcasts its operands and can.
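To illustrate, here is a minimal standalone sketch (plain TensorFlow, not TF-Agents code) using the shapes from my case above:

```python
import tensorflow as tf

cond = tf.constant([True])                  # shape (1,), like the epsilon condition
greedy = tf.constant([26], dtype=tf.int64)  # shape (1,), like greedy_action.action
random = tf.constant(5, dtype=tf.int64)     # shape (),  like random_action.action

# tf.compat.v2.where (i.e. tf.where) broadcasts its arguments, so the
# scalar random action is broadcast against the (1,)-shaped greedy action.
action = tf.where(cond, greedy, random)
print(action)  # tf.Tensor([26], shape=(1,), dtype=int64)

# tf.compat.v1.where requires x and y to have identical shapes, so the
# same call raises an error instead of broadcasting.
try:
    tf.compat.v1.where(cond, greedy, random)
except (ValueError, tf.errors.InvalidArgumentError) as e:
    print('v1.where failed:', type(e).__name__)
```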

Did I do something wrong on my end? Can't we simply change v1 to v2?

JustinACoder commented 1 year ago

Although I still wonder why we don't change v1 to v2, I found my problem. The wrong action shape was caused by my observation_and_action_constraint_splitter function, which I hadn't implemented correctly.

I tried computing the action constraints directly in TensorFlow inside observation_and_action_constraint_splitter, instead of defining the legal moves in my env and simply splitting them out in that function. It was working relatively well. However, I was returning the action mask with shape () when it had to be of shape (1,), so I simply had to change the last line of my observation_and_action_constraint_splitter function from

return observation, action_mask

To

return observation, tf.expand_dims(action_mask, axis=0)
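For anyone hitting the same thing, here is a minimal sketch of what the corrected splitter can look like. The observation layout and the 'state'/'mask' keys are specific to my setup, not a TF-Agents convention:

```python
import tensorflow as tf

def observation_and_action_constraint_splitter(observation):
    # Hypothetical layout: the env packs the real observation under
    # 'state' and the per-action legality mask under 'mask'.
    action_mask = observation['mask']
    # The policy expects a leading batch dimension on the mask, so add
    # it explicitly, e.g. shape (num_actions,) -> (1, num_actions).
    return observation['state'], tf.expand_dims(action_mask, axis=0)
```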