hartikainen opened this issue 4 years ago
Off the cuff (didn't look closely at your situation): one way I've seen similar things happen is that TensorFlow flushes denormal floats to zero in graph mode, but not in eager mode. If your program comes near the wrong kind of numerical instability, that behavior can easily manifest as a spurious `nan`. Could that be happening here?
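As a toy illustration of that failure mode (a hypothetical example; whether subnormals are actually flushed depends on the TensorFlow build and hardware):

import numpy as np
import tensorflow as tf

tiny = np.float32(1e-41)  # subnormal (denormal) float32

@tf.function
def ratio(x):
    return x / x  # 1.0 normally; becomes 0 / 0 = nan if x is flushed to zero

print(tiny / tiny)          # NumPy / eager: 1.0
print(ratio(tiny).numpy())  # may be nan on builds that flush subnormals in graph mode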
@axch that could indeed be an issue, since the `Tanh` bijector, which is used to squash the outputs here, is sometimes pretty sensitive to numerical values. I'll try to verify this. If it turns out to be the issue, do you have any ideas on how to get around it?
@hartikainen Nothing super-satisfying, I'm afraid.
Hey @axch, thanks for the tips! I finally had some time to dive a bit deeper into this issue and managed to narrow the problem down a bit.
There is indeed a numerical issue in the tanh computation. The reason things don't fail in eager mode is that the bijector caching somehow behaves differently in eager mode than in graph mode. The `nan`s appear in the `Tanh._inverse` method, which is called in graph mode but not at all in eager mode. I also noticed that if `sample` and the corresponding `log_prob(sample)` calls happen within the same `tf.function`, then graph mode also caches things correctly.
It's actually a bit surprising that this works in eager mode but not in graph mode. I would've guessed that this would fail exactly the other way around. Is the difference in caching behavior between the two modes intentional?
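As a side note, a small illustration of why hitting `Tanh._inverse` is so fragile in float32: `tanh` saturates to exactly 1.0 for moderately large inputs, and `atanh(1.0)` is infinite, which poisons the downstream log prob:

import tensorflow as tf

x = tf.constant(9.0)                 # float32
y = tf.math.tanh(x)
print(y.numpy())                     # 1.0 exactly -- float32 can't represent 1 - ~2e-8
print(tf.math.atanh(y).numpy())      # inf, so any downstream log-prob term becomes inf/nan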
Here's a new snippet that has all the softlearning-specific code removed:
import sys
import numpy as np
import tensorflow as tf
import tensorflow_probability as tfp
import tree
SHIFTS = np.array([
-7.24678897857666, -9.600299835205078, -3.4814586639404297, 0.9964615702629089, 7.821957111358643, -1.0383307933807373, 7.808460235595703, -0.6535720825195312, -0.7894048690795898, -0.3533932566642761, -0.18972229957580566, -1.7350220680236816, 3.385869026184082, -3.1827340126037598, -4.716894626617432, 1.1301740407943726, -7.8049211502075195
], dtype=np.float32)
SCALES = np.array([
0.014257000759243965, 3.0556118488311768, 0.10338590294122696, 0.18965205550193787, 0.24353782832622528, 0.2770414352416992, 3.332500696182251, 0.13196967542171478, 4.393350124359131, 0.41254958510398865, 0.8798311352729797, 0.0069636269472539425, 0.05451832339167595, 0.9536689519882202, 0.9592384696006775, 0.4479171633720398, 3.489938259124756
], dtype=np.float32)
class CustomTanh(tfp.bijectors.Tanh):
    def _forward(self, *args, **kwargs):
        result = super(CustomTanh, self)._forward(*args, **kwargs)
        tf.print(tf.reduce_all(
            tf.math.is_finite(result)), 'CustomTanh._forward')
        return result

    def _inverse(self, *args, **kwargs):
        result = super(CustomTanh, self)._inverse(*args, **kwargs)
        tf.print(tf.reduce_all(
            tf.math.is_finite(result)), 'CustomTanh._inverse')
        return result

    def _forward_log_det_jacobian(self, *args, **kwargs):
        result = super(CustomTanh, self)._forward_log_det_jacobian(
            *args, **kwargs)
        tf.print(tf.reduce_all(
            tf.math.is_finite(result)), 'CustomTanh._forward_log_det_jacobian')
        return result
def main():
    tanh = CustomTanh()
    base_distribution = tfp.distributions.MultivariateNormalDiag(
        loc=SHIFTS, scale_diag=SCALES)
    distribution = tanh(base_distribution)

    @tf.function
    def get_samples():
        samples = distribution.sample((256, ))
        return samples

    @tf.function
    def get_log_probs(samples):
        log_probs = distribution.log_prob(samples)
        return log_probs

    @tf.function
    def get_samples_and_log_probs():
        samples = distribution.sample((256, ))
        log_probs = distribution.log_prob(samples)
        return samples, log_probs

    samples = get_samples()
    log_probs = get_log_probs(samples)
    samples2, log_probs2 = get_samples_and_log_probs()

    is_finites = tree.map_structure(
        lambda x: tf.reduce_all(tf.math.is_finite(x)).numpy(),
        (samples, log_probs, samples2, log_probs2))
    assert all(is_finites), is_finites


if __name__ == '__main__':
    run_eagerly = sys.argv[1].lower() == 'true'
    tf.config.experimental_run_functions_eagerly(run_eagerly)
    main()
The outputs for eager and graph mode are shown below. There are two notable things:

1. `_inverse` is never called in eager mode but is called in graph mode.
2. Samples and their log probs computed within a single `tf.function` are cached correctly even in graph mode.

Eager:
$ python -m tests.test_tanh_3 true
1 CustomTanh._forward
1 CustomTanh._forward_log_det_jacobian
1 CustomTanh._forward
1 CustomTanh._forward_log_det_jacobian
Non-eager:
$ python -m tests.test_tanh_3 false
1 CustomTanh._forward
0 CustomTanh._inverse
0 CustomTanh._forward_log_det_jacobian
1 CustomTanh._forward
1 CustomTanh._forward_log_det_jacobian
Traceback (most recent call last):
File "/Users/hartikainen/conda/envs/softlearning-tf2/lib/python3.7/runpy.py", line 193, in _run_module_as_main
"__main__", mod_spec)
File "/Users/hartikainen/conda/envs/softlearning-tf2/lib/python3.7/runpy.py", line 85, in _run_code
exec(code, run_globals)
File "/Users/hartikainen/github/rail-berkeley/softlearning-2/tests/test_tanh_3.py", line 77, in <module>
main()
File "/Users/hartikainen/github/rail-berkeley/softlearning-2/tests/test_tanh_3.py", line 71, in main
assert all(is_finites), is_finites
AssertionError: (True, False, True, True)
@hartikainen The bijector cache is keyed by Tensor id, but the id effectively changes between eager and graph modes. If you sample in one `tf.function` and compute the `log_prob` in another, the cache is very likely to be defeated. The fix is what you already came up with: include the sampling and log_prob evaluation in the same `tf.function`.
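A minimal sketch of that pattern, reusing the `distribution` from the snippet above (the function name is illustrative):

@tf.function
def sample_and_log_prob(n):
    x = distribution.sample(n)
    # Within a single trace, `x` is the very Tensor the Tanh bijector just
    # produced, so the id-keyed cache can pair it with its pre-tanh value and
    # log_prob never needs the numerically fragile Tanh._inverse.
    return x, distribution.log_prob(x)

samples, log_probs = sample_and_log_prob(256)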
I think it should also be ok to have nested `tf.function`s, like this:
@tf.function
def sample():
    return distribution.sample(10)

@tf.function
def log_prob(x):
    return distribution.log_prob(x)

@tf.function  # The outer one is key
def doit():
    return log_prob(sample())
but I'm not actually sure whether that will work correctly or not (there is some chance that the inner `tf.function` will add something to some Tensor that breaks the bijector cache). This possibility is mostly relevant for a library or a large program where it may be impractical to control exactly where the `tf.function`s are, but if that's not your case, don't worry about it.
Thanks @axch, I think that makes sense. In our softlearning code, we have a case where the caching does not work even if the sample and log prob are executed within the same `@tf.function` (see here). I wonder if there's a way for me to somehow force the caching?
Edit: Indeed, in my above example, if I change `get_samples_and_log_probs` to
@tf.function
def get_samples_and_log_probs():
    samples = get_samples()
    log_probs = get_log_probs(samples)
    return samples, log_probs
the script still fails.
Interesting. But it does work if you take out the inner `tf.function`s, correct? Is that the problem in your softlearning code too? I don't know where to look -- your link seems to be to a ~3000-line PR.
Yeah, it works if I take out the inner `tf.function`s. Sorry, I meant to link to the specific lines where things fail. The setup is pretty much the same as above, i.e. a couple of nested `tf.function`s.
The call happens here: https://github.com/rail-berkeley/softlearning/blob/a187972416694b53628eb5c2c1dcd9760f0c7b62/softlearning/algorithms/sac.py#L149-L150
And the called methods are defined here: https://github.com/rail-berkeley/softlearning/blob/a187972416694b53628eb5c2c1dcd9760f0c7b62/softlearning/policies/gaussian_policy.py#L39-L68
So it looks like we've unblocked your ability to get work done?
If all else failed, I was actually going to suggest defining `actions_and_log_probs` as I see you have there. It's not a bad API -- in a world where bijector caching hadn't worked out, TFP distributions would probably all have had a `sample_and_log_prob` method for exactly this reason.
Yep, I think we're all good now. Obviously, it would be great if nested `tf.function`s preserved the caching behavior, but I'm not sure what that would involve or whether it's even possible to implement. Feel free to close this issue if that feature request belongs somewhere else 🙂 Thanks a lot for your help!
System information
macOS Catalina 10.15.2 (19C57)
Issue

We have a case where we use tfp bijectors to transform a latent gaussian distribution into another one conditioned on inputs (reinforcement learning observations, specifically). The shift and scale parameters of the transformation are parameterized by a feedforward network that takes the observations as inputs. The output of the transformed gaussian distribution is finally passed through a `Tanh` bijector.

As shown below, the implementation uses two custom bijectors (`ConditionalScale` and `ConditionalShift`) to handle the transformation. The reason for this is that I've been unable to implement the same functionality cleanly with the existing `Scale` and `Shift` bijectors [1, 2].

Now, this implementation works perfectly fine when running without `tf.function` decorators, or alternatively, as demonstrated below, when running with `tf.config.experimental_run_functions_eagerly(True)`. I don't understand `tf.function` tracing well enough to pinpoint the exact root cause of the error below, but ultimately the issue is that with `tf.function`s the code below consistently produces `nan`s, whereas with them disabled things work fine. Indeed, in my original implementation, which can be found in the softlearning repo, the same setup learns fine even when running in graph mode as long as I remove the `tf.function` decorators from the `GaussianPolicy.{actions,log_probs}` methods.

My questions thus are:
1. Is this expected behavior with the `tf.function` decorators? Currently, it seems like a bug to me, but I'm unsure about that because of my lack of experience with how the tfp objects work together with `tf.function`s.
2. I would expect to be able to wrap `GaussianPolicy.{actions,log_probs}` inside a `tf.function`, but maybe that's wrong and there's actually a better way to achieve a similar effect? Previously, our `GaussianPolicy` was implemented fully with `tf.keras.Model`s, which turned out to be super messy.
3. My guess is that `self.shift_and_scale_model` or `self.action_distribution` in `GaussianPolicy.{actions,log_probs}` get traced incorrectly within the `tf.function`.
Standalone code to reproduce the issue

Sorry, the code is a bit verbose because I didn't really understand where the issue arises exactly and thus couldn't prune it down much more. This is a slightly stripped version of my implementation in the softlearning project.
class ConditionalShift(bijector.Bijector):
    """Compute `Y = g(X; shift) = X + shift`."""
    ...

class GaussianPolicy():
    def __init__(self, input_shape, output_shape):
        output_size = tf.reduce_prod(output_shape)
        ...
def main():
    # NOTE: this is using a fixed input. However, the issue is consistently reproducible with
    ...

if __name__ == '__main__':
    run_eagerly = sys.argv[1].lower() == 'true'
    tf.config.experimental_run_functions_eagerly(run_eagerly)
    main()
$ python -m tests.test_broken_actions_v2 True
Success.

$ python -m tests.test_broken_actions_v2 False
Traceback (most recent call last):
  File "/Users/hartikainen/conda/envs/softlearning-2/lib/python3.7/runpy.py", line 193, in _run_module_as_main
    "__main__", mod_spec)
  File "/Users/hartikainen/conda/envs/softlearning-2/lib/python3.7/runpy.py", line 85, in _run_code
    exec(code, run_globals)
  File "/Users/hartikainen/github/rail-berkeley/softlearning-2/tests/test_broken_actions_v2.py", line 213, in <module>
    main()
  File "/Users/hartikainen/github/rail-berkeley/softlearning-2/tests/test_broken_actions_v2.py", line 205, in main
    assert not found_nans, "Failure."
AssertionError: Failure.
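Since the repro code above arrives truncated in this thread, here is a simplified, non-conditional sketch of the distribution structure it describes, with fixed shift and scale standing in for the network-parameterized `ConditionalShift`/`ConditionalScale`:

import numpy as np
import tensorflow_probability as tfp

# Diagonal Gaussian, scaled and shifted, then squashed through Tanh.
base = tfp.distributions.MultivariateNormalDiag(
    loc=np.zeros(2, dtype=np.float32),
    scale_diag=np.ones(2, dtype=np.float32))

bijector = tfp.bijectors.Chain([   # applied right-to-left
    tfp.bijectors.Tanh(),
    tfp.bijectors.Shift(np.float32(0.5)),
    tfp.bijectors.Scale(np.float32(2.0)),
])

distribution = tfp.distributions.TransformedDistribution(
    distribution=base, bijector=bijector)

actions = distribution.sample(4)
log_probs = distribution.log_prob(actions)  # fine eagerly; the nan issue appears under tf.function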