cornellius-gp / gpytorch

A highly efficient implementation of Gaussian Processes in PyTorch

[Bug] _permute_batch has wrong behavior in some instances with LazyEvaluatedKernelTensor #1853

Open Balandat opened 2 years ago

Balandat commented 2 years ago

🐛 Bug

In some situations, _permute_batch, when operating on a batched LazyEvaluatedKernelTensor, does not properly permute the batch dimensions but instead performs something like an expand.

To reproduce

I tried hard to find a minimal example, but so far I've only been able to reliably reproduce this when using a SAASBO model in Ax. Full repro here: permute_batch_shape_issue_repro.ipynb.txt - this is on current master of gpytorch, botorch, and Ax.

Stack trace/error message

The offending part is this one:

ipdb> covar_blocks_lazy.lazy_tensors[0].shape
torch.Size([1, 4, 100, 100])
ipdb> covar_blocks_lazy.lazy_tensors[0]._permute_batch(1, 0).shape
torch.Size([4, 4, 100, 100])

Expected Behavior

Rather than a lazy tensor of shape torch.Size([4, 4, 100, 100]), this should return a tensor of shape torch.Size([4, 1, 100, 100]), i.e., perform a proper permutation of the batch dimensions.
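
For reference, the dense analogue of this permutation with plain torch tensors looks like the following (just an illustration of the expected semantics, not the lazy-tensor code path):

import torch

# Dense stand-in for the lazy tensor from the repro: batch shape (1, 4), matrix dims (100, 100).
dense = torch.randn(1, 4, 100, 100)

# Permuting the two batch dimensions (the analogue of _permute_batch(1, 0))
# swaps them and leaves the trailing matrix dimensions untouched.
permuted = dense.permute(1, 0, 2, 3)
print(permuted.shape)  # torch.Size([4, 1, 100, 100])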

saitcakmak commented 2 years ago

I tried to dig into it a bit with my minimal knowledge of lazy tensors.

https://github.com/cornellius-gp/gpytorch/blob/70eb9f9561bc3ce1f5c00c36e1d865c6b78fca44/gpytorch/lazy/lazy_tensor.py#L190

This line takes in two 4 x 1 x 100 x s tensors as components (where s takes values 5, 6, and 100), and the kwargs contain the kernel and 'last_dim_is_batch': False. The result of the operation has res.shape = 4 x 4 x 100 x 100.

Looking into what happens in the res.shape call, the shape is retrieved from the cache even on the first call. Ignoring that and stepping into res._size(), execution goes into the else branch of this check because kernel.batch_shape is [4], which produces the shape 4 x 4 x 100 x 100 (the first expected_shape is 4 x 1 x 100 x 100, and lines 250-260 make it 4 x 4 x 100 x 100):

https://github.com/cornellius-gp/gpytorch/blob/70eb9f9561bc3ce1f5c00c36e1d865c6b78fca44/gpytorch/lazy/lazy_evaluated_kernel_tensor.py#L237-L260
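
Concretely, the broadcast described above works out like this (a minimal sketch with plain torch shapes, not the actual gpytorch code path):

import torch

input_batch = torch.Size([4, 1])  # batch dims of expected_shape after the inputs were permuted
kernel_batch = torch.Size([4])    # kernel.batch_shape, untouched by the permutation

# Right-aligned broadcasting treats [4] as [1, 4], so the batch dims come out as [4, 4].
print(torch.broadcast_shapes(input_batch, kernel_batch))  # torch.Size([4, 4])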

Balandat commented 2 years ago

Hmm, I see what's going on, and this is pretty nasty. We're properly permuting the batch dimensions of the input tensors, but we don't do anything to the batch_shape of the kernel itself, so things get broadcast incorrectly in the shape computation. Specifically, in l. 251 we're broadcasting the expected size from the inputs ([4, 1, 100, 100]) against a kernel batch shape of [4], which gets broadcast to [4, 4], causing the issue.

What we ought to do is also permute the kernel batch shape: interpret [4] as [1, 4] by pre-padding with ones, permute that to [4, 1], and then broadcast, which would get us to the correct [4, 1, 100, 100].
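
A rough sketch of that shape arithmetic (the helper below is hypothetical, just to make the padding-and-permuting step concrete):

import torch

def permuted_kernel_batch_shape(kernel_batch_shape, num_batch_dims, dims):
    # Hypothetical helper: pre-pad the kernel batch shape with ones up to num_batch_dims
    # dimensions, then apply the same batch permutation that was applied to the inputs.
    padded = [1] * (num_batch_dims - len(kernel_batch_shape)) + list(kernel_batch_shape)
    return torch.Size([padded[d] for d in dims])

input_batch = torch.Size([4, 1])                          # inputs, already permuted
kernel_batch = permuted_kernel_batch_shape(torch.Size([4]), 2, dims=(1, 0))
print(kernel_batch)                                       # torch.Size([4, 1])
print(torch.broadcast_shapes(input_batch, kernel_batch))  # torch.Size([4, 1]) -- correct batch dims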

I'll see if I can write a custom method for this, though I am a bit scared about the complexities that we might run into if we start permuting dimensions of kernels...

Balandat commented 2 years ago

@sdaulton, @dme65 I am not sure what happens exactly when this does not hit the max eager threshold, but I wonder whether we're unnecessarily broadcasting even in that case, and as a result compute the wrong thing (or the same thing multiple times unnecessarily).

Balandat commented 2 years ago

@gpleiss, @jacobrgardner do either of you know whether this has ever worked properly in this setting? Or is batch-evaluating a batched model with different batch dimensions and then lazily handling the resulting posterior covariance just not something we've run into / considered before?

jacobrgardner commented 2 years ago

@Balandat Since the batch shapes of an arbitrary kernel are probably hard to modify on the fly for this, maybe we should special-case operations that modify the batch shape so that they only modify the reported size of the LEKT, and are actually applied only when we evaluate_kernel? I know I ran into a similar issue before with kernel batch handling, which is why kernels define __getitem__, but I don't think we've seen this specific issue. If we want the solution you proposed, __getitem__ on Kernel is probably a reasonable template to start from.
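
A rough sketch of that deferral idea might look like the following (the class and method names are made up for illustration, not actual gpytorch API):

import torch

class DeferredBatchPermutation:
    # Hypothetical wrapper: record the batch permutation on the lazy kernel tensor
    # and only apply it to the result once the kernel is actually evaluated.
    def __init__(self, lazy_kernel_tensor, dims):
        self.lekt = lazy_kernel_tensor
        self.dims = tuple(dims)

    def size(self):
        # Report the permuted size without touching the underlying kernel.
        shape = self.lekt.size()
        batch_shape, matrix_shape = shape[:-2], shape[-2:]
        return torch.Size(tuple(batch_shape[d] for d in self.dims) + tuple(matrix_shape))

    def evaluate_kernel(self):
        # Evaluate first, then permute the batch dims of the (no longer kernel-lazy) result.
        return self.lekt.evaluate_kernel()._permute_batch(*self.dims)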

Balandat commented 2 years ago

Hmm, yeah, interesting idea. I think there are a few things to be careful about - e.g. we will need things like .size() and .shape to return the already-permuted shape even if the actual tensor itself hasn't been permuted yet. But it might not be too bad; I'll give it a try.
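
In terms of the hypothetical sketch in the previous comment, that requirement would look roughly like this (covar stands in for the LazyEvaluatedKernelTensor of size [1, 4, 100, 100] from the repro):

deferred = DeferredBatchPermutation(covar, dims=(1, 0))
print(deferred.size())                  # torch.Size([4, 1, 100, 100]), reported before any evaluation
evaluated = deferred.evaluate_kernel()  # the actual permutation only happens here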