balbasty / nitorch

Neuroimaging in PyTorch

Probabilistic programming #36

Open balbasty opened 3 years ago

balbasty commented 3 years ago

PyTorch already provides almost everything needed to write a probabilistic programming (PP) library (see, e.g., Pyro), and it looks like a good exercise. Plus, if it works well, it could form a nice basis for implementing our inference algorithms. There is already a bit of PP in torch.distributions, where deterministic transformations can be applied to known distributions to create new ones. However, it is mainly aimed at sampling and computing PDFs.
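For reference, here is a sketch of what torch.distributions already offers, using only its public API: a TransformedDistribution obtained by pushing a Normal through ExpTransform can be sampled and its log-density evaluated, but a generic transformed distribution does not derive moments such as its mean. This illustrates the existing torch behaviour, not the proposed library.

```python
import torch
from torch.distributions import LogNormal, Normal, TransformedDistribution
from torch.distributions.transforms import ExpTransform

# Push a standard Normal through exp() to build a log-normal variable.
base = Normal(torch.tensor(0.0), torch.tensor(1.0))
lognormal = TransformedDistribution(base, [ExpTransform()])

x = lognormal.sample((5,))    # sampling is supported
logp = lognormal.log_prob(x)  # so is the density (change-of-variables formula)

# The densities agree with torch's built-in LogNormal.
reference = LogNormal(torch.tensor(0.0), torch.tensor(1.0))
print(torch.allclose(logp, reference.log_prob(x)))  # True

# Moments, however, are not derived automatically for a generic transform.
try:
    _ = lognormal.mean
    mean_available = True
except NotImplementedError:
    mean_available = False
print(mean_available)  # False
```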

In this framework, my idea is to use RandomVariable objects as building blocks. A RandomVariable is a lot like (and is inspired by) a torch.distributions.Distribution: it has methods to sample and to compute PDFs, CDFs, means, variances, etc. It also has a batch_shape and an event_shape to allow batching lots of similar random variables (think of a replication plate in a Bayesian network).
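The batch_shape/event_shape convention borrowed here can be illustrated with torch.distributions itself. In this sketch (sizes chosen arbitrarily), Independent reinterprets the last batch dimension as an event dimension, so ten replicated subjects, one per plate entry, each get a single joint log-density over a 3-vector.

```python
import torch
from torch.distributions import Independent, Normal

# 10 replicated subjects (a "plate"), each with a 3-dimensional observation.
loc = torch.zeros(10, 3)
scale = torch.ones(10, 3)

elementwise = Normal(loc, scale)
print(elementwise.batch_shape, elementwise.event_shape)  # torch.Size([10, 3]) torch.Size([])

# Treat the last dimension as part of the event: 10 independent 3-vectors.
per_subject = Independent(elementwise, 1)
print(per_subject.batch_shape, per_subject.event_shape)  # torch.Size([10]) torch.Size([3])

x = per_subject.sample()
print(per_subject.log_prob(x).shape)  # torch.Size([10]): one log-density per subject
```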

However, it also implements algebraic operations (sum, product, inverse) between random variables (or deterministic values), as well as more complex operations (exp, log, logdet, etc.). This is something torch does not do. The idea is that if such an operation has a nice analytical result, that result is returned (the sum of two independent Gaussians is a Gaussian), whereas if no nice result is known, a slightly dumber object is returned (e.g., a Sum of two random variables knows how to compute its expected value, but not its PDF). Another nice thing about this library is that the parameters of a RandomVariable can themselves be RandomVariables, so it is easy to build generative models. Parameters are registered within the RandomVariable object, so the Bayesian network is stored implicitly (a bit like the computational graph is stored implicitly in torch). I have a function that uses this to check (conditional) independence between variables.
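A minimal sketch of these ideas, assuming a deliberately tiny design: RandomVariable, Normal, Sum, expect and independent are hypothetical names for illustration, not the API on the proba branch. Addition returns a closed-form Normal only when both operands are Gaussians with fixed parameters that the graph shows to be independent; otherwise it returns a generic Sum that only knows its expected value. The independence check covers only the marginal case (two variables with no shared ancestor, counting themselves, are d-separated given nothing); conditional independence would need a full d-separation test.

```python
class RandomVariable:
    """Hypothetical base class: parameters may themselves be RandomVariables."""

    def __init__(self, **params):
        self.params = params  # registering parameters implicitly records the graph

    def parents(self):
        return [p for p in self.params.values() if isinstance(p, RandomVariable)]

    def ancestors(self):
        """Every variable this one depends on, including itself."""
        found, stack = {self}, [self]
        while stack:
            for parent in stack.pop().parents():
                if parent not in found:
                    found.add(parent)
                    stack.append(parent)
        return found

    def __add__(self, other):
        return Sum(self, other)


def expect(value):
    """Expected value of a random variable, or the value itself if deterministic."""
    return value.expectation() if isinstance(value, RandomVariable) else value


def independent(a, b):
    """True if the graph guarantees marginal independence (no shared ancestor)."""
    return not (a.ancestors() & b.ancestors())


class Sum(RandomVariable):
    """Fallback result: knows its expected value, but not its density."""

    def __init__(self, left, right):
        super().__init__(left=left, right=right)

    def expectation(self):
        # Linearity of expectation holds even for dependent terms.
        return expect(self.params["left"]) + expect(self.params["right"])


class Normal(RandomVariable):
    def __init__(self, mean, var):
        super().__init__(mean=mean, var=var)

    def expectation(self):
        # Law of total expectation: E[x] = E[E[x | mean]] = E[mean].
        return expect(self.params["mean"])

    def __add__(self, other):
        closed_form = (
            isinstance(other, Normal)
            and not self.parents()
            and not other.parents()
            and independent(self, other)
        )
        if closed_form:
            # Independent Gaussians: means add and variances add.
            return Normal(self.params["mean"] + other.params["mean"],
                          self.params["var"] + other.params["var"])
        return super().__add__(other)


mu = Normal(0.0, 100.0)   # prior on a mean
x = Normal(mu, 1.0)       # a parameter that is itself a random variable
z = Normal(3.0, 2.0)

a = z + Normal(1.0, 1.0)  # independent Gaussians: closed form
b = z + z                 # same variable twice: not independent, falls back to Sum
c = x + z                 # x has a random parent: falls back to Sum

print(type(a).__name__, a.params)          # Normal {'mean': 4.0, 'var': 3.0}
print(type(b).__name__, b.expectation())   # Sum 6.0
print(type(c).__name__, c.expectation())   # Sum 3.0
print(independent(x, z), independent(x, mu), independent(c, x))  # True False False
```

Note that z + z falls back to Sum: treating it as a sum of two independent Gaussians would give variance 4.0 instead of the correct 8.0, which is why the closed form is gated on the graph.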

It's a fair amount of work because quite a few distributions must be implemented. The next step is to build "inference" constructors (from_ml, from_moments) and update functions for conjugate priors, although I am not sure yet what the nicest API would be. I also need to add KL divergences (these are in torch.distributions, so we could copy from there). Then we can probably have a variational framework where we define factors and a lower bound, and use autograd to optimise the posterior parameters. I think that's what Pyro does.
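As a sketch of what a conjugate update could look like, take the simplest case: a Normal prior on the mean of Normal data with known noise variance. Precisions add, and the posterior mean is the precision-weighted average of the prior mean and the data. The function name normal_mean_posterior and the numbers are hypothetical; the KL term uses torch's existing kl_divergence registry (new distribution pairs can be added with torch.distributions.kl.register_kl).

```python
import torch
from torch.distributions import Normal
from torch.distributions.kl import kl_divergence


def normal_mean_posterior(prior: Normal, data: torch.Tensor, noise_var: float) -> Normal:
    """Hypothetical conjugate update: Normal prior on the mean of Normal data with known variance."""
    prior_prec = prior.scale ** -2                        # precision = 1 / variance
    post_prec = prior_prec + data.numel() / noise_var     # precisions add
    post_mean = (prior.loc * prior_prec + data.sum() / noise_var) / post_prec
    return Normal(post_mean, post_prec.rsqrt())           # scale = 1 / sqrt(precision)


prior = Normal(torch.tensor(0.0), torch.tensor(10.0))
data = torch.tensor([1.8, 2.1, 2.4, 1.7])
posterior = normal_mean_posterior(prior, data, noise_var=0.25)

# KL(posterior || prior): how far the data moved us from the prior, in nats.
kl = kl_divergence(posterior, prior)
print(posterior.loc, posterior.scale, kl)
```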

I'll push a branch and link to this issue very soon.

JohnAshburner commented 2 years ago

Is there anything in https://github.com/spedemon/ilang that might be useful here? The author is local to the Martinos Center.

balbasty commented 2 years ago

Thanks for the pointer, John! If you're interested, you may want to have a look at Pyro as well (https://pyro.ai). Both libraries rely on some form of Monte Carlo (MC) sampling to estimate the evidence (or its lower bound).

When I started this, I was hoping to have something that could, when the model is simple, compute expected values analytically instead. But I haven't touched it in a long time: https://github.com/balbasty/nitorch/tree/36-proba/nitorch/proba
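To make the contrast concrete, here is a sketch on an assumed toy conjugate model (mu ~ N(0, 100), x_i | mu ~ N(mu, 0.25), with made-up data). elbo_monte_carlo estimates the evidence lower bound with reparameterised samples, roughly the Pyro-style approach, while elbo_analytic computes the same expectations in closed form. Both function names are hypothetical. When q is the exact posterior the bound is tight, so the analytic ELBO should match the log evidence computed from the marginal distribution of x.

```python
import math

import torch
from torch.distributions import MultivariateNormal, Normal
from torch.distributions.kl import kl_divergence

torch.manual_seed(0)
dtype = torch.float64

# Assumed toy model: mu ~ N(m0, v0), x_i | mu ~ N(mu, s2)
m0, v0, s2 = 0.0, 100.0, 0.25
prior = Normal(torch.tensor(m0, dtype=dtype), torch.tensor(v0, dtype=dtype).sqrt())
data = torch.tensor([1.8, 2.1, 2.4, 1.7], dtype=dtype)
n = data.numel()


def elbo_analytic(q: Normal) -> torch.Tensor:
    """Closed form: E_q[log p(x | mu)] - KL(q || prior), using E_q[(x - mu)^2] = (x - m)^2 + v."""
    m, v = q.loc, q.scale ** 2
    expected_loglik = (-0.5 * math.log(2 * math.pi * s2) - ((data - m) ** 2 + v) / (2 * s2)).sum()
    return expected_loglik - kl_divergence(q, prior)


def elbo_monte_carlo(q: Normal, num_samples: int = 100_000) -> torch.Tensor:
    """Reparameterised MC estimate of the same bound."""
    mu = q.rsample((num_samples,))                                      # shape (S,)
    loglik = Normal(mu[:, None], math.sqrt(s2)).log_prob(data).sum(-1)  # shape (S,)
    return (loglik + prior.log_prob(mu) - q.log_prob(mu)).mean()


# Log evidence: marginally, x ~ N(m0 * 1, s2 * I + v0 * 1 1^T).
cov = s2 * torch.eye(n, dtype=dtype) + v0 * torch.ones(n, n, dtype=dtype)
log_evidence = MultivariateNormal(torch.full((n,), m0, dtype=dtype), cov).log_prob(data)

# Exact posterior (conjugate update) and an arbitrary rough approximation.
post_prec = 1.0 / v0 + n / s2
exact_q = Normal((m0 / v0 + data.sum() / s2) / post_prec,
                 torch.tensor(post_prec, dtype=dtype).rsqrt())
rough_q = Normal(torch.tensor(1.5, dtype=dtype), torch.tensor(0.5, dtype=dtype))

print(elbo_analytic(rough_q), elbo_monte_carlo(rough_q))  # close; the MC value is noisy
print(elbo_analytic(exact_q), log_evidence)               # equal: the bound is tight
```

Because rsample is differentiable, either estimate could be maximised over the parameters of q with autograd; the analytic one simply has no sampling noise.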