asnelt / mmae

Package for Multimodal Autoencoders in TensorFlow / Keras
GNU General Public License v3.0

Add example of using mmae for supervised learning task #1

Closed rezacsedu closed 4 years ago

rezacsedu commented 5 years ago

Please add an example of using mmae for a supervised learning task (e.g., classification).

asnelt commented 4 years ago

Supervised learning is not the focus of the mmae package, so I would not add such an example to the package.

Below, I paste an example where the mmae package is used to find a low-dimensional latent representation of MNIST images, and a classifier is then trained on those latent representations. Multiple modalities are "simulated" by splitting each MNIST image in half and using a different output activation and loss function for each half. Classification accuracy is around 83%; using the original images instead yields around 90%. The point is that we can still do relatively well when given 8-dimensional image representations instead of the full 28-by-28 images.

import numpy as np
from tensorflow.keras.datasets import mnist
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense
from mmae.multimodal_autoencoder import MultimodalAutoencoder

# Load example data
(x_train, y_train), (x_validation, y_validation) = mnist.load_data()
# Scale pixel values to range [0, 1]
x_train = x_train.astype('float32') / 255.0
x_validation = x_validation.astype('float32') / 255.0
# Flatten images
x_train = x_train.reshape((len(x_train), np.prod(x_train.shape[1:])))
x_validation = x_validation.reshape((len(x_validation),
                                     np.prod(x_validation.shape[1:])))
# Split each image: each split represents one modality
split = int(x_train.shape[1] / 2)
# Multimodal training data
data = [x_train[:, :split], x_train[:, split:]]
# Multimodal validation data
validation_data = [x_validation[:, :split], x_validation[:, split:]]
# Set network parameters
input_shapes = [split, split]
# Number of units of each layer of encoder network
hidden_dims = [128, 64, 8]
# Output activation functions for each modality
output_activations = ['sigmoid', 'linear']
# Name of Keras optimizer
optimizer = 'adam'
# Loss functions corresponding to a noise model for each modality
loss = ['bernoulli_divergence', 'gaussian_divergence']
# Construct autoencoder network
autoencoder = MultimodalAutoencoder(input_shapes, hidden_dims,
                                    output_activations)
autoencoder.compile(optimizer, loss)
# Train model where input and output are the same
autoencoder.fit(data, epochs=100, batch_size=256,
                validation_data=validation_data)
# Get latent representations
latent_data = autoencoder.encode(data)
latent_validation_data = autoencoder.encode(validation_data)

# Train simple classifier on latent representation
classifier = Sequential([Dense(10, activation='softmax')])
classifier.compile(optimizer='adam', loss='sparse_categorical_crossentropy',
                   metrics=['accuracy'])
classifier.fit(latent_data, y_train, epochs=10)
# Evaluate validation performance
classifier.evaluate(latent_validation_data, y_validation)
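
For the ~90% comparison figure mentioned above, the same kind of classifier can be trained directly on the raw flattened pixels. A minimal sketch of that baseline, reusing the variables defined above:

# Baseline for comparison: the same single-layer classifier trained on the
# raw 784-pixel images instead of the 8-dimensional latent codes
baseline = Sequential([Dense(10, activation='softmax')])
baseline.compile(optimizer='adam', loss='sparse_categorical_crossentropy',
                 metrics=['accuracy'])
baseline.fit(x_train, y_train, epochs=10)
# Evaluate validation performance of the raw-pixel baseline
baseline.evaluate(x_validation, y_validation)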

To visualize the reconstructions, we can do:

import matplotlib.pyplot as plt

decoded_imgs = autoencoder.predict(validation_data)
# Concatenate outputs
decoded_imgs = np.concatenate((decoded_imgs[0], decoded_imgs[1]), axis=1)
n = 10  # Number of digits to display
plt.figure(figsize=(20, 4))
for i in range(n):
    # Display original
    ax = plt.subplot(2, n, i + 1)
    plt.imshow(x_validation[i].reshape(28, 28))
    plt.gray()
    ax.get_xaxis().set_visible(False)
    ax.get_yaxis().set_visible(False)
    # Display reconstruction
    ax = plt.subplot(2, n, i + 1 + n)
    plt.imshow(decoded_imgs[i].reshape(28, 28))
    plt.gray()
    ax.get_xaxis().set_visible(False)
    ax.get_yaxis().set_visible(False)
plt.show()

You can clearly see the split in the plot.
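
The same reconstructions can also be produced by going through the latent space explicitly; assuming the autoencoder's decode method complements encode, a rough equivalent of the predict call above would be:

# Round trip through the 8-dimensional latent space; this should give the
# same reconstructions as autoencoder.predict(validation_data) above
reconstructed = autoencoder.decode(latent_validation_data)
decoded_imgs = np.concatenate((reconstructed[0], reconstructed[1]), axis=1)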

rezacsedu commented 4 years ago

@asnelt, thanks a million for adding these two examples. Really appreciated!