secondmind-labs / GPflux

Deep GPs built on top of TensorFlow/Keras and GPflow
https://secondmind-labs.github.io/GPflux/
Apache License 2.0

Deep GP fit on Step Data #75

Open Aadesh-1404 opened 2 years ago

Aadesh-1404 commented 2 years ago

Describe the bug

I want to fit a Deep GP on step data, so I am using the method shown in the GPflux tutorial on the motorcycle dataset. However, the fit is not what I expected, i.e. it does not look like the one shown in Prof. Neil Lawrence's blog. I can obtain such a fit using PyDeepGP. I have attached the code I used for both GPflux and PyDeepGP.

To reproduce

Steps to reproduce the behaviour:

GPflux Implementation

```python
try:
    import gpflux
except ModuleNotFoundError:
    %pip install gpflux
    import gpflux

try:
    import tensorflow as tf
except ModuleNotFoundError:
    %pip install tensorflow
    import tensorflow as tf

import numpy as np
import pandas as pd
import gpflow
import matplotlib.pyplot as plt  # used by plot() below; missing from the snippet as posted

from gpflux.architectures import Config, build_constant_input_dim_deep_gp
from gpflux.models import DeepGP

tf.keras.backend.set_floatx("float64")
tf.get_logger().setLevel("INFO")

## Data
num_low = 25
num_high = 25
gap = -.1
noise = 0.0001
x = np.vstack((np.linspace(-1, -gap/2.0, num_low)[:, np.newaxis],
               np.linspace(gap/2.0, 1, num_high)[:, np.newaxis])).reshape(-1,)
y = np.vstack((np.zeros((num_low, 1)), np.ones((num_high, 1))))
scale = np.sqrt(y.var())
offset = y.mean()
yhat = ((y-offset)/scale).reshape(-1,)

## Model
config = Config(
    num_inducing=x.shape[0],
    inner_layer_qsqrt_factor=1e-3,
    likelihood_noise_variance=1e-3,
    whiten=True,
)
deep_gp: DeepGP = build_constant_input_dim_deep_gp(
    np.array(x.reshape(-1, 1)), num_layers=4, config=config)

training_model: tf.keras.Model = deep_gp.as_training_model()
training_model.compile(optimizer=tf.optimizers.Adam(learning_rate=0.01))

callbacks = [
    tf.keras.callbacks.ReduceLROnPlateau("loss", factor=0.95, patience=3, min_lr=1e-6, verbose=0),
    gpflux.callbacks.TensorBoard(),
    tf.keras.callbacks.ModelCheckpoint(filepath="ckpts/", save_weights_only=True, verbose=0),
]

history = training_model.fit(
    {"inputs": x.reshape(-1, 1), "targets": y.reshape(-1, 1)},
    batch_size=6,
    epochs=1000,
    callbacks=callbacks,
    verbose=0,
)

## Predict
def plot(model, X, Y, ax=None):
    if ax is None:
        fig, ax = plt.subplots()

    x = X  # keep the training inputs for the scatter plot
    x_margin = 1.0
    N = 50
    # Evaluate the model on a grid that extends beyond the training range
    X = np.linspace(X.min() - x_margin, X.max() + x_margin, N).reshape(-1, 1)
    out = model(X)

    mu = out.f_mean.numpy().squeeze()
    var = out.f_var.numpy().squeeze()
    X = X.squeeze()
    lower = mu - 2 * np.sqrt(var)
    upper = mu + 2 * np.sqrt(var)

    ax.set_ylim(Y.min() - 0.5, Y.max() + 0.5)
    ax.plot(x, Y, "kx", alpha=0.5)
    ax.plot(X, mu, "C1")
    ax.set_xlim(-2, 2)
    ax.fill_between(X, lower, upper, color="C1", alpha=0.3)


prediction_model = deep_gp.as_prediction_model()
plot(prediction_model, x.reshape(-1, 1), y.reshape(-1, 1))
```

Plot obtained as a result of the above code:

deep_gp_step_tut2
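For reference, the step data built in the GPflux snippet can be examined on its own with NumPy. The sketch below only repeats the data-generation arithmetic from that snippet; the helper name `make_step_data` is hypothetical and not part of GPflux or PyDeepGP. It makes two properties of the data visible: with a negative `gap` the two input segments overlap around x = 0, and the standardised targets `yhat` are computed in the snippet but the raw `y` is what gets passed to `fit`.

```python
import numpy as np


def make_step_data(num_low=25, num_high=25, gap=-0.1):
    """Rebuild the step data from the snippet above (hypothetical helper)."""
    x_low = np.linspace(-1, -gap / 2.0, num_low)    # inputs with target 0
    x_high = np.linspace(gap / 2.0, 1, num_high)    # inputs with target 1
    x = np.concatenate([x_low, x_high])
    y = np.concatenate([np.zeros(num_low), np.ones(num_high)])
    # Standardised targets, as computed (but not used for fitting) in the snippet.
    yhat = (y - y.mean()) / np.sqrt(y.var())
    return x_low, x_high, x, y, yhat


x_low, x_high, x, y, yhat = make_step_data()

# With a negative gap, the low segment ends to the right of where the
# high segment starts, so the two segments overlap around x = 0.
print("low segment ends at:", x_low.max())
print("high segment starts at:", x_high.min())
print("standardised target values:", np.unique(yhat))
```

Whether the overlap or the unstandardised targets matter for the fit is not established in this thread; the sketch is only meant to make the inputs to both models easier to inspect.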

Expected behaviour

PyDeepGP Implementation

```python
try:
    import deepgp
except ModuleNotFoundError:
    %pip install git+https://github.com/SheffieldML/PyDeepGP.git
    import deepgp

try:
    import GPy
except ModuleNotFoundError:
    %pip install -qq GPy
    import GPy

try:
    import tinygp
except ModuleNotFoundError:
    %pip install -q tinygp
    import tinygp

import seaborn as sns
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
from tinygp import kernels, GaussianProcess
from jax.config import config
import numpy as np

try:
    import jaxopt
except ModuleNotFoundError:
    %pip install jaxopt
    import jaxopt

config.update("jax_enable_x64", True)

## Data
num_low = 25
num_high = 25
gap = -0.1
noise = 0.0001
x = jnp.vstack(
    (jnp.linspace(-1, -gap / 2.0, num_low)[:, jnp.newaxis],
     jnp.linspace(gap / 2.0, 1, num_high)[:, jnp.newaxis])
).reshape(-1,)
y = jnp.vstack((jnp.zeros((num_low, 1)), jnp.ones((num_high, 1))))
scale = jnp.sqrt(y.var())
offset = y.mean()
yhat = ((y - offset) / scale).reshape(-1,)

# Test inputs extend beyond the training range
xnew = jnp.vstack(
    (jnp.linspace(-2, -gap / 2.0, 25)[:, jnp.newaxis],
     jnp.linspace(gap / 2.0, 2, 25)[:, jnp.newaxis])
).reshape(-1,)

## Model
num_hidden = 3
latent_dim = 1

kernels = [*[GPy.kern.RBF(latent_dim, ARD=True)] * num_hidden]  # hidden kernels
kernels.append(GPy.kern.RBF(np.array(x.reshape(-1, 1)).shape[1]))  # we append a kernel for the input layer

m = deepgp.DeepGP(
    [y.reshape(-1, 1).shape[1], *[latent_dim] * num_hidden, x.reshape(-1, 1).shape[1]],
    X=np.array(x.reshape(-1, 1)),  # training input
    Y=np.array(y.reshape(-1, 1)),  # training output
    inits=[*["PCA"] * num_hidden, "PCA"],  # initialise layers
    kernels=kernels,
    num_inducing=x.shape[0],
    back_constraint=False,
)
m.initialize_parameter()


def optimise_dgp(model, messages=True):
    """Utility function for optimising deep GP by first
    reinitialising the Gaussian noise at each layer
    (for reasons pertaining to stability)
    """
    model.initialize_parameter()
    for layer in model.layers:
        layer.likelihood.variance.constrain_positive(warning=False)
        layer.likelihood.variance = 1.0  # small variance may cause collapse
    model.optimize(messages=messages, max_iters=10000)


optimise_dgp(m, messages=True)

## Predict
mu_dgp, var_dgp = m.predict(xnew.reshape(-1, 1))

# NOTE: latexify, is_latexify_enabled and marksize are not defined anywhere
# in this snippet as posted; they must come from the poster's own environment.
plt.figure()
latexify(width_scale_factor=2, fig_height=1.75)
plt.plot(xnew, mu_dgp, "blue")
plt.scatter(x, y, c="r", s=marksize)
plt.fill_between(
    xnew.flatten(),
    mu_dgp.flatten() - 1.96 * jnp.sqrt(var_dgp.flatten()),
    mu_dgp.flatten() + 1.96 * jnp.sqrt(var_dgp.flatten()),
    alpha=0.3,
    color="C1",
)
sns.despine()
legendsize = 4.5 if is_latexify_enabled() else 9
plt.legend(labels=["Mean", "Data", "Confidence"], loc=2, prop={"size": legendsize}, frameon=False)
plt.xlabel("$x$")
plt.ylabel("$y$")
sns.despine()
plt.show()
```

Plot obtained from the above code:

deep_gp_step_pydeepgp

System information

sebastianober commented 2 years ago

Hi @Aadesh-1404,

This is actually a difference in the models - the DGP model used in PyDeepGP is from Damianou and Lawrence (2013) and uses latent variables in the intermediate layers, whereas the model we follow is from Salimbeni and Deisenroth (2017), and doesn't use latent variables. This means that the fits will be different. However, you should be able to get similar fits (but not exactly the same) using GPflux by following the tutorial https://secondmind-labs.github.io/GPflux/notebooks/deep_cde.html
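As a rough illustration of what a latent-variable layer of the kind used in the linked tutorial does to the inputs (an assumption about that approach, not code taken from GPflux or PyDeepGP): each data point x_n gets its own latent variable w_n, and the concatenation [x_n, w_n] is what the subsequent GP layers see. The helper name `augment_with_latents` is hypothetical, and the latents are simply drawn from a standard-normal prior here, whereas in an actual model they would have a learned posterior.

```python
import numpy as np


def augment_with_latents(X, latent_dim=1, seed=0):
    """Concatenate one standard-normal latent sample to every input row.

    Hypothetical helper for illustration only: it mimics the input
    augmentation step, not the inference over the latent variables.
    """
    rng = np.random.default_rng(seed)
    W = rng.standard_normal((X.shape[0], latent_dim))  # one w_n per data point
    return np.concatenate([X, W], axis=-1)


X = np.linspace(-1.0, 1.0, 50).reshape(-1, 1)
X_aug = augment_with_latents(X)
print(X.shape, X_aug.shape)
```

The extra input dimension is what gives such a model room to represent sharp or multi-modal behaviour that a plain stack of GP layers tends to smooth over; the tutorial linked above remains the reference for how to set this up in GPflux.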

References

Damianou and Lawrence (2013): http://proceedings.mlr.press/v31/damianou13a

Salimbeni and Deisenroth (2017): https://arxiv.org/abs/1705.08933