bilby-dev / bilby

A unified framework for stochastic sampling packages and gravitational-wave inference in Python.
https://bilby-dev.github.io/bilby/
MIT License

Added qmc_quad based method for estimation of the constrained normalization factor #839

Open JasperMartins opened 2 weeks ago

JasperMartins commented 2 weeks ago

This PR implements a new method to estimate the normalization factor for constrained priors. The changes are two-fold:

  1. The integration is performed with scipy.integrate.qmc_quad, a quasi-Monte Carlo integration routine that is expected to yield better results than regular Monte Carlo integration. However, the routine requires a rescaling step from the unit cube rather than direct sampling.
  2. The termination of the integration is based on its relative statistical error rather than the number of accepted samples.

I have tested the two implementations with a relatively easy scenario: A 2D uniform prior on the [-1,1] cube, constrained to an inscribed circle with different radii:

image

The new method is significantly faster for high normalization factors, and the relative errors show a similar spread.
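For reference, the test scenario has a closed-form answer, so any estimator can be checked against it. The sketch below is a stand-alone illustration, not the benchmark script used for the plot; the function names are hypothetical and the normalization factor is assumed to be the reciprocal of the accepted fraction, so small radii correspond to high factors.

```python
import numpy as np


def exact_fraction(radius):
    """Area of the disc divided by the area of the [-1, 1]^2 square."""
    return np.pi * radius**2 / 4.0


def mc_fraction(radius, n_samples=200_000, seed=0):
    """Plain Monte Carlo estimate of the accepted fraction."""
    rng = np.random.default_rng(seed)
    xy = rng.uniform(-1.0, 1.0, size=(n_samples, 2))
    inside = np.sum(xy**2, axis=1) <= radius**2
    return inside.mean()


for radius in (0.1, 0.5, 1.0):
    exact = exact_fraction(radius)
    estimate = mc_fraction(radius)
    # Assumed convention: normalization factor = 1 / accepted fraction.
    print(f"r={radius}: exact factor={1 / exact:.2f}, MC factor={1 / estimate:.2f}")
```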

The implementation is marked as a draft because of the requirement of a rescale-method of the priors. Thus, it could be nice to keep the old method as a fallback. Also, the relative-error termination criterion could be applied just as well to the old implementation.

Related issue: https://github.com/bilby-dev/bilby/issues/835

ColmTalbot commented 2 weeks ago

The implementation is marked as a draft because of the requirement of a rescale-method of the priors. Thus, it could be nice to keep the old method as a fallback. Also, the relative-error termination criterion could be applied just as well to the old implementation.

We currently have a (soft) requirement that all priors should implement a rescale method (currently it will just return None, which is not ideal, https://github.com/bilby-dev/bilby/blob/main/bilby/core/prior/base.py#L137-L153), so this approach should be safe.

Even if we make the base class raise an error when you attempt to rescale, it will still be possible to use some samplers in that case.

It's possible that people will implement their own prior subclasses that don't support rescale, so I'm not opposed to keeping a fallback. It may make everything easier if we change bilby.core.prior.base.Prior.rescale to raise a NotImplementedError, but that should probably get some more eyes and be done in a separate PR.

ColmTalbot commented 2 weeks ago

Actually, I think the existing method won't work if the new prior doesn't implement rescale, as sample relies on rescale by default. I think it's sufficiently unlikely that people are manually defining sample without rescale.

JasperMartins commented 1 week ago

I have updated the PR quite a bit. The core logic of the integration of the normalization factor is now handled by one of two functions: either MC integration based on samples from PriorDict.sample, or quasi-MC integration based on the rescale method. The user can choose which is used, but if qmc_quad is selected, the code also checks whether the rescale method is implemented and falls back to from_samples if not. For both methods, termination of the integration is handled via a bound on the estimated relative error, and a max_trials kwarg can be used to limit the number of probability evaluations.
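The fallback logic could look roughly like the sketch below. This is an illustration with made-up stand-in classes, not bilby's prior classes or the code in this PR; supports_rescale and choose_method are hypothetical helpers. The check accepts both behaviours discussed above: a base rescale that returns None and one that raises NotImplementedError.

```python
import numpy as np


class ToyUniform:
    """Hypothetical stand-in for a prior that implements rescale."""

    def __init__(self, minimum, maximum):
        self.minimum = minimum
        self.maximum = maximum

    def rescale(self, val):
        return self.minimum + np.asarray(val) * (self.maximum - self.minimum)


class ToySampleOnly:
    """Hypothetical user-defined prior that only provides sample."""

    def rescale(self, val):
        return None  # mirrors a base class whose rescale returns None

    def sample(self, size=None):
        return np.random.default_rng().normal(size=size)


def supports_rescale(prior):
    """Probe rescale once; treat None or NotImplementedError as unsupported."""
    try:
        return prior.rescale(0.5) is not None
    except NotImplementedError:
        return False


def choose_method(priors, requested="qmc_quad"):
    """Fall back to from_samples if any prior cannot be rescaled."""
    if requested == "qmc_quad" and not all(
        supports_rescale(prior) for prior in priors.values()
    ):
        return "from_samples"
    return requested


print(choose_method({"x": ToyUniform(-1, 1), "y": ToyUniform(-1, 1)}))
print(choose_method({"x": ToyUniform(-1, 1), "y": ToySampleOnly()}))
```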

I have also optimized the from_samples implementation. Before, every time min_accept was not reached, new samples were appended to a list and the constraint was applied to the full list, yielding a steep increase in runtime with the number of iterations, when it would have been sufficient to check only the new samples.
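A minimal sketch of the incremental idea, assuming a generic sample callable and constraint callable (both hypothetical, not the PriorDict API): only running counts are kept, the constraint is evaluated on each new chunk alone, and the loop stops once the binomial relative error of the accepted fraction falls below a tolerance or max_trials is reached.

```python
import numpy as np


def estimate_fraction(sample, constraint, rtol=1e-2, chunk=10_000,
                      max_trials=10_000_000):
    """Estimate the accepted fraction while checking only new samples."""
    n_trials = 0
    n_accepted = 0
    rel_err = np.inf
    while n_trials < max_trials:
        new = sample(chunk)
        # Evaluate the constraint on the new chunk only and keep running counts.
        n_accepted += int(np.count_nonzero(constraint(new)))
        n_trials += chunk
        if n_accepted > 0:
            p = n_accepted / n_trials
            # Relative standard error of a binomial proportion.
            rel_err = np.sqrt((1.0 - p) / (p * n_trials))
            if rel_err < rtol:
                break
    return n_accepted / n_trials, rel_err, n_trials


rng = np.random.default_rng(1)


def sample_square(size):
    """Draw points uniformly from the [-1, 1]^2 square."""
    return rng.uniform(-1.0, 1.0, size=(size, 2))


def inside_disc(xy, radius=0.5):
    """Constraint: point lies inside the inscribed disc of the given radius."""
    return np.sum(xy**2, axis=1) <= radius**2


fraction, rel_err, n_trials = estimate_fraction(sample_square, inside_disc)
print(f"fraction={fraction:.4f}, rel_err={rel_err:.2e}, trials={n_trials}")
```

Because only two integers are carried between iterations, the cost per iteration stays constant instead of growing with the number of samples drawn so far.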

For the example I gave above, the qmc-based implementation is now actually slower than the from_samples method due to a higher overhead. For priors that implement sample by just calling rescale on unit samples, the two methods should perform much more similarly. I still selected qmc_quad as the default because, as the attached plot shows, for normalization factors close to 1 (which is more likely in most applications), the relative error is smaller.

I have also improved the robustness against bugs by checking whether the chosen keys are sufficient to compute the constraint, and whether the PriorDict is constrained in the first place.

image