CUQI-DTU / CUQIpy

https://cuqi-dtu.github.io/CUQIpy/
Apache License 2.0

Support burnthin in get_samples method #515

Closed: nabriis closed this issue 2 months ago

nabriis commented 2 months ago

Description:

Make it easier for users to remove burn-in after calling get_samples when using Gibbs sampling.

Suggested change: subclass dict as a "JointSamples" object that holds multiple Samples objects, and give it a burnthin method so that users can call:

sampler.get_samples().burnthin(200) # Should work for both a single Samples object and a JointSamples object
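For context, here is a minimal sketch of the workaround this issue aims to remove. It assumes, as described above, that get_samples() on a Gibbs sampler currently returns a plain dict of Samples objects. StandInSamples is a hypothetical stand-in for CUQIpy's Samples class (only samples, Ns and burnthin are mimicked), so the sketch runs without CUQIpy installed.

```python
import numpy as np


class StandInSamples:
    """Hypothetical stand-in for cuqi.samples.Samples, for illustration only."""

    def __init__(self, samples):
        # Samples are stored column-wise: shape (dim, Ns)
        self.samples = np.asarray(samples)

    @property
    def Ns(self):
        return self.samples.shape[-1]

    def burnthin(self, Nb, Nt=1):
        # Drop the first Nb samples, then keep every Nt-th remaining sample
        return StandInSamples(self.samples[..., Nb::Nt])


# Assumed current behaviour: a plain dict of Samples objects, one per parameter
samples = {
    "x": StandInSamples(np.zeros((128, 1000))),
    "s": StandInSamples(np.ones((1, 1000))),
}

# Without JointSamples, burn-in has to be removed key by key
samples_bt = {key: s.burnthin(200) for key, s in samples.items()}
print({key: s.Ns for key, s in samples_bt.items()})  # {'x': 800, 's': 800}
```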

The code in _samples.py would be:

class JointSamples(dict):
    """ An object used to store samples from joint distributions. 

    This object is a simple overload of the dictionary class to allow easy access to certain methods 
    of Samples objects without having to iterate over each key in the dictionary. 

    """

    def burnthin(self, Nb, Nt=1):
        """ Remove burn-in and thin samples for all samples in the dictionary. Returns a copy of the samples stored in the dictionary. """
        return JointSamples({key: samples.burnthin(Nb, Nt) for key, samples in self.items()})

    def __repr__(self) -> str: 
        return "CUQIpy JointSamples Dict:\n" + \
               "-------------------------\n\n" + \
               "Keys:\n {}\n\n".format(list(self.keys())) + \
               "Ns (number of samples):\n {}\n\n".format({key: samples.Ns for key, samples in self.items()}) + \
               "Geometry:\n {}\n\n".format({key: samples.geometry for key, samples in self.items()}) + \
               "Shape:\n {}\n\n".format({key: samples.shape for key, samples in self.items()})

The repr would return output such as:

CUQIpy JointSamples Dict:
-------------------------

Keys:
 ['x', 's']

Ns (number of samples):
 {'x': 250, 's': 250}

Geometry:
 {'x': _DefaultGeometry1D(128,), 's': _DefaultGeometry1D(1,)}

Shape:
 {'x': (128, 250), 's': (1, 250)}
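To try the proposed burnthin method outside CUQIpy, the sketch below pairs it with a hypothetical StandInSamples class that mimics only the parts of Samples it needs (the real Samples class has much more). It shows the same burnthin call working on a single object and on a JointSamples object, including thinning with Nt.

```python
import numpy as np


class StandInSamples:
    """Hypothetical stand-in for cuqi.samples.Samples, for illustration only."""

    def __init__(self, samples):
        self.samples = np.asarray(samples)  # shape (dim, Ns)

    @property
    def Ns(self):
        return self.samples.shape[-1]

    @property
    def shape(self):
        return self.samples.shape

    def burnthin(self, Nb, Nt=1):
        # Remove the first Nb samples, then keep every Nt-th one
        return StandInSamples(self.samples[..., Nb::Nt])


class JointSamples(dict):
    """Dictionary of samples objects with a shared burnthin method (as proposed)."""

    def burnthin(self, Nb, Nt=1):
        # Apply burnthin to every entry and wrap the result in a new JointSamples
        return JointSamples({key: samples.burnthin(Nb, Nt) for key, samples in self.items()})


single = StandInSamples(np.random.randn(128, 1000))
joint = JointSamples({
    "x": StandInSamples(np.random.randn(128, 1000)),
    "s": StandInSamples(np.random.randn(1, 1000)),
})

# The same call works for both kinds of object
print(single.burnthin(200).Ns)                                  # 800
print({k: s.shape for k, s in joint.burnthin(200, 2).items()})  # {'x': (128, 400), 's': (1, 400)}
```

Because burnthin returns a new JointSamples rather than modifying the stored objects, the original samples stay available for comparison.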

Definition of done:

nabriis commented 2 months ago

@jakobsj Please review the DoD and the suggested change.

jakobsj commented 2 months ago

@nabriis Sounds sensible, and many thanks for adding the example code. I think a brief discussion at the daily meeting would help clarify this. In general, this new class should align as closely as possible with Samples, and presumably also with JointDistribution, so that prepending "Joint" to "Samples" and to "Distribution" means the same thing in both cases; your suggestion appears to achieve that. Which other aspects of alignment should we consider? For example, could other Samples functionality, such as plotting, be aligned as well, so that it produces meaningful plots for each set of samples?

nabriis commented 2 months ago

@jakobsj Plotting and other operations are usually done like this:

samples["x"].plot_ci()

I think this works fine. Calling samples.plot_ci() directly on the joint object would produce too many plots. Over time, one could add more methods and attributes from the Samples object, such as Ns or compute_std, that return, e.g., a list of standard deviations for the joint object.
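A sketch of how such per-key methods could be added over time. The method name compute_std follows the comment above but is an assumption here, as is the StandInSamples class (its compute_std simply takes the standard deviation across samples). Returning a dict keyed by parameter name, rather than a list, is a design choice in this sketch that keeps each result attached to its parameter.

```python
import numpy as np


class StandInSamples:
    """Hypothetical stand-in for cuqi.samples.Samples, for illustration only."""

    def __init__(self, samples):
        self.samples = np.asarray(samples)  # shape (dim, Ns)

    def compute_std(self):
        # Hypothetical: element-wise standard deviation across samples
        return np.std(self.samples, axis=-1)


class JointSamples(dict):
    """Dictionary of samples objects; reductions are forwarded per key."""

    def compute_std(self):
        # One result per parameter, keyed by parameter name
        return {key: samples.compute_std() for key, samples in self.items()}


joint = JointSamples({
    "x": StandInSamples(np.tile([1.0, 3.0], (3, 1))),  # 3 parameters, 2 samples each
    "s": StandInSamples([[2.0, 2.0]]),                 # 1 parameter, constant
})
stds = joint.compute_std()
print(stds["x"])  # [1. 1. 1.]
print(stds["s"])  # [0.]
```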

jakobsj commented 2 months ago

@nabriis But doesn't that assume that samples there is a dict, and the proposal is to change it to an instance of a new JointSamples class?

nabriis commented 2 months ago

It does assume that samples is a dict (or at least that it supports key-value access). The proposal is to subclass JointSamples from dict to preserve that behavior.
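A minimal check of what subclassing dict preserves, using a hypothetical StandInSamples stand-in instead of CUQIpy's Samples: key access such as samples["x"], key iteration and isinstance(..., dict) behave as before, so dict-style code like samples["x"].plot_ci() should be unaffected, while burnthin becomes available on the container and can be chained.

```python
import numpy as np


class StandInSamples:
    """Hypothetical stand-in for cuqi.samples.Samples, for illustration only."""

    def __init__(self, samples):
        self.samples = np.asarray(samples)  # shape (dim, Ns)

    def burnthin(self, Nb, Nt=1):
        return StandInSamples(self.samples[..., Nb::Nt])


class JointSamples(dict):
    def burnthin(self, Nb, Nt=1):
        return JointSamples({key: samples.burnthin(Nb, Nt) for key, samples in self.items()})


samples = JointSamples(x=StandInSamples(np.zeros((4, 10))), s=StandInSamples(np.zeros((1, 10))))

# Existing dict-style code keeps working
print(isinstance(samples, dict))           # True
print(list(samples.keys()))                # ['x', 's']
print(type(samples["x"]).__name__)         # StandInSamples

# The new method returns a JointSamples, so further calls can be chained
print(type(samples.burnthin(2)).__name__)  # JointSamples
```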

jakobsj commented 2 months ago

Ah, I see, I had missed that. Interesting; maybe that does make a lot of sense. Somehow I would have expected it to be subclassed from Samples if it is to "behave like Samples", for example by having a burnthin method. Would multiple inheritance make sense here?
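One consideration for the multiple-inheritance question, sketched with a hypothetical StandInSamples class rather than CUQIpy's actual Samples: if JointSamples inherits from both dict and a Samples-like class (dict first), dict.__init__ runs and the Samples-like initializer does not. Any inherited method or property that reads per-object sample data then fails unless JointSamples overrides it.

```python
import numpy as np


class StandInSamples:
    """Hypothetical stand-in for cuqi.samples.Samples, for illustration only."""

    def __init__(self, samples):
        self.samples = np.asarray(samples)  # shape (dim, Ns)

    @property
    def Ns(self):
        return self.samples.shape[-1]


class JointSamples(dict, StandInSamples):
    """Hypothetical multiple-inheritance variant: dict first in the MRO."""


joint = JointSamples({"x": StandInSamples(np.zeros((4, 10)))})
print([cls.__name__ for cls in JointSamples.__mro__])
# ['JointSamples', 'dict', 'StandInSamples', 'object']

print(isinstance(joint, StandInSamples))  # True
print(hasattr(joint, "samples"))          # False: StandInSamples.__init__ never ran

try:
    joint.Ns  # inherited property reads self.samples, which does not exist
except AttributeError as err:
    print("AttributeError:", err)
```

So isinstance checks against Samples would pass, but each inherited member would still need a joint-aware override, which is roughly what the dict-only proposal already provides explicitly.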

jakobsj commented 2 months ago

@nabriis DoD approved.