pyro-ppl / pyro

Deep universal probabilistic programming with Python and PyTorch
http://pyro.ai
Apache License 2.0

Mass Matrix for NUTS and HMC #3208

Status: Open. ConnorStoneAstro opened this issue 1 year ago

ConnorStoneAstro commented 1 year ago

Issue Description

I would like to be able to interact with the mass matrix in the NUTS and HMC samplers. In many cases I have access to the covariance matrix at the MAP, so being able to set the (inverse) mass matrix directly could speed up sampling considerably. It would also be nice to be able to read the mass matrix back out: the inverse mass matrix adapted during warmup approximates the posterior covariance of the parameters, which is useful for other purposes.
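The reasoning behind using the MAP covariance is that, at the MAP, the inverse Hessian of the negative log posterior approximates the posterior covariance (a Laplace approximation), and HMC/NUTS tends to mix best when the inverse mass matrix is close to that covariance. A minimal sketch, assuming a toy 2-D Gaussian posterior so the exact covariance is known; it uses NumPy and a hypothetical `numerical_hessian` helper (central finite differences), not any Pyro API:

```python
import numpy as np


def numerical_hessian(f, x, eps=1e-3):
    """Hypothetical helper: central-difference Hessian of a scalar function f at x."""
    x = np.asarray(x, dtype=float)
    d = x.size
    hess = np.empty((d, d))
    for i in range(d):
        e_i = np.zeros(d)
        e_i[i] = eps
        for j in range(d):
            e_j = np.zeros(d)
            e_j[j] = eps
            # Mixed second difference; exact (up to rounding) for quadratics.
            hess[i, j] = (
                f(x + e_i + e_j) - f(x + e_i - e_j)
                - f(x - e_i + e_j) + f(x - e_i - e_j)
            ) / (4.0 * eps * eps)
    return hess


# Toy posterior: a correlated 2-D Gaussian, so the exact covariance is known.
true_cov = np.array([[1.0, 0.8], [0.8, 2.0]])
precision = np.linalg.inv(true_cov)
map_point = np.array([0.3, -1.2])


def neg_log_posterior(theta):
    """Negative log density up to a constant; its minimum is the MAP."""
    r = theta - map_point
    return 0.5 * r @ precision @ r


# Laplace approximation: covariance ~ inverse Hessian of -log p at the MAP.
laplace_cov = np.linalg.inv(numerical_hessian(neg_log_posterior, map_point))
print(np.round(laplace_cov, 6))
```

For a real model the Hessian would more likely come from autodiff (for example `torch.autograd.functional.hessian`), and the result (or just its diagonal, for a diagonal mass matrix) would be the candidate inverse mass matrix, converted to a torch tensor before handing it to Pyro.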

This isn't a bug, just a feature I would like to have access to.

I was able to "hack" a way to get access, but it is not a long-term solution because it involves overwriting (monkey-patching) BlockMassMatrix.configure:

```python
import torch
from pyro.infer.mcmc.adaptation import BlockMassMatrix
from pyro.ops.welford import WelfordCovariance


def new_configure(self, mass_matrix_shape, adapt_mass_matrix=True, options={}):
    """
    Sets up an initial mass matrix.

    :param dict mass_matrix_shape: a dict that maps tuples of site names to the shape of
        the corresponding mass matrix. Each tuple of site names corresponds to a block.
    :param bool adapt_mass_matrix: a flag to decide whether an adaptation scheme will be used.
    :param dict options: tensor options to construct the initial mass matrix.
    """
    inverse_mass_matrix = {}
    for site_names, shape in mass_matrix_shape.items():
        self._mass_matrix_size[site_names] = shape[0]
        # A 1-D shape means a diagonal block; a 2-D shape means a dense block.
        diagonal = len(shape) == 1
        inverse_mass_matrix[site_names] = (
            torch.full(shape, self._init_scale, **options)
            if diagonal
            else torch.eye(*shape, **options) * self._init_scale
        )
        if adapt_mass_matrix:
            adapt_scheme = WelfordCovariance(diagonal=diagonal)
            self._adapt_scheme[site_names] = adapt_scheme

    # Change relative to Pyro's configure: only install the default matrix if
    # none has been set yet, so a user-assigned inverse mass matrix survives setup.
    if len(self.inverse_mass_matrix.keys()) == 0:
        self.inverse_mass_matrix = inverse_mass_matrix


# Monkey-patch the class so every NUTS/HMC kernel picks up the new behavior.
BlockMassMatrix.configure = new_configure
```

Then, once I had the mass matrix, I could set it directly:

```python
nuts_kernel.mass_matrix_adapter.inverse_mass_matrix = {("x",): inv_mass}
```
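For that assignment, each value has to match the block layout that configure builds: a tuple of site names maps to a 1-D tensor of variances for a diagonal block, or to a full (d, d) matrix for a dense block (which layout you get depends on the kernel's full_mass setting). A small sketch, assuming a covariance estimate is already available; `make_inverse_mass_matrix` is a hypothetical helper, not part of Pyro:

```python
import torch


def make_inverse_mass_matrix(cov, site_names=("x",), dense=False):
    """Hypothetical helper: package a covariance estimate as an
    inverse-mass-matrix dict keyed by a tuple of site names.

    A diagonal block keeps only the variances (1-D tensor); a dense block
    keeps the full (d, d) matrix, which must be symmetric positive definite.
    """
    cov = torch.as_tensor(cov)
    block = cov.clone() if dense else torch.diagonal(cov).clone()
    return {tuple(site_names): block}


cov = torch.tensor([[1.0, 0.8], [0.8, 2.0]])
diag_inv_mass = make_inverse_mass_matrix(cov)
dense_inv_mass = make_inverse_mass_matrix(cov, dense=True)
print(diag_inv_mass[("x",)])         # tensor([1., 2.])
print(dense_inv_mass[("x",)].shape)  # torch.Size([2, 2])
```

The diagonal variant discards the correlations, so it only captures per-parameter scales; the dense variant keeps them but would need a kernel configured for a dense mass matrix.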

I could probably make the change myself, but I imagine there is a more elegant way to do this.

fehiepsi commented 1 year ago

I think making a subclass of BlockMassMatrix that returns the desired configuration makes sense. Then you can set:

```python
nuts_kernel.mass_matrix_adapter = ThatSubclass()
```

I guess we could also expose inverse_mass_matrix as a kernel constructor argument, as we do in NumPyro.

ConnorStoneAstro commented 1 year ago

Subclassing BlockMassMatrix is more elegant than what I did, but I would still have to maintain the subclass myself, so any Pyro update could break it. Exposing the inverse mass matrix as a constructor argument would be great!