wesselb / stheno

Gaussian process modelling in Python
MIT License

Add hyperparameter optimization example #18

Closed patel-zeel closed 2 years ago

patel-zeel commented 2 years ago

Hi @wesselb.

Please let me know your suggestions/comments on the example. README_process.py was really handy in doing this :)

P.S.: This pull request is continued from #11.

coveralls commented 2 years ago

Pull Request Test Coverage Report for Build 1504464141


Totals Coverage Status
Change from base Build 1476935926: 0.0%
Covered Lines: 869
Relevant Lines: 869

💛 - Coveralls
wesselb commented 2 years ago

Hey @patel-zeel,

This is looking good! I'm wondering, do you think it would be useful to implement the optimisation with the usual PyTorch Adam optimiser? The way of doing it with Varz is concise, but will probably be unfamiliar to most. An example that uses a widely used optimiser could be helpful!

patel-zeel commented 2 years ago

Yes, that sounds reasonable @wesselb. So, would it be okay to use the variables and model definition from Varz together with torch.optim.LBFGS (or torch.optim.Adam), or should we use torch.tensor variables directly?

wesselb commented 2 years ago

Perhaps strip away Varz entirely and use torch.tensor directly? What do you think?

patel-zeel commented 2 years ago

I think Varz is useful for handling positivity constraints painlessly. Users coming from the neural-network (nn) community may find applying positivity manually a bit of overhead. Another benefit (that I personally like) is the vs handle: we don't have to worry about which parameters to pass to which function in a slightly complicated model; we can simply pass vs to each of them. Varz may seem a bit alien at first, but as a user, now that I know its benefits, I'd love to switch to Varz even for my non-GP torch models.

Would you like to highlight torch.tensor benefits over Varz variables?
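To make the "apply positivity manually" point concrete: the usual manual approach, also used by the PyTorch example later in this thread, is to optimise an unconstrained log-parameter and map it through `exp`. Below is a minimal NumPy sketch of that idea on a hypothetical toy problem (estimating the variance of zero-mean Gaussian data by gradient ascent); it only illustrates the transform and is not how Varz implements `Positive`.

```python
import numpy as np


def fit_variance_via_log(y, init_var=1.0, lr=0.1, steps=2000):
    """Maximum-likelihood variance of zero-mean data, optimised in log space.

    The optimiser only ever updates `log_var`, which is unconstrained; the
    variance `exp(log_var)` is therefore positive at every step.
    """
    n = len(y)
    sum_sq = np.sum(y ** 2)
    log_var = np.log(init_var)
    for _ in range(steps):
        var = np.exp(log_var)
        # Log-likelihood: -0.5 * n * log(2 * pi * var) - 0.5 * sum_sq / var.
        # Chain rule with d(var)/d(log_var) = var gives this gradient:
        grad = -0.5 * n + 0.5 * sum_sq / var
        # Gradient ascent step (averaged over the data for a stable step size).
        log_var += lr * grad / n
    return np.exp(log_var)


rng = np.random.default_rng(0)
y = rng.normal(scale=0.5, size=1000)
var_hat = fit_variance_via_log(y)
# The closed-form maximum-likelihood estimate is mean(y ** 2); the two should agree.
print(var_hat, np.mean(y ** 2))
```

The same reparametrisation is what `torch.exp(self.log_var)` does in the `torch.nn.Module` version; the bookkeeping of remembering which raw parameters need the transform is the overhead being discussed.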

wesselb commented 2 years ago

All good points! What about we add two examples, one in which we use Varz and which people can use to get started quickly, and one in which we use the PyTorch optimiser, which people can reference if they don't want to use Varz?

Moreover, over time I've come to the conclusion that @parametrised is a bit too much magic and does too much in the background. I'm thinking of actually removing it, because it doesn't seem to give noticeably better code (apart from looking nice). What are your thoughts on this? If you agree, perhaps we should remove the commented-out @parametrised alternative?
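As a rough illustration of the "magic" being discussed: a decorator in the spirit of `@parametrised` can inspect the wrapped function's signature and fill in annotated arguments from the variable container at call time. The toy version below is hypothetical (a plain dict stands in for `vs`, `Positive` is a local marker class, and values are stored in log space); it is not Varz's implementation.

```python
import inspect
import math


class Positive:
    """Hypothetical marker annotation: the argument should be positive."""


def toy_parametrised(f):
    """Fill `Positive`-annotated arguments of `f(vs, ...)` from the dict `vs`."""
    params = list(inspect.signature(f).parameters.values())[1:]  # Skip `vs`.

    def wrapper(vs, **kw_args):
        for param in params:
            if param.annotation is Positive and param.name not in kw_args:
                # Create the latent (log-space) variable on first use...
                latent = vs.setdefault(param.name, 0.0)
                # ...and pass its positive transform to the function.
                kw_args[param.name] = math.exp(latent)
        return f(vs, **kw_args)

    return wrapper


@toy_parametrised
def model(vs, scale: Positive, variance: Positive, noise: Positive):
    return scale, variance, noise


vs = {"noise": math.log(0.2)}
print(model(vs))  # `scale` and `variance` were created behind the scenes.
print(sorted(vs))  # ['noise', 'scale', 'variance']
```

Nothing at the call site `model(vs)` reveals that two new variables were just created, which is the conciseness-versus-transparency trade-off in question.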

Perhaps the structlike features of Varz would be nice to illustrate:

from stheno import EQ, GP

def model(vs):
    p = vs.struct
    # Varz handles positivity (and other) constraints.
    kernel = p.variance.positive() * EQ().stretch(p.scale.positive())
    return GP(kernel), p.noise.positive()

(Personally, I think this may be the cleanest pattern.)
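To show what the struct-like access buys, here is a hypothetical, minimal stand-in for `vs.struct` in plain Python: attribute access names a variable (nested attributes give dotted names), and `.positive()` creates it on first use in log space and returns a positive value. `ToyVars` and `ToyStruct` are made up for this sketch and are not Varz's implementation.

```python
import math


class ToyVars:
    """Hypothetical variable store mapping names to latent (log-space) values."""

    def __init__(self):
        self.latent = {}

    @property
    def struct(self):
        return ToyStruct(self)


class ToyStruct:
    """Attribute access builds a dotted variable name, e.g. `p.kernel.scale`."""

    def __init__(self, vs, path=""):
        self._vs = vs
        self._path = path

    def __getattr__(self, name):
        # Only called for attributes not found normally, i.e. variable names.
        path = f"{self._path}.{name}" if self._path else name
        return ToyStruct(self._vs, path)

    def positive(self, init=1.0):
        # Create the variable on first use; later calls reuse the stored value.
        latent = self._vs.latent.setdefault(self._path, math.log(init))
        return math.exp(latent)


def model(vs):
    p = vs.struct
    variance = p.variance.positive()
    scale = p.kernel.scale.positive(init=0.5)
    noise = p.noise.positive(init=0.1)
    return variance, scale, noise


vs = ToyVars()
model(vs)
print(sorted(vs.latent))  # ['kernel.scale', 'noise', 'variance']
```

The model function only names its parameters where it uses them; declaring, initialising and constraining them is left to the container.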

patel-zeel commented 2 years ago

The all-inclusive approach looks great to me. So, I am adding the following variants:

I wouldn't vote for removing @parametrised from the API: it somehow looks more Pythonic (or futuristic), and we can't know in advance which style users will prefer. I agree that we should remove it from the example.

The structlike feature is wonderful; I had been feeling the need for it. You have a great sense of what makes a good API/software :)

wesselb commented 2 years ago

> The all-inclusive approach looks great to me. So, I am adding the following variants:

Excellent! I think that would be absolutely fantastic. :)

> I wouldn't vote for removing @parametrised from API because somehow it looks more Pythonic (or futuristic) and over time we may not know what users will choose to use.

That's a good point too! Perhaps then best to just keep @parametrised and see whether people find it more useful.

> Structlike feature is wonderful. I was feeling a need for it. You have a great sense of what would make a good API/software :)

Thank you. :) I'm still undecided about the best way to handle parameters, but so far this structlike construction (modulo the need to write p = vs.struct) seems like it might hit the sweet spot!

patel-zeel commented 2 years ago

Added both examples. I think the torch example could be improved a bit. I am wondering whether f.kernel.factor(0)-like syntax can be avoided without applying positive() to the raw parameters again. Do you have any thoughts on this?

Also, I had a very hard time finding out that the parameters must have shape torch.Size([]) and not torch.Size([1]). I don't know exactly why, but parameters of shape torch.Size([1]) did their own weird stuff rather than converging to sensible values. What could be a possible solution to help PyTorch users avoid this? I am not sure whether a similar problem exists in TensorFlow or JAX too.

wesselb commented 2 years ago

Hey @patel-zeel,

Thanks for this! I've iterated once on your examples: I've refactored the PyTorch one to use the more common torch.nn.Module pattern and made the Varz one line up with it. I also tried to make them as short as possible and as similar as possible to the other examples. What do you think of the two below?

Torch:

import lab as B
import matplotlib.pyplot as plt
import torch
from wbml.plot import tweak

from stheno.torch import EQ, GP

# Increase regularisation because PyTorch defaults to 32-bit floats.
B.epsilon = 1e-6

# Define points to predict at.
x = torch.linspace(0, 2, 100)
x_obs = torch.linspace(0, 2, 50)

# Sample a true, underlying function and observations with observation noise `0.05`.
f_true = torch.sin(5 * x)
y_obs = torch.sin(5 * x_obs) + 0.05 ** 0.5 * torch.randn(50)

class Model(torch.nn.Module):
    """A GP model with learnable parameters."""

    def __init__(self, init_var=0.3, init_scale=1, init_noise=0.2):
        super().__init__()
        # Ensure that the parameters are positive and make them learnable.
        self.log_var = torch.nn.Parameter(torch.log(torch.tensor(init_var)))
        self.log_scale = torch.nn.Parameter(torch.log(torch.tensor(init_scale)))
        self.log_noise = torch.nn.Parameter(torch.log(torch.tensor(init_noise)))

    def construct(self):
        kernel = torch.exp(self.log_var) * EQ().stretch(torch.exp(self.log_scale))
        return GP(kernel), torch.exp(self.log_noise)

model = Model()
f, noise = model.construct()

# Condition on observations and make predictions before optimisation.
f_post = f | (f(x_obs, noise), y_obs)
prior_before = f, noise
pred_before = f_post(x, noise).marginal_credible_bounds()

# Perform optimisation.
opt = torch.optim.Adam(model.parameters(), lr=5e-2)
for _ in range(1000):
    opt.zero_grad()
    f, noise = model.construct()
    loss = -f(x_obs, noise).logpdf(y_obs)
    loss.backward()
    opt.step()

f, noise = model.construct()

# Condition on observations and make predictions after optimisation.
f_post = f | (f(x_obs, noise), y_obs)
prior_after = f, noise
pred_after = f_post(x, noise).marginal_credible_bounds()

def plot_prediction(prior, pred):
    f, noise = prior
    mean, lower, upper = pred
    plt.scatter(x_obs, y_obs, label="Observations", style="train", s=20)
    plt.plot(x, f_true, label="True", style="test")
    plt.plot(x, mean, label="Prediction", style="pred")
    plt.fill_between(x, lower, upper, style="pred")
    plt.ylim(-2, 2)
    plt.text(
        0.02,
        0.02,
        f"var = {f.kernel.factor(0):.2f}, "
        f"scale = {f.kernel.factor(1).stretches[0]:.2f}, "
        f"noise = {noise:.2f}",
        transform=plt.gca().transAxes,
    )
    tweak()

# Plot result.
plt.figure(figsize=(10, 4))
plt.subplot(1, 2, 1)
plt.title("Before optimisation")
plot_prediction(prior_before, pred_before)
plt.subplot(1, 2, 2)
plt.title("After optimisation")
plot_prediction(prior_after, pred_after)
plt.savefig("readme_example13_optimisation_torch.png")
plt.show()
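For reference, the loss minimised in the loop above, `-f(x_obs, noise).logpdf(y_obs)`, is the negative log marginal likelihood of the observations under the zero-mean GP. Below is a NumPy sketch of that quantity for the same kernel, assuming stheno's `EQ()` is exp(-r²/2) and that `.stretch(scale)` divides the inputs by `scale`. It ignores the `B.epsilon` regularisation set in the examples and is meant only to show what is being optimised, not to reproduce stheno's internals.

```python
import numpy as np


def eq_kernel(x1, x2, var, scale):
    """var * exp(-(x1 - x2)^2 / (2 * scale^2)), i.e. `var * EQ().stretch(scale)`."""
    diff = x1[:, None] - x2[None, :]
    return var * np.exp(-0.5 * (diff / scale) ** 2)


def neg_log_marginal_likelihood(x, y, var, scale, noise):
    """-log N(y | 0, K + noise * I), computed via a Cholesky factorisation."""
    n = len(x)
    cov = eq_kernel(x, x, var, scale) + noise * np.eye(n)
    chol = np.linalg.cholesky(cov)
    # Solve L alpha = y, so that y^T cov^{-1} y = alpha^T alpha.
    alpha = np.linalg.solve(chol, y)
    # log|cov| = 2 * sum(log(diag(L))).
    log_det = 2 * np.sum(np.log(np.diag(chol)))
    return 0.5 * (alpha @ alpha + log_det + n * np.log(2 * np.pi))


rng = np.random.default_rng(0)
x_obs = np.linspace(0, 2, 50)
y_obs = np.sin(5 * x_obs) + 0.05 ** 0.5 * rng.standard_normal(50)
print(neg_log_marginal_likelihood(x_obs, y_obs, var=0.3, scale=1.0, noise=0.2))
```

The first term rewards fitting the data, while the log-determinant penalises overly flexible kernels, which is why optimising it moves `var`, `scale` and `noise` towards sensible values.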

Varz:

import lab as B
import matplotlib.pyplot as plt
import torch
from varz import Vars, minimise_l_bfgs_b, parametrised, Positive
from wbml.plot import tweak

from stheno.torch import EQ, GP

# Increase regularisation because PyTorch defaults to 32-bit floats.
B.epsilon = 1e-6

# Define points to predict at.
x = torch.linspace(0, 2, 100)
x_obs = torch.linspace(0, 2, 50)

# Sample a true, underlying function and observations with observation noise `0.05`.
f_true = torch.sin(5 * x)
y_obs = torch.sin(5 * x_obs) + 0.05 ** 0.5 * torch.randn(50)

def model(vs):
    """Construct a model with learnable parameters."""
    p = vs.struct  # Varz handles positivity (and other) constraints.
    kernel = p.variance.positive() * EQ().stretch(p.scale.positive())
    return GP(kernel), p.noise.positive()

@parametrised
def model_alternative(vs, scale: Positive, variance: Positive, noise: Positive):
    """Equivalent to :func:`model`, but with `@parametrised`."""
    kernel = variance * EQ().stretch(scale)
    return GP(kernel), noise

vs = Vars(torch.float32)
f, noise = model(vs)

# Condition on observations and make predictions before optimisation.
f_post = f | (f(x_obs, noise), y_obs)
prior_before = f, noise
pred_before = f_post(x, noise).marginal_credible_bounds()

def objective(vs):
    f, noise = model(vs)
    evidence = f(x_obs, noise).logpdf(y_obs)
    return -evidence

# Learn hyperparameters.
minimise_l_bfgs_b(objective, vs)

f, noise = model(vs)

# Condition on observations and make predictions after optimisation.
f_post = f | (f(x_obs, noise), y_obs)
prior_after = f, noise
pred_after = f_post(x, noise).marginal_credible_bounds()

def plot_prediction(prior, pred):
    f, noise = prior
    mean, lower, upper = pred
    plt.scatter(x_obs, y_obs, label="Observations", style="train", s=20)
    plt.plot(x, f_true, label="True", style="test")
    plt.plot(x, mean, label="Prediction", style="pred")
    plt.fill_between(x, lower, upper, style="pred")
    plt.ylim(-2, 2)
    plt.text(
        0.02,
        0.02,
        f"var = {f.kernel.factor(0):.2f}, "
        f"scale = {f.kernel.factor(1).stretches[0]:.2f}, "
        f"noise = {noise:.2f}",
        transform=plt.gca().transAxes,
    )
    tweak()

# Plot result.
plt.figure(figsize=(10, 4))
plt.subplot(1, 2, 1)
plt.title("Before optimisation")
plot_prediction(prior_before, pred_before)
plt.subplot(1, 2, 2)
plt.title("After optimisation")
plot_prediction(prior_after, pred_after)
plt.savefig("readme_example14_optimisation_varz.png")
plt.show()
wesselb commented 2 years ago

> Also, I had a very hard time finding out that parameters must be torch.Size([]) and not torch.Size([1]).

Ouch! I can confirm that something weird happens in that case. At first it wasn't clear what was going on, but I think I now understand. First, note that you can use heterogeneous observation noise by feeding a noise vector:

>>> import numpy as np
>>> from stheno import EQ, GP
>>> x = np.linspace(0, 5, 3)

>>> f = GP(EQ())

>>> f(x, 1)  # Homogeneous observation noise
<FDD:
 process=GP(0, EQ()),
 input=array([0. , 2.5, 5. ]),
 noise=<diagonal matrix: batch=(), shape=(3, 3), dtype=int64
        diag=[1 1 1]>>

>>> f(x, 1).noise  # Homogeneous observation noise
<diagonal matrix: batch=(), shape=(3, 3), dtype=int64
 diag=[1 1 1]>

>>> f(x, np.array([1, 1.1, 1.2]))  # Heterogeneous observation noise
<FDD:
 process=GP(0, EQ()),
 input=array([0. , 2.5, 5. ]),
 noise=<diagonal matrix: batch=(), shape=(3, 3), dtype=float64
        diag=[1.  1.1 1.2]>>

>>> f(x, np.array([1, 1.1, 1.2])).noise  # Heterogeneous observation noise
<diagonal matrix: batch=(), shape=(3, 3), dtype=float64
 diag=[1.  1.1 1.2]>

If you want to use homogeneous observation noise but accidentally give np.array([1]), then something weird happens. To begin with, note that the noise becomes a 1x1 diagonal matrix:

>>> f(x, np.array([1])).noise 
<diagonal matrix: batch=(), shape=(1, 1), dtype=int64
 diag=[1]>

This is fine in principle. However, when this matrix is added to an nxn matrix, broadcasting rules dictate that the result is a matrix full of ones rather than a diagonal matrix!

>>> f(x, np.array([1])).noise + np.zeros((5, 5))
<dense matrix: batch=(), shape=(5, 5), dtype=float64
 mat=[[1. 1. 1. 1. 1.]
      [1. 1. 1. 1. 1.]
      [1. 1. 1. 1. 1.]
      [1. 1. 1. 1. 1.]
      [1. 1. 1. 1. 1.]]>

This means that all the noises are perfectly correlated, which will give you weird logpdf values and predictions and likely explains what you're seeing.
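The same broadcasting rule can be reproduced with plain NumPy arrays, independently of stheno: a 1x1 matrix added to an nxn matrix is stretched over every entry, whereas the intended homogeneous noise only occupies the diagonal. In this sketch a matrix of zeros stands in for the kernel matrix.

```python
import numpy as np

n = 5
kernel_matrix = np.zeros((n, n))  # Stand-in for an n x n kernel matrix.

# Intended: homogeneous noise 1 on the diagonal only.
intended = kernel_matrix + 1 * np.eye(n)

# Accidental: noise of shape (1,) turned into a 1x1 matrix, which broadcasts
# against the n x n matrix and fills *every* entry, correlating all the noise.
accidental = kernel_matrix + np.diag(np.array([1.0]))

print(np.count_nonzero(intended))  # 5: only the diagonal
print(np.count_nonzero(accidental))  # 25: every entry
```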

I think this behaviour is correct, but, as you experienced, it is possible to run into this issue without noticing, perhaps too easily. What do you think is the best solution? Perhaps a "gotchas" section in the README?
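Whatever ends up in the README, one hypothetical safeguard a user could apply is to normalise the noise before passing it to `f(x, noise)`: collapse a single-element array to a 0-d scalar (the NumPy analogue of `torch.Size([])`) and reject any other length that does not match the number of inputs. `normalise_noise` is a name made up for this sketch, not part of stheno. For PyTorch, `reshape(())` also exists on tensors, and creating a parameter from a Python float, as in `torch.tensor(0.2)`, already gives `torch.Size([])`.

```python
import numpy as np


def normalise_noise(noise, n):
    """Return noise as a 0-d scalar (homogeneous) or a length-n vector (heterogeneous)."""
    noise = np.asarray(noise)
    if noise.size == 1:
        # Shape (1,) or (1, 1) -> shape (): avoids the 1x1-matrix broadcasting trap.
        return noise.reshape(())
    if noise.shape != (n,):
        raise ValueError(
            f"Expected a scalar or a vector of length {n}, got shape {noise.shape}."
        )
    return noise


print(normalise_noise(np.array([0.2]), 50).shape)  # ()
print(normalise_noise(0.2, 50).shape)  # ()
print(normalise_noise(np.full(50, 0.2), 50).shape)  # (50,)
```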

wesselb commented 2 years ago

> I am wondering if f.kernel.factor(0) like syntax can be avoided without using positive() on raw parameters again. Do you have any thoughts on this?

Hmm, in the new pattern, you could store self.scale = torch.exp(self.log_scale) in Model.construct, so you can access the variable after construction. How would that sound?
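A sketch of that suggestion, written with NumPy so that it stands alone: `construct` stores the constrained values on the model, and a hypothetical `snapshot()` helper copies them into a plain dict, which is one way to keep "before" and "after" values apart. In the PyTorch `Model` above, the corresponding copy would be something like `torch.exp(self.log_scale).detach().clone()` or `.item()`. `ToyModel` and `snapshot` are illustrative names, not part of the examples.

```python
import numpy as np


class ToyModel:
    """Stores log-parameters; `construct` records the positive values it used."""

    def __init__(self, init_var=0.3, init_scale=1.0, init_noise=0.2):
        self.log_var = np.log(init_var)
        self.log_scale = np.log(init_scale)
        self.log_noise = np.log(init_noise)

    def construct(self):
        # Keep the constrained values around so they can be inspected afterwards.
        self.var = np.exp(self.log_var)
        self.scale = np.exp(self.log_scale)
        self.noise = np.exp(self.log_noise)
        return self.var, self.scale, self.noise

    def snapshot(self):
        # Copy into plain floats, so later updates do not change the snapshot.
        return {"var": float(self.var), "scale": float(self.scale), "noise": float(self.noise)}


model = ToyModel()
model.construct()
before = model.snapshot()

model.log_scale += 0.5  # Stand-in for an optimiser step.
model.construct()
after = model.snapshot()

print(before["scale"], after["scale"])
```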

patel-zeel commented 2 years ago

> I've iterated once on your examples: I've refactored the PyTorch one to use a more common pattern with torch.nn.Module and made the Varz one line up. I also tried to shorten them as much as possible and as similar as possible to the other examples. What do you think of the two below?

These look fantastic to me!

> I think this behaviour is correct, but, as you experienced, it is easily possible to run into this issue unknowingly, perhaps too easily. What do you think is the best solution? Perhaps a "gotchas" section in the README?

Thank you for taking the time to explain this in great detail. Yes, "common mistakes" or "gotchas" would be amazing.

> Hmm, in the new pattern, you could store self.scale = torch.exp(self.log_scale) in Model.construct, so you can access the variable after construction. How would that sound?

In this case, storing self.scale individually for prior_before and prior_after is difficult. Any ideas to do that in an elegant way?

wesselb commented 2 years ago

> Thank you for taking the time to explain this in great detail. Yes, "common mistakes" or "gotchas" would be amazing.

Perfect! I'll add that soon. :)

> In this case, storing self.scale individually for prior_before and prior_after is difficult. Any ideas to do that in an elegant way?

Ah, that's a good point. Hmm, perhaps the examples are now fine as they are? What do you think? I'm happy to merge them in their current state!

patel-zeel commented 2 years ago

> Perfect! I'll add that soon. :)

Thank you :)

> Hmm, perhaps the examples are now fine as they are? What do you think? I'm happy to merge them in their current state!

Yes. The current version looks good to me as well.

wesselb commented 2 years ago

Perfect! Merging this now, then. Thanks, @patel-zeel! :)

patel-zeel commented 2 years ago

Great! Thank you @wesselb :)