cornellius-gp / gpytorch

A highly efficient implementation of Gaussian Processes in PyTorch
MIT License

trivial optimization issue #1754

Open cagatayyildiz opened 3 years ago

cagatayyildiz commented 3 years ago

Hi,

I have been implementing decoupled sampling of GPs and bumped into a rather strange issue. I reduced it to the following minimal example: a sparse GP in which we minimize a toy loss defined on the variational posterior:

import torch, gpytorch

class ApproximateGPModel(gpytorch.models.ApproximateGP):
    def __init__(self, inducing_points):
        variational_distribution = gpytorch.variational.CholeskyVariationalDistribution(inducing_points.size(-1))
        variational_strategy = gpytorch.variational.VariationalStrategy(
            self, inducing_points, variational_distribution, learn_inducing_locations=True
        )
        super().__init__(variational_strategy)
        self.mean_module = gpytorch.means.ConstantMean()
        self.covar_module = gpytorch.kernels.ScaleKernel(gpytorch.kernels.RBFKernel())

    def forward(self, x):
        mean_x = self.mean_module(x)
        covar_x = self.covar_module(x)
        return gpytorch.distributions.MultivariateNormal(mean_x, covar_x)

model = ApproximateGPModel(torch.randn(20))
model.train()

opt = torch.optim.Adam(model.parameters())
for i in range(2):
    opt.zero_grad()
    # model.train() # this needs to be called in each iteration, otherwise the following error occurs
    loss = -model.variational_strategy.variational_distribution.rsample(torch.Size([10])).sum()
    loss.backward()
    print('iter={:d}, \tloss={:.3f}'.format(i,loss.item()))
    opt.step()

If I do not call model.train() before each gradient computation, the second iteration fails with the following error:

Trying to backward through the graph a second time (or directly access saved variables after they have already been freed). Saved intermediate values of the graph are freed when you call .backward() or autograd.grad(). Specify retain_graph=True if you need to backward through the graph a second time or if you need to access saved variables after calling backward.

Why is this happening? I guess either this is a trivial bug or I am missing something quite trivial.

Thanks, Cagatay.

wjmaddox commented 3 years ago

Yeah, this is because of the way the variational distribution is stored (i.e., it is cached). I haven't quite figured out whether it is a bug yet.
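The mechanism can be sketched in plain PyTorch, without GPyTorch. This is a hypothetical stand-in, not GPyTorch's code: the `_cache` dict and the `cached_scale` / `training_step` helpers are illustrative names, and `param.exp()` plays the role of the cached variational distribution. Reusing the cached tensor across two `backward()` calls raises the same error as in the issue, while clearing the cache first does not.

```python
import torch

param = torch.nn.Parameter(torch.randn(3))
_cache = {}  # hypothetical memo dict, standing in for a memoized property


def cached_scale():
    # Compute once, then reuse. The cached tensor still points into the autograd
    # graph that produced it; exp() saves its output for the backward pass.
    if "scale" not in _cache:
        _cache["scale"] = param.exp()
    return _cache["scale"]


def training_step(clear_cache=False):
    if clear_cache:
        _cache.clear()  # forces the graph to be rebuilt from param
    loss = (cached_scale() * torch.randn(3)).sum()
    loss.backward()  # frees the saved tensors of the graph it traverses


training_step()  # first backward: fine
try:
    training_step()  # reuses the cached tensor whose graph was already freed
    second_step_failed = False
except RuntimeError as err:
    second_step_failed = "backward through the graph a second time" in str(err)

training_step(clear_cache=True)  # succeeds again: the graph is rebuilt
print(second_step_failed)  # True
```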

A couple of workarounds that should be fine; I'm not entirely sure which one is preferred:

opt = torch.optim.Adam(model.parameters())
for i in range(10):
    opt.zero_grad()

    variational_dist = model.variational_strategy.variational_distribution
    loss = -variational_dist.rsample(torch.Size([10])).sum()
    loss.backward()
    print('iter={:d}, \tloss={:.3f}'.format(i,loss.item()))
    del model.variational_strategy._memoize_cache  # drop the memoize cache holding the stale distribution so it is rebuilt next iteration
    opt.step()

Alternatively:

opt = torch.optim.Adam(model.parameters())
for i in range(10):
    opt.zero_grad()

    # use the private method to regenerate the variational distribution
    variational_dist = model.variational_strategy._variational_distribution()
    loss = -variational_dist.rsample(torch.Size([10])).sum()
    loss.backward()
    print('iter={:d}, \tloss={:.3f}'.format(i,loss.item()))
    opt.step()

gpleiss commented 2 years ago

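I think this counts as a bug :) In variational_strategy, we clear the cache (which stores the current variational distribution) when __call__ is called, but we do not clear it otherwise. Since you are not calling your model, the cache never gets cleared, so each iteration reuses the distribution computed in the first one, whose autograd graph was already freed by the first loss.backward().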
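A hypothetical toy module can mimic the caching behavior described above. `ToyStrategy`, `variational_scale`, and `run` are illustrative names, not GPyTorch APIs, and clearing inside `forward()` stands in for the clearing GPyTorch does in `__call__`. The sketch assumes only that the cache is cleared on a training-mode call or on `train()`, which is why calling `model.train()` every iteration made the original loop work.

```python
import torch


class ToyStrategy(torch.nn.Module):
    """Hypothetical toy (not GPyTorch's VariationalStrategy): memoizes a tensor
    derived from a parameter and clears the memo only in forward() and train()."""

    def __init__(self):
        super().__init__()
        self.raw = torch.nn.Parameter(torch.randn(3))
        self._memo = None

    @property
    def variational_scale(self):
        # Stand-in for the cached variational distribution: built once, then reused.
        if self._memo is None:
            self._memo = self.raw.exp()
        return self._memo

    def train(self, mode=True):
        self._memo = None  # switching modes clears the cache
        return super().train(mode)

    def forward(self, x):
        if self.training:
            self._memo = None  # calling the module in training mode clears the cache
        return x * self.variational_scale


def run(clear_with=None, iters=3):
    """Returns how many backward() calls failed over `iters` iterations."""
    strategy = ToyStrategy()
    opt = torch.optim.Adam(strategy.parameters())
    failures = 0
    for _ in range(iters):
        opt.zero_grad()
        if clear_with == "train":
            strategy.train()  # mirrors the model.train() call from the issue
        elif clear_with == "call":
            strategy(torch.ones(3))  # any training-mode call clears the cache
        loss = -(strategy.variational_scale * torch.randn(10, 3)).sum()
        try:
            loss.backward()
        except RuntimeError:
            failures += 1  # backward reached the cached graph freed in iteration 1
        opt.step()
    return failures


print(run(), run("train"), run("call"))  # 2 0 0
```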

I'm not sure how best to deal with this issue. We could do some form of less aggressive caching, or smarter clearing of the cache. @wjmaddox's workarounds should be fine in the meantime.
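One possible shape of the "less aggressive caching" idea, as a hypothetical sketch rather than anything GPyTorch implements: reuse a cached value only when autograd is not recording, and invalidate it whenever a parameter is modified in place. `GradAwareCache` is an illustrative name, and the invalidation check relies on PyTorch's private `Tensor._version` counter, which may change between releases.

```python
import torch


class GradAwareCache:
    """Hypothetical sketch (not a GPyTorch API) of less aggressive caching."""

    def __init__(self, compute, params):
        self.compute = compute      # zero-argument function that builds the value
        self.params = list(params)  # tensors the value depends on
        self._key = None
        self._value = None

    def get(self):
        if torch.is_grad_enabled():
            # While autograd is recording, never reuse: a cached graph would be
            # freed by the first backward() and break the next one.
            return self.compute()
        # Without autograd, reuse the value until a parameter changes in place
        # (e.g. via optimizer.step()), which bumps its private _version counter.
        key = tuple(p._version for p in self.params)
        if key != self._key:
            self._value, self._key = self.compute(), key
        return self._value


param = torch.nn.Parameter(torch.zeros(2))
cache = GradAwareCache(lambda: param.exp(), [param])

for _ in range(2):
    cache.get().sum().backward()  # a fresh graph each time, so no error

with torch.no_grad():
    first = cache.get()
    second = cache.get()  # reused: param has not changed
    param.add_(1.0)       # in-place update bumps param._version
    third = cache.get()   # recomputed from the updated param
```

The trade-off is that training-time calls always recompute, so in this sketch caching only pays off during prediction under `torch.no_grad()`.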