PGM-Lab / InferPy

InferPy: Deep Probabilistic Modeling with Tensorflow Made Easy
Apache License 2.0

Support for convolutional neural network #206

Open JonvoWoo opened 4 years ago

JonvoWoo commented 4 years ago

Hi, this library is very convenient. Does it support convolutional neural networks, using layers such as tfp.layers.Convolution2DFlipout? For example:

nnetwork = inf.layers.Sequential([
    tfp.layers.Convolution2DFlipout(.....),
    tf.keras.layers.BatchNormalization(),
    tf.keras.layers.Activation('relu'),
    tf.keras.layers.MaxPooling2D(......),
    tf.keras.layers.Flatten(),
    tfp.layers.DenseFlipout(......),
    tfp.layers.DenseFlipout(........)
])

rcabanasdepaz commented 4 years ago

Yes, InferPy is compatible with tfp.layers.Convolution2DFlipout. Below is a small example for MNIST binary classification. Note that when the NN contains a variational layer, inf.layers.Sequential must be used so that the inference algorithm can access the loss that such a layer adds. By contrast, if we were using tf.keras.layers.Conv2D, we could use tf.keras.Sequential directly.
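To make the point about the container concrete, here is a framework-free toy in NumPy. It is only an illustration of the idea, not InferPy's or TensorFlow Probability's implementation: the classes ToyVariationalDense, ToyPlainDense and ToySequential and the function negative_elbo are all hypothetical. The assumption is that a variational layer records a KL(q(w) || p(w)) penalty each time it is called (Keras layers expose such terms through their `losses`), and that the inference objective must add that penalty to the data term. A container that does not collect its layers' losses would silently drop it, while a deterministic layer contributes nothing.

```python
import numpy as np


def kl_normal_std_normal(mu, sigma):
    """Closed-form KL( N(mu, sigma^2) || N(0, 1) ), summed over all weights."""
    return 0.5 * np.sum(sigma**2 + mu**2 - 1.0 - 2.0 * np.log(sigma))


class ToyVariationalDense:
    """Hypothetical stand-in for a variational layer: samples weights from
    q(w) = N(mu, sigma^2) on every call and records the KL penalty."""

    def __init__(self, n_in, n_out, rng):
        self.mu = rng.normal(0.0, 0.1, size=(n_in, n_out))
        self.sigma = np.full((n_in, n_out), 0.5)
        self.rng = rng
        self.losses = []

    def __call__(self, x):
        w = self.mu + self.sigma * self.rng.standard_normal(self.mu.shape)
        self.losses = [kl_normal_std_normal(self.mu, self.sigma)]
        return x @ w


class ToyPlainDense:
    """Hypothetical deterministic layer: it adds no extra loss."""

    def __init__(self, n_in, n_out, rng):
        self.w = rng.normal(0.0, 0.1, size=(n_in, n_out))
        self.losses = []

    def __call__(self, x):
        return x @ self.w


class ToySequential:
    """Hypothetical container that chains layers AND exposes their losses."""

    def __init__(self, layers):
        self.layers = layers

    def __call__(self, x):
        for layer in self.layers:
            x = layer(x)
        return x

    @property
    def losses(self):
        return [loss for layer in self.layers for loss in layer.losses]


def negative_elbo(model, x, y, noise_scale=1.0):
    """One-sample negative ELBO: Gaussian NLL (up to a constant) + KL terms."""
    pred = model(x)
    nll = 0.5 * np.sum((y - pred) ** 2) / noise_scale**2
    return nll + sum(model.losses)


# Usage: only the variational layer contributes a KL term to the objective.
rng = np.random.default_rng(0)
model = ToySequential([ToyVariationalDense(3, 4, rng), ToyPlainDense(4, 1, rng)])
x = rng.standard_normal((5, 3))
y = rng.standard_normal((5, 1))
print(negative_elbo(model, x, y))
```

The KL term is recomputed inside each forward pass, so it is only available after the model has been called; that is why the inference routine has to read it from the container rather than from the layer definitions.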

Note as well that an extra (channel) dimension is added to the input of the NN with tf.expand_dims. Convolutional layers in Keras require 4-dimensional inputs, which with the default data_format="channels_last" are laid out as (#batch, #dim1, #dim2, #channel), while the output of the InferPy variable is (#batch, #dim1, #dim2). The new axis therefore goes last.

import tensorflow as tf
import inferpy as inf
from inferpy.data import mnist
import tensorflow_probability as tfp
import numpy as np

N = 1000 # data size
(x_train, y_train), (x_test, y_test) = mnist.load_data(num_instances=N,
                                  digits=[0, 1], vectorize=False)

S = np.shape(x_train)[1:]

@inf.probmodel
def cnn_flipout_classifier(S):
    with inf.datamodel():
        x = inf.Normal(tf.ones(S), 1, name="x")

        nn = inf.layers.Sequential([
            tfp.layers.Convolution2DFlipout(4, kernel_size=(10,10), padding="same", activation="relu"),
            tf.keras.layers.Flatten(),  # (#batch, #dim1, #dim2, 4) -> (#batch, #dim1*#dim2*4)
            tf.keras.layers.Dense(1, activation='sigmoid')
        ])

        # add the channel axis last: (#batch, #dim1, #dim2) -> (#batch, #dim1, #dim2, 1)
        y = inf.Normal(nn(tf.expand_dims(x, -1)), 0.001, name="y")

p = cnn_flipout_classifier(S)

# Empty Q model
@inf.probmodel
def qmodel():
    pass

q = qmodel()

# set the inference algorithm
VI = inf.inference.VI(q, epochs=2000)

# learn the parameters
p.fit({"x": x_train, "y": y_train}, VI)

# evaluate the model
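def evaluate(p, x, y):
    N = np.shape(x)[0]
    # sample y from the posterior predictive given the test images
    output = p.posterior_predictive("y", {"x": x}).sample()
    # threshold the sigmoid output at 0.5 and flatten the predictions to shape (N,)
    y_pred = np.reshape(1 * (output > 0.5), (N,))
    return np.sum(y_pred == y) / N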

acc = evaluate(p, x_test[:N], y_test[:N])
print(f"accuracy = {acc}")
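The reshape in evaluate matters because of NumPy broadcasting. The sketch below uses made-up probabilities and labels, and assumes (this is not stated in the thread) that the posterior-predictive samples of y carry a trailing unit axis, shape (N, 1), from the Dense(1) output layer, while the labels have shape (N,). Comparing the two directly broadcasts to an (N, N) matrix instead of N element-wise comparisons. The helper accuracy is hypothetical and simply mirrors the thresholding step of evaluate.

```python
import numpy as np


def accuracy(probs, labels):
    """Fraction of correct 0/1 predictions from probabilities thresholded at 0.5."""
    n = np.shape(labels)[0]
    y_pred = np.reshape(1 * (probs > 0.5), (n,))  # (n, 1) -> (n,)
    return np.sum(y_pred == labels) / n


# Made-up example: 4 predicted probabilities with a trailing unit axis, and 4 labels.
probs = np.array([[0.9], [0.2], [0.7], [0.4]])
labels = np.array([1, 0, 0, 0])

# Without the reshape, (4, 1) == (4,) broadcasts to a (4, 4) matrix.
print((1 * (probs > 0.5) == labels).shape)  # (4, 4)
print(accuracy(probs, labels))  # 0.75
```

Summing the unreshaped (N, N) comparison would count matches between every prediction and every label, which can produce a plausible-looking but meaningless "accuracy".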