ray-project / ray

Ray is a unified framework for scaling AI and Python applications. Ray consists of a core distributed runtime and a set of AI Libraries for accelerating ML workloads.
https://ray.io
Apache License 2.0

[RLlib] Categorical action dist incorrectly uses tf.random.categorical #24055

Open HJasperson opened 2 years ago

HJasperson commented 2 years ago

What happened + What you expected to happen

The problem is here: https://github.com/ray-project/ray/blob/6d8d7398df4f90abd008468c5b4fb1ebfa587256/rllib/models/tf/tf_action_dist.py#L90-L91

tf.random.categorical takes in log probabilities even though the name of the input variable is 'logits'. Here, self.inputs are the logits, so masked actions (where logits[a]=0) are considered valid samples by tf. This also impacts MultiCategorical since it ultimately calls this same method.

Versions / Dependencies

ray 1.11.0 (still present in 1.12), Python 3.9, TensorFlow 2.7, RHEL 7.9

Reproduction script

import tensorflow as tf

# mask last 3 actions
z = tf.constant([[0.5, 0.5, 0.5, 0, 0, 0]])

# current - will sample masked actions
tf.random.categorical(z, 10)

# corrected - won't sample masked actions
tf.random.categorical(tf.math.log(z), 10)
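For reference, the sampling probabilities implied by the two calls above can be worked out by hand. The following is a minimal pure-Python sketch, assuming (as the TensorFlow documentation describes the `logits` argument) that tf.random.categorical samples in proportion to the softmax of its input. The helper `softmax` and the variable names are illustrative and are not RLlib code.

```python
import math


def softmax(logits):
    """Probabilities implied by a list of unnormalized log-probabilities."""
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]  # math.exp(-inf) is 0.0
    total = sum(exps)
    return [e / total for e in exps]


z = [0.5, 0.5, 0.5, 0.0, 0.0, 0.0]

# "current" call: z is passed as-is, so a 0 entry becomes exp(0) = 1 before
# normalization and keeps a non-zero probability.
probs_raw = softmax(z)  # about 0.21 for the first three, about 0.13 for the rest

# "corrected" call: log(z) maps 0 to -inf (math.log(0) raises in plain Python,
# so -inf is substituted by hand here), and exp(-inf) = 0.
log_z = [math.log(v) if v > 0 else float("-inf") for v in z]
probs_log = softmax(log_z)  # 1/3 for the first three, exactly 0.0 for the rest

print(probs_raw)
print(probs_log)
```

Under this assumption, whether a 0 entry means "masked" depends on whether the vector holds probabilities (as z does in the script) or raw logits, which is the point the rest of the thread argues about.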

Issue Severity

High: It blocks me from completing my task.

sven1977 commented 2 years ago

Hey @HJasperson, thanks for filing this issue. However, I'm not sure the fix would be that simple. Imagine a NN that outputs some values via its last linear layer, which is supposed to produce action logits (not log-probs!). Some of these logits may be 0 or at least very close to zero, which does NOT mean that these actions should be masked. Even negative action logit values are common and should still lead to these actions sometimes being sampled.

Can you instead either: a) create a custom action distribution class and specify that in your model config:

config:
  model:
    custom_action_dist: [your registered string ID]

(See rllib/examples/autoregressive_action_dist.py for an example of how to do so.)

b) or change your model in such a way that it outputs logits for the masked actions that are close to -inf, such that the sampling step really won't pick those actions anymore.
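Option b) can be sketched without TensorFlow. The snippet below is only an illustration under stated assumptions: `mask_logits`, `softmax`, and the constant `LARGE_NEGATIVE` are hypothetical names, and the value chosen for the constant (roughly the most negative float32) is one possible choice rather than anything prescribed by RLlib.

```python
import math

# Hypothetical stand-in for "close to -inf"; roughly the float32 minimum.
LARGE_NEGATIVE = -3.4e38


def mask_logits(logits, valid):
    """Replace the logits of invalid actions with a very large negative value."""
    return [x if ok else LARGE_NEGATIVE for x, ok in zip(logits, valid)]


def softmax(logits):
    """Probabilities implied by unnormalized log-probabilities."""
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]  # underflows to 0.0 for masked entries
    total = sum(exps)
    return [e / total for e in exps]


# Logits from a linear layer: 0 and negative values are ordinary, valid outputs.
logits = [1.2, 0.0, -0.7, 2.0]
valid = [True, True, True, False]  # only the last action is masked

probs = softmax(mask_logits(logits, valid))
print(probs)
```

Here the actions with logits 0.0 and -0.7 keep a non-zero probability, while the explicitly masked action ends up with probability 0.0. The mask is carried separately from the logit values, so a logit that happens to be 0 is not mistaken for a masked action.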

HJasperson commented 2 years ago

What I'm suggesting is an emergency fix because right now you are giving a tf method incorrect inputs. The method expects log probabilities and you are giving it logits. Just look at the tf documentation!

You are welcome to address special cases (e.g. very small logits) through some helper method, but as it stands now, any results of a network trained via this distribution should be considered incorrect.

HJasperson commented 2 years ago

I updated the PR with tf.math.log(tf.nn.softmax(tf.where(x!=0, x, tf.float64.min))), which should address some of your concerns. Of course, the padding/masking value can be modified later, but padding/masking with zeros is by far the most common choice.

kouroshHakha commented 1 year ago

The input should be logits (a.k.a. unnormalized log-probs; see the TF documentation for tf.random.categorical). The implementation in RLlib seems to be correct. Can you clarify what your RLlib use case is, and share a repro script that shows when it breaks?
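The phrase "unnormalized log-probs" can be made concrete with a small pure-Python check. This is a sketch under the assumption that the sampler normalizes its input with a softmax; `softmax` and `log_softmax` are illustrative helpers written here, not TensorFlow or RLlib functions.

```python
import math


def softmax(logits):
    """Probabilities implied by unnormalized log-probabilities."""
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def log_softmax(logits):
    """Normalized log-probabilities: logits minus their log-sum-exp."""
    m = max(logits)
    lse = m + math.log(sum(math.exp(x - m) for x in logits))
    return [x - lse for x in logits]


logits = [2.0, 0.0, -1.5]
shifted = [x + 7.0 for x in logits]   # same logits plus an arbitrary constant
log_probs = log_softmax(logits)       # properly normalized log-probabilities

p_logits = softmax(logits)
p_shifted = softmax(shifted)
p_log_probs = softmax(log_probs)

print(p_logits)
print(p_shifted)
print(p_log_probs)
```

All three inputs describe the same distribution, because adding a constant to every logit cancels in the normalization. Normalized log-probabilities are therefore one valid input among many, and a logit of 0 or below still corresponds to a strictly positive probability.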

HJasperson commented 1 year ago

You are incorrect - self.inputs are not log probabilities.