JonvoWoo opened this issue 4 years ago (status: Open)
Yes, InferPy is compatible with `tfp.layers.Convolution2DFlipout`. Below is a small example of binary classification on MNIST. Note that when the NN contains a variational layer, `inf.layers.Sequential` must be used so that the inference algorithm can access the loss of that layer. By contrast, if we were using `tf.keras.layers.Conv2D`, we could use `tf.keras.Sequential` directly.
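Why the container matters can be pictured with a toy, pure-Python sketch. This is a conceptual illustration under stated assumptions, not InferPy's or TFP's code, and every class and function name in it is hypothetical. The assumption is that a variational layer such as `Convolution2DFlipout` records a KL-divergence term in its `losses`, and that the training objective (the negative ELBO) adds those terms to the data term. The inference can only add them if it knows where to find them; in this sketch, that is the container's `losses` property.

```python
# Conceptual sketch only (hypothetical names, NOT InferPy's implementation):
# variational layers record a KL term, and the container gathers those terms
# so the objective can add them to the data term.


class ToyVariationalLayer:
    """Hypothetical stand-in for a layer like Convolution2DFlipout."""

    def __init__(self, kl):
        self.kl = kl        # pretend KL(q(w) || p(w)) for this layer's weights
        self.losses = []

    def __call__(self, x):
        self.losses = [self.kl]  # KL term recorded on each forward pass
        return x                 # identity forward pass, for illustration only


class ToyDeterministicLayer:
    """Hypothetical stand-in for a plain Keras layer: contributes no loss."""

    losses = []

    def __call__(self, x):
        return x


class ToySequential:
    """Hypothetical container that exposes the losses of all its layers."""

    def __init__(self, layers):
        self.layers = layers

    def __call__(self, x):
        for layer in self.layers:
            x = layer(x)
        return x

    @property
    def losses(self):
        return [loss for layer in self.layers for loss in layer.losses]


def negative_elbo(nll, model):
    # data term (negative log-likelihood) plus every KL term the layers registered
    return nll + sum(model.losses)


model = ToySequential([ToyVariationalLayer(kl=0.5),
                       ToyDeterministicLayer(),
                       ToyVariationalLayer(kl=0.25)])
model(1.0)
print(negative_elbo(2.0, model))  # 2.75
```

If the container did not expose the KL terms, the objective would reduce to the data term alone and the weight posterior would no longer be regularized toward the prior.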
Note as well that an extra dimension is added to the input of the NN with `tf.expand_dims`. Convolutional layers in Keras require 4-D inputs, by default in `channels_last` order (#batch, #dim1, #dim2, #channel), while the InferPy variable `x` has shape (#batch, #dim1, #dim2). The channel axis is therefore appended as the last axis.

```python
import tensorflow as tf
import inferpy as inf
from inferpy.data import mnist
import tensorflow_probability as tfp
import numpy as np

N = 1000  # data size

# load only digits 0 and 1, keeping each image as a 2-D array
(x_train, y_train), (x_test, y_test) = mnist.load_data(num_instances=N,
                                                       digits=[0, 1], vectorize=False)
S = np.shape(x_train)[1:]  # shape of a single image


@inf.probmodel
def cnn_flipout_classifier(S):
    with inf.datamodel():
        x = inf.Normal(tf.ones(S), 1, name="x")

        # inf.layers.Sequential (not tf.keras.Sequential) so that the loss
        # of the Flipout layer is available to the inference
        nn = inf.layers.Sequential([
            tfp.layers.Convolution2DFlipout(4, kernel_size=(10, 10), padding="same", activation="relu"),
            tf.keras.layers.GlobalMaxPool2D(),
            tf.keras.layers.Dense(1, activation="sigmoid")
        ])

        # append a channel axis: (#batch, #dim1, #dim2) -> (#batch, #dim1, #dim2, 1)
        y = inf.Normal(nn(tf.expand_dims(x, -1)), 0.001, name="y")


p = cnn_flipout_classifier(S)


# Empty Q model: p has no latent variables of its own
# (the Flipout layer keeps its own weight posterior)
@inf.probmodel
def qmodel():
    pass


q = qmodel()

# set the inference algorithm
VI = inf.inference.VI(q, epochs=2000)

# learn the parameters
p.fit({"x": x_train, "y": y_train}, VI)


# evaluate the model
def evaluate(p, x, y):
    N = np.shape(x)[0]
    # draw one sample of the predicted outputs for the given images
    output = p.posterior_predictive("y", {"x": x}).sample()
    # threshold at 0.5 to obtain 0/1 class predictions
    y_pred = np.reshape(1 * (output > 0.5), (N,))
    return np.sum(y_pred == y) / N


acc = evaluate(p, x_test[:N], y_test[:N])
print(f"accuracy = {acc}")
```
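To make the thresholding step in `evaluate()` concrete, here is a NumPy-only sketch that repeats it on made-up numbers (the probabilities and labels below are invented for illustration, not outputs of the model above). Because `y` is Normal with scale 0.001 around the sigmoid output, a sample is essentially P(y = 1), so values above 0.5 are mapped to class 1.

```python
import numpy as np

# Made-up sigmoid outputs for N = 5 images, shaped (N, 1) like the output of
# the single-unit Dense layer, together with made-up true labels.
output = np.array([[0.91], [0.12], [0.55], [0.40], [0.03]])
y = np.array([1, 0, 0, 0, 0])
N = np.shape(output)[0]

# Same steps as evaluate(): threshold at 0.5, flatten to (N,), compare.
y_pred = np.reshape(1 * (output > 0.5), (N,))
acc = np.sum(y_pred == y) / N

print(y_pred)  # [1 0 1 0 0]
print(acc)     # 0.8
```

Since `evaluate()` uses a single posterior-predictive sample, the reported accuracy can vary slightly from run to run.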
Hi, this library is very convenient. Does it support convolutional neural networks, for example `tfp.layers.Convolution2DFlipout`?
```python
nnetwork = inf.layers.Sequential([
    tfp.layers.Convolution2DFlipout(.....),
    tf.keras.layers.BatchNormalization(),
    tf.keras.layers.Activation('relu'),
    tf.keras.layers.MaxPooling2D(......),
    tf.keras.layers.Flatten(),
    tfp.layers.DenseFlipout(......),
    tfp.layers.DenseFlipout(........)
])
```