ageron / handson-ml2

A series of Jupyter notebooks that walk you through the fundamentals of Machine Learning and Deep Learning in Python using Scikit-Learn, Keras and TensorFlow 2.
Apache License 2.0

[QUESTION] Plotting loss with the Deep Convolutional GAN #563

Open Reisa14 opened 2 years ago

Reisa14 commented 2 years ago

Describe what is unclear to you

When creating autoencoders, we used fit() to produce a history object, which could then be used to plot the loss on the training and validation sets over the epochs, as on p. 590:

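# the model's variable name is missing from the original post; "autoencoder" is a placeholder
history = autoencoder.fit(X_train, X_train, epochs=25, batch_size=128,
                          validation_data=(X_valid, X_valid))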
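For reference, fit() returns a History object whose history attribute is a dict mapping each logged quantity (here "loss" and "val_loss") to a list with one value per epoch. Below is a minimal sketch of plotting any dict in that shape with matplotlib. The helper name plot_loss_history is hypothetical and is not part of the notebooks.

```python
import matplotlib.pyplot as plt


def plot_loss_history(history_dict, keys=None, ax=None):
    """Hypothetical helper: plot one curve per key of a dict shaped like
    Keras's History.history (metric name -> list of per-epoch values)."""
    if keys is None:
        keys = list(history_dict)
    if ax is None:
        _, ax = plt.subplots(figsize=(8, 5))
    for key in keys:
        values = history_dict[key]
        # number epochs from 1, the way Keras prints them during training
        ax.plot(range(1, len(values) + 1), values, label=key)
    ax.set_xlabel("Epoch")
    ax.set_ylabel("Loss")
    ax.grid(True)
    ax.legend()
    return ax
```

For the autoencoder above, this would be called as plot_loss_history(history.history, keys=["loss", "val_loss"]) followed by plt.show().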

However, when creating the GANs and deep convolutional GANs, we don't use fit(); instead we use the custom train_gan function:

def train_gan(gan, dataset, batch_size, codings_size, n_epochs=20):
    generator, discriminator = gan.layers
    for epoch in range(n_epochs):
        print("Epoch {}/{}".format(epoch + 1, n_epochs))
        for X_batch in dataset:
            # phase 1 - training the discriminator
            X_batch = tf.cast(X_batch, tf.float32)
            noise = tf.random.normal(shape=[batch_size, codings_size])
            generated_images = generator(noise)
            X_fake_and_real = tf.concat([generated_images, X_batch], axis=0)
            y1 = tf.constant([[0.]] * batch_size + [[1.]] * batch_size)
            discriminator.trainable = True
            discriminator.train_on_batch(X_fake_and_real, y1)
            # phase 2 - training the generator
            noise = tf.random.normal(shape=[batch_size, codings_size])
            y2 = tf.constant([[1.]] * batch_size)
            discriminator.trainable = False
            gan.train_on_batch(noise, y2)
        plot_multiple_images(generated_images, 8)

If we wanted to plot the loss of both the discriminator and generator across all epochs in the example on page 599, how would we go about this?
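One possible approach, sketched under some assumptions: Keras's train_on_batch() returns the loss for that batch (a scalar when the model was compiled without metrics, otherwise a list whose first item is the loss), but train_gan() above throws both return values away. The variant below keeps the same loop, captures the discriminator and generator losses, averages them over each epoch, and returns them in a dict of per-epoch lists, the same shape fit() stores in history.history. The key names "d_loss" and "g_loss" and the returned dict are my own choices, not something the notebook defines, and the call to the notebook's plot_multiple_images() helper is left out so the snippet runs on its own.

```python
import numpy as np
import tensorflow as tf


def train_gan(gan, dataset, batch_size, codings_size, n_epochs=20):
    """Same loop as the notebook's train_gan(), but it also returns a dict of
    per-epoch mean losses (the dict format is an assumption of this sketch)."""
    generator, discriminator = gan.layers
    history = {"d_loss": [], "g_loss": []}
    for epoch in range(n_epochs):
        print("Epoch {}/{}".format(epoch + 1, n_epochs))
        d_batch_losses, g_batch_losses = [], []
        for X_batch in dataset:
            # phase 1 - training the discriminator
            X_batch = tf.cast(X_batch, tf.float32)
            noise = tf.random.normal(shape=[batch_size, codings_size])
            generated_images = generator(noise)
            X_fake_and_real = tf.concat([generated_images, X_batch], axis=0)
            y1 = tf.constant([[0.]] * batch_size + [[1.]] * batch_size)
            discriminator.trainable = True
            d_loss = discriminator.train_on_batch(X_fake_and_real, y1)
            # phase 2 - training the generator
            noise = tf.random.normal(shape=[batch_size, codings_size])
            y2 = tf.constant([[1.]] * batch_size)
            discriminator.trainable = False
            g_loss = gan.train_on_batch(noise, y2)
            # train_on_batch() gives a scalar loss, or a list starting with the
            # loss if metrics were compiled; ravel()[0] handles both cases
            d_batch_losses.append(float(np.ravel(d_loss)[0]))
            g_batch_losses.append(float(np.ravel(g_loss)[0]))
        # average the noisy per-batch losses into one value per epoch
        history["d_loss"].append(float(np.mean(d_batch_losses)))
        history["g_loss"].append(float(np.mean(g_batch_losses)))
    return history
```

Usage would be something like history = train_gan(gan, dataset, batch_size, codings_size), after which the two lists can be plotted against the epoch number with matplotlib (or with pd.DataFrame(history).plot()). Averaging per epoch smooths the per-batch values; keep the per-batch lists instead if you want finer resolution. Keep in mind that GAN losses usually oscillate rather than steadily decrease, since each network's loss depends on the other, so these curves are a rough diagnostic rather than a measure of image quality.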
