tensorflow / probability

Probabilistic reasoning and statistical analysis in TensorFlow
https://www.tensorflow.org/probability/
Apache License 2.0

`GaussianProcess` sample does not work with new index points #837

Open fdtomasi opened 4 years ago

fdtomasi commented 4 years ago

A minimal example is the following:

import numpy as np
import tensorflow_probability as tfp
tfd = tfp.distributions
tfk = tfp.math.psd_kernels
index_points = np.array([1., 2, 3])[:, None]
other_index_points = np.array([1., 2])[:, None]

gp = tfd.GaussianProcess(tfk.ExponentiatedQuadratic(), index_points=index_points)
# gp.sample(index_points=other_index_points)  # raises ValueError: Tensor's shape (2,) is not compatible with supplied shape (3,)

The GaussianProcess class does implement _sample_n to take into account new index points:

  def _sample_n(self, n, seed=None, index_points=None):
    return self.get_marginal_distribution(index_points).sample(n, seed=seed)

Following the chain of calls, sample calls _call_sample_n, which first calls _sample_n (this succeeds) and then _set_sample_static_shape (in distribution.py). That second call is the one that fails: it reads self.event_shape, which is computed from the index points given at construction time, so the information about the new index points is lost and the static shape cannot be set on the sample.
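To make the failure mode concrete, here is a toy model of that call chain in plain Python. It is only a sketch of the mechanism described above, not TFP's code: the classes ToyDistribution and ToyGP and the helper _check_static_shape are made up for illustration, and samples are plain lists rather than tensors.

```python
class ToyDistribution:
    """Toy stand-in for the base distribution class (hypothetical, not TFP)."""

    @property
    def event_shape(self):
        # A property: it has no way to receive per-call kwargs.
        return self._event_shape()

    def sample(self, n, **kwargs):
        samples = self._sample_n(n, **kwargs)  # kwargs reach _sample_n ...
        self._check_static_shape(samples, n)   # ... but not the shape check
        return samples

    def _check_static_shape(self, samples, n):
        expected = (n,) + self.event_shape
        actual = (len(samples), len(samples[0]))
        if actual != expected:
            raise ValueError(
                f"shape {actual} is not compatible with expected shape {expected}")


class ToyGP(ToyDistribution):
    """Toy GP whose event shape is the number of index points."""

    def __init__(self, index_points):
        self.index_points = index_points

    def _event_shape(self):
        return (len(self.index_points),)

    def _sample_n(self, n, index_points=None):
        points = self.index_points if index_points is None else index_points
        return [[0.0] * len(points) for _ in range(n)]


gp = ToyGP([1.0, 2.0, 3.0])
gp.sample(2)                                # fine: 3 points, event shape (3,)
gp.sample(2, index_points=[1.0, 2.0, 4.0])  # fine: different points, same count
try:
    gp.sample(2, index_points=[1.0, 2.0])   # 2 points vs. event shape (3,)
except ValueError as err:
    print("ValueError:", err)
```

In this toy version, as in the report, _sample_n does the right thing with the new index points; the error comes only from the shape check that runs afterwards.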

A workaround is to use

gp.get_marginal_distribution(index_points=other_index_points).sample()  # works fine

and that is what I have been doing until now. However, I wanted to extend the GaussianProcess class by overriding only the _sample_n method rather than the public sample method, so that the various checks in sample are retained. I could not find a way to do that short of reimplementing sample and bypassing the checks.
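For reference, the workaround can be wrapped in a small helper. This is a sketch only: the name sample_at is hypothetical, and it assumes nothing beyond the get_marginal_distribution(index_points=...) call shown above and the usual sample(sample_shape, seed=...) signature of the returned distribution.

```python
def sample_at(gp, index_points, sample_shape=(), seed=None):
    """Sample a GP at new index points via its marginal distribution.

    Hypothetical helper: `gp` is any object exposing
    get_marginal_distribution(index_points=...), such as the
    GaussianProcess instance from the example above.
    """
    marginal = gp.get_marginal_distribution(index_points=index_points)
    # The marginal carries its own event shape, so the static shape check
    # in sample() sees the new number of index points.
    return marginal.sample(sample_shape, seed=seed)


# Intended usage with the objects defined in the minimal example:
#   sample_at(gp, other_index_points)
```

This sidesteps the shape check rather than fixing it, so it does not help when the goal is to override only _sample_n in a subclass.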

Needless to say, if we define

other_index_points = np.array([1., 2, 4])[:, None]

then everything works fine (the index points are different, but the event_shape is the same).

Is there something I missed?

csuter commented 4 years ago

Thanks for filing this. I vaguely realized that sampling with index_points was broken (I tried it once and found what you found) but I hadn't dug much into the underlying issues. I think the fix here actually involves changes to the base distribution class, which has potential to be somewhat involved. We'd basically need to admit **kwargs in the event_shape method, and plumb it through from sample. For consistency we'd need to do the same for event_shape_tensor. Finally, we'd need to update GP to accept the index_points kwarg to its impl of these methods.

csuter commented 4 years ago

Hm, actually event_shape is a property, not a method, so it can't accept **kwargs. This needs some more thought.
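The constraint is easy to see in plain Python. The sketch below uses made-up classes (WithProperty, WithMethod) and a hypothetical method name event_shape_for; it illustrates why a property cannot take index_points and what a method-based alternative could look like, and is not a proposal for TFP's actual API.

```python
class WithProperty:
    """event_shape as a property: attribute access takes no arguments."""

    @property
    def event_shape(self):
        return (3,)


class WithMethod:
    """Hypothetical alternative: a method that accepts index_points,
    with the property delegating to it using the default points."""

    def __init__(self, index_points):
        self.index_points = index_points

    def event_shape_for(self, index_points=None):
        points = self.index_points if index_points is None else index_points
        return (len(points),)

    @property
    def event_shape(self):
        return self.event_shape_for()


try:
    # The property is evaluated first and returns a tuple; the call is
    # then applied to that tuple, which is not callable.
    WithProperty().event_shape(index_points=[1.0, 2.0])
except TypeError as err:
    print("TypeError:", err)

d = WithMethod([1.0, 2.0, 3.0])
print(d.event_shape)                  # shape for the construction-time points
print(d.event_shape_for([1.0, 2.0]))  # shape for the new points
```

Any design along these lines would still require sample to call the method form with the forwarded kwargs, which is the plumbing through the base class mentioned in the previous comment.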