cornellius-gp / gpytorch

A highly efficient implementation of Gaussian Processes in PyTorch
MIT License

Speed issue: combining kernels in Hadamard multi-task setting #893

Closed: tadesautels closed this issue 4 years ago

tadesautels commented 5 years ago

Hi All,

I am trying to build models for doing multi-task regression and have some key questions about speed at prediction time for these models. I'm looking for intuition from you all on why two seemingly similar models run at very different speeds.

In my application, not all tasks appear together; I therefore need capabilities like those in the Hadamard_multitask_gp_regression example notebook. FYI, I'm also using the multitask_fantasies branch with a relatively recent version of master merged into it.

On to my specifics: I have two model classes, A and B, both derived from gpytorch.models.ExactGP. In each case, the input to the kernels is run through (very similar) shallow neural networks to do some feature transformation (removed from the examples below).

Model class A has a kernel defined like so in its __init__ method:

self.covar_module = gpytorch.kernels.RBFKernel()
self.task_covar_module = gpytorch.kernels.IndexKernel(
    num_tasks=num_tasks, rank=num_tasks - 1
)

and a forward method in which

def forward(self, x, i):
    ...
    covar_x = self.covar_module(x)
    covar_i = self.task_covar_module(i)
    covar = covar_x.mul(covar_i)
    return gpytorch.distributions.MultivariateNormal(mean_x, covar)

where mean_x is defined via the model's mean_module.
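For concreteness, here is a minimal, self-contained sketch of roughly what model class A looks like (the feature-transform network is omitted, and the class name, constant mean, and likelihood handling are placeholders rather than my exact code):

import torch
import gpytorch


class HadamardModelA(gpytorch.models.ExactGP):
    # Sketch only: RBF data kernel multiplied elementwise by an IndexKernel over tasks.
    def __init__(self, train_x, train_i, train_y, likelihood, num_tasks):
        super().__init__((train_x, train_i), train_y, likelihood)
        self.mean_module = gpytorch.means.ConstantMean()
        self.covar_module = gpytorch.kernels.RBFKernel()
        self.task_covar_module = gpytorch.kernels.IndexKernel(
            num_tasks=num_tasks, rank=num_tasks - 1
        )

    def forward(self, x, i):
        mean_x = self.mean_module(x)
        covar_x = self.covar_module(x)        # data covariance
        covar_i = self.task_covar_module(i)   # task covariance from integer task indices
        covar = covar_x.mul(covar_i)          # Hadamard (elementwise) product
        return gpytorch.distributions.MultivariateNormal(mean_x, covar)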

Similarly, model class B has a kernel defined as follows in its __init__ method:

self.covar_module = (
    gpytorch.kernels.RBFKernel(active_dims=torch.tensor(stationary_dims))
    + gpytorch.kernels.LinearKernel(active_dims=torch.tensor(linear_dims))
)

self.task_covar_module = gpytorch.kernels.IndexKernel(
    num_tasks=num_tasks, rank=num_tasks - 1
)

where stationary_dims and linear_dims split the (expanded) feature dimension, and a forward method:

def forward(self, x, i):
    ...
    xall = torch.cat([xstationary, xlin], dim=1)
    covar_x = self.covar_module(xall)
    covar_i = self.task_covar_module(i)
    covar = covar_x.mul(covar_i)
    return gpytorch.distributions.MultivariateNormal(mean_x, covar)
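To be explicit, the only structural difference from A is the data kernel: a sum of two kernels that each act on their own slice of the concatenated feature columns. A small sketch with made-up index lists (in my real model, stationary_dims and linear_dims come from the feature expansion):

import torch
import gpytorch

# Hypothetical column split of xall = torch.cat([xstationary, xlin], dim=1)
stationary_dims = [0, 1, 2]
linear_dims = [3, 4]

covar_module = (
    gpytorch.kernels.RBFKernel(active_dims=torch.tensor(stationary_dims))
    + gpytorch.kernels.LinearKernel(active_dims=torch.tensor(linear_dims))
)
# Everything else (the IndexKernel and the Hadamard product with the task
# covariance) is the same as in model class A above.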

At prediction time, the model has training data from O(10k) observations and is given a set of O(10k) targets (feature vectors pred_x and indices pred_i). I'm computing a single, big ol' multivariate normal over all of the targets, which I'm doing as follows:

with torch.no_grad(), gpytorch.settings.fast_pred_var(), gpytorch.settings.max_root_decomposition_size(int(1e5)):
    joint_mvn = model(pred_x, pred_i)

The max_root_decomposition_size is enormous and already makes this take a long time. It's necessary because I care quite a lot about the fine, across-task structure of the posterior, which (as far as I've been able to tell) is resolved poorly with the default value.

The weird thing is that, at prediction time, model class B (with the additive kernel) takes MUCH longer to run. I unfortunately don't have exact timings at the moment, but model class A makes predictions in seconds in this setting, whereas model class B takes 15+ minutes. The only explanation that seems plausible to me is that this is related to the additive structure of the kernel; it does seem odd, however, that this case would not be optimized. I have also noticed similar problems with the product structure of the spatial and task covariances, which is part of what led me to do a single, big covariance calculation (but that's a different story).
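In case it helps, the kind of wall-clock comparison I'm describing looks roughly like this (model_a, model_b, and the prediction tensors are stand-ins for my actual objects; pulling out the covariance is just to make sure the lazy evaluation actually happens before the timer stops):

import time

import torch
import gpytorch


def time_joint_prediction(model, pred_x, pred_i):
    # Rough wall-clock timing of one big joint predictive MVN.
    model.eval()
    with torch.no_grad(), gpytorch.settings.max_root_decomposition_size(int(1e5)):
        start = time.perf_counter()
        joint_mvn = model(pred_x, pred_i)
        _ = joint_mvn.covariance_matrix   # force evaluation of the lazy covariance
        if pred_x.is_cuda:
            torch.cuda.synchronize()      # needed for meaningful GPU timings
        return time.perf_counter() - start


# print("A:", time_joint_prediction(model_a, pred_x, pred_i))
# print("B:", time_joint_prediction(model_b, pred_x, pred_i))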

Do any of you have intuition as to why model class B takes so much longer to predict than A does?

gpleiss commented 5 years ago

@tadesautels - I made a notebook (attached) to try to reproduce these results. I trained GP models with 2000 training data points and 4 features.

On CPU, A and B take the same amount of time. On GPU, B takes about twice as long as A. Is this consistent with the results you're seeing?

Hadamard_Additive_Timing.ipynb.zip

gpleiss commented 5 years ago

Hadamard_Additive_Timing (1).ipynb.zip

Sorry - I realized that one didn't have timing results for predictions. I'm seeing the same speed for A and B at prediction time. Can you reproduce?

On another note - if max_root_decomposition_size = N (the number of training points), then there's not really any point in using fast_pred_var. You'll probably get more stable (and slightly faster) results without that flag.
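Concretely, something like this (same placeholder names as in your snippet) should be enough, with fast_pred_var dropped:

with torch.no_grad(), gpytorch.settings.max_root_decomposition_size(int(1e5)):
    joint_mvn = model(pred_x, pred_i)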

gpleiss commented 4 years ago

closing for now - reopen if there are still questions :)