cornellius-gp / gpytorch

A highly efficient implementation of Gaussian Processes in PyTorch

torch.gather does not work with Batch GPs #1255

Open Akella17 opened 4 years ago

Akella17 commented 4 years ago

I am using the batch feature to deploy multiple independent GPs, but the loss can only be computed for one GP head per training example. Here is how I implemented it:

# inputs dimension: num_data * input_dim
# predictions.mean dimension: num_GP_heads * num_data
# predictions.covariance_matrix dimension: num_GP_heads * num_data * num_data
# selection dimension: num_data (index of the GP head to use for each data point)
# targets dimension: num_data
predictions = model(inputs)

# Gather, for each data point, the mean and covariance entries of its selected head.
# data_batch_size == num_data, so the gathered covariance is num_data * num_data.
# modified_predictions_mean dimension: num_data
# modified_predictions_covar dimension: num_data * num_data
modified_predictions_mean = predictions.mean.gather(0, selection.unsqueeze(0)).squeeze(0)
modified_predictions_covar = predictions.covariance_matrix.gather(
    0, selection.unsqueeze(0).unsqueeze(-1).repeat(1, 1, data_batch_size)
).squeeze(0)
modified_predictions = gpytorch.distributions.MultivariateNormal(modified_predictions_mean, modified_predictions_covar)

loss = -mll(modified_predictions, targets)
optimizer.zero_grad()
loss.backward()
optimizer.step()

However, the loss has shape num_GP_heads instead of num_data. Moreover, even when num_data is 1 (meaning the loss needs to be computed for only one of the GP heads), the loss is still num_GP_heads-dimensional and non-zero for all GP heads. Is there anything wrong with my implementation?
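The gather indexing itself can be checked in isolation. Below is a minimal torch-only sketch with toy shapes (the tensor names and sizes are placeholders, not the actual model) that reproduces the dimensions described above:

import torch

num_GP_heads, num_data = 4, 3
mean = torch.randn(num_GP_heads, num_data)             # stand-in for predictions.mean
covar = torch.randn(num_GP_heads, num_data, num_data)  # stand-in for predictions.covariance_matrix
selection = torch.randint(num_GP_heads, (num_data,))   # GP head index per data point

# Same gather pattern as above: pick, per data point, the entries of its selected head.
sel_mean = mean.gather(0, selection.unsqueeze(0)).squeeze(0)
sel_covar = covar.gather(
    0, selection.unsqueeze(0).unsqueeze(-1).repeat(1, 1, num_data)
).squeeze(0)

print(sel_mean.shape)   # torch.Size([3])    -> num_data
print(sel_covar.shape)  # torch.Size([3, 3]) -> num_data x num_data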

jacobrgardner commented 4 years ago

I'm somewhat confused -- why would the GP MLL loss have a num_data shape? The MLL doesn't decompose as a sum over the individual data points, so there's not really a good analogue of a "per-data-point" loss for GPs.
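For context (a standard result, not specific to this issue), the exact GP marginal log-likelihood over n training points is

$$\log p(\mathbf{y} \mid X) = -\tfrac{1}{2}\,\mathbf{y}^\top \left(K_{XX} + \sigma^2 I\right)^{-1}\mathbf{y} - \tfrac{1}{2}\log\left|K_{XX} + \sigma^2 I\right| - \tfrac{n}{2}\log 2\pi,$$

where both the quadratic term and the log-determinant couple all n points through the kernel matrix, so the expression does not split into a sum of independent per-data-point terms.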

Akella17 commented 4 years ago

@jacobrgardner Thanks for the quick response. I agree with what you are saying. But what I don't understand is why the loss is non-zero for all the GP heads when num_data = 1 (only one GP head is used/gathered to compute the loss).

In other words, I was expecting the loss to be zero for every GP head except the one that was gathered to compute the loss.
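As a sanity check of that expectation, here is a standalone torch sketch (using torch.distributions directly rather than the actual gpytorch model and likelihood, with toy shapes): after the gather, modified_predictions describes a single non-batch multivariate normal, so its log-probability of a num_data-shaped target is a single scalar with no per-head dimension.

import torch

num_data = 3
mean = torch.zeros(num_data)     # stand-in for the gathered mean
covar = torch.eye(num_data)      # stand-in for the gathered covariance
targets = torch.randn(num_data)

# A single (non-batch) multivariate normal: its log-probability is one scalar,
# not a vector with one entry per GP head.
dist = torch.distributions.MultivariateNormal(mean, covar)
print(dist.log_prob(targets).shape)  # torch.Size([]) -- a scalar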