Joshuaalbert / jaxns

Probabilistic Programming and Nested sampling in JAX
https://jaxns.readthedocs.io/

Framework upgrade #132

Closed Joshuaalbert closed 2 months ago

Joshuaalbert commented 7 months ago

Is your feature request related to a problem? Please describe. As we move towards blending deep learning with probabilistic programming, we will update the framework to be more Module-like.

Describe the solution you'd like We currently have two types of variables, Bayesian and parametrised. Bayesian variables take on random values sampled from the prior and require sampling to realise them. Parametrised variables are point-wise estimates of a prior RV and do not require sampling. Both parameter types are used in several ways (a sketch follows the list below):

  1. In EvidenceMaximisation, where the parametrised variables are optimised to maximise the evidence with respect to the Bayesian variables. The parametrised variables are constrained automatically by their prior distributions.
  2. In Bayesian neural networks where some variables are Bayesian.
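
A minimal sketch of the two variable types, assuming the jaxns 2.x style `prior_model` generator and a `Prior(...).parametrised()` marker for point estimates (the exact method name and constructor signatures are assumptions here, not a definitive API):

```python
import tensorflow_probability.substrates.jax as tfp
from jaxns import Model, Prior

tfpd = tfp.distributions


def prior_model():
    # Parametrised variable: a point estimate, constrained by its prior,
    # optimised (e.g. during evidence maximisation) rather than sampled.
    mu = yield Prior(tfpd.Normal(loc=0., scale=1.), name='mu').parametrised()
    # Bayesian variable: a random value sampled from its prior during nested sampling.
    x = yield Prior(tfpd.Normal(loc=mu, scale=1.), name='x')
    return x


def log_likelihood(x):
    return tfpd.Normal(loc=0., scale=1.).log_prob(x)


model = Model(prior_model=prior_model, log_likelihood=log_likelihood)
```

In this picture, evidence maximisation would optimise the parametrised variables while nested sampling handles the Bayesian ones.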

Currently, model.params gives the parametrised variables, and model(params) is a pure function that produces a new model bound to those parameters. Then, to get the likelihood input variables, we call model.prepare_input(U), which takes a sample U.
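
For concreteness, the current pattern might look like the following (a sketch built around the calls named above; `sample_U` is an assumed helper for drawing a base-space sample):

```python
import jax

# Pytree of the parametrised (point-estimate) variables.
params = model.params

# Binding the params is pure: it returns a new model and leaves the old one untouched.
model_with_params = model(params)

# Map a base-space sample U to the likelihood input variables.
U = model_with_params.sample_U(key=jax.random.PRNGKey(0))  # sample_U is an assumption
inputs = model_with_params.prepare_input(U)
```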

We'd like to be able to define a separation between the generative prior side and the likelihood side. It would then be possible to chain generative components. Each generative component can have its own parameters, of both types. We can imagine extracting just the parametrised variables of one component and maximising the evidence over those.
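
Purely as an illustration of the proposed chaining (none of this is an existing API), each generative component could be a prior-model generator carrying its own Bayesian and parametrised variables, composed with `yield from`:

```python
def component_a():
    # Component A: its own parametrised and Bayesian variables.
    scale = yield Prior(tfpd.HalfNormal(scale=1.), name='a_scale').parametrised()
    z = yield Prior(tfpd.Normal(loc=0., scale=scale), name='a_z')
    return z


def component_b(z):
    # Component B consumes A's output and adds its own variables.
    w = yield Prior(tfpd.Normal(loc=z, scale=1.), name='b_w')
    return w


def prior_model():
    # Chain the generative components; evidence could be maximised
    # over the parametrised variables of just one component.
    z = yield from component_a()
    w = yield from component_b(z)
    return w
```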

We'll need to tackle this using a few key example problems to assess different interfaces.

Joshuaalbert commented 2 months ago

This has more or less been done as the framework has stabilised and been vetted as more projects use JAXNS.