probml / pyprobml

Python code for "Probabilistic Machine Learning" book by Kevin Murphy
MIT License

Figure out why ELBO does not improve with pymc3 ADVI demo #694

Closed: murphyk closed this issue 2 years ago

murphyk commented 2 years ago

The advi.hist (ELBO) should increase over time, but it does not seem to, even though the posterior looks sensible. Find out why. https://github.com/probml/probml-notebooks/blob/main/notebooks/beta_binom_approx_post_pymc.ipynb
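For context, the trace in question comes from something like the following (a rough sketch of the notebook's setup, not its exact code; the prior parameters here are illustrative):

import pymc3 as pm
import matplotlib.pyplot as plt

data = [1] * 10 + [0]  # 10 heads, 1 tail, as in the notebook

with pm.Model() as model:
    theta = pm.Beta("theta", alpha=1.0, beta=1.0)  # illustrative prior
    obs = pm.Bernoulli("obs", p=theta, observed=data)
    advi = pm.ADVI()
    approx = advi.fit(30000)

plt.plot(advi.hist)  # the objective trace whose trend is in question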

pat749 commented 2 years ago

Hi @murphyk, I would like to work on this issue. I implemented this using stochastic gradient descent and it works fine. The main issue is in the loss, which does not use the ELBO, and I also see that the optimization step has not been done.

pat749 commented 2 years ago

Hi @murphyk, can you give me some more references for this problem? I am currently working on it. Thanks.

murphyk commented 2 years ago

@pat749 Maybe just try an exponentially weighted moving average on the objective and see what the trend is. The odd thing is that it does not seem to be converging (perhaps it just needs more iterations?). Also the variance is surprisingly high (perhaps because of the ADVI transformation?). It would probably be best to implement this method from scratch, using SVI and a suitable bijector.
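For example, a minimal smoothing pass over the loss history (advi.hist here is the pymc trace from above; the helper name ewma is illustrative):

import numpy as np

def ewma(x, alpha=0.05):
    # Exponentially weighted moving average of a 1-D objective trace.
    out = np.empty(len(x))
    out[0] = x[0]
    for t in range(1, len(x)):
        out[t] = alpha * x[t] + (1 - alpha) * out[t - 1]
    return out

# smoothed = ewma(np.asarray(advi.hist)); plotting this reveals the underlying trend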

[Screenshot: ELBO objective trace]

pat749 commented 2 years ago

Yes, I have started from scratch and am currently implementing it using SVI. Thanks.

pat749 commented 2 years ago

Hi @murphyk, can you review it? I have finished and reviewed it on my side: https://colab.research.google.com/drive/1AhWrR5YtGOw7h8gCVpIWXkBo2jeNXVz2

pat749 commented 2 years ago

Here are the papers I took help from: https://arxiv.org/abs/1802.02538 and http://www.stat.columbia.edu/~gelman/research/unpublished/advi_journal.

karm-patel commented 2 years ago

Hi, Dr. @murphyk, I have tried to implement the ADVI method from scratch in JAX (colab notebook). I have incorporated feedback from Prof. @nipunbatra and @patel-zeel. Could you please give your comments on this?

Checklist

While writing the code I kept the following things in mind:

Approach and Results

Here is a brief summary of the approach and the results I am getting.

  1. Prior, likelihood, and posterior:
     Prior: theta ~ Beta(concentration1=3.0, concentration0=3.0)
     Likelihood: p(data|theta) ~ Bernoulli(probs=theta)
     Dataset: [1 1 1 0 0 0 0 0 0 1 0 1]


  2. Negative ELBO for JAX and PYMC: I took the proposal distribution to be a Normal distribution and the transformation function to be the sigmoid, and then applied ADVI from scratch. The figure below shows the negative ELBO for both the JAX and PYMC implementations. For pymc, I referred to the code from the same notebook: https://github.com/probml/probml-notebooks/blob/main/notebooks/beta_binom_approx_post_pymc.ipynb
     [Figures: negative-ELBO traces for JAX and PYMC]
  3. Posterior samples (samples from the learnt distribution q(x)): after obtaining the optimized parameters (loc and scale), I took 1000 samples from the Normal distribution and transformed them to the (0, 1) support using beta_sample = sigmoid(normal_sample). I observed that the JAX implementation gives slightly better results than pymc. I used bandwidth adjustment (sns.kdeplot(..., bw_adjust=2)) to plot the samples more smoothly. A minimal JAX sketch of this whole recipe appears after this list.

  4. Comparison of mean and standard deviation:
     [Figure: comparison of mean and standard deviation]
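The core of the from-scratch objective looks roughly as follows (a minimal sketch assuming tensorflow_probability's JAX substrate; names such as negative_elbo and n_mc are mine, not necessarily those of the colab):

import jax
import jax.numpy as jnp
from tensorflow_probability.substrates import jax as tfp

tfd = tfp.distributions

data = jnp.array([1, 1, 1, 0, 0, 0, 0, 0, 0, 1, 0, 1])  # dataset from item 1
prior = tfd.Beta(concentration1=3.0, concentration0=3.0)

def negative_elbo(params, key, n_mc=100):
    # Monte-Carlo estimate of the negative ELBO under q = Normal(loc, scale).
    loc, log_scale = params
    q = tfd.Normal(loc, jnp.exp(log_scale))   # variational dist in unconstrained space
    z = q.sample(n_mc, seed=key)              # reparameterized samples
    theta = jax.nn.sigmoid(z)                 # transform to the (0, 1) support
    # log |d theta / d z| for the sigmoid transformation
    log_jac = jax.nn.log_sigmoid(z) + jax.nn.log_sigmoid(-z)
    log_lik = tfd.Bernoulli(probs=theta[:, None]).log_prob(data).sum(axis=-1)
    log_joint = prior.log_prob(theta) + log_lik + log_jac
    return -(log_joint - q.log_prob(z)).mean()

After optimizing (loc, log_scale), the posterior samples in item 3 are just sigmoid applied to draws from Normal(loc, scale).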

References:

I have used the following references to accomplish this task.

  1. ADVI paper: https://arxiv.org/abs/1603.00788
  2. Code-First ML blog: https://code-first-ml.github.io/book2/notebooks/introduction/variational.html
  3. Blog: https://luiarthur.github.io/statorial/varinf/introvi/
  4. Video: https://www.youtube.com/watch?v=HxQ94L8n0vU

murphyk commented 2 years ago

This looks great. But a few nits:

[Screenshot: list of review comments, quoted and addressed in the reply below]

karm-patel commented 2 years ago

Hi, Dr. @murphyk, thank you for the feedback. I have addressed your comments as follows (updated colab).

  • Your main fitting loop seems a bit slow - maybe try jitting the grad and Jacobian as well as the objective? Or using lax.scan?

Here I tried the following things:

  1. Previously I was jitting only the elbo_loss() function, and grad_and_loss_fn remained un-jitted. So I jitted the grad_and_loss_fn function, and by experimenting with various combinations I observed that jitting the parent function (grad_and_loss_fn) automatically traces all the functions it calls into the same compiled computation. However, this did not improve the speed.

  2. Then I used lax.scan() as you suggested. It reduced the training time from 100 seconds to 2.5 seconds (a 40x speedup)! The loop has roughly the shape sketched below.
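Roughly, the scan-based loop looks like this (a sketch with my own names; the optax optimizer choice is illustrative, and negative_elbo is the jitted objective discussed above):

from functools import partial

import jax
import jax.numpy as jnp
import optax

optimizer = optax.adam(1e-2)  # illustrative optimizer and learning rate

@partial(jax.jit, static_argnames="n_iters")
def fit(params, key, n_iters=5000):
    opt_state = optimizer.init(params)

    def step(carry, key_t):
        # One SVI step: estimate the gradient of the negative ELBO and update params.
        params, opt_state = carry
        loss, grads = jax.value_and_grad(negative_elbo)(params, key_t)
        updates, opt_state = optimizer.update(grads, opt_state)
        params = optax.apply_updates(params, updates)
        return (params, opt_state), loss

    keys = jax.random.split(key, n_iters)  # one PRNG key per iteration
    (params, _), losses = jax.lax.scan(step, (params, opt_state), keys)
    return params, losses

# Usage: params, losses = fit((jnp.array(0.0), jnp.array(0.0)), jax.random.PRNGKey(0))

The speedup comes from lax.scan compiling the whole loop once, instead of dispatching every iteration from Python.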

  • Also in log_prior_likelihood_jacobian maybe clarify the math to say something like: z ~ N(0, I), theta = sigmoid(z), objective = log Beta(theta) + log p(x|theta) + log |J(z)|

I have tried to improve this by jointly referring to book 1, book 2, and the ADVI paper.
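Concretely, the change-of-variables term works out as below (a small sketch; the softplus form is just one numerically stable way to write it):

import jax
import jax.numpy as jnp

def log_abs_det_jacobian_sigmoid(z):
    # theta = sigmoid(z)  =>  d theta / d z = sigmoid(z) * (1 - sigmoid(z)),
    # so log|J(z)| = log sigmoid(z) + log sigmoid(-z) = -softplus(-z) - softplus(z).
    return -jax.nn.softplus(-z) - jax.nn.softplus(z)

# Sanity check against autodiff:
# jax.grad(jax.nn.sigmoid)(0.3) ~= jnp.exp(log_abs_det_jacobian_sigmoid(0.3))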

  • Please add those references to your code.
  • There are some formatting issues with your latex, e.g. p(\frac{theta}{data}) should be p(\theta|x).
  • I would not call q a proposal distribution (since we are not doing importance sampling), I would just call it a variational distribution.

I have addressed all the above concerns.

  • Do you know why the pymc.ADVI elbo is so noisy?

I tried digging into the pymc source code for this, but did not succeed in figuring out the cause.

Additionally, I have fixed the following bug:

[Figures: old figure vs. updated figure]

murphyk commented 2 years ago

This is great, thanks! I will add it to probml-notebooks soon.

murphyk commented 2 years ago

@karm-patel I have checked in your code at https://github.com/probml/probml-notebooks/blob/main/notebooks/advi_beta_binom_jax.ipynb. I modified it to use the same dataset (10 heads, 1 tail) as https://github.com/probml/probml-notebooks/blob/main/notebooks/beta_binom_approx_post_pymc.ipynb. The new figure for the book is attached, and looks much better.

[Figure: the new figure for the book]

However there is still a problem. It looks like the approximate blue posterior q(theta) has support outside of the [0,1] range, but this is not true, as you can verify by looking at the (transformed) posterior samples. I think the problem is with this line

sns.kdeplot(transformed_samples, color="blue" ,label="$q(x)$ learned", bw_adjust=2)

Please figure out a way to ensure the KDE does not extrapolate past [0,1]. For example, figure out how arviz makes plots like this (ignore the HPDI line, but note that the support does not extend below 0).

[Screenshot: arviz posterior plot whose support stops at 0]

murphyk commented 2 years ago

I created https://github.com/probml/pyprobml/issues/738 (Laplace) and https://github.com/probml/pyprobml/issues/737 (HMC) to track the other two approximate inference examples. Then we will be liberated from PyMC :)

karm-patel commented 2 years ago

However there is still a problem. It looks like the approximate blue posterior q(theta) has support outside of the [0,1] range [...] Please figure out a way to ensure the KDE does not extrapolate past [0,1].

Thanks for pointing out this issue. I think we can use sns.kdeplot(..., clip=(0.0, 1.0)) to restrict the plot to the support, for example as below.
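A minimal sketch of the fix (transformed_samples is the array from the line quoted above):

import seaborn as sns

# clip=(0.0, 1.0) keeps the KDE inside the support of theta
sns.kdeplot(transformed_samples, color="blue", label="$q(x)$ learned",
            bw_adjust=1.5, clip=(0.0, 1.0))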

So I have created a PR (probml/probml-notebooks#82) for the following minor changes:

  1. I observed that after modifying the dataset, n_samples mistakenly remained 12 instead of 11; because of this, the likelihood and posterior did not match exactly. I have corrected this.
     [Figures: old vs. updated]
  2. Also, the plots of q(x) (learned) and p(x) (true posterior) differed too much; this was due to the large value of bw_adjust=2. I have decreased bw_adjust from 2 to 1.5 and added clip=(0.0, 1.0).
     [Figures: sns.kdeplot(..., bw_adjust=2) (old) vs. sns.kdeplot(..., bw_adjust=1.5, clip=(0.0, 1.0)) (updated)]

Please let me know if something is wrong or if more edits are needed.