tensorflow / probability

Probabilistic reasoning and statistical analysis in TensorFlow
https://www.tensorflow.org/probability/
Apache License 2.0

Cannot save tfd.PixelCNN with tf.train.Checkpoint #902

Closed DrKwint closed 4 years ago

DrKwint commented 4 years ago

This is a duplicate of https://github.com/tensorflow/tensorflow/issues/38893, but I'm also submitting here because I'm unsure which repo it should be on. Apologies if this is the wrong place for it.

System information

Have I written custom code (as opposed to using a stock example script provided in TensorFlow): Yes
OS Platform and Distribution (e.g., Linux Ubuntu 16.04): Linux Ubuntu 18.04
Mobile device (e.g. iPhone 8, Pixel 2, Samsung Galaxy) if the issue happens on mobile device:
TensorFlow installed from (source or binary): source
TensorFlow version (use command below): v2.1.0-rc2-17-ge5bf8de 2.1.0
Python version: 3.6.9
Bazel version (if compiling from source):
GCC/Compiler version (if compiling from source):
CUDA/cuDNN version: 10.1
GPU model and memory: GTX 1070

Describe the current behavior

Trying to save a model which includes a tfd.PixelCNN gives the traceback:

  File "test.py", line 16, in <module>
    checkpoint.save(file_prefix='fails_before_here')
  File "/home/equint/GitHub/pyroclast/env/lib/python3.6/site-packages/tensorflow_core/python/training/tracking/util.py", line 1902, in save
    file_path = self.write("%s-%d" % (file_prefix, checkpoint_number))
  File "/home/equint/GitHub/pyroclast/env/lib/python3.6/site-packages/tensorflow_core/python/training/tracking/util.py", line 1832, in write
    output = self._saver.save(file_prefix=file_prefix)
  File "/home/equint/GitHub/pyroclast/env/lib/python3.6/site-packages/tensorflow_core/python/training/tracking/util.py", line 1168, in save
    file_prefix=file_prefix_tensor, object_graph_tensor=object_graph_tensor)
  File "/home/equint/GitHub/pyroclast/env/lib/python3.6/site-packages/tensorflow_core/python/training/tracking/util.py", line 1108, in _save_cached_when_graph_building
    object_graph_tensor=object_graph_tensor)
  File "/home/equint/GitHub/pyroclast/env/lib/python3.6/site-packages/tensorflow_core/python/training/tracking/util.py", line 1076, in _gather_saveables
    feed_additions) = self._graph_view.serialize_object_graph()
  File "/home/equint/GitHub/pyroclast/env/lib/python3.6/site-packages/tensorflow_core/python/training/tracking/graph_view.py", line 379, in serialize_object_graph
    trackable_objects, path_to_root = self._breadth_first_traversal()
  File "/home/equint/GitHub/pyroclast/env/lib/python3.6/site-packages/tensorflow_core/python/training/tracking/graph_view.py", line 199, in _breadth_first_traversal
    for name, dependency in self.list_dependencies(current_trackable):
  File "/home/equint/GitHub/pyroclast/env/lib/python3.6/site-packages/tensorflow_core/python/training/tracking/graph_view.py", line 159, in list_dependencies
    return obj._checkpoint_dependencies
  File "/home/equint/GitHub/pyroclast/env/lib/python3.6/site-packages/tensorflow_core/python/training/tracking/data_structures.py", line 509, in _checkpoint_dependencies
    "automatically un-wrapped and subsequently ignored." % (self,)))
ValueError: Unable to save the object ListWrapper([0, 1, 2]) (a list wrapper constructed to track trackable TensorFlow objects). A list element was replaced (__setitem__, __setslice__), deleted (__delitem__, __delslice__), or moved (sort). In order to support restoration on object creation, tracking is exclusively for append-only data structures

Describe the expected behavior

Saving a distribution with tf.train.Checkpoint.save should not raise an error.

Standalone code to reproduce the issue

    import tensorflow as tf
    import tensorflow_probability as tfp

    tfd = tfp.distributions

    model = tfd.PixelCNN(
        image_shape=(28, 28, 1),
        conditional_shape=(28, 28, 1),
        num_resnet=1,
        num_hierarchies=2,
        num_filters=32,
        num_logistic_mix=4,
        dropout_p=.3,
    )
    checkpoint = tf.train.Checkpoint(model=model)
    checkpoint.save(file_prefix='fails_before_here')

emilyfertig commented 4 years ago

Thanks -- this turned out to be an issue in TFP's WeightNorm layer and is now fixed.
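
For readers who hit the same ValueError, the mechanism described in the message can be illustrated with a toy model. This is only a sketch in plain Python: `AppendOnlyList`, `checkpoint_dependencies`, and `try_save` are hypothetical names invented here, not TensorFlow's `ListWrapper` or any TFP API. The `range(4)` and `pop(3)` values are assumptions, chosen only so the list ends up as `[0, 1, 2]` like the one in the traceback. The thread does not say what the actual WeightNorm fix was. The sketch shows a list that remembers any non-append mutation and refuses to be saved afterwards, alongside the usual way to avoid that: compute the final contents first, then hand them to the tracked container.

```python
class AppendOnlyList:
    """Toy stand-in for a checkpoint-tracked list (hypothetical, not TensorFlow code)."""

    def __init__(self, items=()):
        self._items = list(items)
        # Set to True as soon as an element is replaced or deleted.
        self._non_append_mutation = False

    def append(self, item):
        # Appending keeps the structure restorable, so no flag is set.
        self._items.append(item)

    def pop(self, index=-1):
        # Deleting an element is a non-append mutation.
        self._non_append_mutation = True
        return self._items.pop(index)

    def __setitem__(self, index, value):
        # Replacing an element is a non-append mutation.
        self._non_append_mutation = True
        self._items[index] = value

    def checkpoint_dependencies(self):
        # Mimics the save-time check: refuse once a non-append mutation happened.
        if self._non_append_mutation:
            raise ValueError(
                "Unable to save the object AppendOnlyList(%r): a list element "
                "was replaced or deleted." % (self._items,))
        return list(self._items)


def try_save(tracked):
    """Return 'saved' if the toy save-time check passes, else 'failed'."""
    try:
        tracked.checkpoint_dependencies()
        return "saved"
    except ValueError:
        return "failed"


# Pattern that fails: track the list first, then delete from it in place.
axes = AppendOnlyList(range(4))
axes.pop(3)
print(try_save(axes))

# Pattern that works: build the final contents first, then start tracking.
final_axes = AppendOnlyList(i for i in range(4) if i != 3)
print(try_save(final_axes))
```

In real TensorFlow code, the equivalent of the second pattern would be to finish building a plain Python list in a local variable and only then assign it to an attribute of a trackable object, so the tracked wrapper never sees a replace, delete, or sort.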