Closed by murphyk 2 years ago
Hi @murphyk, I would like to work on this issue. I implemented this using stochastic gradient descent and it works fine; the main issue is in the loss, which does not use the ELBO, and I also see that the optimization has not been done.
Hi @murphyk, can you also give me some more references for this problem? I am currently working on it. Thanks.
@pat749 Maybe just try exponentially weighted moving average on the objective and see what the trend is. The odd thing is that it does not seem to be converging (perhaps it just needs more iterations?). Also the variance is surprisingly high (perhaps because of ADVI transformation?). Probably it would best to implement this method from scratch, using SVI and a suitable bijector?
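(Just to illustrate the smoothing idea: the loss trace below is synthetic, and the ewma helper is made up for this sketch, not code from the notebook.)

```python
import numpy as np
import matplotlib.pyplot as plt

def ewma(x, alpha=0.05):
    """Exponentially weighted moving average of a 1-D trace."""
    out = np.empty(len(x))
    out[0] = x[0]
    for t in range(1, len(x)):
        out[t] = alpha * x[t] + (1 - alpha) * out[t - 1]
    return out

# Synthetic noisy objective trace, standing in for the real ELBO history.
rng = np.random.default_rng(0)
loss_hist = 5.0 * np.exp(-np.arange(2000) / 500.0) + rng.normal(scale=1.0, size=2000)

plt.plot(loss_hist, alpha=0.3, label="raw objective")
plt.plot(ewma(loss_hist), label="EWMA")
plt.xlabel("iteration")
plt.legend()
plt.show()
```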
Yes, I have started from scratch and am currently implementing it using SVI. Thanks.
https://colab.research.google.com/drive/1AhWrR5YtGOw7h8gCVpIWXkBo2jeNXVz2 Hi @murphyk, can you review this? I have finished it on my side and reviewed it myself.
Here are the papers I used for reference: https://arxiv.org/abs/1802.02538 and http://www.stat.columbia.edu/~gelman/research/unpublished/advi_journal.
Hi Dr. @murphyk, I have tried to implement the ADVI method from scratch in JAX (colab notebook). I have incorporated the feedback from Prof. @nipunbatra and @patel-zeel. Can you please give your comments on this?
While writing the code I kept the following things in mind:
- Used jax.vmap() wherever possible.

Here is a brief summary of the approach and the results I'm getting.
Prior, likelihood, and posterior:
- Prior: theta ~ Beta(concentration1=3.0, concentration0=3.0)
- Likelihood: data | theta ~ Bernoulli(probs=theta)
- Dataset: [1 1 1 0 0 0 0 0 0 1 0 1]
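For reference, a minimal JAX sketch of this model's log-joint (same Beta(3, 3) prior, Bernoulli likelihood, and dataset; the log_joint name is just illustrative and not necessarily what the notebook uses):

```python
import jax.numpy as jnp
from jax.scipy.stats import beta, bernoulli

data = jnp.array([1, 1, 1, 0, 0, 0, 0, 0, 0, 1, 0, 1])

def log_joint(theta):
    """log p(theta) + log p(data | theta) for the Beta-Bernoulli model."""
    log_prior = beta.logpdf(theta, a=3.0, b=3.0)        # Beta(3, 3) prior
    log_lik = jnp.sum(bernoulli.logpmf(data, p=theta))  # i.i.d. Bernoulli likelihood
    return log_prior + log_lik

print(log_joint(0.4))  # scalar log-joint at theta = 0.4
```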
[Figure: side-by-side comparison of the JAX and PyMC results]
Posterior samples (samples from the learnt distribution q(x)): After getting the optimized parameters (loc and scale), I took 1000 samples from the Normal distribution and transformed them to the Beta support using beta_sample = sigmoid(normal_sample). I observed that the JAX implementation gives slightly better results than PyMC.
I used bandwidth adjustment (sns.kdeplot(..., bw_adjust=2)) to plot the samples more smoothly (see the bandwidth-adjustment note in the earlier issue).
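Concretely, the sampling-and-transform step looks roughly like the sketch below; loc and scale here are placeholder values standing in for the optimized variational parameters, not the values from the notebook:

```python
import jax
import numpy as np
import seaborn as sns

# Placeholder values standing in for the optimized variational parameters.
loc, scale = -0.1, 0.6

key = jax.random.PRNGKey(0)
normal_samples = loc + scale * jax.random.normal(key, shape=(1000,))  # samples from q in z-space
beta_samples = jax.nn.sigmoid(normal_samples)                         # push through the bijector to (0, 1)

sns.kdeplot(np.asarray(beta_samples), bw_adjust=2, label="$q(x)$ learned")
```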
Comparison of mean and standard deviation:
I have used the following references to accomplish this task.
This looks great. But a few nits:
Please add those references to your code.
There are some formatting issues with your LaTeX, e.g. p(\frac{theta}{data}) should be p(\theta|x).
Also, in log_prior_likelihood_jacobian, maybe clarify the math to say something like:
z ~ N(0, I)
theta = sigmoid(z)
objective = log Beta(theta) + log p(x|theta) + log |J(z)|
I would not call q a proposal distribution (since we are not doing importance sampling), I would just call it a variational distribution.
Your main fitting loop seems a bit slow - maybe try jitting the grad and Jacobian as well as the objective? Or using lax.scan?
Do you know why the pymc.ADVI elbo is so noisy?
Please check in this advi demo first, before working on the other JAX algorithms.
Hi, Dr. @murphyk, Thank you for providing the feedback, I have addressed your comments as the following. (updated colab)
- Your main fitting loop seems a bit slow - maybe try jitting the grad and Jacobian as well as the objective? Or using lax.scan?
So here I tried the following things:
- Previously I was jitting only the elbo_loss() function, and grad_and_loss_fn remained un-jitted. So I jitted grad_and_loss_fn, and by experimenting with various combinations I observed that jitting the parent function (grad_and_loss_fn) automatically jits all functions called by it. However, this did not improve the speed.
- Then I used lax.scan() as suggested by you. It reduced the training time from 100 seconds to 2.5 seconds (40x)! A rough sketch of the scan-based loop is shown below.
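Roughly, the scan-based loop has this shape; note that elbo_loss below is a stand-in quadratic (not the notebook's actual objective), and step / loss_and_grad are illustrative names:

```python
import jax
import jax.numpy as jnp
from jax import lax

def elbo_loss(params, key):
    # Stand-in quadratic objective so the sketch runs on its own;
    # the real function returns the (negative) ELBO estimate.
    return jnp.sum(params ** 2)

loss_and_grad = jax.jit(jax.value_and_grad(elbo_loss))

def step(params, key):
    loss, grads = loss_and_grad(params, key)
    params = params - 0.05 * grads   # plain SGD update
    return params, loss              # new carry, per-step output

keys = jax.random.split(jax.random.PRNGKey(0), 5000)  # one PRNG key per iteration
init_params = jnp.zeros(2)                            # e.g. (loc, log_scale)
params, losses = lax.scan(step, init_params, keys)
```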
- Also, in log_prior_likelihood_jacobian, maybe clarify the math to say something like: z ~ N(0, I); theta = sigmoid(z); objective = log Beta(theta) + log p(x|theta) + log |J(z)|
I have tried to improve this by jointly referring to book1, book2, and the ADVI paper.
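For concreteness, one way to write that objective term in JAX is sketched below; this follows the same Beta(3, 3)/Bernoulli setup but is not necessarily the notebook's exact log_prior_likelihood_jacobian:

```python
import jax
import jax.numpy as jnp
from jax.scipy.stats import beta, bernoulli

data = jnp.array([1, 1, 1, 0, 0, 0, 0, 0, 0, 1, 0, 1])

def log_prior_lik_jacobian(z):
    """log Beta(theta) + log p(x | theta) + log |J(z)|, with theta = sigmoid(z)."""
    theta = jax.nn.sigmoid(z)
    log_prior = beta.logpdf(theta, a=3.0, b=3.0)
    log_lik = jnp.sum(bernoulli.logpmf(data, p=theta))
    # d(sigmoid)/dz = sigmoid(z) * (1 - sigmoid(z)), so
    # log |J(z)| = log_sigmoid(z) + log_sigmoid(-z)
    log_jac = jax.nn.log_sigmoid(z) + jax.nn.log_sigmoid(-z)
    return log_prior + log_lik + log_jac
```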
- Please add those references to your code.
- There are some formatting issues with your LaTeX, e.g. p(\frac{theta}{data}) should be p(\theta|x).
- I would not call q a proposal distribution (since we are not doing importance sampling), I would just call it a variational distribution.
I have addressed all the above concerns.
- Do you know why the pymc.ADVI elbo is so noisy?
I have tried digging into the PyMC source code for this but have not yet been able to figure out the cause.
Additionally, I have fixed a bug in elbo_loss(), which previously yielded a smooth loss. The new figure looks as follows:
[Figure: old loss curve | updated loss curve]
This is great, thanks! I will add it to probml-notebooks soon.
@karm-patel I have checked in your code at https://github.com/probml/probml-notebooks/blob/main/notebooks/advi_beta_binom_jax.ipynb. I modified it to use the same dataset (10 heads, 1 tail) as https://github.com/probml/probml-notebooks/blob/main/notebooks/beta_binom_approx_post_pymc.ipynb. The new figure for the book is attached, and looks much better.
However there is still a problem. It looks like the approximate blue posterior q(theta) has support outside of the [0,1] range, but this is not true, as you can verify by looking at the (transformed) posterior samples. I think the problem is with this line
sns.kdeplot(transformed_samples, color="blue", label="$q(x)$ learned", bw_adjust=2)
Please figure out a way to ensure the KDE does not extrapolate past [0,1]. For example, figure out how arviz makes plots like this (ignore the HPDI line, but note that the support is cut off below 0).
I created https://github.com/probml/pyprobml/issues/738 (Laplace) and https://github.com/probml/pyprobml/issues/737 (HMC) to track the other two approximate inference examples. Then we will be liberated from PyMC :)
- However there is still a problem. It looks like the approximate blue posterior q(theta) has support outside of the [0,1] range, but this is not true, as you can verify by looking at the (transformed) posterior samples. I think the problem is with this line: sns.kdeplot(transformed_samples, color="blue", label="$q(x)$ learned", bw_adjust=2). Please figure out a way to ensure the KDE does not extrapolate past [0,1].
Thanks for pointing out this issue. I think we can use sns.kdeplot(..., clip=(0.0, 1.0)) to restrict the plot to the support.
So I have created a PR (probml/probml-notebooks#82) with the following minor changes:
- n_samples remained 12 instead of 11; because of this, the likelihood and posterior were not matching exactly. I have corrected this.
[Figure: old | updated]
- I have decreased bw_adjust from 2 to 1.5 and added clip=(0.0, 1.0).
[Figure: sns.kdeplot(..., bw_adjust=2) (old) | sns.kdeplot(..., bw_adjust=1.5, clip=(0.0, 1.0)) (updated)]

Please let me know if something is wrong or if you need more edits on this.
The advi.hist (ELBO) should increase over time, but it does not seem to, even though the posterior looks sensible. Find out why. https://github.com/probml/probml-notebooks/blob/main/notebooks/beta_binom_approx_post_pymc.ipynb