cornellius-gp / gpytorch

A highly efficient implementation of Gaussian Processes in PyTorch
MIT License

Extracting gradients of the log-marginal likelihood (w.r.t. original hypers) evaluated at an arbitrary point #1125

Open vr308 opened 4 years ago

vr308 commented 4 years ago

There is a way to extract the gradient vector w.r.t. the "raw" (transformed) hypers, but I was wondering if I could obtain the gradients w.r.t. the original hypers.

    import torch
    import gpytorch

    # SpectralMixtureGPModel, train_x, and train_y are defined elsewhere
    n_restarts = 5
    num_steps = 1000
    losses = torch.zeros(n_restarts, num_steps)

    for i in range(n_restarts):

        likelihood = gpytorch.likelihoods.GaussianLikelihood()
        model_ml = SpectralMixtureGPModel(train_x, train_y, likelihood, 2)

        model_ml.train()
        likelihood.train()
        # Use the Adam optimizer
        optimizer = torch.optim.Adam(model_ml.parameters(), lr=0.05)
        # "Loss" for GPs - the negative marginal log likelihood
        mll_ml = gpytorch.mlls.ExactMarginalLogLikelihood(likelihood, model_ml)

        for j in range(num_steps):
            optimizer.zero_grad()
            output = model_ml(train_x)
            loss = -mll_ml(output, train_y)
            losses[i, j] = loss.item()
            loss.backward()
            # Extracting gradients w.r.t. the raw hypers
            raw_mixture_mean_grad = model_ml.covar_module.raw_mixture_means.grad
            if j % 100 == 0:
                print('Iter %d/%d - Loss: %.3f' % (j + 1, num_steps, loss.item()))
            optimizer.step()

The use case is to visualise the gradient field superimposed on the negative log marginal likelihood surface, like the one below:

[Attached image: adam_lr_001 (gradient field over the negative log marginal likelihood surface)]
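
In the meantime, one workaround is the chain rule: gpytorch stores each hyper as a raw parameter plus a registered constraint, so the original hyper is constraint.transform(raw). When the transform is elementwise (as the default softplus-based Positive constraint is), its Jacobian is diagonal and the raw gradient can be rescaled directly. A minimal sketch, assuming the constraint is exposed as raw_mixture_means_constraint (gpytorch's naming convention for registered constraints) and run right after loss.backward():

    # Convert the raw-parameter gradient into a gradient w.r.t. the original
    # (constrained) hyper via the chain rule:
    #     dL/d(theta) = dL/d(raw) / (d(theta)/d(raw))
    # Assumes the constraint transform is elementwise, so the Jacobian is
    # diagonal and the sum trick below recovers it in one backward call.
    raw = model_ml.covar_module.raw_mixture_means
    constraint = model_ml.covar_module.raw_mixture_means_constraint

    raw_leaf = raw.detach().clone().requires_grad_(True)
    theta = constraint.transform(raw_leaf)  # original hyper = transform(raw)
    dtheta_draw, = torch.autograd.grad(theta.sum(), raw_leaf)
    mixture_mean_grad = raw.grad / dtheta_draw  # gradient w.r.t. original hypers

These rescaled gradients are the ones to plot on the original-hyper axes of a surface like the one above.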

jacobrgardner commented 4 years ago

Hmm, this is actually surprisingly challenging to do right now, because every time you call one of the getters (e.g., kernel.outputscale), it returns a newly transformed version of the raw outputscale.

Maybe one thing we could do is have these getters cache the transformed version, and clear the cache whenever backward is called? This would let you do something like torch.autograd.grad(mll, model.covar_module.base_kernel.lengthscale) and have it actually work.

Let me prototype that for lengthscales and see if it causes any problems.
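
For reference, a minimal sketch of what such a caching getter could look like; this is a hypothetical illustration of the idea, not gpytorch's actual implementation, and _lengthscale_cache is an assumed name:

    class CachingLengthscaleMixin:
        # Hypothetical sketch of the proposed caching getter. Repeated accesses
        # return the same transformed tensor, so it can be passed as a target
        # to torch.autograd.grad; a backward hook invalidates the cache so the
        # next access re-transforms the (possibly updated) raw parameter.
        @property
        def lengthscale(self):
            if getattr(self, "_lengthscale_cache", None) is None:
                value = self.raw_lengthscale_constraint.transform(self.raw_lengthscale)
                if value.requires_grad:
                    # Clear the cache once backward has flowed through this tensor
                    value.register_hook(lambda grad: self._clear_cache())
                self._lengthscale_cache = value
            return self._lengthscale_cache

        def _clear_cache(self):
            self._lengthscale_cache = None

With the forward pass and a later lengthscale access returning the same tensor, torch.autograd.grad(loss, model.covar_module.base_kernel.lengthscale, retain_graph=True) would then differentiate w.r.t. the actual (transformed) lengthscale used to compute the loss.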