cornellius-gp / gpytorch

A highly efficient implementation of Gaussian Processes in PyTorch
MIT License

Slow marginal posterior computations #732

Open Balandat opened 5 years ago

Balandat commented 5 years ago

Sometimes we're only interested in the marginal posterior rather than the joint posterior. If we have a batch of b x q x d points (where b is the batch dimension, q is the number of points in each batch, and d is the feature dimension), one way to compute it is to batch-evaluate the model on the permuted q x b x 1 x d tensor, in which case gpytorch computes a batched MVN with event_shape=1 and batch shape q x b.
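A minimal sketch of the reshaping described above, in plain PyTorch. The sizes b, q, d are arbitrary illustrative values and X is a random stand-in for the b x q x d batch:

```python
import torch

b, q, d = 4, 3, 2          # illustrative sizes (assumed, not from the issue)
X = torch.rand(b, q, d)    # b x q x d: b batches of q points each

# Move q to the front and give every point its own 1-point event dimension,
# so a model evaluated on it sees a q x b batch of single-point inputs.
X_marg = X.permute(1, 0, 2).unsqueeze(-2)  # q x b x 1 x d

print(X_marg.shape)  # torch.Size([3, 4, 1, 2])
# Slice [i, j] holds the single point X[j, i]
print(torch.equal(X_marg[1, 2, 0], X[2, 1]))  # True
```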

You'd expect this to be significantly faster than computing the joint posterior and taking its diagonal to get the marginal variances. In fact, it's significantly slower, as the following notebook shows: slow_marginal_posterior.ipynb.txt

Not quite sure what’s going on here. Possibly the marginal sampling doesn’t actually use a different code path, and still naively computes a root decomposition of a batched 1x1 matrix using an iterative algorithm. There could also be a bunch of overhead in the batched LazyTensor indexing.
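For reference, the root decomposition of a 1x1 covariance has a closed form: the Cholesky factor of [[v]] is [[sqrt(v)]]. The sketch below (plain PyTorch, made-up variances; it says nothing about what GPyTorch does internally) checks that a general batched Cholesky over q x b tiny 1x1 matrices matches an elementwise square root, which is all a dedicated marginal code path would need:

```python
import torch

q, b = 3, 4
# Positive variances arranged as a q x b batch of 1x1 covariance matrices
var = torch.rand(q, b, 1, 1, dtype=torch.float64) + 0.1

# General-purpose route: batched Cholesky of q*b tiny 1x1 matrices
L_chol = torch.linalg.cholesky(var)

# Closed form for the 1x1 case: the root is the square root of the variance
L_sqrt = var.sqrt()

print(torch.allclose(L_chol, L_sqrt))  # True
```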

Any thoughts on this? Is there a better way to get the marginal posterior?

jacobrgardner commented 5 years ago

@Balandat Just getting back to work from being at ICML, sorry for the delay on this.

Looking at your code, I think there's some misunderstanding about what model(test_x) is doing. In general, the predictive "covariance" matrix we return is a LazyTensor where none of the heavy math involved in getting the joint is done until you actually ask for off-diagonal elements. Thus, the most efficient way I know to get the marginal distribution in GPyTorch is exactly what you are doing in get_from_joint.
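To illustrate the lazy-evaluation idea, here is a toy class that stores the two factors of a product A @ B and only computes what is asked for. It is hypothetical and much simpler than GPyTorch's actual LazyTensor classes; the point is that asking for the diagonal never builds the full t x t matrix:

```python
import torch


class LazyProduct:
    """Toy stand-in for a lazily evaluated product A @ B (hypothetical;
    not GPyTorch's LazyTensor API). Nothing is multiplied at construction."""

    def __init__(self, A: torch.Tensor, B: torch.Tensor):
        self.A = A  # t x n
        self.B = B  # n x t

    def diag(self) -> torch.Tensor:
        # Only the t diagonal entries: row i of A dotted with column i of B,
        # O(t * n) work instead of O(t^2 * n).
        return (self.A * self.B.transpose(-1, -2)).sum(-1)

    def evaluate(self) -> torch.Tensor:
        # Full t x t product; needed only when off-diagonal entries are requested.
        return self.A @ self.B


t, n = 5, 7
A = torch.randn(t, n, dtype=torch.float64)
B = torch.randn(n, t, dtype=torch.float64)
lazy = LazyProduct(A, B)
print(torch.allclose(lazy.diag(), torch.diagonal(lazy.evaluate())))  # True
```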

Treating everything as batch mode computations seems like it's just always going to be slower because the underlying LAPACK and MAGMA calls probably just handle batch dimensions in worse ways.

Balandat commented 5 years ago

@jacobrgardner I do understand the lazy evaluation of the covariance matrix; my description above was incorrect. (By the way, it turns out that in this data regime, computing the full joint takes essentially the same wall time as computing only the marginals: slow_marginal_posterior_full_joint.ipynb.txt.)

> Treating everything as batch mode computations seems like it's just always going to be slower because the underlying LAPACK and MAGMA calls probably just handle batch dimensions in worse ways.

Either that, or there are some inefficiencies in the batching acrobatics. Let me try to do some profiling to see where the time is being spent.
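One rough way to start that profiling is to time the two routes outside GPyTorch. The sketch below uses plain dense PyTorch: the RBF kernel, noise level, and sizes are assumptions for illustration, and the two hypothetical functions mimic the joint-then-diagonal route and the batch-of-1-point route. Wall-clock numbers will vary by machine; torch.profiler can break the time down per operator.

```python
import time
import torch

torch.manual_seed(0)
n, t, d, noise = 200, 50, 3, 1e-2  # illustrative sizes, not the notebook's
X = torch.rand(n, d, dtype=torch.float64)   # training inputs
Xs = torch.rand(t, d, dtype=torch.float64)  # test inputs


def rbf(x1, x2):
    # Unit-lengthscale, unit-outputscale RBF kernel (assumed for illustration)
    return torch.exp(-0.5 * torch.cdist(x1, x2).pow(2))


K = rbf(X, X) + noise * torch.eye(n, dtype=torch.float64)


def marginal_var_joint():
    # One n x n solve shared by all t test points; keep only the diagonal
    Ks = rbf(Xs, X)                            # t x n
    V = torch.linalg.solve(K, Ks.T)            # n x t
    return 1.0 - (Ks * V.T).sum(-1)            # t


def marginal_var_batched():
    # Each test point as its own 1-point batch: t separate n x n solves
    Xb = Xs.unsqueeze(-2)                      # t x 1 x d
    Ks = rbf(Xb, X.expand(t, n, d))            # t x 1 x n
    V = torch.linalg.solve(K.expand(t, n, n), Ks.transpose(-1, -2))  # t x n x 1
    return 1.0 - (Ks @ V).reshape(t)           # t


for fn in (marginal_var_joint, marginal_var_batched):
    start = time.perf_counter()
    fn()
    print(f"{fn.__name__}: {time.perf_counter() - start:.4f} s")

print(torch.allclose(marginal_var_joint(), marginal_var_batched()))  # True
```

In this sketch the batched route repeats the n x n factorization once per test point, while the joint route does it once; that is one concrete way batching can lose even though each batch element is tiny.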

jacobrgardner commented 5 years ago

I would buy that they are the same on data that small.

There are two fundamental differences between full and diagonal mode (i.e., calling covariance_matrix vs variance):

  1. In full mode, we call LazyEvaluatedKernelTensor.evaluate on a t x t matrix (K**), and in diag mode we call .diag on it instead.
  2. In full mode, we call MatmulLazyTensor.evaluate on a (t x n) * (n x t) lazy tensor (K*(K+sI)^{-1}K*'), and in diag mode we call diag on it, which will do an elementwise multiply instead of a matmul.

Item 2 in particular has a vastly different running time in full vs. diag mode (a t x t matmul versus t row-wise products), but at this data size the difference probably doesn't matter on modern hardware. In fact, in very small data settings (e.g., very small t), it may be that neither 1 nor 2 is the actual bottleneck.
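The two differences can be made concrete with a small dense sketch of the exact GP predictive covariance K** - K*(K+sI)^{-1}K*'. The RBF kernel, noise level, and sizes are assumptions, and plain tensors stand in for GPyTorch's lazy tensors. Both modes share the same solve; they differ only in how much of K** and of the final product gets formed:

```python
import torch

torch.manual_seed(0)
n, t, d, noise = 100, 40, 2, 1e-2  # illustrative sizes (assumed)
X = torch.rand(n, d, dtype=torch.float64)   # training inputs
Xs = torch.rand(t, d, dtype=torch.float64)  # test inputs


def rbf(x1, x2):
    # Unit-lengthscale, unit-outputscale RBF kernel (assumed for illustration)
    return torch.exp(-0.5 * torch.cdist(x1, x2).pow(2))


L = torch.linalg.cholesky(rbf(X, X) + noise * torch.eye(n, dtype=torch.float64))
Ks = rbf(Xs, X)                    # K*: t x n
W = torch.cholesky_solve(Ks.T, L)  # (K + sI)^{-1} K*': n x t

# Full mode: the t x t K** (item 1) and a (t x n) @ (n x t) matmul (item 2), O(t^2 n)
cov_full = rbf(Xs, Xs) - Ks @ W

# Diag mode: only k(x, x) = 1 for this kernel (item 1) and a row-wise
# elementwise multiply-and-sum (item 2), O(t n)
var_diag = torch.ones(t, dtype=torch.float64) - (Ks * W.T).sum(-1)

print(torch.allclose(torch.diagonal(cov_full), var_diag))  # True
```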