Open jiaboli007 opened 2 years ago
Here's what I did, and it seems to work. I changed CustomLayer to actually take z_mu and z_sigma as inputs. It seems the old code just had them in global scope? Works for me now!
# =========================
# Define custom loss
# VAE is trained using two loss functions: reconstruction loss and KL divergence
# Let us add a class to define a custom layer with loss
class CustomLayer(keras.layers.Layer):

    def vae_loss(self, x, z_decoded, z_mu, z_sigma):
        x = K.flatten(x)
        z_decoded = K.flatten(z_decoded)

        # Reconstruction loss (as we used sigmoid activation we can use binary crossentropy)
        recon_loss = keras.metrics.binary_crossentropy(x, z_decoded)

        # KL divergence
        kl_loss = -5e-4 * K.mean(1 + z_sigma - K.square(z_mu) - K.exp(z_sigma), axis=-1)
        return K.mean(recon_loss + kl_loss)

    # add custom loss to the class
    def call(self, inputs):
        x = inputs[0]
        z_decoded = inputs[1]
        z_mu = inputs[2]
        z_sigma = inputs[3]
        loss = self.vae_loss(x, z_decoded, z_mu, z_sigma)
        self.add_loss(loss, inputs=inputs)
        return x

# apply the custom loss to the input images and the decoded latent distribution sample
y = CustomLayer()([input_img, z_decoded, z_mu, z_sigma])
# y is basically the original image after encoding input img to mu, sigma, z
# and decoding sampled z values.
# This will be used as output for vae
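For anyone checking the KL line by hand, here is a rough plain-Python sketch of the same expression for a single sample. It assumes z_sigma holds the log-variance (which is what K.exp(z_sigma) implies); the function name kl_term and the weight parameter are made up for illustration and are not part of the notebook.

```python
import math


def kl_term(z_mu, z_sigma, weight=-5e-4):
    """Weighted KL term for one sample, mirroring the expression in vae_loss.

    z_mu and z_sigma are plain lists of floats, one entry per latent dimension.
    z_sigma is assumed to be the log-variance (hypothetical helper, not in the notebook).
    """
    terms = [1 + s - m ** 2 - math.exp(s) for m, s in zip(z_mu, z_sigma)]
    # Mean over the latent dimensions, like K.mean(..., axis=-1)
    return weight * sum(terms) / len(terms)


# A latent code that already matches the standard normal prior contributes nothing
print(kl_term([0.0, 0.0], [0.0, 0.0]))

# Moving the mean away from zero makes the term positive
print(kl_term([1.0, 1.0], [0.0, 0.0]))
```

The term is zero when z_mu is 0 and z_sigma is 0, and grows as the encoder output drifts away from the prior, so it acts as a penalty added to the reconstruction loss.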
It works! Thank you very much, Dylan!
It worked for me too, so this issue should be closed. Thank you, Dylan!
I got the following error when I ran the Jupyter notebook of your MNIST VAE sample:
TypeError: You are passing KerasTensor(type_spec=TensorSpec(shape=(), dtype=tf.float32, name=None), name='tf.math.reduce_sum/Sum:0', description="created by layer 'tf.math.reduce_sum'"), an intermediate Keras symbolic input/output, to a TF API that does not allow registering custom dispatchers, such as tf.cond, tf.function, gradient tapes, or tf.map_fn. Keras Functional model construction only supports TF API calls that do support dispatching, such as tf.math.add or tf.reshape. Other APIs cannot be called directly on symbolic Keras inputs/outputs. You can work around this limitation by putting the operation in a custom Keras layer call and calling that layer on this symbolic input/output.
The error happens at
I am using the latest version of TF (2.7.0) and Keras (2.8.0). I also tried older versions of both TF and Keras with no luck.