cornellius-gp / gpytorch

A highly efficient implementation of Gaussian Processes in PyTorch
MIT License

MVN with tensor __add__ and __mul__ [Feature Request] #1077

Status: Open (sebascuri opened this issue 4 years ago)

sebascuri commented 4 years ago

🚀 Feature Request

Currently, MultivariateNormal implements `__add__` and `__mul__` for either (i) another MultivariateNormal or (ii) floats/ints. It would be useful to also accept tensors (with compatible dimensions).

Motivation

Say I have an MVN with mean [3, 2] and some covariance, and I want to shift it by [-1, 2]. Currently, the only way to do this is to build a new MVN object by hand, because `__add__` would require me to wrap the shift in another MVN object with some positive-definite covariance matrix.
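For a pure shift, the covariance does not change at all: if x ~ N(μ, Σ), then x + b ~ N(μ + b, Σ). Below is a minimal sketch using PyTorch's own `torch.distributions.MultivariateNormal` (the class gpytorch's `MultivariateNormal` subclasses) rather than gpytorch itself; `shift` is a hypothetical helper, not an existing API. Reusing `scale_tril` means no new positive-definite matrix has to be supplied.

```python
import torch
from torch.distributions import MultivariateNormal


def shift(mvn: MultivariateNormal, offset: torch.Tensor) -> MultivariateNormal:
    """Return the distribution of x + offset for x ~ mvn (hypothetical helper)."""
    # Adding a constant only moves the mean; reusing the Cholesky factor keeps
    # the covariance exactly as it was.
    return MultivariateNormal(mvn.mean + offset, scale_tril=mvn.scale_tril)


base = MultivariateNormal(torch.tensor([3.0, 2.0]),
                          covariance_matrix=torch.tensor([[1.0, 0.5], [0.5, 2.0]]))
shifted = shift(base, torch.tensor([-1.0, 2.0]))
print(shifted.mean)  # tensor([2., 4.])
```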

Say I have a multitask MVN with a mean and covariance per task, and I would like to shift and rescale each task by different values. Currently, this is not supported.

This is useful for multi-output regression, where each output can have a different mean and scale. The inputs/outputs are normalized for training, but the predictions have to be rescaled back to the original units.
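Un-normalizing a multitask prediction is an elementwise affine map y = σ_t · z + μ_t applied per task, so the joint covariance gets multiplied elementwise by the outer product of the per-entry scales. A minimal sketch on plain tensors, assuming the joint covariance is stored in interleaved order (index = point · t + task); the helper `unnormalize_multitask` and its argument names are hypothetical.

```python
import torch


def unnormalize_multitask(mu, cov, y_mean, y_std):
    """Map a normalized multitask prediction back to the original output scale.

    Hypothetical helper. Assumes:
      mu:  (n, t) predictive mean in normalized units,
      cov: (n*t, n*t) joint covariance in interleaved order (index = point * t + task),
      y_mean, y_std: (t,) per-task statistics used to normalize the training targets.
    """
    n, t = mu.shape
    # y = y_std * z + y_mean, applied task by task (broadcast over the n points).
    new_mu = mu * y_std + y_mean
    # Per-entry scale in interleaved order: [s_0, ..., s_{t-1}] repeated n times.
    s = y_std.repeat(n)
    # Cov[diag(s) z] = diag(s) Cov[z] diag(s) = Cov[z] * (s s^T) elementwise.
    new_cov = cov * torch.outer(s, s)
    return new_mu, new_cov


n, t = 3, 2
new_mu, new_cov = unnormalize_multitask(torch.zeros(n, t), torch.eye(n * t),
                                        torch.tensor([10.0, -5.0]), torch.tensor([2.0, 3.0]))
print(new_mu[0])                # tensor([10., -5.])
print(torch.diag(new_cov)[:2])  # tensor([4., 9.])
```

With a task-major (non-interleaved) layout the scale vector would instead be `y_std.repeat_interleave(n)`.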

Pitch

Describe the solution you'd like

I'd like a solution that either:

(i) for each task of a multitask MVN, returns an MVN that I could shift and scale (this could be useful by itself), or

(ii) for a multitask MVN, accepts a tensor in both the `__add__` and `__mul__` methods.
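Option (i) amounts to marginalizing the joint multitask Gaussian onto one task, which only requires selecting the matching rows and columns of the joint covariance. Which rows those are depends on how the n·t entries are ordered; gpytorch's `MultitaskMultivariateNormal` can, as far as I know, store either an interleaved layout or a task-major one, but treat that as an assumption to check. The sketch below works on plain tensors, and `task_marginal` is a hypothetical helper.

```python
import torch


def task_marginal(mean, cov, task, interleaved=True):
    """Mean and covariance of one task from a joint multitask Gaussian (hypothetical helper).

    mean: (n, t); cov: (n*t, n*t). With interleaved=True the covariance index is
    point * t + task; otherwise it is task * n + point (task-major blocks).
    """
    n, t = mean.shape
    if interleaved:
        idx = torch.arange(n) * t + task
    else:
        idx = task * n + torch.arange(n)
    # Marginalizing a Gaussian = keeping the selected rows and columns.
    return mean[:, task], cov[idx][:, idx]


mean = torch.tensor([[0.0, 10.0], [1.0, 11.0]])  # n=2 points, t=2 tasks
cov = torch.diag(torch.tensor([1.0, 2.0, 3.0, 4.0]))
m, c = task_marginal(mean, cov, task=1)
print(m)              # tensor([10., 11.])
print(torch.diag(c))  # tensor([2., 4.])
```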

Describe alternatives you've considered

Currently my code looks like this:

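```python
from typing import Union

import torch
from torch import Tensor
from gpytorch.distributions import MultitaskMultivariateNormal, MultivariateNormal


def shift_mvn(mvn: Union[MultivariateNormal, MultitaskMultivariateNormal],
              mean: Union[float, Tensor], scale: Union[float, Tensor]) -> MultivariateNormal:
    r"""Shift and scale a multivariate normal.

    Creates a new multivariate normal

    .. math:: \mathcal{N}(s \odot \mu + m, \operatorname{diag}(s) \, \Sigma \, \operatorname{diag}(s))

    where the original mvn is

    .. math:: \mathcal{N}(\mu, \Sigma)

    Parameters
    ----------
    mvn: MultivariateNormal or MultitaskMultivariateNormal
        Multivariate normal with mean of size `n x t` or `n` and covariance `nt x nt` or `n x n`.
    mean: float or Tensor
        Shift m added to the mean, broadcastable to the mean's shape (e.g. `t`, `n x t`, or `n`).
    scale: float or Tensor
        Elementwise scale s, broadcastable to the mean's shape (e.g. `t`, `n x t`, or `n`).
    """
    mu = mvn.mean
    # Broadcast shift and scale to the mean's shape so scalar, per-task and per-point values work alike.
    mean = torch.as_tensor(mean, dtype=mu.dtype, device=mu.device).expand_as(mu)
    scale = torch.as_tensor(scale, dtype=mu.dtype, device=mu.device).expand_as(mu)
    if not isinstance(mvn, MultitaskMultivariateNormal):  # Simple MVN.
        # Cov[s * x] = diag(s) Sigma diag(s), i.e. Sigma scaled elementwise by s s^T
        # (multiplying by scale ** 2 is only correct for a scalar scale).
        sigma = mvn.covariance_matrix * (scale.unsqueeze(-1) * scale.unsqueeze(-2))
        return MultivariateNormal(mu * scale + mean, covariance_matrix=sigma)
    num_points, num_tasks = mu.shape
    sigma = mvn.covariance_matrix  # dense (nt x nt) joint covariance
    mvns = []
    for i in range(num_tasks):
        # It could be useful to have a method mvn.task(task_nr: int) -> MultivariateNormal
        # on MultitaskMultivariateNormal that returns the MVN associated with each task.
        # NOTE: this block slicing assumes a non-interleaved (task-major) covariance layout,
        # and cross-task covariances are dropped by from_independent_mvns below.
        mean_ = mu[..., i]
        cov_ = sigma[i * num_points:(i + 1) * num_points,
                     i * num_points:(i + 1) * num_points]
        mvns.append(shift_mvn(MultivariateNormal(mean_, cov_), mean[..., i], scale[..., i]))
    return MultitaskMultivariateNormal.from_independent_mvns(mvns)
```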
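The covariance term is where tensor scales need care. For an elementwise scale vector s, Cov[s ⊙ x] = diag(s) Σ diag(s), which equals Σ multiplied elementwise by the outer product s sᵀ. Multiplying Σ by `scale ** 2` instead broadcasts s² across the columns only, so it agrees with the exact result only when s is a scalar (or all entries are equal). A quick numerical check on a random covariance, in plain PyTorch:

```python
import torch

torch.manual_seed(0)
n = 4
a = torch.randn(n, n)
sigma = a @ a.T + n * torch.eye(n)        # random positive-definite covariance
s = torch.tensor([1.0, 2.0, 3.0, 4.0])    # per-dimension scale

exact = torch.diag(s) @ sigma @ torch.diag(s)  # Cov[s * x]
outer_form = sigma * torch.outer(s, s)         # elementwise by s s^T
naive = sigma * s ** 2                         # scales column j by s_j^2 only

print(torch.allclose(exact, outer_form))  # True
print(torch.allclose(exact, naive))       # False
```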

I think something like

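```python
import torch


class MultivariateNormal:  # sketch: only the proposed methods are shown
    ...

    def __add__(self, other):
        if isinstance(other, MultivariateNormal):
            # A + B for independent random variables: means and covariances add.
            return self.__class__(
                mean=self.mean + other.mean,
                covariance_matrix=(self.lazy_covariance_matrix + other.lazy_covariance_matrix),
            )
        elif isinstance(other, (int, float, torch.Tensor)):
            # Shifting by a constant (or a broadcastable tensor) moves only the mean.
            return self.__class__(self.mean + other, self.lazy_covariance_matrix)
        else:
            raise RuntimeError("Unsupported type {} for addition w/ MultivariateNormal".format(type(other)))

    def __mul__(self, other):
        if not isinstance(other, (int, float, torch.Tensor)):
            raise RuntimeError("Can only multiply by scalars or tensors")
        if isinstance(other, (int, float)):
            if other == 1:
                return self
            return self.__class__(mean=self.mean * other, covariance_matrix=self.lazy_covariance_matrix * (other ** 2))
        # Elementwise scale s: Cov[s * x] = diag(s) Sigma diag(s), i.e. Sigma * (s s^T) elementwise.
        # TODO: size checking. This path only handles a single-task MVN; for a
        # MultitaskMultivariateNormal, s must first be flattened in the covariance's order.
        s = other.expand_as(self.mean)
        return self.__class__(
            mean=self.mean * s,
            covariance_matrix=self.covariance_matrix * (s.unsqueeze(-1) * s.unsqueeze(-2)),
        )


class MultitaskMultivariateNormal(MultivariateNormal):
    def get_task(self, task_nr):
        # Returns the marginal MVN of a single task.
        # NOTE: the block slicing assumes a non-interleaved (task-major) covariance layout.
        num_points, num_tasks = self._output_shape
        mean_ = self.mean[..., task_nr]
        cov_ = self.lazy_covariance_matrix[task_nr * num_points:(task_nr + 1) * num_points,
                                           task_nr * num_points:(task_nr + 1) * num_points]
        return MultivariateNormal(mean_, cov_)
```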

should work

Are you willing to open a pull request? (We LOVE contributions!!!)

Yes

jacobrgardner commented 4 years ago

@sebascuri Given that the existence of `__add__` and `__mul__` is already a deviation from the PyTorch base class (which doesn't support any form of + or - on distributions), I think this could be reasonable. Right now, since A + B semantically means "treat A and B as random variables and add them," being able to also do s*A + m makes sense to me.
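Under those random-variable semantics, both operations have closed forms for Gaussians, with one caveat: the covariances of A and B simply add only when A and B are independent; otherwise the cross-covariance terms must be included. A minimal sketch with `torch.distributions.MultivariateNormal`, where `add_independent` and `affine` are hypothetical helpers rather than gpytorch APIs:

```python
import torch
from torch.distributions import MultivariateNormal


def add_independent(a: MultivariateNormal, b: MultivariateNormal) -> MultivariateNormal:
    """Distribution of A + B, assuming A and B are independent (hypothetical helper)."""
    return MultivariateNormal(a.mean + b.mean,
                              covariance_matrix=a.covariance_matrix + b.covariance_matrix)


def affine(a: MultivariateNormal, s: torch.Tensor, m: torch.Tensor) -> MultivariateNormal:
    """Distribution of s * A + m for an elementwise scale s (hypothetical helper)."""
    # Cov[s * A] = diag(s) Cov[A] diag(s) = Cov[A] * (s s^T) elementwise.
    return MultivariateNormal(s * a.mean + m,
                              covariance_matrix=a.covariance_matrix * torch.outer(s, s))


A = MultivariateNormal(torch.tensor([1.0, 0.0]), covariance_matrix=torch.eye(2))
B = MultivariateNormal(torch.tensor([0.0, 2.0]), covariance_matrix=2 * torch.eye(2))
C = add_independent(affine(A, torch.tensor([2.0, 3.0]), torch.tensor([0.5, 0.5])), B)
print(C.mean)  # tensor([2.5000, 2.5000])
```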

In my opinion, feel free to open a PR since you mention you're willing to!