Balandat opened this issue 5 years ago
@Balandat Just getting back to work from being at ICML, sorry for the delay on this.
Looking at your code, I think there's some misunderstanding about what `model(test_x)` is doing. In general, the predictive "covariance" matrix we return is a `LazyTensor`, where none of the heavy math involved in getting the joint is done until you actually ask for off-diagonal elements. Thus, the most efficient way I know of to get the marginal distribution in GPyTorch is exactly what you are doing in `get_from_joint`.
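For concreteness, here is a minimal, hedged sketch of that pattern with a toy model (an assumed setup, not the attached notebook's code): asking for `.variance` only ever touches the diagonal of the lazy covariance, while `.covariance_matrix` would force evaluation of the full joint.

```python
import torch
import gpytorch

# Toy exact GP, just to illustrate the point above (not the notebook's model).
class ToyGP(gpytorch.models.ExactGP):
    def __init__(self, train_x, train_y, likelihood):
        super().__init__(train_x, train_y, likelihood)
        self.mean_module = gpytorch.means.ConstantMean()
        self.covar_module = gpytorch.kernels.ScaleKernel(gpytorch.kernels.RBFKernel())

    def forward(self, x):
        return gpytorch.distributions.MultivariateNormal(
            self.mean_module(x), self.covar_module(x)
        )

train_x, train_y, test_x = torch.randn(100, 3), torch.randn(100), torch.randn(50, 3)
likelihood = gpytorch.likelihoods.GaussianLikelihood()
model = ToyGP(train_x, train_y, likelihood)
model.eval()
likelihood.eval()

with torch.no_grad():
    post = model(test_x)       # MultivariateNormal whose covariance is a LazyTensor
    marg_var = post.variance   # evaluates only the diagonal -> marginal variances
    # post.covariance_matrix   # this is what would trigger the full 50 x 50 evaluation
```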
Treating everything as batch mode computations seems like it's just always going to be slower because the underlying LAPACK and MAGMA calls probably just handle batch dimensions in worse ways.
@jacobrgardner I do understand the lazy evaluation of the covariance matrix; my description above was incorrect. (btw, it turns out that in this data regime getting the full joint is basically indistinguishable from getting the marginal in terms of wall time: slow_marginal_posterior_full_joint.ipynb.txt).

> Treating everything as batch mode computations seems like it's just always going to be slower because the underlying LAPACK and MAGMA calls probably just handle batch dimensions in worse ways.

Either that, or there are some inefficiencies in the batching acrobatics. Let me try to do some profiling to see where the time is being spent.
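One hedged way to do that profiling, using PyTorch's autograd profiler on the toy `model` and `test_x` from the sketch above (not the notebook's setup):

```python
import torch

# Hedged profiling sketch: wrap the marginal-variance computation in the autograd
# profiler and see where the CPU time goes. `model` and `test_x` are the toy objects
# from the earlier snippet in this thread.
with torch.no_grad(), torch.autograd.profiler.profile() as prof:
    model(test_x).variance

print(prof.key_averages().table(sort_by="cpu_time_total", row_limit=15))
```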
I would buy that they are the same on data that small.
There are two fundamental differences between full and diagonal mode (e.g., calling `covariance_matrix` vs `variance`):

1. In full mode we call `LazyEvaluatedKernelTensor.evaluate` on a `t x t` matrix (`K**`), and in diag mode we call `.diag` on it instead.
2. In full mode we call `MatmulLazyTensor.evaluate` on a `(t x n) * (n x t)` lazy tensor (`K*(K+sI)^{-1}K*'`), and in diag mode we call `diag` on it, which will do an elementwise multiply instead of a matmul (see the sketch below).

Item 2 in particular clearly has a vastly different running time in full vs diag mode, but it probably doesn't matter on modern hardware. In fact, at very low data settings (e.g., `t` very small), it may be the case that neither 1 nor 2 is the actual bottleneck.
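To make point 2 concrete, here is a small plain-PyTorch illustration (not GPyTorch internals) of why the diagonal path is so much cheaper: the diagonal of a matrix product can be computed with an elementwise multiply and a sum, skipping the full matmul.

```python
import torch

# diag(A @ B) via elementwise multiply + sum, vs. the full matmul.
# A stands in for K_* (K + sigma^2 I)^{-1}, B for K_*'; sizes are arbitrary.
t, n = 500, 2000
A = torch.randn(t, n, dtype=torch.float64)
B = torch.randn(n, t, dtype=torch.float64)

full_diag = (A @ B).diagonal()                # O(t^2 n): full t x t product, then the diagonal
cheap_diag = (A * B.transpose(0, 1)).sum(-1)  # O(t n): elementwise multiply + row sums

print(torch.allclose(full_diag, cheap_diag))  # True, up to floating-point error
```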
Sometimes we're only interested in the marginal posterior rather than the joint posterior. One way to compute this, if we have a batch of `b x q x d` points (where `b` is the batch dimension, `q` is the number of points in each batch, and `d` is the dimension of the features), is to batch-evaluate the model on the permuted `q x b x 1 x d` tensor, in which case gpytorch computes a batched MVN of `event_shape=1` with batch shape `q x b`.

You'd expect this to be significantly faster than computing the joint posterior and taking the diagonal to get the marginal variances. But it's actually significantly slower, as the following notebook shows: slow_marginal_posterior.ipynb.txt
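A hedged sketch of the two code paths being compared here (a toy model and illustrative sizes, not the attached notebook's code):

```python
import torch
import gpytorch

# Same toy GP as in the earlier sketch, repeated so this snippet runs on its own.
class ToyGP(gpytorch.models.ExactGP):
    def __init__(self, train_x, train_y, likelihood):
        super().__init__(train_x, train_y, likelihood)
        self.mean_module = gpytorch.means.ConstantMean()
        self.covar_module = gpytorch.kernels.ScaleKernel(gpytorch.kernels.RBFKernel())

    def forward(self, x):
        return gpytorch.distributions.MultivariateNormal(
            self.mean_module(x), self.covar_module(x)
        )

b, q, d = 32, 8, 3  # illustrative sizes
likelihood = gpytorch.likelihoods.GaussianLikelihood()
model = ToyGP(torch.randn(100, d), torch.randn(100), likelihood)
model.eval()
likelihood.eval()

X = torch.randn(b, q, d)  # a batch of b sets of q test points

with torch.no_grad():
    # (1) joint posterior over the q points of each batch, then take the diagonal
    joint = model(X)                 # batch shape b, event shape q
    var_from_joint = joint.variance  # b x q marginal variances

    # (2) batch-evaluate on the permuted q x b x 1 x d tensor: a batched MVN with
    #     event_shape = 1 and batch_shape = q x b
    X_perm = X.permute(1, 0, 2).unsqueeze(-2)               # q x b x 1 x d
    var_batched = model(X_perm).variance                    # q x b x 1
    var_batched = var_batched.squeeze(-1).transpose(0, 1)   # back to b x q

print((var_from_joint - var_batched).abs().max())  # ~0 up to numerical error
```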
Not quite sure what’s going on here. Possibly the marginal sampling doesn’t actually use a different code path, and still naively computes a root decomposition of a batched 1x1 matrix using an iterative algorithm. There could also be a bunch of overhead in the batched LazyTensor indexing.
Any thoughts on this? Is there a better way to get the marginal posterior?