probml / pyprobml

Python code for "Probabilistic Machine Learning" book by Kevin Murphy
MIT License

make JAX implementation of SVI for mixture models #348

Closed by murphyk 3 years ago

murphyk commented 3 years ago

Stochastic Variational Inference (SVI) can be used to do online optimization of Bayesian/latent variable models, such as finite mixtures of Gaussians or Bernoullis; this is faster than batch EM for large datasets. Here is an implementation of SVI for GMMs in TensorFlow Probability. The goals of this issue are as follows (a rough JAX sketch is included after the list):

  1. make a JAX translation of this code, using the new distrax library.
  2. generalize the code to work with mixtures of Bernoullis, and apply it to the same binarized MNIST data used in mixBerMNIST.py
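
A minimal, hypothetical sketch of the kind of JAX + distrax code task 1 asks for. It fits the mixture parameters by minibatch gradient descent on the negative log marginal likelihood, which is a simplified stand-in for full SVI (a complete solution would also place variational posteriors over the parameters, as in the TFP example). All function and variable names here (`make_mixture`, `sgd_step`, etc.) are illustrative, not from the repo.

```python
import jax
import jax.numpy as jnp
import distrax

K, D = 5, 2  # number of mixture components, data dimension

def make_mixture(params):
    # Mixture of diagonal-covariance Gaussians built from unconstrained parameters.
    return distrax.MixtureSameFamily(
        mixture_distribution=distrax.Categorical(logits=params["mix_logits"]),
        components_distribution=distrax.MultivariateNormalDiag(
            loc=params["locs"],
            scale_diag=jax.nn.softplus(params["raw_scales"])))

def neg_log_lik(params, batch):
    # Average negative log marginal likelihood over a minibatch.
    return -jnp.mean(make_mixture(params).log_prob(batch))

@jax.jit
def sgd_step(params, batch, lr=1e-2):
    # One stochastic gradient step on a minibatch.
    loss, grads = jax.value_and_grad(neg_log_lik)(params, batch)
    new_params = jax.tree_util.tree_map(lambda p, g: p - lr * g, params, grads)
    return new_params, loss

key = jax.random.PRNGKey(0)
params = {"mix_logits": jnp.zeros(K),
          "locs": jax.random.normal(key, (K, D)),
          "raw_scales": jnp.zeros((K, D))}

# Toy data; for the issue this would be replaced by the real dataset.
data = jax.random.normal(jax.random.PRNGKey(1), (1000, D)) + 3.0
for step in range(200):
    key, sub = jax.random.split(key)
    idx = jax.random.choice(sub, data.shape[0], shape=(64,), replace=False)
    params, loss = sgd_step(params, data[idx])
```
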
shivaditya-meduri commented 3 years ago

Hi @murphyk, I will start working on it!

shivaditya-meduri commented 3 years ago

Thank you for the link to "Bayesian Gaussian Mixture Modeling with Stochastic Variational Inference"; it is very helpful.

murphyk commented 3 years ago

This blog post has some useful background on SVI for GMMs: https://ypei.me/posts/2019-02-14-raise-your-elbo.html. I will cover some of this in vol. 2 of my book, but that's not ready for sharing yet.
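
For reference, the quantity that post (and SVI in general) maximizes is the evidence lower bound (ELBO). Writing $z$ for the latent cluster assignments, $\theta$ for the parameters, and $q(z)$ for the variational posterior (notation mine, not from the post):

$$
\log p(x \mid \theta) \;\ge\; \mathcal{L}(q, \theta) = \mathbb{E}_{q(z)}\big[\log p(x, z \mid \theta)\big] - \mathbb{E}_{q(z)}\big[\log q(z)\big]
$$

SVI maximizes this bound using stochastic gradient estimates computed on minibatches of data.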

shivaditya-meduri commented 3 years ago

Thank you @murphyk, I will start implementing SVI for BMM
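
In case it helps, here is a minimal, hypothetical sketch of how the BMM likelihood could be expressed with distrax (names are illustrative, not from the repo). The resulting `neg_log_lik` could be dropped into the same minibatch gradient loop as the GMM sketch above and trained on the binarized MNIST data used in mixBerMNIST.py.

```python
import jax
import jax.numpy as jnp
import distrax

K, D = 10, 784  # mixture components, binarized MNIST pixels

def make_bernoulli_mixture(params):
    # Each component is a product of D independent Bernoullis over pixels.
    components = distrax.Independent(
        distrax.Bernoulli(logits=params["pixel_logits"]),  # logits shape (K, D)
        reinterpreted_batch_ndims=1)
    return distrax.MixtureSameFamily(
        mixture_distribution=distrax.Categorical(logits=params["mix_logits"]),
        components_distribution=components)

def neg_log_lik(params, batch):
    # batch: (B, D) array of 0/1 pixel values.
    return -jnp.mean(make_bernoulli_mixture(params).log_prob(batch))

params = {"mix_logits": jnp.zeros(K),
          "pixel_logits": 0.01 * jax.random.normal(jax.random.PRNGKey(0), (K, D))}
```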

Nirzu97 commented 3 years ago

@shivaditya-meduri Are you still working on this?

shivaditya-meduri commented 3 years ago

If you want to contribute, that would be great. I will create a draft PR so you can look at the code I have written so far.

shivaditya-meduri commented 3 years ago

I have completed task 1; it is task 2 I could use your help with, and after that I will add you as a collaborator :). Thank you! I wanted to upload a single commit covering both tasks, which is why I have not committed the task 1 code yet.