tensorflow / probability

Probabilistic reasoning and statistical analysis in TensorFlow
https://www.tensorflow.org/probability/

save_weights causes ValueError when DistributionLambda is used with JointDistributionSequential #926

Open MArpogaus opened 4 years ago

MArpogaus commented 4 years ago

Hello!

I am currently trying to use JointDistributionSequential to predict multiple distributions with a Mixture Density Network.

Minimal example:

import tensorflow as tf
import tensorflow_probability as tfp

from tensorflow_probability import distributions as tfd

from functools import partial

neurons = 32
components = 2
no_dists = 20

HiddenLayer = partial(
    tf.keras.layers.Dense,
    activation="elu",
    kernel_initializer="he_normal",
    kernel_regularizer=tf.keras.regularizers.l2(0.01)
)

OutputLayer = partial(
    tf.keras.layers.Dense,
    activation="linear",
    #kernel_regularizer=tf.keras.regularizers.l2(0.001)
)

inputs = tf.keras.layers.Input(shape=(1,))
h1 = HiddenLayer(neurons)(inputs)
h2 = HiddenLayer(neurons // 2)(h1)  # integer division: Dense units must be an int

logits = OutputLayer(no_dists * components, name="logits")(h2)
logits_rshpd = tf.keras.layers.Reshape((no_dists, components))(logits)

locs = OutputLayer(no_dists * components, name="locs")(h2)
locs_rshpd = tf.keras.layers.Reshape((no_dists, components))(locs)

# softplus keeps the scales positive (the layer emits scales, not log-scales)
scales = OutputLayer(no_dists * components, activation='softplus', name="scales")(h2)
scales_rshpd = tf.keras.layers.Reshape((no_dists, components))(scales)

def joint(pvector):
    # pvector is the tuple of three reshaped tensors, each (batch, no_dists, components)
    logits, locs, scales = pvector
    mixtures = []
    for d in range(no_dists):
        # one univariate Gaussian mixture per output dimension
        mixture = tfd.MixtureSameFamily(
            mixture_distribution=tfd.Categorical(logits=logits[:, d]),
            components_distribution=tfd.Normal(
                loc=locs[:, d],
                scale=scales[:, d]))
        mixtures.append(mixture)

    # concatenate the no_dists scalar mixtures into one vector-valued distribution
    joint = tfd.JointDistributionSequential(mixtures)
    return tfd.Blockwise(joint)

out_joint = tfp.layers.DistributionLambda(joint)(
    (
        logits_rshpd,
        locs_rshpd,
        scales_rshpd
    )
)

gmm_model = tf.keras.Model(
    inputs,
    out_joint,
    name ='mdn'
)

gmm_model.summary()
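
For reference, a model like this is typically trained against the negative log-likelihood of the distribution the final layer emits; a minimal sketch of that step, assuming the standard tfp.layers pattern (the optimizer choice is illustrative):

negloglik = lambda y, rv: -rv.log_prob(y)  # rv is the Blockwise joint the model outputs
gmm_model.compile(optimizer="adam", loss=negloglik)
# y batches then have shape (batch, no_dists) to match the joint's event shape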

Training with the TensorFlow Keras API works as expected, but when I use keras.callbacks.ModelCheckpoint or gmm_model.save_weights('test') to save the weights, I get the following error:

---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
<ipython-input-3-6285a5a969b4> in <module>
----> 1 gmm_model.save_weights('test')

~/miniconda3/envs/tfgpu/lib/python3.7/site-packages/tensorflow/python/keras/engine/network.py in save_weights(self, filepath, overwrite, save_format)
   1165              'saved.\n\nConsider using a TensorFlow optimizer from `tf.train`.')
   1166             % (optimizer,))
-> 1167       self._trackable_saver.save(filepath, session=session)
   1168       # Record this checkpoint so it's visible from tf.train.latest_checkpoint.
   1169       checkpoint_management.update_checkpoint_state_internal(

~/miniconda3/envs/tfgpu/lib/python3.7/site-packages/tensorflow/python/training/tracking/util.py in save(self, file_prefix, checkpoint_number, session)
   1185     file_io.recursive_create_dir(os.path.dirname(file_prefix))
   1186     save_path, new_feed_additions = self._save_cached_when_graph_building(
-> 1187         file_prefix=file_prefix_tensor, object_graph_tensor=object_graph_tensor)
   1188     if new_feed_additions:
   1189       feed_dict.update(new_feed_additions)

~/miniconda3/envs/tfgpu/lib/python3.7/site-packages/tensorflow/python/training/tracking/util.py in _save_cached_when_graph_building(self, file_prefix, object_graph_tensor)
   1125     (named_saveable_objects, graph_proto,
   1126      feed_additions) = self._gather_saveables(
-> 1127          object_graph_tensor=object_graph_tensor)
   1128     if (self._last_save_object_graph != graph_proto
   1129         # When executing eagerly, we need to re-create SaveableObjects each time

~/miniconda3/envs/tfgpu/lib/python3.7/site-packages/tensorflow/python/training/tracking/util.py in _gather_saveables(self, object_graph_tensor)
   1093     """Wraps _serialize_object_graph to include the object graph proto."""
   1094     (named_saveable_objects, graph_proto,
-> 1095      feed_additions) = self._graph_view.serialize_object_graph()
   1096     if object_graph_tensor is None:
   1097       with ops.device("/cpu:0"):

~/miniconda3/envs/tfgpu/lib/python3.7/site-packages/tensorflow/python/training/tracking/graph_view.py in serialize_object_graph(self)
    377       ValueError: If there are invalid characters in an optimizer's slot names.
    378     """
--> 379     trackable_objects, path_to_root = self._breadth_first_traversal()
    380     return self._serialize_gathered_objects(
    381         trackable_objects, path_to_root)

~/miniconda3/envs/tfgpu/lib/python3.7/site-packages/tensorflow/python/training/tracking/graph_view.py in _breadth_first_traversal(self)
    197             % (current_trackable,))
    198       bfs_sorted.append(current_trackable)
--> 199       for name, dependency in self.list_dependencies(current_trackable):
    200         if dependency not in path_to_root:
    201           path_to_root[dependency] = (

~/miniconda3/envs/tfgpu/lib/python3.7/site-packages/tensorflow/python/training/tracking/graph_view.py in list_dependencies(self, obj)
    157     # pylint: disable=protected-access
    158     obj._maybe_initialize_trackable()
--> 159     return obj._checkpoint_dependencies
    160     # pylint: enable=protected-access
    161 

~/miniconda3/envs/tfgpu/lib/python3.7/site-packages/tensorflow/python/training/tracking/data_structures.py in __getattribute__(self, name)
    740       # in particular seems to look up properties on the wrapped object instead
    741       # of the wrapper without this logic.
--> 742       return object.__getattribute__(self, name)
    743     else:
    744       return super(_DictWrapper, self).__getattribute__(name)

~/miniconda3/envs/tfgpu/lib/python3.7/site-packages/tensorflow/python/training/tracking/data_structures.py in _checkpoint_dependencies(self)
    781           "mutable data structure.\n\nIf you don't need this dictionary "
    782           "checkpointed, wrap it in a non-trackable "
--> 783           "object; it will be subsequently ignored." % (self,))
    784     if self._self_external_modification:
    785       raise ValueError(

ValueError: Unable to save the object {140228692728528: ListWrapper([<tfp.distributions.MixtureSameFamily 'distribution_lambda_MixtureSameFamily' batch_shape=[?] event_shape=[] dtype=float32>, <tfp.distributions.MixtureSameFamily 'distribution_lambda_MixtureSameFamily_1' batch_shape=[?] event_shape=[] dtype=float32>, <tfp.distributions.MixtureSameFamily 'distribution_lambda_MixtureSameFamily_2' batch_shape=[?] event_shape=[] dtype=float32>, <tfp.distributions.MixtureSameFamily 'distribution_lambda_MixtureSameFamily_3' batch_shape=[?] event_shape=[] dtype=float32>, <tfp.distributions.MixtureSameFamily 'distribution_lambda_MixtureSameFamily_4' batch_shape=[?] event_shape=[] dtype=float32>, <tfp.distributions.MixtureSameFamily 'distribution_lambda_MixtureSameFamily_5' batch_shape=[?] event_shape=[] dtype=float32>, <tfp.distributions.MixtureSameFamily 'distribution_lambda_MixtureSameFamily_6' batch_shape=[?] event_shape=[] dtype=float32>, <tfp.distributions.MixtureSameFamily 'distribution_lambda_MixtureSameFamily_7' batch_shape=[?] event_shape=[] dtype=float32>, <tfp.distributions.MixtureSameFamily 'distribution_lambda_MixtureSameFamily_8' batch_shape=[?] event_shape=[] dtype=float32>, <tfp.distributions.MixtureSameFamily 'distribution_lambda_MixtureSameFamily_9' batch_shape=[?] event_shape=[] dtype=float32>, <tfp.distributions.MixtureSameFamily 'distribution_lambda_MixtureSameFamily_10' batch_shape=[?] event_shape=[] dtype=float32>, <tfp.distributions.MixtureSameFamily 'distribution_lambda_MixtureSameFamily_11' batch_shape=[?] event_shape=[] dtype=float32>, <tfp.distributions.MixtureSameFamily 'distribution_lambda_MixtureSameFamily_12' batch_shape=[?] event_shape=[] dtype=float32>, <tfp.distributions.MixtureSameFamily 'distribution_lambda_MixtureSameFamily_13' batch_shape=[?] event_shape=[] dtype=float32>, <tfp.distributions.MixtureSameFamily 'distribution_lambda_MixtureSameFamily_14' batch_shape=[?] event_shape=[] dtype=float32>, <tfp.distributions.MixtureSameFamily 'distribution_lambda_MixtureSameFamily_15' batch_shape=[?] event_shape=[] dtype=float32>, <tfp.distributions.MixtureSameFamily 'distribution_lambda_MixtureSameFamily_16' batch_shape=[?] event_shape=[] dtype=float32>, <tfp.distributions.MixtureSameFamily 'distribution_lambda_MixtureSameFamily_17' batch_shape=[?] event_shape=[] dtype=float32>, <tfp.distributions.MixtureSameFamily 'distribution_lambda_MixtureSameFamily_18' batch_shape=[?] event_shape=[] dtype=float32>, <tfp.distributions.MixtureSameFamily 'distribution_lambda_MixtureSameFamily_19' batch_shape=[?] event_shape=[] dtype=float32>])} (a dictionary wrapper constructed automatically on attribute assignment). The wrapped dictionary contains a non-string key which maps to a trackable object or mutable data structure.

If you don't need this dictionary checkpointed, wrap it in a non-trackable object; it will be subsequently ignored.

What am I doing wrong here?

Thank you very much for your help!

gitlabspy commented 4 years ago

Did you solve it? 🤔

MArpogaus commented 4 years ago

> Did you solve it? 🤔

No. I decided to implement it as a loss function instead.

Custom Loss Function:

import tensorflow as tf

from tensorflow.keras.losses import Loss

from thesis.distributions import MixedNormal
from thesis.distributions import MixedLogNormal

class MixedDensityLoss(Loss):
    def __init__(
            self,
            log_normal=False,
            **kwargs):
        if log_normal:
            self.mixed_density = MixedLogNormal()
        else:
            self.mixed_density = MixedNormal()

        super().__init__(**kwargs)

    def call(self, y, pvector):
        # build the joint mixture from the raw network output and score y under it
        dist = self.mixed_density(pvector)
        y = tf.squeeze(y)
        nll = -dist.log_prob(y)
        return nll

Gaussian Mixture Model:

import tensorflow as tf
from tensorflow_probability import distributions as tfd


class MixedNormal():
    """Builds a Blockwise joint of per-dimension Gaussian mixtures
    from a raw parameter tensor."""

    def __call__(self, pvector):

        mixture = self.gen_mixture(pvector)

        return mixture

    def slice_parameter_vectors(self, pvector):
        """Returns an unpacked list of parameter vectors.

        Expects pvector of shape (batch, num_dist, 3, components),
        where axis 2 packs (logits, locs, log_scales).
        """
        num_dist = pvector.shape[1]
        sliced_pvectors = []
        for d in range(num_dist):
            sliced_pvector = [pvector[:, d, p] for p in range(3)]
            sliced_pvectors.append(sliced_pvector)
        return sliced_pvectors

    def gen_mixture(self, out):
        pvs = self.slice_parameter_vectors(out)
        mixtures = []

        for pv in pvs:
            logits, locs, log_scales = pv
            # softplus (not softmax) maps the unconstrained outputs to positive scales
            scales = tf.math.softplus(log_scales)
            mixtures.append(
                tfd.MixtureSameFamily(
                    mixture_distribution=tfd.Categorical(logits=logits),
                    components_distribution=tfd.Normal(
                        loc=locs,
                        scale=scales))
            )

        joint = tfd.JointDistributionSequential(
            mixtures, name='joint_mixtures')
        blkws = tfd.Blockwise(joint)
        return blkws
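
A minimal wiring sketch for these two pieces, assuming the network emits one raw parameter tensor of shape (batch, no_dists, 3, components) as slice_parameter_vectors expects (the layer reuse and names here are illustrative):

# reuse inputs/h2/OutputLayer from the first example; the model emits raw parameters only
pvector = tf.keras.layers.Reshape((no_dists, 3, components))(
    OutputLayer(no_dists * 3 * components, name="pvector")(h2))

mdn = tf.keras.Model(inputs, pvector, name="mdn_loss_variant")
mdn.compile(optimizer="adam", loss=MixedDensityLoss())

# the model output is now a plain tensor and the distribution only exists inside
# the loss, so save_weights should no longer trip over tracked distribution objects
mdn.save_weights("test")
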
sid-kap commented 1 year ago

I'm having the same issue. Minimal example here: #1681