HuangYaowei opened this issue 4 years ago
Sorry for the slow reply @HuangYaowei - in general, it is very likely that additive decompositions are slower than non-additive decompositions (up to roughly 20,000 data points or so). This is because GPyTorch can take advantage of parallelism when there isn't a decomposition.
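For instance, the non-additive case is just one kernel evaluated over all dimensions in a single batched operation (a minimal sketch; the data sizes below are made up for illustration):

```python
import torch
from gpytorch.kernels import MaternKernel, ScaleKernel

# Toy data (sizes are assumptions, for illustration only)
train_x = torch.randn(1000, 27)

# Non-additive baseline: one ARD Matern 5/2 kernel over all 27 dimensions.
# The whole n x n covariance is computed in a single batched operation,
# which is what GPyTorch parallelizes well on a GPU.
covar_module = ScaleKernel(MaternKernel(nu=2.5, ard_num_dims=train_x.shape[1]))

K = covar_module(train_x)  # lazy n x n covariance
print(K.shape)             # torch.Size([1000, 1000])
```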
However, it's also worth noting that your implementation can be fixed :) AdditiveStructureKernel is designed to be a parallel object. The way you have currently constructed your kernel:
```python
for i in range(27):  # range(self.n_sub)
    # get active dimensions for each subspace_id
    subspace_dim = subspace_dim_list[i]
    kern_i = MaternKernel(
        lengthscale_constraint=lengthscale_constraint,
        ard_num_dims=len(subspace_dim),
        nu=2.5,
        active_dims=subspace_dim,
    )
    if i == 0:
        kern = kern_i
    else:
        kern += kern_i
covar_module = AdditiveStructureKernel(base_kernel=kern, num_dims=train_x.shape[1])
```
^^ This would make all 27 kernels operate in series. The proper way to use AdditiveStructureKernel is:

```python
covar_module = AdditiveStructureKernel(base_kernel=MaternKernel(), num_dims=train_x.shape[1])
```
The AdditiveStructureKernel wrapper converts the MaternKernel into a batch of MaternKernels (one for each dimension) and then sums them up. By using batch operations we can get GPU parallelism. With the code you wrote (a separate kernel object for each dimension), there's no way to use GPU parallelism.
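A minimal sketch of this batched usage (the data here is synthetic and only meant to illustrate the shapes):

```python
import torch
from gpytorch.kernels import AdditiveStructureKernel, MaternKernel

train_x = torch.randn(1000, 27)  # synthetic inputs, one column per dimension

# A single MaternKernel is broadcast into a batch of 27 one-dimensional
# kernels (one per input dimension); the batch is evaluated in parallel
# and then summed over the batch dimension.
covar_module = AdditiveStructureKernel(
    base_kernel=MaternKernel(nu=2.5),
    num_dims=train_x.shape[1],
)

K = covar_module(train_x)
print(K.shape)  # torch.Size([1000, 1000]) -- the per-dimension kernels are already summed
```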
Thanks for your reply @gpleiss. Yes, I have also found that additive decompositions are slower than non-additive decompositions. However, I have an extra question: since we generally assume a GPU runs faster than a CPU, I compared the speed of GPyTorch and GPy and found that in high-dimensional spaces GPyTorch does not run faster than GPy. Have you compared the speed with other repositories? How can I improve my speed in GPyTorch? The speed comparison was run in an environment with 1 GPU and 28 CPUs. I have already evaluated 30 points and want to infer the next point with a Matern 5/2 kernel.
What exactly are you comparing here? Can you provide code examples?
Hi @gpleiss, we are having some difficulties with additive GPs (even following your AdditiveStructureKernel example): the additive kernel takes something like 10x longer to fit (with 9 dims) compared with a full RBF kernel, and is slightly slower to predict with, for PairwiseGPs in BoTorch. I have a minimal repro using vanilla regression (ScaleKernel(RBFKernel())), which results in a fitting time that is ~2.4x slower and a prediction time that is about 74% slower. Is this expected?
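For context, a comparison along these lines might look like the sketch below (the synthetic 9-dimensional data, the plain ExactGP regression model, and the timing helper are assumptions for illustration, not the actual repro):

```python
import time
import torch
import gpytorch

# Synthetic 9-dimensional regression data (sizes are assumptions)
train_x = torch.randn(500, 9)
train_y = train_x.sin().sum(-1) + 0.1 * torch.randn(500)


class SimpleGP(gpytorch.models.ExactGP):
    def __init__(self, train_x, train_y, likelihood, covar_module):
        super().__init__(train_x, train_y, likelihood)
        self.mean_module = gpytorch.means.ConstantMean()
        self.covar_module = covar_module

    def forward(self, x):
        return gpytorch.distributions.MultivariateNormal(
            self.mean_module(x), self.covar_module(x)
        )


def time_fit(covar_module, n_iter=50):
    # Fit a fresh model with the given kernel and return the wall-clock time.
    likelihood = gpytorch.likelihoods.GaussianLikelihood()
    model = SimpleGP(train_x, train_y, likelihood, covar_module)
    model.train()
    likelihood.train()
    optimizer = torch.optim.Adam(model.parameters(), lr=0.1)
    mll = gpytorch.mlls.ExactMarginalLogLikelihood(likelihood, model)
    start = time.time()
    for _ in range(n_iter):
        optimizer.zero_grad()
        loss = -mll(model(train_x), train_y)
        loss.backward()
        optimizer.step()
    return time.time() - start


full_rbf = gpytorch.kernels.ScaleKernel(
    gpytorch.kernels.RBFKernel(ard_num_dims=train_x.shape[1])
)
additive = gpytorch.kernels.ScaleKernel(
    gpytorch.kernels.AdditiveStructureKernel(
        gpytorch.kernels.RBFKernel(), num_dims=train_x.shape[1]
    )
)

print("full RBF fit time (s):", time_fit(full_rbf))
print("additive fit time (s):", time_fit(additive))
```

With small data like this, the additive version may well come out slower, which matches the observations in this thread.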
How much data? Can you share the repro?
Here I would like to train a GP model on a very high-dimensional X. I first decompose X into 27 subspaces (subspace_dim) and then use the sum of 27 MaternKernels as the covar_module; however, this is even slower than a ScaleKernel on the original X without decomposition. What should I do?
Code snippet to reproduce
However, when I run the code on a single GPU, it needs 384s to run. If I use
it only needs 58s, but my original hope was to make it faster than 58s, e.g. around 10s.
Have I implemented AdditiveStructureKernel correctly? Or what should I do? Thank you very much.
By the way, I followed what GPy does; I will also paste their code here: