Open fdtomasi opened 4 years ago
Thanks for filing this. I vaguely realized that sampling with `index_points` was broken (I tried it once and found what you found), but I hadn't dug much into the underlying issues. I think the fix here actually involves changes to the base distribution class, which has the potential to be somewhat involved. We'd basically need to admit `**kwargs` in the `event_shape` method and plumb it through from `sample`. For consistency, we'd need to do the same for `event_shape_tensor`. Finally, we'd need to update GP to accept the `index_points` kwarg in its implementation of these methods.
Hm, actually `event_shape` is a property, not a method, so it can't accept `**kwargs`. This needs some more thought.
A minimal example is the following:
The `GaussianProcess` class does implement `_sample_n` to take into account new index points:

Following the chain of calls, `sample` calls `_call_sample_n`, which calls `_sample_n` (successful) and then `_set_sample_static_shape` (in `distribution.py`; this call is unsuccessful because it calls `self.event_shape`, hence the information on the new index points is lost), which fails at setting the shape of the sample.

A workaround is to use
and that is what I have been doing until now. However, I wanted to extend the `GaussianProcess` class and override only the `_sample_n` method, rather than the general `sample`, to retain the various checks, but I could not find a way to do so (short of reimplementing `sample` and skipping the checks).

Needless to say, if we define
then everything works fine (the index points are different, but the `event_shape` is the same).

Is there something I missed?
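
For anyone who wants to follow the call chain described above without installing TensorFlow Probability, here is a toy, pure-Python sketch of the same pattern. Everything in it is a hypothetical stand-in rather than TFP code: the `ToyDistribution` and `ToyGP` classes, the `_event_shape` helper, the list-based samples, and the `ValueError` are all assumptions made for illustration. It only mimics the structure: `sample` forwards keyword arguments to `_sample_n`, but the shape check goes through the `event_shape` property, which cannot receive them.

```python
class ToyDistribution:
    """Hypothetical stand-in for a base distribution class (not TFP code)."""

    @property
    def event_shape(self):
        # A property takes no arguments, so it only sees construction-time state.
        return self._event_shape()

    def sample(self, n, **kwargs):
        samples = self._sample_n(n, **kwargs)  # kwargs reach _sample_n ...
        return self._set_sample_static_shape(samples, n)  # ... but not this step

    def _set_sample_static_shape(self, samples, n):
        # Assumption: the toy raises on a mismatch to make the failure visible.
        expected = (n,) + self.event_shape
        actual = (len(samples), len(samples[0]))
        if actual != expected:
            raise ValueError(f"shape mismatch: got {actual}, expected {expected}")
        return samples


class ToyGP(ToyDistribution):
    """Hypothetical GP-like class whose event shape depends on its index points."""

    def __init__(self, index_points):
        self.index_points = list(index_points)

    def _event_shape(self):
        return (len(self.index_points),)

    def _sample_n(self, n, index_points=None):
        points = self.index_points if index_points is None else list(index_points)
        # Placeholder zeros; a real GP would draw from a multivariate normal.
        return [[0.0 for _ in points] for _ in range(n)]


gp = ToyGP(index_points=[0.0, 1.0, 2.0])

# Different index points, same count: the event shape still matches.
gp.sample(2, index_points=[5.0, 6.0, 7.0])

# Different count: _sample_n succeeds, then the shape check fails.
try:
    gp.sample(2, index_points=[5.0, 6.0])
except ValueError as err:
    print(err)
```

In this toy version, the second call fails only after `_sample_n` has already produced a correctly sized result, which matches the behaviour reported above. It also shows why overriding `_sample_n` alone is not enough as long as the shape check relies on a property.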