tensorflow / probability

Probabilistic reasoning and statistical analysis in TensorFlow
https://www.tensorflow.org/probability/
Apache License 2.0

Bijector log_prob outputs nan when executing with @tf.function but not without it. #840

Open hartikainen opened 4 years ago

hartikainen commented 4 years ago

System information

Issue

We have a case where we use tfp bijectors to transform a latent Gaussian distribution into another one conditioned on inputs (specifically, reinforcement learning observations). The shift and scale parameters of the transformation are parameterized by a feedforward network that takes the observations as inputs. The output of the transformed Gaussian distribution is finally passed through a Tanh bijector.

As shown below, the implementation uses two custom bijectors (ConditionalScale and ConditionalShift) to handle the transformation. The reason for this is that I've been unable to implement the same functionality cleanly with the existing Scale and Shift bijectors [1, 2].

Now, this implementation works perfectly fine when running without tf.function decorators or, alternatively, as demonstrated below, when running with tf.config.experimental_run_functions_eagerly(True). I don't understand tf.function's tracing well enough to pinpoint the exact root cause of the error below, but ultimately the issue is that with tf.function the code below consistently produces nans, whereas with it disabled things work fine. Indeed, in my original implementation, which can be found in the softlearning repo, the same setup learns fine even in graph mode as long as I remove the tf.function decorators from the GaussianPolicy.{actions,log_probs} methods.

My questions thus are:

Standalone code to reproduce the issue

Sorry, the code is a bit verbose: I didn't really understand where exactly the issue arises, so I couldn't prune it down much further. This is a slightly stripped version of my implementation in the softlearning project.

import sys

import tensorflow as tf
import tensorflow_probability as tfp
from tensorflow_probability.python.bijectors import bijector
from tensorflow_probability.python.internal import assert_util
from tensorflow_probability.python.internal import dtype_util

class ConditionalScale(bijector.Bijector):
    def __init__(self,
                 dtype=tf.float32,
                 validate_args=False,
                 name='conditional_scale'):
        """Instantiates the `ConditionalScale` bijector.

        This `Bijector`'s forward operation is:

        ```none
        Y = g(X) = scale * X
        ```

        Args:
          validate_args: Python `bool` indicating whether arguments should be
            checked for correctness.
          name: Python `str` name given to ops managed by this object.
        """
        parameters = dict(locals())
        with tf.name_scope(name) as name:
            super(ConditionalScale, self).__init__(
                forward_min_event_ndims=0,
                is_constant_jacobian=True,
                validate_args=validate_args,
                dtype=dtype,
                parameters=parameters,
                name=name)

    def _maybe_assert_valid_scale(self, scale):
        if not self.validate_args:
            return ()
        is_non_zero = assert_util.assert_none_equal(
            scale,
            tf.zeros((), dtype=scale.dtype),
            message='Argument `scale` must be non-zero.')
        return (is_non_zero,)

    def _forward(self, x, scale):
        with tf.control_dependencies(self._maybe_assert_valid_scale(scale)):
            return x * scale

    def _inverse(self, y, scale):
        with tf.control_dependencies(self._maybe_assert_valid_scale(scale)):
            return y / scale

    def _forward_log_det_jacobian(self, x, scale):
        with tf.control_dependencies(self._maybe_assert_valid_scale(scale)):
            return tf.math.log(tf.abs(scale))

class ConditionalShift(bijector.Bijector):
    """Compute Y = g(X; shift) = X + shift.

    where `shift` is a numeric `Tensor`.

    Example Use:

    ```python
    shift = Shift([-1., 0., 1])
    x = [1., 2, 3]
    # `forward` is equivalent to:
    # y = x + shift
    y = shift.forward(x)  # [0., 2., 4.]
    ```
    """
    def __init__(self,
                 dtype=tf.float32,
                 validate_args=False,
                 name='conditional_shift'):
        """Instantiates the `ConditionalShift` bijector.

        Args:
          validate_args: Python `bool` indicating whether arguments should be
            checked for correctness.
          name: Python `str` name given to ops managed by this object.
        """
        parameters = dict(locals())
        with tf.name_scope(name) as name:
            super(ConditionalShift, self).__init__(
                forward_min_event_ndims=0,
                is_constant_jacobian=True,
                dtype=dtype,
                validate_args=validate_args,
                parameters=parameters,
                name=name)

    @classmethod
    def _is_increasing(cls):
        return True

    def _forward(self, x, shift):
        return x + shift

    def _inverse(self, y, shift):
        return y - shift

    def _forward_log_det_jacobian(self, x, shift):
        # is_constant_jacobian = True for this bijector, hence the
        # `log_det_jacobian` need only be specified for a single input, as
        # this will be tiled to match `event_ndims`.
        return tf.zeros((), dtype=dtype_util.base_dtype(x.dtype))

class GaussianPolicy:
    def __init__(self, input_shape, output_shape):
        output_size = tf.reduce_prod(output_shape)

        input_ = tf.keras.layers.Input(input_shape)
        out = input_
        out = tf.keras.layers.Dense(10, activation='relu')(out)
        out = tf.keras.layers.Dense(10, activation='relu')(out)
        out = tf.keras.layers.Dense(output_size * 2, activation='linear')(out)

        def split_shift_and_log_scale_diag_fn(inputs):
            shift_and_log_scale_diag = inputs
            shift, log_scale_diag = tf.split(
                shift_and_log_scale_diag,
                num_or_size_splits=2,
                axis=-1)
            scale_diag = tf.exp(log_scale_diag)
            return [shift, scale_diag]

        shift, scale = tf.keras.layers.Lambda(
            split_shift_and_log_scale_diag_fn)(out)

        self.shift_and_scale_model = tf.keras.Model(input_, (shift, scale))

        base_distribution = tfp.distributions.MultivariateNormalDiag(
            loc=tf.zeros(output_shape), scale_diag=tf.ones(output_shape))

        raw_action_distribution = tfp.bijectors.Chain((
            ConditionalShift(name='shift'),
            ConditionalScale(name='scale'),
        ))(base_distribution)

        self.base_distribution = base_distribution
        self.raw_action_distribution = raw_action_distribution
        self.action_distribution = tfp.bijectors.Tanh()(
            raw_action_distribution)

    @tf.function(experimental_relax_shapes=True)
    def actions(self, observations):
        """Compute actions for given observations."""
        batch_shape = tf.shape(observations)[0]
        shifts, scales = self.shift_and_scale_model(observations)
        actions = self.action_distribution.sample(
            batch_shape,
            bijector_kwargs={'scale': {'scale': scales},
                             'shift': {'shift': shifts}})
        return actions

    @tf.function(experimental_relax_shapes=True)
    def log_probs(self, observations, actions):
        """Compute log probabilities of `actions` given observations."""
        shifts, scales = self.shift_and_scale_model(observations)
        log_probs = self.action_distribution.log_prob(
            actions,
            bijector_kwargs={'scale': {'scale': scales},
                             'shift': {'shift': shifts}})[..., tf.newaxis]
        return log_probs

def main():
    # NOTE: this is using a fixed input. However, the issue is consistently
    # reproducible also with "real" observations coming from the RL
    # environment, i.e. the eager version of the code never fails, whereas
    # the graph version fails randomly (yet consistently).
    observations = tf.repeat([[
        0.23420376, -0.32872833, 0.03206815, 0.18556681,
        -2.0855187, 2.12688574, -3.91442398, 2.80974896,
    ]], 256, axis=0)  # (256, 8)

    found_nans = False

    for i in range(10):
        # Loop over a few different initializations of the policy parameters.
        policy = GaussianPolicy(
            input_shape=observations.shape[1:], output_shape=(2, ))

        actions = policy.actions(observations)  # (256, 2)
        log_probs = policy.log_probs(observations, actions)  # (256, 1)

        if tf.reduce_any(tf.math.is_nan(log_probs)):
            found_nans = True
            break

    assert not found_nans, "Failure."

    print("Success.")

if __name__ == '__main__':
    run_eagerly = sys.argv[1].lower() == 'true'
    tf.config.experimental_run_functions_eagerly(run_eagerly)
    main()


Other info / logs

$ python -m tests.test_broken_actions_v2 True
Success.

$ python -m tests.test_broken_actions_v2 False
Traceback (most recent call last):
  File "/Users/hartikainen/conda/envs/softlearning-2/lib/python3.7/runpy.py", line 193, in _run_module_as_main
    "__main__", mod_spec)
  File "/Users/hartikainen/conda/envs/softlearning-2/lib/python3.7/runpy.py", line 85, in _run_code
    exec(code, run_globals)
  File "/Users/hartikainen/github/rail-berkeley/softlearning-2/tests/test_broken_actions_v2.py", line 213, in <module>
    main()
  File "/Users/hartikainen/github/rail-berkeley/softlearning-2/tests/test_broken_actions_v2.py", line 205, in main
    assert not found_nans, "Failure."
AssertionError: Failure.

axch commented 4 years ago

Off the cuff (didn't look closely at your situation): One way I've seen similar things happen is that TensorFlow flushes denormal floats to zero in graph mode, but not eager mode. If your program comes near the wrong kind of numerical instability, that behavior can easily manifest as a spurious nan. Could that be happening here?
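
For illustration, a minimal sketch of how flush-to-zero can manifest as a nan (the denormal value and ops here are illustrative assumptions, not taken from the reproduction above; actual flushing depends on hardware and compilation flags):

import numpy as np
import tensorflow as tf

denormal = np.float32(1e-40)  # below the smallest normal float32 (~1.18e-38)

# If denormals survive, log() is large but finite; if the value is flushed
# to zero, log() becomes -inf, and differences of infinities yield nan.
print(tf.math.log(denormal))                        # ~-92.1 if not flushed
flushed = np.float32(0.0)                           # the value after flush-to-zero
print(tf.math.log(flushed))                         # -inf
print(tf.math.log(flushed) - tf.math.log(flushed))  # nan: -inf - (-inf)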

hartikainen commented 4 years ago

@axch that could indeed be the issue, since the Tanh bijector, which is used to squash the outputs here, is sometimes pretty sensitive to numerical values. I'll see if I can verify that. If this turns out to be the issue, do you have any ideas on how to get around it?
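
(As a minimal sketch of that sensitivity, with illustrative values not taken from the thread: the forward/inverse round trip through tanh is not recoverable once the output saturates in float32.)

import tensorflow as tf

x = tf.constant(10.0)
y = tf.tanh(x)      # 1 - tanh(10) ~ 4e-9, below float32 resolution near 1,
                    # so y rounds to exactly 1.0
print(tf.atanh(y))  # inf; a log_prob built on atanh then turns into -inf/nan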

axch commented 4 years ago

@hartikainen Nothing super-satisfying, I'm afraid.

hartikainen commented 4 years ago

Hey @axch, thanks for the tips! I finally had some time to dive a bit deeper into this issue and managed to narrow the problem down a bit.

There indeed is a numerical issue in the tanh computation. The reason things don't fail in eager mode is that the bijector caching in eager mode somehow differs from the caching in graph mode. The nans appear in the Tanh._inverse method, which is called in graph mode but not at all in eager mode. I also noticed that if sample and the corresponding log_prob(sample) methods are called within the same tf.function, then graph mode caches things correctly too.

It's actually a bit surprising that this works in eager mode but not in graph mode. I would've guessed that this would fail exactly the other way around. Is the difference in caching behavior between the two modes intentional?

Here's a new snippet that has all the softlearning-specific code removed:

import sys

import numpy as np
import tensorflow as tf
import tensorflow_probability as tfp
import tree

SHIFTS = np.array([
    -7.24678897857666, -9.600299835205078, -3.4814586639404297, 0.9964615702629089, 7.821957111358643, -1.0383307933807373, 7.808460235595703, -0.6535720825195312, -0.7894048690795898, -0.3533932566642761, -0.18972229957580566, -1.7350220680236816, 3.385869026184082, -3.1827340126037598, -4.716894626617432, 1.1301740407943726, -7.8049211502075195
], dtype=np.float32)
SCALES = np.array([
    0.014257000759243965, 3.0556118488311768, 0.10338590294122696, 0.18965205550193787, 0.24353782832622528, 0.2770414352416992, 3.332500696182251, 0.13196967542171478, 4.393350124359131, 0.41254958510398865, 0.8798311352729797, 0.0069636269472539425, 0.05451832339167595, 0.9536689519882202, 0.9592384696006775, 0.4479171633720398, 3.489938259124756
], dtype=np.float32)

class CustomTanh(tfp.bijectors.Tanh):
    def _forward(self, *args, **kwargs):
        result = super(CustomTanh, self)._forward(*args, **kwargs)
        tf.print(tf.reduce_all(
            tf.math.is_finite(result)), 'CustomTanh._forward')
        return result

    def _inverse(self, *args, **kwargs):
        result = super(CustomTanh, self)._inverse(*args, **kwargs)
        tf.print(tf.reduce_all(
            tf.math.is_finite(result)), 'CustomTanh._inverse')
        return result

    def _forward_log_det_jacobian(self, *args, **kwargs):
        result = super(CustomTanh, self)._forward_log_det_jacobian(
            *args, **kwargs)
        tf.print(tf.reduce_all(
            tf.math.is_finite(result)), 'CustomTanh._forward_log_det_jacobian')
        return result

def main():
    tanh = CustomTanh()
    base_distribution = tfp.distributions.MultivariateNormalDiag(
        loc=SHIFTS, scale_diag=SCALES)
    distribution = tanh(base_distribution)

    @tf.function
    def get_samples():
        samples = distribution.sample((256, ))
        return samples

    @tf.function
    def get_log_probs(samples):
        log_probs = distribution.log_prob(samples)
        return log_probs

    @tf.function
    def get_samples_and_log_probs():
        samples = distribution.sample((256, ))
        log_probs = distribution.log_prob(samples)
        return samples, log_probs

    samples = get_samples()
    log_probs = get_log_probs(samples)

    samples2, log_probs2 = get_samples_and_log_probs()

    is_finites = tree.map_structure(
        lambda x: tf.reduce_all(tf.math.is_finite(x)).numpy(),
        (samples, log_probs, samples2, log_probs2))

    assert all(is_finites), is_finites

if __name__ == '__main__':
    run_eagerly = sys.argv[1].lower() == 'true'
    tf.config.experimental_run_functions_eagerly(run_eagerly)
    main()

The outputs for eager and graph mode are shown below. There are two notable things:

  1. The _inverse method is never called in eager mode but is called in graph mode.
  2. The samples and log probs that are computed within the same tf.function are cached correctly even in graph mode.

Eager:

$ python -m tests.test_tanh_3 true
1 CustomTanh._forward
1 CustomTanh._forward_log_det_jacobian
1 CustomTanh._forward
1 CustomTanh._forward_log_det_jacobian

Non-eager:

$ python -m tests.test_tanh_3 false
1 CustomTanh._forward
0 CustomTanh._inverse
0 CustomTanh._forward_log_det_jacobian
1 CustomTanh._forward
1 CustomTanh._forward_log_det_jacobian
Traceback (most recent call last):
  File "/Users/hartikainen/conda/envs/softlearning-tf2/lib/python3.7/runpy.py", line 193, in _run_module_as_main
    "__main__", mod_spec)
  File "/Users/hartikainen/conda/envs/softlearning-tf2/lib/python3.7/runpy.py", line 85, in _run_code
    exec(code, run_globals)
  File "/Users/hartikainen/github/rail-berkeley/softlearning-2/tests/test_tanh_3.py", line 77, in <module>
    main()
  File "/Users/hartikainen/github/rail-berkeley/softlearning-2/tests/test_tanh_3.py", line 71, in main
    assert all(is_finites), is_finites
AssertionError: (True, False, True, True)

axch commented 4 years ago

@hartikainen The bijector cache is keyed by Tensor id, but the id effectively changes between eager and graph modes. If you sample in one tf.function and compute the log_prob in another, the cache is very likely to be defeated. The fix is what you already came up with: include the sampling and log_prob evaluation in the same tf.function.
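
A minimal sketch of that distinction (Normal/Tanh chosen here only for brevity; this is not code from the thread):

import tensorflow as tf
import tensorflow_probability as tfp

dist = tfp.bijectors.Tanh()(tfp.distributions.Normal(0., 1.))

# Cache hit: log_prob receives the very Tensor the bijector produced (and
# cached) while sampling, so Tanh._inverse never runs.
@tf.function
def sample_and_log_prob():
    x = dist.sample(5)
    return x, dist.log_prob(x)

# Cache miss: the sample crosses a tf.function boundary, so log_prob sees a
# different Tensor (a placeholder of the second graph) and must invert tanh.
@tf.function
def log_prob(x):
    return dist.log_prob(x)

sample_and_log_prob()     # no atanh involved
log_prob(dist.sample(5))  # falls back to Tanh._inverse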

I think it should also be ok to have nested tf.functions, like this:

@tf.function
def sample():
  return distribution.sample(10)

@tf.function
def log_prob(x):
  return distribution.log_prob(x)

@tf.function  # The outer one is key
def doit():
  return log_prob(sample())

but I'm not actually sure whether that will work correctly or not (there is some chance that the inner tf.function will add something to some Tensor that breaks the bijector cache). This possibility is mostly relevant for a library or a large program where it may be impractical to control exactly where the tf.functions are, but if that's not your case, don't worry about it.

hartikainen commented 4 years ago

Thanks @axch, I think that makes sense. In our softlearning code, we have a case where the caching does not work even if the sample and log prob are executed within the same @tf.function (see here). I wonder if there's a way for me to somehow force the caching?

Edit: Indeed, in my above example, if I change the get_samples_and_log_probs to

    @tf.function
    def get_samples_and_log_probs():
        samples = get_samples()
        log_probs = get_log_probs(samples)
        return samples, log_probs

the script still fails.

axch commented 4 years ago

Interesting. But it does work if you take out the inner tf.functions, correct? Is that the problem in your softlearning code too? I don't know where to look -- your link seems to be to a ~3000-line PR.

hartikainen commented 4 years ago

Yeah, it works if I take out the inner tf.functions. Sorry, I meant to link to the specific lines where things fail. The setup is pretty much the same as above, i.e. a couple of nested tf.functions.

The call happens here: https://github.com/rail-berkeley/softlearning/blob/a187972416694b53628eb5c2c1dcd9760f0c7b62/softlearning/algorithms/sac.py#L149-L150

And the called methods are defined here: https://github.com/rail-berkeley/softlearning/blob/a187972416694b53628eb5c2c1dcd9760f0c7b62/softlearning/policies/gaussian_policy.py#L39-L68

axch commented 4 years ago

So it looks like we've unblocked your ability to get work done?

If all else failed, I was actually going to suggest defining actions_and_log_probs as I see you have there. It's not a bad API -- in a world where bijector caching hadn't worked out, TFP distributions would probably all have had a sample_and_log_prob method for exactly this reason.
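
For reference, a minimal sketch of such a combined method for the GaussianPolicy above (the actual softlearning definition may differ in its details):

    @tf.function(experimental_relax_shapes=True)
    def actions_and_log_probs(self, observations):
        """Sample actions and compute their log probs in a single trace so
        the bijector cache is hit and Tanh._inverse is never invoked."""
        shifts, scales = self.shift_and_scale_model(observations)
        bijector_kwargs = {'scale': {'scale': scales},
                           'shift': {'shift': shifts}}
        actions = self.action_distribution.sample(
            tf.shape(observations)[0], bijector_kwargs=bijector_kwargs)
        log_probs = self.action_distribution.log_prob(
            actions, bijector_kwargs=bijector_kwargs)[..., tf.newaxis]
        return actions, log_probs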

hartikainen commented 4 years ago

Yep, I think we're all good now. Obviously, it would be great if nested tf.functions preserved the caching behavior, but I'm not sure what that would involve or whether it's even possible to implement. Feel free to close this issue if this feature request belongs somewhere else 🙂 Thanks a lot for your help!