cornellius-gp / gpytorch

A highly efficient implementation of Gaussian Processes in PyTorch
MIT License

[Feature Request] Memory-efficient additive kernels #748

Open idelbrid opened 5 years ago

idelbrid commented 5 years ago

🚀 Feature Request

Motivation

Hi, I'm working on using additive kernels on moderately large datasets, and I've been running into out-of-memory errors due to the way that additive kernels are handled. That is, if we have a d-dimensional data set and an additive kernel with one component per dimension, we end up storing d separate N x N kernel matrices, or, using AdditiveStructureKernel, a d x N x N tensor. On a 12 GB GPU, this runs out of memory for the "pol" UCI data set with N=13500, d=26.
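For scale, a back-of-the-envelope estimate for the "pol" numbers above, assuming float32 storage (4 bytes per entry) and ignoring every other allocation (autograd buffers, solver workspace, and so on):

```python
# Rough memory estimate for additive-kernel storage, assuming float32 (4 bytes/entry).
N, d = 13500, 26          # "pol" data set size quoted above
bytes_per_entry = 4       # assumption: single precision

one_matrix = N * N * bytes_per_entry   # a single N x N kernel matrix
stacked = d * one_matrix               # the d x N x N tensor

print(f"one N x N matrix: {one_matrix / 1e9:.2f} GB")
print(f"d x N x N tensor: {stacked / 1e9:.2f} GB")
```

A single N x N matrix is about 0.73 GB, while the stacked d x N x N tensor is about 18.95 GB, which already exceeds a 12 GB card on its own.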

Pitch

In principle, a fully additive kernel shouldn't be significantly more expensive, in computation or memory, than a non-additive (RBF) kernel. An iterative computation should only need N x N memory instead of d x N x N. My stop-gap solution is to use a custom forward/backward function that uses loops instead of expanding to larger tensors; it is pasted below. Maybe you have better ideas about how it could be integrated into the codebase.

Using this function is much faster than using an AdditiveStructureKernel or AdditiveKernel when it lets you avoid checkpointing the kernel.

That being said, I don't think it is well optimized: if I use cdist, it's very slow but uses very little memory, whereas if I compute the distances manually, it's much faster but uses much more memory. This implementation also isn't very modular (only the RBF kernel is supported, and it doesn't run in batch mode).
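The cdist route and the manual route compute the same n x m matrix of squared 1-D distances for a single column. A minimal equivalence check, assuming small float64 inputs (it does not measure the speed or memory difference reported above):

```python
import torch

torch.manual_seed(0)
n, m, i = 8, 6, 2
x1 = torch.randn(n, 4, dtype=torch.float64)
x2 = torch.randn(m, 4, dtype=torch.float64)

# cdist route: treat column i as n (resp. m) points in a 1-D space.
sq_cdist = torch.cdist(x1[:, i:i + 1], x2[:, i:i + 1]).pow(2)

# Broadcasting route: an n x 1 column minus a 1 x m row broadcasts to n x m.
sq_bcast = (x1[:, i].view(n, 1) - x2[:, i].view(1, m)).pow(2)

print(torch.allclose(sq_cdist, sq_bcast))
```

Note that `expand()` and `view()` return views rather than copies, so in the broadcasting route the main new allocation is the n x m result itself.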

Any thoughts? I'd be willing to make a PR if you think something like this would be a good idea. Any suggestions for improving performance here would be helpful too.
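The accumulate-instead-of-expand idea from the pitch can be illustrated without any GPyTorch machinery. This is a minimal, forward-only sketch in plain PyTorch (both function names are hypothetical, not GPyTorch API) that computes the same additive RBF kernel twice: once by materializing a d x n x m tensor, and once with a single n x m accumulator. It has no backward pass; making the loop differentiable without keeping d temporaries alive is what the custom backward in the implementation below is for.

```python
import torch


def additive_rbf_expanded(x1, x2, lengthscale):
    # Materializes a d x n x m tensor of per-dimension RBF kernels, then sums over d.
    x1_, x2_ = x1 / lengthscale, x2 / lengthscale            # n x d, m x d
    diff = x1_.t().unsqueeze(2) - x2_.t().unsqueeze(1)       # d x n x m
    return diff.pow(2).div(-2).exp().sum(dim=0)              # n x m


def additive_rbf_accumulated(x1, x2, lengthscale):
    # Keeps one n x m accumulator; each step allocates one n x m temporary
    # (the difference), which pow_/div_/exp_ then modify in place.
    x1_, x2_ = x1 / lengthscale, x2 / lengthscale
    n, m = x1.shape[0], x2.shape[0]
    kernel = torch.zeros(n, m, dtype=x1.dtype, device=x1.device)
    for i in range(x1.shape[1]):
        kernel.add_((x1_[:, i].view(n, 1) - x2_[:, i].view(1, m)).pow_(2).div_(-2).exp_())
    return kernel


torch.manual_seed(0)
x1 = torch.randn(50, 6, dtype=torch.float64)
x2 = torch.randn(40, 6, dtype=torch.float64)
lengthscale = torch.rand(6, dtype=torch.float64) + 0.5   # ARD lengthscale, shape (d,)

K_expanded = additive_rbf_expanded(x1, x2, lengthscale)
K_accumulated = additive_rbf_accumulated(x1, x2, lengthscale)
print(torch.allclose(K_expanded, K_accumulated))
```

The expanded version's peak memory grows with d x n x m (several such tensors exist at once), while the accumulated version stays at two n x m buffers regardless of d.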

Additional context

Implementation:

import torch


class GAMFunction(torch.autograd.Function):
    """Compute a sum of 1-D RBF kernels (one per input dimension) using only O(nm) memory.

    The outputs of forward/backward are n x m matrices, so we can get away with only allocating
    n x m matrices at a time instead of expanding to a d x n x m tensor.
    Does not support batch mode! Assumes lengthscale is a 1-D tensor of shape (d,) (ARD) or (1,) (shared).
    """
    @staticmethod
    def forward(ctx, x1, x2, lengthscale):
        n, d = x1.shape
        m, d2 = x2.shape
        if d2 != d:
            raise ValueError("Dimension mismatch")
        x1_ = x1.div(lengthscale)  # n x d, scaled by lengthscale
        x2_ = x2.div(lengthscale)  # m x d, scaled by lengthscale
        ctx.save_for_backward(x1, x2, lengthscale)  # maybe have to change?
        kernel = torch.zeros(n, m, dtype=x1_.dtype, device=x1.device)  # use accumulator+loop instead of expansion
        for i in range(d):
            # does cdist still create a new n x m tensor in the graph? Any way to avoid allocating the memory?
            # Should just create temporary n x m tensor and add it to the accumulator.
            with torch.no_grad():
                # kernel.add_(torch.cdist(x1_[:, i:i+1], x2_[:, i:i+1]).pow_(2).div_(-2).exp_())
                # kernel.add_((x1_[:, i].expand(m, -1).t() - x2_[:,i].expand(n, -1)).pow_(2).div_(-2).exp_())
                # The cdist implementation is dramatically slower! But the above is too data hungry somehow?
                #   it must be due to the double 'expand's. The below is almost as fast and saves memory.
                kernel.add_((x1_[:, i].view(n, 1) - x2_[:, i].expand(n, -1)).pow_(2).div_(-2).exp_())
        return kernel

    @staticmethod
    def backward(ctx, grad_output):
        x1, x2, lengthscale = ctx.saved_tensors
        x1_ = x1.div(lengthscale)  # probably could just save the scaled x1/x2 tensors from forward
        x2_ = x2.div(lengthscale)
        n, d = x1.shape
        m, d2 = x2.shape
        num_l = torch.numel(lengthscale)  # support ARD/single lengthscale
        need_x1_grad, need_x2_grad = ctx.needs_input_grad[0], ctx.needs_input_grad[1]
        lengthscale_grad = torch.zeros_like(lengthscale)
        x1_grad = torch.zeros_like(x1) if need_x1_grad else None
        x2_grad = torch.zeros_like(x2) if need_x2_grad else None

        # Again, use accumulators instead of expansion. Less computationally efficient, but more memory efficient.
        with torch.no_grad():
            for i in range(d):
                idx = i if num_l > 1 else 0
                sq_dist = (x2_[:, i].expand(n, -1) - x1_[:, i].view(n, 1)).pow_(2)
                # sq_dist = torch.cdist(x1_[:, i:i + 1], x2_[:, i:i + 1]).pow_(2)
                Delta_K = sq_dist.div(-2).exp_().mul_(grad_output)  # k * grad_output; reused below
                # dk/dl = k * sq_dist / l (sq_dist is overwritten in place here; it is not needed afterwards)
                lengthscale_grad[idx].add_(sq_dist.mul_(Delta_K).sum().div(lengthscale[idx]))

                if need_x1_grad or need_x2_grad:
                    Delta_K_diff = (x2_[:, i].expand(n, -1) - x1_[:, i].view(n, 1)).mul_(Delta_K)
                    if need_x1_grad:
                        x1_grad[:, i] = Delta_K_diff.sum(dim=1).div_(lengthscale[idx])  # sum over columns (x2 points)
                    if need_x2_grad:
                        x2_grad[:, i] = -Delta_K_diff.sum(dim=0).div_(lengthscale[idx])  # sum over rows (x1 points)
        return x1_grad, x2_grad, lengthscale_grad
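One way to sanity-check the hand-derived gradients in `backward` is to compare them against autograd on the expanded d x n x m computation, using inputs small enough that the expansion is cheap. This sketch re-implements the same per-dimension formulas (dk/dl = k * s / l and dk/dx1 = k * (x2 - x1) / l^2, with s the scaled squared distance) instead of importing `GAMFunction`, so it checks the math rather than that exact code; a 1-D lengthscale of shape (d,) is assumed.

```python
import torch

torch.manual_seed(0)
n, m, d = 7, 5, 3
x1 = torch.randn(n, d, dtype=torch.float64, requires_grad=True)
x2 = torch.randn(m, d, dtype=torch.float64, requires_grad=True)
lengthscale = (torch.rand(d, dtype=torch.float64) + 0.5).requires_grad_()
grad_output = torch.randn(n, m, dtype=torch.float64)

# Reference: plain autograd through the expanded d x n x m computation.
diff = (x1.t().unsqueeze(2) - x2.t().unsqueeze(1)) / lengthscale.view(d, 1, 1)
K = diff.pow(2).div(-2).exp().sum(dim=0)
K.backward(grad_output)

# Hand-written formulas mirroring the per-dimension loop in backward.
with torch.no_grad():
    ls_grad = torch.zeros(d, dtype=torch.float64)
    x1_grad = torch.zeros(n, d, dtype=torch.float64)
    x2_grad = torch.zeros(m, d, dtype=torch.float64)
    x1_, x2_ = x1 / lengthscale, x2 / lengthscale
    for i in range(d):
        delta = x2_[:, i].view(1, m) - x1_[:, i].view(n, 1)    # n x m, scaled differences
        weighted_k = delta.pow(2).div(-2).exp() * grad_output   # k * grad_output (Delta_K)
        ls_grad[i] = (delta.pow(2) * weighted_k).sum() / lengthscale[i]
        x1_grad[:, i] = (delta * weighted_k).sum(dim=1) / lengthscale[i]
        x2_grad[:, i] = -(delta * weighted_k).sum(dim=0) / lengthscale[i]

print(torch.allclose(ls_grad, lengthscale.grad),
      torch.allclose(x1_grad, x1.grad),
      torch.allclose(x2_grad, x2.grad))
```

With `GAMFunction` itself in scope, `torch.autograd.gradcheck(GAMFunction.apply, (x1, x2, lengthscale))` on small float64 inputs would perform a similar check directly against the finite-difference Jacobian.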

Tests:

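import torch
import gpytorch
from gpytorch.kernels import AdditiveKernel, AdditiveStructureKernel, RBFKernel, ScaleKernel
from gpytorch.likelihoods import GaussianLikelihood
from gpytorch.mlls import ExactMarginalLogLikelihood

# pol data set; trainX/trainY, ExactGPModel, and MemoryEfficientGamKernel are defined elsewhere (not shown).
trainX = trainX.to('cuda:0')
trainY = trainY.to('cuda:0')
def train(kernel):
    model = ExactGPModel(trainX, trainY, GaussianLikelihood(), kernel).to('cuda:0')
    mll = ExactMarginalLogLikelihood(model.likelihood, model).to('cuda:0')
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)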
    model.train()
    mll.train()
    with gpytorch.settings.fast_computations(), \
          gpytorch.settings.skip_posterior_variances(True), \
          gpytorch.settings.memory_efficient(True):
        for i in range(20):
            optimizer.zero_grad()  # reset gradients accumulated in the previous iteration
            output = model(trainX)
            loss = -mll(output, trainY)
            loss.backward()
            optimizer.step()
mem_eff_kernel = ScaleKernel(MemoryEfficientGamKernel()).to('cuda:0')
%time train(mem_eff_kernel)
CPU times: user 23.1 s, sys: 6.56 s, total: 29.7 s
Wall time: 30 s
batch_add_kernel = ScaleKernel(AdditiveStructureKernel(RBFKernel(), trainX.shape[1])).to('cuda:0')
with gpytorch.beta_features.checkpoint_kernel(2000):
    %time train(batch_add_kernel)
CPU times: user 5min 7s, sys: 1min 31s, total: 6min 39s
Wall time: 6min 43s
subkernels = [RBFKernel(active_dims=[i]) for i in range(trainX.shape[1])]
add_kernel = ScaleKernel(AdditiveKernel(*subkernels)).to('cuda:0')
with gpytorch.beta_features.checkpoint_kernel(2500):
    %time train(add_kernel)
CPU times: user 4min 51s, sys: 1min 45s, total: 6min 37s
Wall time: 6min 41s
jacobrgardner commented 5 years ago

AdditiveStructureKernel and ProductStructureKernel intentionally use a factor of d more memory because they handle the structure in batch mode rather than in a for loop. This turns out to be extremely important for speed in two cases: (1) SV-DKL with additive structure in the last layer, as in the paper, and (2) SKIP.

AdditiveKernel and ProductKernel, however, do compute their kernel results in for loops and should not use an additional factor of d memory unless the underlying kernels return things other than tensors or non-lazy tensors.

If you are running out of memory with AdditiveKernel, could you try running your code in a gpytorch.settings.lazily_evaluate_kernels(False) context and see if that solves the memory issue? In my opinion, AdditiveKernel (i.e., without Structure) using a factor of d more memory would be a bug, not intended behavior.

idelbrid commented 5 years ago

Thanks for the reply. Using gpytorch.settings.lazily_evaluate_kernels(False) doesn't resolve the memory issue with AdditiveKernel. Using this setting actually seems to use more memory, as I now run out of memory with a smaller checkpoint size (and without checkpointing).

subkernels = [RBFKernel(active_dims=[i]) for i in range(trainX.shape[1])]
add_kernel = ScaleKernel(AdditiveKernel(*subkernels)).to('cuda:0')
with gpytorch.beta_features.checkpoint_kernel(2500), gpytorch.settings.lazily_evaluate_kernels(False):
    train(add_kernel)  # CUDA out of memory

I'm not entirely sure why this is the case. One guess is that since all kernels are computed separately, a copy of each kernel tensor might be cached during the forward pass for use in the backward pass.
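The general mechanism behind that guess can be observed in plain PyTorch, though this does not reproduce GPyTorch's AdditiveKernel internals. The sketch below (function and hook names are hypothetical) uses `torch.autograd.graph.saved_tensors_hooks`, which requires PyTorch 1.10 or newer (much newer than the 1.1.0 used here), to count how many elements autograd keeps for the backward pass when a per-dimension loop is recorded normally, versus under `torch.no_grad()`.

```python
import torch


def additive_rbf_loop(x1, x2, lengthscale):
    # Naive per-dimension loop recorded by autograd (no custom backward).
    n, m = x1.shape[0], x2.shape[0]
    kernel = 0
    for i in range(x1.shape[1]):
        diff = (x1[:, i].view(n, 1) - x2[:, i].view(1, m)) / lengthscale[i]
        kernel = kernel + diff.pow(2).div(-2).exp()
    return kernel


saved_numels = []


def pack(t):
    # Called once for every tensor autograd saves for the backward pass.
    saved_numels.append(t.numel())
    return t


def unpack(t):
    return t


torch.manual_seed(0)
n, m, d = 200, 150, 8
x1, x2 = torch.randn(n, d), torch.randn(m, d)
lengthscale = torch.ones(d, requires_grad=True)

with torch.autograd.graph.saved_tensors_hooks(pack, unpack):
    K = additive_rbf_loop(x1, x2, lengthscale)
with_grad = sum(saved_numels)

saved_numels.clear()
with torch.no_grad(), torch.autograd.graph.saved_tensors_hooks(pack, unpack):
    additive_rbf_loop(x1, x2, lengthscale)
without_grad = sum(saved_numels)

print(with_grad / (n * m), without_grad)
```

With autograd recording, the saved elements amount to at least d full n x m buffers (for example, each `exp` output is kept), so memory grows with d even though the loop itself only accumulates one n x m result. Under `torch.no_grad()` nothing is saved, which is why a memory-efficient version needs its own backward.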

Also, I forgot to mention: this is with GPyTorch 0.3.2 and PyTorch 1.1.0.

gpleiss commented 5 years ago

@idelbrid @jacobrgardner: I think we could add something similar to what you're proposing, where we accumulate the kernel when memory is an issue. I'm not sure exactly where it would fit within the GPyTorch architecture, though... maybe AdditiveKernel could have a "memory efficient" option that would turn off lazy evaluation and accumulate the kernel matrix?