wesselb / stheno

Gaussian process modelling in Python
MIT License
214 stars 18 forks source link

Documentation about Multi-Output Regression #20

Open vdsmax opened 2 years ago

vdsmax commented 2 years ago

Hi @wesselb,

I am trying to use your Multi-Output Regression example with some data I have. I don't understand how to correctly pass the data to the VGP and then make a prediction. My input locations x_obs differ between the outputs, so my situation is not exactly the same as in the example. I have nine sets of inputs [x1,x2,x3,x4,x5,x6,x7,x8,x9] with corresponding observations [y1,y2,y3,y4,y5,y6,y7,y8,y9]. Also, with the example you provided, is it possible to optimize hyperparameters if the VGP had any?

Here is the code I was trying to use, with three different outputs to simulate data. Thank you in advance for your help.

import matplotlib.pyplot as plt
import numpy as np
from wbml.plot import tweak
from stheno import B, Measure, GP, EQ, Delta, Matern52

class VGP:
    """A vector-valued GP, stored as a list of scalar component processes.

    Args:
        ps (list): The component processes, one per output.
    """

    def __init__(self, ps):
        self.ps = ps

    def __add__(self, other):
        """Add two VGPs component by component.

        Like `zip`, this stops at the shorter of the two component lists.
        """
        summed = []
        for mine, theirs in zip(self.ps, other.ps):
            summed.append(mine + theirs)
        return VGP(summed)

    def lmatmul(self, A):
        """Left-multiply the components by the matrix `A`.

        Component `row` of the result is `sum_col A[row, col] * self.ps[col]`.

        Args:
            A (matrix): Two-dimensional matrix with `A.shape[1]` no larger than
                the number of components.

        Returns:
            :class:`VGP`: A VGP with `A.shape[0]` components.
        """
        n_rows, n_cols = A.shape
        mixed = []
        for row in range(n_rows):
            acc = 0
            for col in range(n_cols):
                acc += A[row, col] * self.ps[col]
            mixed.append(acc)
        return VGP(mixed)

# Define points to predict at.
x = B.linspace(0, 10, 5)

# Create some sample data. Every output may be observed at its own inputs.
x1 = np.atleast_2d(np.linspace(0, 10, 5)).T
x2 = np.atleast_2d(np.linspace(0, 9, 5)).T
x3 = np.atleast_2d(np.linspace(0, 7, 5)).T
y1 = np.atleast_2d(np.linspace(0, 10, 5)).T
y2 = np.atleast_2d(np.linspace(0, 10, 5)).T
y3 = np.atleast_2d(np.linspace(0, 10, 5)).T

x_obs = [x1, x2, x3]  # Inputs of output `i` are `x_obs[i]`.
y_obs = [y1, y2, y3]  # Observations of output `i` are `y_obs[i]`.

# Model parameters:
m = 3  # Number of latent processes.
p = 3  # Number of outputs.
H = B.randn(p, m)  # (p, m) mixing matrix, drawn at random.

with Measure() as prior:
    # Construct `m` independent latent functions.
    us = VGP([GP(Matern52()) for _ in range(m)])
    # Construct multi-output prior: output `i` is `sum_j H[i, j] * us[j]`.
    fs = us.lmatmul(H)
    # Construct noise. The factor `0` makes the observations noise-free.
    e = VGP([GP(0 * Delta()) for _ in range(p)])
    # Construct observation model.
    ys = e + fs

# Sample a true, underlying function at `x` and observations at each output's own
# inputs. Iterate over `fs.ps` directly: `zip(fs.ps)` would yield 1-tuples, which
# cannot be evaluated at `x`.
samples = prior.sample(
    *(f_out(x) for f_out in fs.ps),
    *(y_proc(x_i) for y_proc, x_i in zip(ys.ps, x_obs)),
)
fs_true, ys_obs = samples[:p], samples[p:]

# Compute the posterior and make predictions. To condition on your own data instead
# of the simulated observations, use `y_obs` in place of `ys_obs` below.
post = prior.condition(
    *((y_proc(x_i), y_i) for y_proc, y_i, x_i in zip(ys.ps, ys_obs, x_obs))
)
preds = [post(f_out(x)) for f_out in fs.ps]
# Plot results.
def plot_prediction(x, f, pred, x_obs=None, y_obs=None):
    """Plot the true function, optional observations and a prediction.

    Args:
        x: Inputs at which `f` and `pred` are evaluated.
        f: Values of the true function at `x`.
        pred: Prediction at `x`; must provide `marginal_credible_bounds()`.
        x_obs: Optional observation inputs. When `None`, no observations are drawn.
        y_obs: Observation values matching `x_obs`.
    """
    plt.plot(x, f, label="True", style="test")
    if x_obs is not None:
        plt.scatter(x_obs, y_obs, label="Observations", style="train", s=20)
    mu, lo, hi = pred.marginal_credible_bounds()
    plt.plot(x, mu, label="Prediction", style="pred")
    plt.fill_between(x, lo, hi, style="pred")
    tweak()

# One panel per output.
plt.figure(figsize=(10, 6))
for i in range(p):
    plt.subplot(p, 1, i + 1)
    plt.title(f"Output {i + 1}")
    # Pass only the inputs of output `i`. Passing the whole list `x_obs` gives
    # `plt.scatter` x and y values of different sizes.
    plot_prediction(x, fs_true[i], preds[i], x_obs[i], ys_obs[i])
plt.show()
wesselb commented 2 years ago

Hi @vdsmax!

I've put together a simple MOGP model (not using the example) which might better suit your use case. The script uses JAX to learn hyperparameters. (You can also use another AD framework if you like.)

from stheno.jax import GP, Matern52, Measure
from varz.jax import Vars, minimise_l_bfgs_b
from wbml.plot import tweak

import matplotlib.pyplot as plt
import jax.numpy as jnp
import numpy as np

# Inputs for each output; the outputs are observed at different locations.
x1 = np.linspace(0, 10, 30)
x2 = np.linspace(0, 9, 40)
x3 = np.linspace(0, 7, 50)

# Generate some test data: one separate noisy sample per output from the same GP.
f = GP(Matern52())
y1, y2, y3 = (f(x_i, 0.2).sample().flatten() for x_i in (x1, x2, x3))

p = 3  # Number of outputs
m = 3  # Number of latent processes

def model(vs):
    """Construct the multi-output GP model.

    The `p` outputs are linear mixtures of `m` independent latent processes, which
    induces correlations between the outputs.

    Args:
        vs (:class:`varz.Vars`): Container holding the learnable parameters.

    Returns:
        tuple: The measure `prior`, the list `fs` of the `p` output processes, and
            the `p` learnable observation noises `noises`.
    """
    ps = vs.struct

    with Measure() as prior:
        # Create `m` independent latent processes with learnable length scales
        # initialised to `1`. Zipping with `range(m)` bounds the iteration over
        # `ps.us` to one parameter group per latent process. This must be `m`,
        # not `p`: `H` has `m` columns and `us[j]` is used for every `j < m`.
        us = [
            GP(Matern52().stretch(ps_u.scale.positive(1)))
            for ps_u, _ in zip(ps.us, range(m))
        ]

        # Mix processes together to induce correlations between the outputs:
        # output `i` is `sum_j H[i, j] * us[j]`.
        H = ps.mixing_matrix.unbounded(shape=(p, m))
        fs = [0 for _ in range(p)]
        for i in range(p):
            for j in range(m):
                fs[i] = fs[i] + H[i, j] * us[j]

        # Create learnable observation noises initialised to `0.1`.
        noises = ps.noises.positive(0.1, shape=(p,))

    return prior, fs, noises

def objective(vs):
    """Return the negative log-density of all observations under the model prior.

    This is the quantity that `minimise_l_bfgs_b` minimises to learn the parameters.
    """
    prior, fs, noises = model(vs)
    data = ((x1, y1), (x2, y2), (x3, y3))
    pairs = [(fs[k](x_k, noises[k]), y_k) for k, (x_k, y_k) in enumerate(data)]
    log_density = prior.logpdf(*pairs)
    return -log_density

# Learn the parameters by minimising `objective`.
vs = Vars(jnp.float64)
minimise_l_bfgs_b(objective, vs, trace=True, jit=True)
vs.print()  # Show the learned parameter values.

# Condition the prior on the data of all outputs and extract each output's posterior.
prior, fs, noises = model(vs)
observed = (
    (fs[0](x1, noises[0]), y1),
    (fs[1](x2, noises[1]), y2),
    (fs[2](x3, noises[2]), y3),
)
posterior = prior | observed
f1_post, f2_post, f3_post = (posterior(f_out) for f_out in fs)

def plot_posterior(x, f, x_obs=None, y_obs=None):
    """Plot the mean and marginal credible bounds of process `f` at inputs `x`.

    When `x_obs` is not `None`, the observations `(x_obs, y_obs)` are plotted first.
    """
    if x_obs is not None:
        plt.scatter(x_obs, y_obs, label="Observations", style="train", s=20)
    fdd = f(x)
    mu, lo, hi = fdd.marginal_credible_bounds()
    plt.plot(x, mu, label="Prediction", style="pred")
    plt.fill_between(x, lo, hi, style="pred")
    tweak()

# Plot results: one panel per output, all on a shared grid of 200 points in [0, 10].
plt.figure(figsize=(10, 6))
x_to_plot = np.linspace(0, 10, 200)
panels = [(f1_post, x1, y1), (f2_post, x2, y2), (f3_post, x3, y3)]
for panel, (f_post, x_data, y_data) in enumerate(panels, start=1):
    plt.subplot(3, 1, panel)
    plt.title(f"Output {panel}")
    plot_posterior(x_to_plot, f_post, x_data, y_data)
plt.show()

The script produces the following plot:

Output

Let me know if this suits your needs. :)

vdsmax commented 2 years ago

Thank you very much for your code example. It runs on my side too, and I get the same results using my CPU.

Because the computation time is high for nine inputs on a CPU, I wanted to see whether my GPU would be faster. I followed the steps to use CUDA with JAX and was able to link the two. However, running the same code you gave me, I now get this error:

---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
<ipython-input-4-45889ae4f67a> in <module>
     55 # Perform learning.
     56 vs = Vars(jnp.float64)
---> 57 minimise_l_bfgs_b(objective, vs, trace=True, jit=True)
     58 vs.print()  # Display learned parameters.
     59 

~/python-env/lib/python3.6/site-packages/varz/minimise.py in minimise_l_bfgs_b(f, vs, f_calls, iters, trace, names, jit)
     77         trace=trace,
     78         names=names,
---> 79         jit=jit,
     80     )
     81 

~/python-env/lib/python3.6/site-packages/varz/minimise.py in _minimise_l_bfgs_b(f, vs, f_calls, iters, trace, names, jit)
    154         # Run function once to ensure that all variables are initialised and
    155         # available.
--> 156         res = convert(f(vs, *args), tuple)
    157         val_init, args = res[0], res[1:]
    158 

<ipython-input-4-45889ae4f67a> in objective(vs)
     49         (fs[0](x1, noises[0]), y1),
     50         (fs[1](x2, noises[1]), y2),
---> 51         (fs[2](x3, noises[2]), y3),
     52     )
     53 

~/python-env/lib/python3.6/site-packages/plum/function.cpython-36m-x86_64-linux-gnu.so in plum.function._BoundFunction.__call__()

~/python-env/lib/python3.6/site-packages/plum/function.cpython-36m-x86_64-linux-gnu.so in plum.function.Function.__call__()

~/python-env/lib/python3.6/site-packages/stheno/model/measure.py in logpdf(self, *pairs)
    461         """
    462         fdd, y = combine(*pairs)
--> 463         return self(fdd).logpdf(y)
    464 
    465     @_dispatch

~/python-env/lib/python3.6/site-packages/stheno/random.py in logpdf(self, x)
    210                 B.logdet(self.var)[..., None]  # Correctly line up with `iqf_diag`.
    211                 + B.cast(self.dtype, self.dim) * B.cast(self.dtype, B.log_2_pi)
--> 212                 + B.iqf_diag(self.var, B.subtract(x, self.mean))
    213             )
    214             / 2

~/python-env/lib/python3.6/site-packages/plum/function.cpython-36m-x86_64-linux-gnu.so in plum.function.Function.__call__()

~/python-env/lib/python3.6/site-packages/matrix/ops/iqf_diag.py in iqf_diag(a, b)
     33 @B.dispatch
     34 def iqf_diag(a, b):
---> 35     return iqf_diag(a, b, b)
     36 
     37 

~/python-env/lib/python3.6/site-packages/plum/function.cpython-36m-x86_64-linux-gnu.so in plum.function.Function.__call__()

~/python-env/lib/python3.6/site-packages/matrix/ops/iqf_diag.py in iqf_diag(a, b, c)
     20     """
     21     chol = B.cholesky(a)
---> 22     chol_b = B.solve(chol, b)
     23     if c is b:
     24         chol_c = chol_b

~/python-env/lib/python3.6/site-packages/plum/function.cpython-36m-x86_64-linux-gnu.so in plum.function.Function.__call__()

~/python-env/lib/python3.6/site-packages/lab/util.py in wrapper(*args, **kw_args)
    212 
    213             # Retry call.
--> 214             return getattr(B, f.__name__)(*args, **kw_args)
    215 
    216         return wrapper

~/python-env/lib/python3.6/site-packages/plum/function.cpython-36m-x86_64-linux-gnu.so in plum.function.Function.__call__()

~/python-env/lib/python3.6/site-packages/matrix/ops/solve.py in solve(a, b)
     41         )
     42     a, b = align_batch(a.mat, b)
---> 43     return Dense(B.trisolve(B.dense(a), B.dense(b), lower_a=True))
     44 
     45 

~/python-env/lib/python3.6/site-packages/plum/function.cpython-36m-x86_64-linux-gnu.so in plum.function.Function.__call__()

~/python-env/lib/python3.6/site-packages/lab/shape.py in f_wrapped(*args, **kw_args)
    183         @wraps(f)
    184         def f_wrapped(*args, **kw_args):
--> 185             return f(*(unwrap_dimension(arg) for arg in args), **kw_args)
    186 
    187         return dispatch(f_wrapped)

~/python-env/lib/python3.6/site-packages/lab/jax/linear_algebra.py in triangular_solve(a, b, lower_a)
    125         )
    126 
--> 127     return batch_computation(_triangular_solve, (a, b), (2, 2))
    128 
    129 

~/python-env/lib/python3.6/site-packages/lab/util.py in batch_computation(f, xs, ranks)
    149     for index in indices:
    150         batches.append(
--> 151             f(*[x[_translate_index(index, s)] for x, s in zip(xs, batch_shapes)])
    152         )
    153 

~/python-env/lib/python3.6/site-packages/lab/jax/linear_algebra.py in _triangular_solve(a_, b_)
    122     def _triangular_solve(a_, b_):
    123         return jsla.solve_triangular(
--> 124             a_, b_, trans="N", lower=lower_a, check_finite=False
    125         )
    126 

~/python-env/lib/python3.6/site-packages/jax/_src/scipy/linalg.py in solve_triangular(***failed resolving arguments***)
    223                      overwrite_b=False, debug=None, check_finite=True):
    224   del overwrite_b, debug, check_finite
--> 225   return _solve_triangular(a, b, trans, lower, unit_diagonal)
    226 
    227 

~/python-env/lib/python3.6/site-packages/jax/_src/api.py in cache_miss(*args, **kwargs)
    425         flat_fun, *args_flat,
    426         device=device, backend=backend, name=flat_fun.__name__,
--> 427         donated_invars=donated_invars, inline=inline)
    428     out_pytree_def = out_tree()
    429     out = tree_unflatten(out_pytree_def, out_flat)

~/python-env/lib/python3.6/site-packages/jax/core.py in bind(self, fun, *args, **params)
   1558 
   1559   def bind(self, fun, *args, **params):
-> 1560     return call_bind(self, fun, *args, **params)
   1561 
   1562   def process(self, trace, fun, tracers, params):

~/python-env/lib/python3.6/site-packages/jax/core.py in call_bind(primitive, fun, *args, **params)
   1549       params_tuple, out_axes_transforms)
   1550   tracers = map(top_trace.full_raise, args)
-> 1551   outs = primitive.process(top_trace, fun, tracers, params)
   1552   return map(full_lower, apply_todos(env_trace_todo(), outs))
   1553 

~/python-env/lib/python3.6/site-packages/jax/core.py in process(self, trace, fun, tracers, params)
   1561 
   1562   def process(self, trace, fun, tracers, params):
-> 1563     return trace.process_call(self, fun, tracers, params)
   1564 
   1565   def post_process(self, trace, out_tracers, params):

~/python-env/lib/python3.6/site-packages/jax/core.py in process_call(self, primitive, f, tracers, params)
    604 
    605   def process_call(self, primitive, f, tracers, params):
--> 606     return primitive.impl(f, *tracers, **params)
    607   process_map = process_call
    608 

~/python-env/lib/python3.6/site-packages/jax/interpreters/xla.py in _xla_call_impl(***failed resolving arguments***)
    593                                *unsafe_map(arg_spec, args))
    594   try:
--> 595     return compiled_fun(*args)
    596   except FloatingPointError:
    597     assert config.jax_debug_nans or config.jax_debug_infs  # compiled_fun can only raise in this case

~/python-env/lib/python3.6/site-packages/jax/interpreters/xla.py in _execute_compiled(compiled, avals, handlers, kept_var_idx, *args)
    891           for i, x in enumerate(args)
    892           if x is not token and i in kept_var_idx))
--> 893   out_bufs = compiled.execute(input_bufs)
    894   check_special(xla_call_p.name, out_bufs)
    895   return [handler(*bs) for handler, bs in zip(handlers, _partition_outputs(avals, out_bufs))]

RuntimeError: Internal: Unable to launch triangular solve for thunk 0x2c46c570

Do I need to add something to the code to make it work with a GPU?

wesselb commented 2 years ago

Ouch! That doesn't look good. Could you confirm that running other JAX code on the GPU works fine? If that's the case, I can look into this more closely to see what's going on.

vdsmax commented 2 years ago

I tried some JAX examples on my GPU (like these: https://jax.readthedocs.io/en/latest/notebooks/quickstart.html) and they worked. I think the issue comes from the library. I have jax-0.2.17 and jaxlib-0.1.65+cuda110 installed on my computer.

wesselb commented 2 years ago

Hey @vdsmax,

That's very frustrating. I'm not sure what's going wrong. I am able to run the example on my end on a GPU. I am running jaxlib-0.1.73+cuda11.cudnn82 and jax-0.2.25.

I've created a version of the example using TensorFlow. Perhaps that works for you:

from stheno.tensorflow import GP, Matern52, Measure
from varz.tensorflow import Vars, minimise_l_bfgs_b
from wbml.plot import tweak
import lab.tensorflow as B

import matplotlib.pyplot as plt
import tensorflow as tf
import numpy as np

# Set LAB's global device to the GPU.
B.set_global_device("gpu")

# Inputs for each output; the outputs are observed at different locations.
x1 = np.linspace(0, 10, 200)
x2 = np.linspace(0, 9, 200)
x3 = np.linspace(0, 7, 200)

# Generate some test data: one separate noisy sample per output from the same GP.
f = GP(Matern52())
y1, y2, y3 = (f(x_i, 0.2).sample().flatten() for x_i in (x1, x2, x3))

p = 3  # Number of outputs
m = 3  # Number of latent processes

def model(vs):
    """Construct the multi-output GP model.

    The `p` outputs are linear mixtures of `m` independent latent processes, which
    induces correlations between the outputs.

    Args:
        vs (:class:`varz.Vars`): Container holding the learnable parameters.

    Returns:
        tuple: The measure `prior`, the list `fs` of the `p` output processes, and
            the `p` learnable observation noises `noises`.
    """
    ps = vs.struct

    with Measure() as prior:
        # Create `m` independent latent processes with learnable length scales
        # initialised to `1`. Zipping with `range(m)` bounds the iteration over
        # `ps.us` to one parameter group per latent process. This must be `m`,
        # not `p`: `H` has `m` columns and `us[j]` is used for every `j < m`.
        us = [
            GP(Matern52().stretch(ps_u.scale.positive(1)))
            for ps_u, _ in zip(ps.us, range(m))
        ]

        # Mix processes together to induce correlations between the outputs:
        # output `i` is `sum_j H[i, j] * us[j]`.
        H = ps.mixing_matrix.unbounded(shape=(p, m))
        fs = [0 for _ in range(p)]
        for i in range(p):
            for j in range(m):
                fs[i] = fs[i] + H[i, j] * us[j]

        # Create learnable observation noises initialised to `0.1`.
        noises = ps.noises.positive(0.1, shape=(p,))

    return prior, fs, noises

def objective(vs):
    """Return the negative log-density of all observations under the model prior.

    This is the quantity that `minimise_l_bfgs_b` minimises to learn the parameters.
    """
    prior, fs, noises = model(vs)
    data = ((x1, y1), (x2, y2), (x3, y3))
    pairs = [(fs[k](x_k, noises[k]), y_k) for k, (x_k, y_k) in enumerate(data)]
    log_density = prior.logpdf(*pairs)
    return -log_density

# Learn the parameters by minimising `objective`.
vs = Vars(tf.float64)
minimise_l_bfgs_b(objective, vs, trace=True, jit=True)
vs.print()  # Show the learned parameter values.

# Condition the prior on the data of all outputs and extract each output's posterior.
prior, fs, noises = model(vs)
observed = (
    (fs[0](x1, noises[0]), y1),
    (fs[1](x2, noises[1]), y2),
    (fs[2](x3, noises[2]), y3),
)
posterior = prior | observed
f1_post, f2_post, f3_post = (posterior(f_out) for f_out in fs)

def plot_posterior(x, f, x_obs=None, y_obs=None):
    """Plot the mean and marginal credible bounds of process `f` at inputs `x`.

    When `x_obs` is not `None`, the observations `(x_obs, y_obs)` are plotted first.
    """
    if x_obs is not None:
        plt.scatter(x_obs, y_obs, label="Observations", style="train", s=20)
    fdd = f(x)
    mu, lo, hi = fdd.marginal_credible_bounds()
    plt.plot(x, mu, label="Prediction", style="pred")
    plt.fill_between(x, lo, hi, style="pred")
    tweak()

# Plot results: one panel per output, all on a shared grid of 200 points in [0, 10].
plt.figure(figsize=(10, 6))
x_to_plot = np.linspace(0, 10, 200)
panels = [(f1_post, x1, y1), (f2_post, x2, y2), (f3_post, x3, y3)]
for panel, (f_post, x_data, y_data) in enumerate(panels, start=1):
    plt.subplot(3, 1, panel)
    plt.title(f"Output {panel}")
    plot_posterior(x_to_plot, f_post, x_data, y_data)
plt.show()