from tensorflow import keras
from tensorflow.keras import backend as K
from tensorflow.keras import layers
import numpy as np
import tensorflow as tf
from tensorflow_probability import distributions as tfd
def elu_plus_one_plus_epsilon(x):
"""ELU activation with a very small addition to help prevent
NaN in loss."""
return keras.backend.elu(x) + 1 + keras.backend.epsilon()
class MDN(layers.Layer):
"""A Mixture Density Network Layer for Keras.
This layer has a few tricks to avoid NaNs in the loss function when training:
- Activation for variances is ELU + 1 + 1e-8 (to avoid very small values)
- Mixture weights (pi) are trained in as logits, not in the softmax space.
A loss function needs to be constructed with the same output dimension and number of mixtures.
A sampling function is also provided to sample from distribution parametrised by the MDN outputs.
"""
def __init__(self, output_dimension, num_mixtures, **kwargs):
self.output_dim = output_dimension
self.num_mix = num_mixtures
with tf.name_scope('MDN'):
self.mdn_mus = layers.Dense(self.num_mix * self.output_dim, name='mdn_mus') # mix*output vals, no activation
self.mdn_sigmas = layers.Dense(self.num_mix * self.output_dim, activation=elu_plus_one_plus_epsilon, name='mdn_sigmas') # mix*output vals exp activation
self.mdn_pi = layers.Dense(self.num_mix, name='mdn_pi') # mix vals, logits
super(MDN, self).__init__(**kwargs)
def build(self, input_shape):
with tf.name_scope('mus'):
self.mdn_mus.build(input_shape)
with tf.name_scope('sigmas'):
self.mdn_sigmas.build(input_shape)
with tf.name_scope('pis'):
self.mdn_pi.build(input_shape)
super(MDN, self).build(input_shape)
@property
def trainable_weights(self):
return self.mdn_mus.trainable_weights + self.mdn_sigmas.trainable_weights + self.mdn_pi.trainable_weights
@property
def non_trainable_weights(self):
return self.mdn_mus.non_trainable_weights + self.mdn_sigmas.non_trainable_weights + self.mdn_pi.non_trainable_weights
def call(self, x, mask=None):
with tf.name_scope('MDN'):
mdn_out = layers.concatenate([self.mdn_mus(x),
self.mdn_sigmas(x),
self.mdn_pi(x)],
name='mdn_outputs')
return mdn_out
def compute_output_shape(self, input_shape):
"""Returns output shape, showing the number of mixture parameters."""
return (input_shape[0], (2 * self.output_dim * self.num_mix) + self.num_mix)
def get_config(self):
config = {
"output_dimension": self.output_dim,
"num_mixtures": self.num_mix
}
base_config = super(MDN, self).get_config()
return dict(list(base_config.items()) + list(config.items()))
# @classmethod
# def from_config(cls, config):
# return cls(**config)
def get_mixture_loss_func(output_dim, num_mixes):
"""Construct a loss functions for the MDN layer parametrised by number of mixtures."""
# Construct a loss function with the right number of mixtures and outputs
def mdn_loss_func(y_true, y_pred):
# Reshape inputs in case this is used in a TimeDistribued layer
y_pred = tf.reshape(y_pred, [-1, (2 * num_mixes * output_dim) + num_mixes], name='reshape_ypreds')
y_true = tf.reshape(y_true, [-1, output_dim], name='reshape_ytrue')
# Split the inputs into paramaters
out_mu, out_sigma, out_pi = tf.split(y_pred, num_or_size_splits=[num_mixes * output_dim,
num_mixes * output_dim,
num_mixes],
axis=-1, name='mdn_coef_split')
# Construct the mixture models
cat = tfd.Categorical(logits=out_pi)
component_splits = [output_dim] * num_mixes
mus = tf.split(out_mu, num_or_size_splits=component_splits, axis=1)
sigs = tf.split(out_sigma, num_or_size_splits=component_splits, axis=1)
coll = [tfd.MultivariateNormalDiag(loc=loc, scale_diag=scale) for loc, scale
in zip(mus, sigs)]
mixture = tfd.Mixture(cat=cat, components=coll)
loss = mixture.log_prob(y_true)
loss = tf.negative(loss)
loss = tf.reduce_mean(loss)
return loss
# Actually return the loss function
with tf.name_scope('MDN'):
return mdn_loss_func
error
21:56:16 >>>begin
WARNING:tensorflow:From /Users/vector/miniforge3/envs/tf_macos/lib/python3.8/site-packages/tensorflow/python/ops/linalg/linear_operator_diag.py:167: calling LinearOperator.__init__ (from tensorflow.python.ops.linalg.linear_operator) with graph_parents is deprecated and will be removed in a future version.
Instructions for updating:
Do not pass `graph_parents`. They will no longer be used.
WARNING:tensorflow:From /Users/vector/miniforge3/envs/tf_macos/lib/python3.8/site-packages/tensorflow_probability/python/bijectors/affine_linear_operator.py:116: LinearOperator.graph_parents (from tensorflow.python.ops.linalg.linear_operator) is deprecated and will be removed in a future version.
Instructions for updating:
Do not call `graph_parents`.
WARNING:tensorflow:From /Users/vector/miniforge3/envs/tf_macos/lib/python3.8/site-packages/tensorflow_probability/python/distributions/mixture.py:154: Categorical.event_size (from tensorflow_probability.python.distributions.categorical) is deprecated and will be removed after 2019-05-19.
Instructions for updating:
The `event_size` property is deprecated. Use `num_categories` instead. They have the same value, but `event_size` is misnamed.
---------------------------------------------------------------------------
TypeError Traceback (most recent call last)
<ipython-input-6-6610f0ced4af> in <module>
76 optimizer=keras.optimizers.Adam(lr=0.001)
77 )
---> 78 history=model.fit(
79 X_train_scaled,y_train_scaled,
80 epochs=1000,
~/miniforge3/envs/tf_macos/lib/python3.8/site-packages/tensorflow/python/keras/engine/training.py in fit(self, x, y, batch_size, epochs, verbose, callbacks, validation_split, validation_data, shuffle, class_weight, sample_weight, initial_epoch, steps_per_epoch, validation_steps, validation_batch_size, validation_freq, max_queue_size, workers, use_multiprocessing)
1098 _r=1):
1099 callbacks.on_train_batch_begin(step)
-> 1100 tmp_logs = self.train_function(iterator)
1101 if data_handler.should_sync:
1102 context.async_wait()
~/miniforge3/envs/tf_macos/lib/python3.8/site-packages/tensorflow/python/eager/def_function.py in __call__(self, *args, **kwds)
826 tracing_count = self.experimental_get_tracing_count()
827 with trace.Trace(self._name) as tm:
--> 828 result = self._call(*args, **kwds)
829 compiler = "xla" if self._experimental_compile else "nonXla"
830 new_tracing_count = self.experimental_get_tracing_count()
~/miniforge3/envs/tf_macos/lib/python3.8/site-packages/tensorflow/python/eager/def_function.py in _call(self, *args, **kwds)
869 # This is the first call of __call__, so we have to initialize.
870 initializers = []
--> 871 self._initialize(args, kwds, add_initializers_to=initializers)
872 finally:
873 # At this point we know that the initialization is complete (or less
~/miniforge3/envs/tf_macos/lib/python3.8/site-packages/tensorflow/python/eager/def_function.py in _initialize(self, args, kwds, add_initializers_to)
723 self._graph_deleter = FunctionDeleter(self._lifted_initializer_graph)
724 self._concrete_stateful_fn = (
--> 725 self._stateful_fn._get_concrete_function_internal_garbage_collected( # pylint: disable=protected-access
726 *args, **kwds))
727
~/miniforge3/envs/tf_macos/lib/python3.8/site-packages/tensorflow/python/eager/function.py in _get_concrete_function_internal_garbage_collected(self, *args, **kwargs)
2967 args, kwargs = None, None
2968 with self._lock:
-> 2969 graph_function, _ = self._maybe_define_function(args, kwargs)
2970 return graph_function
2971
~/miniforge3/envs/tf_macos/lib/python3.8/site-packages/tensorflow/python/eager/function.py in _maybe_define_function(self, args, kwargs)
3359
3360 self._function_cache.missed.add(call_context_key)
-> 3361 graph_function = self._create_graph_function(args, kwargs)
3362 self._function_cache.primary[cache_key] = graph_function
3363
~/miniforge3/envs/tf_macos/lib/python3.8/site-packages/tensorflow/python/eager/function.py in _create_graph_function(self, args, kwargs, override_flat_arg_shapes)
3194 arg_names = base_arg_names + missing_arg_names
3195 graph_function = ConcreteFunction(
-> 3196 func_graph_module.func_graph_from_py_func(
3197 self._name,
3198 self._python_function,
~/miniforge3/envs/tf_macos/lib/python3.8/site-packages/tensorflow/python/framework/func_graph.py in func_graph_from_py_func(name, python_func, args, kwargs, signature, func_graph, autograph, autograph_options, add_control_dependencies, arg_names, op_return_value, collections, capture_by_value, override_flat_arg_shapes)
988 _, original_func = tf_decorator.unwrap(python_func)
989
--> 990 func_outputs = python_func(*func_args, **func_kwargs)
991
992 # invariant: `func_outputs` contains only Tensors, CompositeTensors,
~/miniforge3/envs/tf_macos/lib/python3.8/site-packages/tensorflow/python/eager/def_function.py in wrapped_fn(*args, **kwds)
632 xla_context.Exit()
633 else:
--> 634 out = weak_wrapped_fn().__wrapped__(*args, **kwds)
635 return out
636
~/miniforge3/envs/tf_macos/lib/python3.8/site-packages/tensorflow/python/framework/func_graph.py in wrapper(*args, **kwargs)
975 except Exception as e: # pylint:disable=broad-except
976 if hasattr(e, "ag_error_metadata"):
--> 977 raise e.ag_error_metadata.to_exception(e)
978 else:
979 raise
TypeError: in user code:
/Users/vector/miniforge3/envs/tf_macos/lib/python3.8/site-packages/tensorflow/python/keras/engine/training.py:805 train_function *
return step_function(self, iterator)
/Users/vector/Google 云端硬盘/Study/Rocky_exoplanets/mdn__/__init__.py:101 loss_func *
loss = mixture.log_prob(y_true)
/Users/vector/miniforge3/envs/tf_macos/lib/python3.8/site-packages/tensorflow_probability/python/distributions/distribution.py:866 log_prob **
return self._call_log_prob(value, name, **kwargs)
/Users/vector/miniforge3/envs/tf_macos/lib/python3.8/site-packages/tensorflow_probability/python/distributions/distribution.py:848 _call_log_prob
return self._log_prob(value, **kwargs)
/Users/vector/miniforge3/envs/tf_macos/lib/python3.8/site-packages/tensorflow_probability/python/distributions/mixture.py:285 _log_prob
distribution_log_probs = [d.log_prob(x) for d in self.components]
/Users/vector/miniforge3/envs/tf_macos/lib/python3.8/site-packages/tensorflow_probability/python/distributions/mixture.py:285 <listcomp>
distribution_log_probs = [d.log_prob(x) for d in self.components]
/Users/vector/miniforge3/envs/tf_macos/lib/python3.8/site-packages/tensorflow_probability/python/distributions/distribution.py:866 log_prob
return self._call_log_prob(value, name, **kwargs)
/Users/vector/miniforge3/envs/tf_macos/lib/python3.8/site-packages/tensorflow_probability/python/distributions/distribution.py:848 _call_log_prob
return self._log_prob(value, **kwargs)
/Users/vector/miniforge3/envs/tf_macos/lib/python3.8/site-packages/tensorflow_probability/python/internal/distribution_util.py:2094 _fn
return fn(*args, **kwargs)
/Users/vector/miniforge3/envs/tf_macos/lib/python3.8/site-packages/tensorflow_probability/python/distributions/mvn_linear_operator.py:210 _log_prob
return super(MultivariateNormalLinearOperator, self)._log_prob(x)
/Users/vector/miniforge3/envs/tf_macos/lib/python3.8/site-packages/tensorflow_probability/python/distributions/transformed_distribution.py:401 _log_prob
x = self.bijector.inverse(y, **bijector_kwargs)
/Users/vector/miniforge3/envs/tf_macos/lib/python3.8/site-packages/tensorflow_probability/python/bijectors/bijector.py:977 inverse
return self._call_inverse(y, name, **kwargs)
/Users/vector/miniforge3/envs/tf_macos/lib/python3.8/site-packages/tensorflow_probability/python/bijectors/bijector.py:946 _call_inverse
mapping = self._lookup(y=y, kwargs=kwargs)
/Users/vector/miniforge3/envs/tf_macos/lib/python3.8/site-packages/tensorflow_probability/python/bijectors/bijector.py:1346 _lookup
mapping = self._from_y[y].get(subkey, mapping).merge(y=y)
/Users/vector/miniforge3/envs/tf_macos/lib/python3.8/site-packages/tensorflow_probability/python/bijectors/bijector.py:151 __getitem__
return super(WeakKeyDefaultDict, self).__getitem__(weak_key)
/Users/vector/miniforge3/envs/tf_macos/lib/python3.8/site-packages/tensorflow_probability/python/bijectors/bijector.py:181 __hash__
return hash(x)
/Users/vector/miniforge3/envs/tf_macos/lib/python3.8/site-packages/tensorflow/python/framework/ops.py:830 __hash__
raise TypeError("Tensor is unhashable. "
TypeError: Tensor is unhashable. Instead, use tensor.ref() as the key.
code
mdn code as following
error
can anyone help me? thanks