Balandat opened this issue 2 years ago
I tried to dig into it a bit with my minimal knowledge of lazy tensors.
This line takes in two `4 x 1 x 100 x s` tensors as components (where `s` takes the values 5, 6, and 100), and the kwargs contain the kernel and `'last_dim_is_batch': False`. The result of the operation has `res.shape = 4 x 4 x 100 x 100`.

Looking into what happens in the `res.shape` call, the shape is retrieved from the cache even on the first call. Ignoring that and going into `res._size()`, it takes the `else` branch of this check because `kernel.batch_shape = 4`, which produces the shape `4 x 4 x 100 x 100` (the first `expected_shape` is `4 x 1 x 100 x 100`; lines 250-260 make it `4 x 4 x 100 x 100`):
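The broadcast that produces the extra batch dimension can be reproduced in isolation with plain torch shape arithmetic (a minimal sketch using only the shapes quoted above, not the gpytorch internals):

```python
import torch

# Batch portion of the expected shape computed from the input tensors:
# [4, 1, 100, 100] -> batch dims [4, 1]
input_batch = torch.Size([4, 1])

# The kernel's own batch_shape, which is NOT permuted alongside the inputs:
kernel_batch = torch.Size([4])

# Right-aligned broadcasting pads [4] to [1, 4], so [4, 1] x [1, 4] -> [4, 4]
broadcast = torch.broadcast_shapes(input_batch, kernel_batch)
print(broadcast)  # torch.Size([4, 4]) -- the spurious extra batch dimension
```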
Hmm, I see what's going on; this is pretty nasty. Basically we're properly permuting the batch dimensions of the input tensors, but we don't do anything to the `batch_shape` of the kernel itself. As a result we broadcast things in the shape computation: in l. 251 we're broadcasting the expected size from the inputs (`[4, 1, 100, 100]`) against a kernel batch shape of `[4]`, which gets broadcasted to `[4, 4]`, causing the issue.
What we ought to do is permute the batch shape of the kernel as well: interpret `[4]` as `[1, 4]` by pre-padding with ones, permute that to `[4, 1]`, and then broadcast, which would get us the correct `[4, 1, 100, 100]`.
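Under that proposal, the shape arithmetic would look roughly like this (a sketch with a hypothetical helper, assuming the batch permutation swaps the two leading dims; none of these names are gpytorch API):

```python
import torch

def permuted_kernel_batch_shape(kernel_batch, target_ndim, perm):
    # Hypothetical helper: pre-pad the kernel batch shape with ones, then
    # apply the same permutation that was applied to the input batch dims.
    padded = (1,) * (target_ndim - len(kernel_batch)) + tuple(kernel_batch)
    return torch.Size([padded[i] for i in perm])

input_batch = torch.Size([4, 1])   # batch dims of the (permuted) inputs
kernel_batch = torch.Size([4])     # kernel.batch_shape
perm = (1, 0)                      # batch permutation applied to the inputs

fixed = permuted_kernel_batch_shape(kernel_batch, len(input_batch), perm)
print(fixed)                                       # torch.Size([4, 1])
print(torch.broadcast_shapes(input_batch, fixed))  # torch.Size([4, 1])
```

With the kernel batch shape permuted consistently, the broadcast no longer introduces a spurious dimension.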
I'll see if I can write a custom method for this, though I am a bit scared about the complexities that we might run into if we start permuting dimensions of kernels...
@sdaulton, @dme65 I am not sure what happens exactly when this is not hitting the max eager threshold, but I wonder if it could be that we're unnecessarily broadcasting even in that case, and as a result compute the wrong thing (or the same thing multiple times unnecessarily).
@gpleiss, @jacobrgardner do either of you know whether this has ever worked properly in this setting? Or is batch-evaluating a batched model with different batch dimensions and then lazily handling the resulting posterior covariance just not something we've run into / considered before?
@Balandat Since batch shapes of an arbitrary kernel are probably hard to modify on the fly for this, maybe we should special-case operations that modify the batch shape to only modify the size of the LEKT, and then actually apply them only when we `evaluate_kernel`? I know I ran into a similar issue before with kernel batch handling, which is why kernels define `__getitem__`, but I don't think we've seen this specific issue. If we want the solution you proposed, I think `__getitem__` on Kernel is probably a reasonable template to start with.
Hmm, yeah, interesting idea. I think there might be a few things to be careful about - e.g. we will need to have things like `.size()` and `.shape` return the already permuted shape even if the actual tensor itself hasn't been permuted yet. But it might not be too bad; I'll give it a try.
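That deferred approach could be sketched as a thin wrapper that records the pending permutation and reports the permuted shape without touching the underlying tensor (hypothetical names and a plain-tensor stand-in, not the gpytorch LEKT implementation):

```python
import torch

class DeferredPermute:
    """Records a batch permutation and applies it only on materialization."""

    def __init__(self, tensor, perm):
        self._tensor = tensor
        self._perm = tuple(perm)

    @property
    def shape(self):
        # Report the already-permuted shape even though the data is untouched.
        return torch.Size([self._tensor.shape[i] for i in self._perm])

    def evaluate(self):
        # Only here do we actually move data around.
        return self._tensor.permute(self._perm)

x = torch.randn(4, 1, 100, 100)
lazy = DeferredPermute(x, (1, 0, 2, 3))
print(lazy.shape)             # torch.Size([1, 4, 100, 100]) -- no data moved yet
print(lazy.evaluate().shape)  # torch.Size([1, 4, 100, 100])
```

The point is that shape queries stay cheap and consistent, while the actual permutation is folded into whatever work `evaluate_kernel` has to do anyway.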
🐛 Bug

In some situations, `_permute_batch` when operating on a `LazyEvaluatedKernelTensor` that is batched will not properly permute dimensions but do some kind of `expand`.

To reproduce
I tried hard to find a minimal example, but so far I've only been able to reliably reproduce this when using a SAASBO model in Ax. Full repro here: permute_batch_shape_issue_repro.ipynb.txt - this is on current master of gpytorch, botorch, and Ax.
Stack trace/error message
The offending part is this one:
Expected Behavior

Rather than a lazy tensor of shape `torch.Size([4, 4, 100, 100])`, this should return a tensor of size `torch.Size([4, 1, 100, 100])`, i.e., perform a proper permutation of batch dimensions.
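The difference between the two shapes boils down to permute vs. expand semantics, which a standalone torch snippet illustrates (unrelated to the repro notebook; the shapes mirror the ones in this report):

```python
import torch

x = torch.randn(1, 4, 100, 100)

# A proper permutation of the batch dimensions just reorders them:
print(x.permute(1, 0, 2, 3).shape)  # torch.Size([4, 1, 100, 100])

# An expand-like broadcast instead replicates along the size-1 batch dim:
print(x.expand(4, 4, 100, 100).shape)  # torch.Size([4, 4, 100, 100])
```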