Open MArpogaus opened 4 years ago
Did you solve it? 🤔
Did you solve it? 🤔
No. I decided to implement it as a loss function instead.
import tensorflow as tf
from tensorflow.keras.losses import Loss
from thesis.distributions import MixedNormal
from thesis.distributions import MixedLogNormal
class MixtedDensityLoss(Loss):
    """Negative log-likelihood loss for a mixture density network.

    Selects a mixed-density distribution builder at construction time
    and scores targets under the distribution parameterized by the
    network's output vector.

    Parameters
    ----------
    log_normal : bool, optional
        When True use ``MixedLogNormal`` as the density family,
        otherwise ``MixedNormal`` (default).
    **kwargs
        Forwarded unchanged to ``tf.keras.losses.Loss``.
    """

    def __init__(self, log_normal=False, **kwargs):
        # Pick the distribution family once; reused on every call().
        self.mixed_density = MixedLogNormal() if log_normal else MixedNormal()
        super().__init__(**kwargs)

    def call(self, y, pvector):
        """Return the negative log-likelihood of ``y`` under the
        distribution built from the parameter vector ``pvector``.

        NOTE(review): ``tf.squeeze`` with no axis drops *all* size-1
        dimensions of ``y`` — confirm targets never carry a meaningful
        singleton axis.
        """
        distribution = self.mixed_density(pvector)
        targets = tf.squeeze(y)
        return -distribution.log_prob(targets)
class MixedNormal():
    """Builds a blockwise joint distribution of univariate Gaussian
    mixtures from a stacked parameter tensor.

    The instance is callable: ``MixedNormal()(pvector)`` returns a
    ``tfd.Blockwise`` wrapping a ``tfd.JointDistributionSequential`` of
    one ``tfd.MixtureSameFamily`` per predicted distribution.

    NOTE(review): the slicing below assumes ``pvector`` has shape
    ``(components, num_distributions, 3)`` with the trailing axis
    holding (logits, locs, log_scales) — TODO confirm against the
    network's output layer.
    """

    def __init__(self):
        # Stateless builder; nothing to configure.
        pass

    def __call__(self, pvector):
        """Return the blockwise joint mixture built from ``pvector``."""
        mixture = self.gen_mixture(pvector)
        return mixture

    def slice_parameter_vectors(self, pvector):
        """Returns an unpacked list of parameter vectors.

        One entry per distribution (axis 1 of ``pvector``); each entry
        is the list ``[logits, locs, log_scales]`` sliced along the
        trailing parameter axis.
        """
        num_dist = pvector.shape[1]
        sliced_pvectors = []
        for d in range(num_dist):
            # 3 parameters per mixture: logits, locs, log_scales.
            sliced_pvector = [pvector[:, d, p] for p in range(3)]
            sliced_pvectors.append(sliced_pvector)
        return sliced_pvectors

    def gen_mixture(self, out):
        """Build one MixtureSameFamily per parameter slice and join
        them into a single Blockwise distribution.
        """
        pvs = self.slice_parameter_vectors(out)
        mixtures = []
        for pv in pvs:
            logits, locs, log_scales = pv
            # BUG FIX: the original applied tf.math.softmax to
            # log_scales, which normalizes the values to sum to 1
            # across the slice — not a valid positivity transform for
            # Gaussian scale parameters. The name "log_scales" implies
            # the intended transform is exp (alternatively softplus).
            scales = tf.math.exp(log_scales)
            mixtures.append(
                tfd.MixtureSameFamily(
                    mixture_distribution=tfd.Categorical(logits=logits),
                    components_distribution=tfd.Normal(
                        loc=locs,
                        scale=scales))
            )
        joint = tfd.JointDistributionSequential(
            mixtures, name='joint_mixtures')
        blkws = tfd.Blockwise(joint)
        return blkws
I'm having the same issue. Minimal example here: #1681
Hello!
I am currently trying to use
JointDistributionSequential
to predict multiple distributions using a Mixture Density Network. Minimal example:
Training with the TensorFlow Keras API works as expected, but when I use
keras.callbacks.ModelCheckpoint
or `gmm_model.save_weights('test')`
to save the weights, I get the following error. What am I doing wrong here?
Thank you very much for your help!