tensorflow / probability

Probabilistic reasoning and statistical analysis in TensorFlow
https://www.tensorflow.org/probability/
Apache License 2.0
4.27k stars 1.1k forks source link

It is not possible to save a PixelCNN model #812

Open AliKhoda opened 4 years ago

AliKhoda commented 4 years ago

Code:

# Minimal reproduction: a Keras functional model whose output is the PixelCNN
# log-likelihood of its input image.
# NOTE(review): assumes `tf`, `tfp`, and `layers` (tf.keras.layers) were
# imported earlier in the session; they are not imported here.
imgsize = 64
image_shape = (imgsize, imgsize, 3)

pixel_cnn_kwargs = dict(
    image_shape=image_shape,
    num_resnet=1,
    num_hierarchies=2,
    num_filters=160,
    num_logistic_mix=5,
    dropout_p=0.3,
    high=1,
    low=0,
)
dist = tfp.distributions.PixelCNN(**pixel_cnn_kwargs)

image_input = layers.Input(shape=image_shape)
log_prob = dist.log_prob(image_input)

# The training objective is the negative mean log-likelihood, attached via
# add_loss because the model output itself is the log-probability.
model = tf.keras.Model(inputs=image_input, outputs=log_prob)
model.add_loss(-tf.reduce_mean(log_prob))

adam = tf.keras.optimizers.Adam(1e-4)
model.compile(optimizer=adam, metrics=[])

model_json = model.to_json()

Error:

NotImplementedError                       Traceback (most recent call last)
<ipython-input-12-d9399ed0ac2a> in <module>
----> 1 model_json = model.to_json()

~/anaconda3/envs/tf2-gpu/lib/python3.6/site-packages/tensorflow_core/python/keras/engine/network.py in to_json(self, **kwargs)
   1252         A JSON string.
   1253     """
-> 1254     model_config = self._updated_config()
   1255     return json.dumps(
   1256         model_config, default=serialization.get_json_type, **kwargs)

~/anaconda3/envs/tf2-gpu/lib/python3.6/site-packages/tensorflow_core/python/keras/engine/network.py in _updated_config(self)
   1230     from tensorflow.python.keras import __version__ as keras_version  # pylint: disable=g-import-not-at-top
   1231 
-> 1232     config = self.get_config()
   1233     model_config = {
   1234         'class_name': self.__class__.__name__,

~/anaconda3/envs/tf2-gpu/lib/python3.6/site-packages/tensorflow_core/python/keras/engine/network.py in get_config(self)
    916     if not self._is_graph_network:
    917       raise NotImplementedError
--> 918     return copy.deepcopy(get_network_config(self))
    919 
    920   @classmethod

~/anaconda3/envs/tf2-gpu/lib/python3.6/site-packages/tensorflow_core/python/keras/engine/network.py in get_network_config(network, serialize_layer_fn)
   1991           filtered_inbound_nodes.append(node_data)
   1992 
-> 1993     layer_config = serialize_layer_fn(layer)
   1994     layer_config['name'] = layer.name
   1995     layer_config['inbound_nodes'] = filtered_inbound_nodes

~/anaconda3/envs/tf2-gpu/lib/python3.6/site-packages/tensorflow_core/python/keras/utils/generic_utils.py in serialize_keras_object(instance)
    196 
    197   if hasattr(instance, 'get_config'):
--> 198     config = instance.get_config()
    199     serialization_config = {}
    200     for key, item in config.items():

~/anaconda3/envs/tf2-gpu/lib/python3.6/site-packages/tensorflow_core/python/keras/engine/base_layer.py in get_config(self)
    497     # or that `get_config` has been overridden:
    498     if len(extra_args) > 1 and hasattr(self.get_config, '_is_default'):
--> 499       raise NotImplementedError('Layers with arguments in `__init__` must '
    500                                 'override `get_config`.')
    501     return config

NotImplementedError: Layers with arguments in `__init__` must override `get_config`.
AliKhoda commented 4 years ago

Would adding the following `get_config` method to the `_PixelCNNNetwork` class in "tensorflow_probability/python/distributions/pixel_cnn.py" solve the issue?

def get_config(self):
  """Return the base layer config extended with the network's hyperparameters.

  NOTE(review): these keys keep the attributes' leading underscore, so they do
  not match the `__init__` parameter names (`dropout_p`, `num_resnet`, ...);
  feeding this config back to the constructor as keyword arguments will fail.
  `_layer_wrapper` is also not one of the constructor's parameters.
  """
  config = super().get_config().copy()
  # Every entry ends with a comma; the missing one after
  # '_resnet_activation' previously made this dict literal a SyntaxError.
  config.update({
      '_dropout_p': self._dropout_p,
      '_num_resnet': self._num_resnet,
      '_num_hierarchies': self._num_hierarchies,
      '_num_filters': self._num_filters,
      '_num_logistic_mix': self._num_logistic_mix,
      '_receptive_field_dims': self._receptive_field_dims,
      '_resnet_activation': self._resnet_activation,
      '_layer_wrapper': self._layer_wrapper,
  })
  return config
AliKhoda commented 4 years ago

It fixes saving, but loading still fails:

Code:

# Rebuild the model architecture from the JSON string produced by `model.to_json()`.
loaded_model = tf.keras.models.model_from_json(model_json)

Error:

ValueError                                Traceback (most recent call last)
<ipython-input-12-3120867be721> in <module>
      2     model_json = json_file.read()
      3 
----> 4 loaded_model = tf.keras.models.model_from_json(model_json)
      5 loaded_model.summary()
      6 

~/anaconda3/envs/tf2-gpu/lib/python3.6/site-packages/tensorflow_core/python/keras/saving/model_config.py in model_from_json(json_string, custom_objects)
     94   config = json.loads(json_string)
     95   from tensorflow.python.keras.layers import deserialize  # pylint: disable=g-import-not-at-top
---> 96   return deserialize(config, custom_objects=custom_objects)

~/anaconda3/envs/tf2-gpu/lib/python3.6/site-packages/tensorflow_core/python/keras/layers/serialization.py in deserialize(config, custom_objects)
    104       module_objects=globs,
    105       custom_objects=custom_objects,
--> 106       printable_module_name='layer')

~/anaconda3/envs/tf2-gpu/lib/python3.6/site-packages/tensorflow_core/python/keras/utils/generic_utils.py in deserialize_keras_object(identifier, module_objects, custom_objects, printable_module_name)
    301             custom_objects=dict(
    302                 list(_GLOBAL_CUSTOM_OBJECTS.items()) +
--> 303                 list(custom_objects.items())))
    304       with CustomObjectScope(custom_objects):
    305         return cls.from_config(cls_config)

~/anaconda3/envs/tf2-gpu/lib/python3.6/site-packages/tensorflow_core/python/keras/engine/network.py in from_config(cls, config, custom_objects)
    935     """
    936     input_tensors, output_tensors, created_layers = reconstruct_from_config(
--> 937         config, custom_objects)
    938     model = cls(inputs=input_tensors, outputs=output_tensors,
    939                 name=config.get('name'))

~/anaconda3/envs/tf2-gpu/lib/python3.6/site-packages/tensorflow_core/python/keras/engine/network.py in reconstruct_from_config(config, custom_objects, created_layers)
   1891   # First, we create all layers and enqueue nodes to be processed
   1892   for layer_data in config['layers']:
-> 1893     process_layer(layer_data)
   1894   # Then we process nodes in order of layer depth.
   1895   # Nodes that cannot yet be processed (if the inbound node

~/anaconda3/envs/tf2-gpu/lib/python3.6/site-packages/tensorflow_core/python/keras/engine/network.py in process_layer(layer_data)
   1873       from tensorflow.python.keras.layers import deserialize as deserialize_layer  # pylint: disable=g-import-not-at-top
   1874 
-> 1875       layer = deserialize_layer(layer_data, custom_objects=custom_objects)
   1876       created_layers[layer_name] = layer
   1877 

~/anaconda3/envs/tf2-gpu/lib/python3.6/site-packages/tensorflow_core/python/keras/layers/serialization.py in deserialize(config, custom_objects)
    104       module_objects=globs,
    105       custom_objects=custom_objects,
--> 106       printable_module_name='layer')

~/anaconda3/envs/tf2-gpu/lib/python3.6/site-packages/tensorflow_core/python/keras/utils/generic_utils.py in deserialize_keras_object(identifier, module_objects, custom_objects, printable_module_name)
    290     config = identifier
    291     (cls, cls_config) = class_and_config_for_serialized_keras_object(
--> 292         config, module_objects, custom_objects, printable_module_name)
    293 
    294     if hasattr(cls, 'from_config'):

~/anaconda3/envs/tf2-gpu/lib/python3.6/site-packages/tensorflow_core/python/keras/utils/generic_utils.py in class_and_config_for_serialized_keras_object(config, module_objects, custom_objects, printable_module_name)
    248     cls = module_objects.get(class_name)
    249     if cls is None:
--> 250       raise ValueError('Unknown ' + printable_module_name + ': ' + class_name)
    251 
    252   cls_config = config['config']

ValueError: Unknown layer: _PixelCNNNetwork
AliKhoda commented 4 years ago

Code:

# Register a class for the layer name the JSON could not resolve.
# NOTE(review): this maps '_PixelCNNNetwork' to the `tfp.distributions.PixelCNN`
# distribution, not to the `_PixelCNNNetwork` layer class named in the JSON;
# in the traceback below, that distribution class is what receives the config.
loaded_model = tf.keras.models.model_from_json(model_json, custom_objects={'_PixelCNNNetwork': tfp.distributions.PixelCNN})

Error:

TypeError                                 Traceback (most recent call last)
<ipython-input-15-e742f0ee369b> in <module>
      2     model_json = json_file.read()
      3 
----> 4 loaded_model = tf.keras.models.model_from_json(model_json, custom_objects={'_PixelCNNNetwork': tfp.distributions.PixelCNN})
      5 loaded_model.summary()
      6 

~/anaconda3/envs/tf2-gpu/lib/python3.6/site-packages/tensorflow_core/python/keras/saving/model_config.py in model_from_json(json_string, custom_objects)
     94   config = json.loads(json_string)
     95   from tensorflow.python.keras.layers import deserialize  # pylint: disable=g-import-not-at-top
---> 96   return deserialize(config, custom_objects=custom_objects)

~/anaconda3/envs/tf2-gpu/lib/python3.6/site-packages/tensorflow_core/python/keras/layers/serialization.py in deserialize(config, custom_objects)
    104       module_objects=globs,
    105       custom_objects=custom_objects,
--> 106       printable_module_name='layer')

~/anaconda3/envs/tf2-gpu/lib/python3.6/site-packages/tensorflow_core/python/keras/utils/generic_utils.py in deserialize_keras_object(identifier, module_objects, custom_objects, printable_module_name)
    301             custom_objects=dict(
    302                 list(_GLOBAL_CUSTOM_OBJECTS.items()) +
--> 303                 list(custom_objects.items())))
    304       with CustomObjectScope(custom_objects):
    305         return cls.from_config(cls_config)

~/anaconda3/envs/tf2-gpu/lib/python3.6/site-packages/tensorflow_core/python/keras/engine/network.py in from_config(cls, config, custom_objects)
    936     """
    937     input_tensors, output_tensors, created_layers = reconstruct_from_config(
--> 938         config, custom_objects)
    939     model = cls(inputs=input_tensors, outputs=output_tensors,
    940                 name=config.get('name'))

~/anaconda3/envs/tf2-gpu/lib/python3.6/site-packages/tensorflow_core/python/keras/engine/network.py in reconstruct_from_config(config, custom_objects, created_layers)
   1892   # First, we create all layers and enqueue nodes to be processed
   1893   for layer_data in config['layers']:
-> 1894     process_layer(layer_data)
   1895   # Then we process nodes in order of layer depth.
   1896   # Nodes that cannot yet be processed (if the inbound node

~/anaconda3/envs/tf2-gpu/lib/python3.6/site-packages/tensorflow_core/python/keras/engine/network.py in process_layer(layer_data)
   1874       from tensorflow.python.keras.layers import deserialize as deserialize_layer  # pylint: disable=g-import-not-at-top
   1875 
-> 1876       layer = deserialize_layer(layer_data, custom_objects=custom_objects)
   1877       created_layers[layer_name] = layer
   1878 

~/anaconda3/envs/tf2-gpu/lib/python3.6/site-packages/tensorflow_core/python/keras/layers/serialization.py in deserialize(config, custom_objects)
    104       module_objects=globs,
    105       custom_objects=custom_objects,
--> 106       printable_module_name='layer')

~/anaconda3/envs/tf2-gpu/lib/python3.6/site-packages/tensorflow_core/python/keras/utils/generic_utils.py in deserialize_keras_object(identifier, module_objects, custom_objects, printable_module_name)
    310       custom_objects = custom_objects or {}
    311       with CustomObjectScope(custom_objects):
--> 312         return cls(**cls_config)
    313   elif isinstance(identifier, six.string_types):
    314     object_name = identifier

TypeError: __init__() got an unexpected keyword argument 'trainable'
AliKhoda commented 4 years ago

Modifications:

def get_config(self):
    """Serialize the constructor hyperparameters on top of the base layer config.

    Each added key is an `__init__` argument name; its value is read from the
    matching underscore-prefixed attribute (e.g. `dropout_p` <- `self._dropout_p`).

    NOTE(review): `use_weight_norm` and `use_data_init` are accepted by
    `__init__` but not recorded here, so a layer rebuilt from this config
    presumably falls back to their defaults -- confirm this is intended.
    """
    hyperparameter_names = (
        'dropout_p',
        'num_resnet',
        'num_hierarchies',
        'num_filters',
        'num_logistic_mix',
        'receptive_field_dims',
        'resnet_activation',
    )
    config = super().get_config().copy()
    config.update(
        {name: getattr(self, '_' + name) for name in hyperparameter_names})
    return config
def __init__(
      self,
      dropout_p=0.5,
      num_resnet=5,
      num_hierarchies=3,
      num_filters=160,
      num_logistic_mix=10,
      receptive_field_dims=(3, 3),
      resnet_activation='concat_elu',
      use_weight_norm=True,
      use_data_init=True,
      dtype=tf.float32,
      **kwargs):
    """Construct the network, forwarding extra Keras layer arguments to the base.

    `**kwargs` absorbs base-layer config keys such as `trainable` that
    `get_config` emits, so rebuilding the layer from its config no longer
    raises `TypeError: __init__() got an unexpected keyword argument`.
    """
    # The call belongs inside the method body; at column 0 it left the `def`
    # without an indented block.
    super(_PixelCNNNetwork, self).__init__(dtype=dtype, **kwargs)
    # NOTE(review): the rest of the constructor body is elided in this
    # excerpt; presumably it stores the arguments above as `self._<name>`
    # attributes, which `get_config` reads.

Code:

# Import the private layer class whose name appears in the serialized config.
from tensorflow_probability.python.distributions.pixel_cnn import _PixelCNNNetwork

# Map the recorded class name to the layer class itself, rather than to
# `tfp.distributions.PixelCNN`.
loaded_model = tf.keras.models.model_from_json(model_json, custom_objects={'_PixelCNNNetwork': _PixelCNNNetwork})

Error:

---------------------------------------------------------------------------
InvalidArgumentError                      Traceback (most recent call last)
<ipython-input-11-81e595db35bf> in <module>
      4 from tensorflow_probability.python.distributions.pixel_cnn import _PixelCNNNetwork
      5 
----> 6 loaded_model = tf.keras.models.model_from_json(model_json, custom_objects={'_PixelCNNNetwork': _PixelCNNNetwork})
      7 loaded_model.summary()
      8 

~/anaconda3/envs/tf2-gpu/lib/python3.6/site-packages/tensorflow_core/python/keras/saving/model_config.py in model_from_json(json_string, custom_objects)
     94   config = json.loads(json_string)
     95   from tensorflow.python.keras.layers import deserialize  # pylint: disable=g-import-not-at-top
---> 96   return deserialize(config, custom_objects=custom_objects)

~/anaconda3/envs/tf2-gpu/lib/python3.6/site-packages/tensorflow_core/python/keras/layers/serialization.py in deserialize(config, custom_objects)
    104       module_objects=globs,
    105       custom_objects=custom_objects,
--> 106       printable_module_name='layer')

~/anaconda3/envs/tf2-gpu/lib/python3.6/site-packages/tensorflow_core/python/keras/utils/generic_utils.py in deserialize_keras_object(identifier, module_objects, custom_objects, printable_module_name)
    301             custom_objects=dict(
    302                 list(_GLOBAL_CUSTOM_OBJECTS.items()) +
--> 303                 list(custom_objects.items())))
    304       with CustomObjectScope(custom_objects):
    305         return cls.from_config(cls_config)

~/anaconda3/envs/tf2-gpu/lib/python3.6/site-packages/tensorflow_core/python/keras/engine/network.py in from_config(cls, config, custom_objects)
    936     """
    937     input_tensors, output_tensors, created_layers = reconstruct_from_config(
--> 938         config, custom_objects)
    939     model = cls(inputs=input_tensors, outputs=output_tensors,
    940                 name=config.get('name'))

~/anaconda3/envs/tf2-gpu/lib/python3.6/site-packages/tensorflow_core/python/keras/engine/network.py in reconstruct_from_config(config, custom_objects, created_layers)
   1902       if layer in unprocessed_nodes:
   1903         for node_data in unprocessed_nodes.pop(layer):
-> 1904           process_node(layer, node_data)
   1905 
   1906   input_tensors = []

~/anaconda3/envs/tf2-gpu/lib/python3.6/site-packages/tensorflow_core/python/keras/engine/network.py in process_node(layer, node_data)
   1850       if not isinstance(input_tensors, dict) and len(flat_input_tensors) == 1:
   1851         input_tensors = flat_input_tensors[0]
-> 1852       output_tensors = layer(input_tensors, **kwargs)
   1853 
   1854       # Update node index map.

~/anaconda3/envs/tf2-gpu/lib/python3.6/site-packages/tensorflow_core/python/keras/engine/base_layer.py in __call__(self, inputs, *args, **kwargs)
    746           # Build layer if applicable (if the `build` method has been
    747           # overridden).
--> 748           self._maybe_build(inputs)
    749           cast_inputs = self._maybe_cast_inputs(inputs)
    750 

~/anaconda3/envs/tf2-gpu/lib/python3.6/site-packages/tensorflow_core/python/keras/engine/base_layer.py in _maybe_build(self, inputs)
   2114         # operations.
   2115         with tf_utils.maybe_init_scope(self):
-> 2116           self.build(input_shapes)
   2117       # We must set self.built since user defined build functions are not
   2118       # constrained to set self.built.

~/anaconda3/envs/tf2-gpu/lib/python3.6/site-packages/tensorflow_probability/python/distributions/pixel_cnn.py in build(self, input_shape)
    950     inputs = (image_input if conditional_input is None
    951               else [image_input, conditional_input])
--> 952     self._network = tf.keras.Model(inputs=inputs, outputs=outputs)
    953     super(_PixelCNNNetwork, self).build(input_shape)
    954 

~/anaconda3/envs/tf2-gpu/lib/python3.6/site-packages/tensorflow_core/python/keras/engine/training.py in __init__(self, *args, **kwargs)
    144 
    145   def __init__(self, *args, **kwargs):
--> 146     super(Model, self).__init__(*args, **kwargs)
    147     _keras_api_gauge.get_cell('model').set(True)
    148     # initializing _distribution_strategy here since it is possible to call

~/anaconda3/envs/tf2-gpu/lib/python3.6/site-packages/tensorflow_core/python/keras/engine/network.py in __init__(self, *args, **kwargs)
    168         'inputs' in kwargs and 'outputs' in kwargs):
    169       # Graph network
--> 170       self._init_graph_network(*args, **kwargs)
    171     else:
    172       # Subclassed network

~/anaconda3/envs/tf2-gpu/lib/python3.6/site-packages/tensorflow_core/python/training/tracking/base.py in _method_wrapper(self, *args, **kwargs)
    455     self._self_setattr_tracking = False  # pylint: disable=protected-access
    456     try:
--> 457       result = method(self, *args, **kwargs)
    458     finally:
    459       self._self_setattr_tracking = previous_value  # pylint: disable=protected-access

~/anaconda3/envs/tf2-gpu/lib/python3.6/site-packages/tensorflow_core/python/keras/engine/network.py in _init_graph_network(self, inputs, outputs, name, **kwargs)
    271 
    272     if any(not hasattr(tensor, '_keras_history') for tensor in self.outputs):
--> 273       base_layer_utils.create_keras_history(self._nested_outputs)
    274 
    275     self._base_init(name=name, **kwargs)

~/anaconda3/envs/tf2-gpu/lib/python3.6/site-packages/tensorflow_core/python/keras/engine/base_layer_utils.py in create_keras_history(tensors)
    185     keras_tensors: The Tensors found that came from a Keras Layer.
    186   """
--> 187   _, created_layers = _create_keras_history_helper(tensors, set(), [])
    188   return created_layers
    189 

~/anaconda3/envs/tf2-gpu/lib/python3.6/site-packages/tensorflow_core/python/keras/engine/base_layer_utils.py in _create_keras_history_helper(tensors, processed_ops, created_layers)
    247               constants[i] = backend.function([], op_input)([])
    248       processed_ops, created_layers = _create_keras_history_helper(
--> 249           layer_inputs, processed_ops, created_layers)
    250       name = op.name
    251       node_def = op.node_def.SerializeToString()

~/anaconda3/envs/tf2-gpu/lib/python3.6/site-packages/tensorflow_core/python/keras/engine/base_layer_utils.py in _create_keras_history_helper(tensors, processed_ops, created_layers)
    247               constants[i] = backend.function([], op_input)([])
    248       processed_ops, created_layers = _create_keras_history_helper(
--> 249           layer_inputs, processed_ops, created_layers)
    250       name = op.name
    251       node_def = op.node_def.SerializeToString()

~/anaconda3/envs/tf2-gpu/lib/python3.6/site-packages/tensorflow_core/python/keras/engine/base_layer_utils.py in _create_keras_history_helper(tensors, processed_ops, created_layers)
    247               constants[i] = backend.function([], op_input)([])
    248       processed_ops, created_layers = _create_keras_history_helper(
--> 249           layer_inputs, processed_ops, created_layers)
    250       name = op.name
    251       node_def = op.node_def.SerializeToString()

~/anaconda3/envs/tf2-gpu/lib/python3.6/site-packages/tensorflow_core/python/keras/engine/base_layer_utils.py in _create_keras_history_helper(tensors, processed_ops, created_layers)
    247               constants[i] = backend.function([], op_input)([])
    248       processed_ops, created_layers = _create_keras_history_helper(
--> 249           layer_inputs, processed_ops, created_layers)
    250       name = op.name
    251       node_def = op.node_def.SerializeToString()

~/anaconda3/envs/tf2-gpu/lib/python3.6/site-packages/tensorflow_core/python/keras/engine/base_layer_utils.py in _create_keras_history_helper(tensors, processed_ops, created_layers)
    245           else:
    246             with ops.init_scope():
--> 247               constants[i] = backend.function([], op_input)([])
    248       processed_ops, created_layers = _create_keras_history_helper(
    249           layer_inputs, processed_ops, created_layers)

~/anaconda3/envs/tf2-gpu/lib/python3.6/site-packages/tensorflow_core/python/keras/backend.py in __call__(self, inputs)
   3733     return nest.pack_sequence_as(
   3734         self._outputs_structure,
-> 3735         [x._numpy() for x in outputs],  # pylint: disable=protected-access
   3736         expand_composites=True)
   3737 

~/anaconda3/envs/tf2-gpu/lib/python3.6/site-packages/tensorflow_core/python/keras/backend.py in <listcomp>(.0)
   3733     return nest.pack_sequence_as(
   3734         self._outputs_structure,
-> 3735         [x._numpy() for x in outputs],  # pylint: disable=protected-access
   3736         expand_composites=True)
   3737 

~/anaconda3/envs/tf2-gpu/lib/python3.6/site-packages/tensorflow_core/python/framework/ops.py in _numpy(self)
    908       return self._numpy_internal()
    909     except core._NotOkStatusException as e:
--> 910       six.raise_from(core._status_to_exception(e.code, e.message), None)
    911 
    912   @property

~/anaconda3/envs/tf2-gpu/lib/python3.6/site-packages/six.py in raise_from(value, from_value)

InvalidArgumentError: Cannot convert a Tensor of dtype resource to a NumPy array.
AliKhoda commented 4 years ago

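Workaround: do not serialize the architecture at all; save only the weights (e.g. with `model.save_weights`), rebuild the model from code, and then load the weights into it (`model.load_weights`).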