keras-team / keras-io

Keras documentation, hosted live at keras.io

bayesian_neural_networks failed with Mixed Precision enabled #1860

Open LifengWang opened 1 month ago

LifengWang commented 1 month ago

Issue Type

Bug

Source

binary

Keras Version

2.16.0

Custom Code

No

OS Platform and Distribution

Linux Ubuntu 20.04

Python version

3.10

GPU model and memory

No response

Current Behavior?

The bayesian_neural_networks example works well with the default fp32 data type when using legacy Keras:

import os
os.environ["TF_USE_LEGACY_KERAS"] = "1"

However, when I enable mixed precision for bfloat16 or float16 with the following code, the bayesian_neural_networks example fails:

import tensorflow as tf
tf.keras.mixed_precision.set_global_policy("mixed_bfloat16")

Standalone code to reproduce the issue or tutorial link

Just add the following code snippet at the beginning of the code example at https://github.com/keras-team/keras-io/blob/master/examples/keras_recipes/bayesian_neural_networks.py:

import os
os.environ["TF_USE_LEGACY_KERAS"] = "1"
import tensorflow as tf
tf.keras.mixed_precision.set_global_policy("mixed_bfloat16")
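
For reference, here is a condensed, self-contained sketch of the failing path (untested; it reuses roughly the prior/posterior definitions from the example, with a placeholder kl_weight). Calling the DenseVariational layer on the symbolic input is already enough to hit the error:

import os

os.environ["TF_USE_LEGACY_KERAS"] = "1"

import tensorflow as tf
import tensorflow_probability as tfp

tf.keras.mixed_precision.set_global_policy("mixed_bfloat16")


# Prior and posterior roughly as defined in the example.
def prior(kernel_size, bias_size, dtype=None):
    n = kernel_size + bias_size
    return tf.keras.Sequential(
        [
            tfp.layers.DistributionLambda(
                lambda t: tfp.distributions.MultivariateNormalDiag(
                    loc=tf.zeros(n), scale_diag=tf.ones(n)
                )
            )
        ]
    )


def posterior(kernel_size, bias_size, dtype=None):
    n = kernel_size + bias_size
    return tf.keras.Sequential(
        [
            tfp.layers.VariableLayer(
                tfp.layers.MultivariateNormalTriL.params_size(n), dtype=dtype
            ),
            tfp.layers.MultivariateNormalTriL(n),
        ]
    )


inputs = tf.keras.Input(shape=(11,))
# Fails here with "Tensor conversion requested dtype float32 for Tensor with
# dtype bfloat16" while computing the KL term between posterior and prior.
features = tfp.layers.DenseVariational(
    units=8,
    make_prior_fn=prior,
    make_posterior_fn=posterior,
    kl_weight=1 / 256,  # placeholder for 1 / train_sample_size
)(inputs)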

Relevant log output

Traceback (most recent call last):
  File "/root/lifeng/keras-io/examples/keras_recipes/bayesian_neural_networks.py", line 304, in <module>
    bnn_model_small = create_bnn_model(train_sample_size)
  File "/root/lifeng/keras-io/examples/keras_recipes/bayesian_neural_networks.py", line 274, in create_bnn_model
    features = tfp.layers.DenseVariational(
  File "/root/miniconda3/envs/oob/lib/python3.10/site-packages/tf_keras/src/utils/traceback_utils.py", line 70, in error_handler
    raise e.with_traceback(filtered_tb) from None
  File "/root/miniconda3/envs/oob/lib/python3.10/site-packages/tensorflow_probability/python/layers/dense_variational_v2.py", line 123, in call
    self.add_loss(self._kl_divergence_fn(q, r))
  File "/root/miniconda3/envs/oob/lib/python3.10/site-packages/tensorflow_probability/python/layers/dense_variational_v2.py", line 187, in _fn
    kl = kl_divergence_fn(distribution_a, distribution_b)
  File "/root/miniconda3/envs/oob/lib/python3.10/site-packages/tensorflow_probability/python/layers/dense_variational_v2.py", line 180, in kl_divergence_fn
    distribution_a.log_prob(z) - distribution_b.log_prob(z),
  File "/root/miniconda3/envs/oob/lib/python3.10/site-packages/tensorflow_probability/python/distributions/distribution.py", line 1287, in log_prob
    return self._call_log_prob(value, name, **kwargs)
  File "/root/miniconda3/envs/oob/lib/python3.10/site-packages/tensorflow_probability/python/distributions/distribution.py", line 1269, in _call_log_prob
    return self._log_prob(value, **kwargs)
  File "/root/miniconda3/envs/oob/lib/python3.10/site-packages/tensorflow_probability/python/layers/internal/distribution_tensor_coercible.py", line 114, in _log_prob
    return self.tensor_distribution._log_prob(value, **kwargs)
  File "/root/miniconda3/envs/oob/lib/python3.10/site-packages/tensorflow_probability/python/internal/distribution_util.py", line 1350, in _fn
    return fn(*args, **kwargs)
  File "/root/miniconda3/envs/oob/lib/python3.10/site-packages/tensorflow_probability/python/distributions/mvn_linear_operator.py", line 243, in _log_prob
    return super(MultivariateNormalLinearOperator, self)._log_prob(x)
  File "/root/miniconda3/envs/oob/lib/python3.10/site-packages/tensorflow_probability/python/distributions/transformed_distribution.py", line 364, in _log_prob
    log_prob, _ = self.experimental_local_measure(
  File "/root/miniconda3/envs/oob/lib/python3.10/site-packages/tensorflow_probability/python/distributions/transformed_distribution.py", line 611, in experimental_local_measure
    x = self.bijector.inverse(y, **bijector_kwargs)
  File "/root/miniconda3/envs/oob/lib/python3.10/site-packages/tensorflow_probability/python/bijectors/bijector.py", line 1389, in inverse
    return self._call_inverse(y, name, **kwargs)
  File "/root/miniconda3/envs/oob/lib/python3.10/site-packages/tensorflow_probability/python/bijectors/bijector.py", line 1362, in _call_inverse
    y = nest_util.convert_to_nested_tensor(
  File "/root/miniconda3/envs/oob/lib/python3.10/site-packages/tensorflow_probability/python/internal/nest_util.py", line 503, in convert_to_nested_tensor
    return convert_fn((), value, dtype, dtype_hint, name=name)
  File "/root/miniconda3/envs/oob/lib/python3.10/site-packages/tensorflow_probability/python/internal/nest_util.py", line 495, in convert_fn
    return tf.convert_to_tensor(value, dtype, dtype_hint, name=name)
ValueError: Exception encountered when calling layer "dense_variational" (type DenseVariational).

y: Tensor conversion requested dtype float32 for Tensor with dtype bfloat16: <tf.Tensor 'dense_variational/sequential/multivariate_normal_tri_l/tensor_coercible/value/MultivariateNormalTriL/sample/chain_of_shift_of_scale_matvec_linear_operator/forward/shift/forward/add:0' shape=(96,) dtype=bfloat16>

Call arguments received by layer "dense_variational" (type DenseVariational):
  • inputs=tf.Tensor(shape=(None, 11), dtype=bfloat16)
LifengWang commented 1 month ago

The standard neural network works well with mixed precision, but the Bayesian neural network fails.
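
For comparison, a minimal sketch (assuming the same mixed_bfloat16 policy) with plain Dense layers builds without error, since Keras keeps the variables in float32 and only does the computation in bfloat16:

import tensorflow as tf

tf.keras.mixed_precision.set_global_policy("mixed_bfloat16")

inputs = tf.keras.Input(shape=(11,))
x = tf.keras.layers.Dense(8, activation="relu")(inputs)
# Keeping the output layer in float32 is the usual mixed-precision recommendation.
outputs = tf.keras.layers.Dense(1, dtype="float32")(x)
model = tf.keras.Model(inputs, outputs)

print(model.layers[1].compute_dtype)   # bfloat16
print(model.layers[1].variable_dtype)  # float32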

chunduriv commented 1 month ago

@LifengWang,

Thanks for reporting the issue. I have reproduced the behavior; please see the gist for reference.

mattdangerw commented 1 month ago

I'm not sure whether, in general, we expect all examples to run under mixed precision without error. That will often depend on the custom layers/models created by the author of a given example.

But if anyone has interest in diving in here and fixing this up, contributions are welcome!
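
One possible direction, as an untested sketch: force the variational layers back to float32 via the layer-level dtype argument, so the KL term is computed entirely in float32 while the rest of the model stays under the mixed policy. Here units, prior, posterior, train_size, and features are assumed to be the names from the example's create_bnn_model; whether the dtype override propagates into the prior/posterior sub-models would need to be verified.

features = tfp.layers.DenseVariational(
    units=units,
    make_prior_fn=prior,
    make_posterior_fn=posterior,
    kl_weight=1 / train_size,
    activation="sigmoid",
    dtype="float32",  # override the global mixed_bfloat16 policy for this layer only
)(features)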