Ceyron / machine-learning-and-simulation

All the handwritten notes 📝 and source code files 🖥️ used in my YouTube Videos on Machine Learning & Simulation (https://www.youtube.com/channel/UCh0P7KwJhuQ4vrzc3IRuw4Q)

KL_VALUES variable in elbo_interactive_plot #8

Closed: GStechschulte closed this issue 1 year ago

GStechschulte commented 2 years ago

In the elbo_interactive_plot.py script, the values in the KL_VALUES variable were computed using an approximate KL-Divergence.

In the tutorial on approximating the KL-Divergence, a scalar value was returned, not multiple values as in the Python script linked above. It seems to me that these are pre-computed KL-Divergence values for every pair of mu_approx and sigma_approx that the user can choose in the Streamlit app, and that indexing is then used to get the KL value associated with the user's inputs. Is this thinking correct?
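To make my understanding concrete, the precompute-and-index pattern I have in mind looks roughly like the sketch below (the grids, shapes, and the lookup_kl helper are made up for illustration and are not the actual variables from the script):

import numpy as np

# Hypothetical grids of the mu_approx / sigma_approx values selectable in the app
mu_grid = np.linspace(-5.0, 5.0, 101)
sigma_grid = np.linspace(0.1, 5.0, 50)

# Stand-in for the pre-computed table of KL-Divergence values,
# one entry per (mu_approx, sigma_approx) combination
KL_VALUES = np.zeros((len(mu_grid), len(sigma_grid)))

def lookup_kl(mu_approx, sigma_approx):
    # Index the table with the grid points closest to the user's inputs
    i = int(np.argmin(np.abs(mu_grid - mu_approx)))
    j = int(np.argmin(np.abs(sigma_grid - sigma_approx)))
    return KL_VALUES[i, j]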

I am implementing the material from the video Variational Inference | Evidence Lower Bound (ELBO) | Intuition & Visualization in a PyTorch version similar to the elbo_interactive_plot file linked above. The initial implementation can be found here. One peculiarity I have found is that the KL-Divergence value is ~6 when the initial parameters are mu_approx = 0. and sigma_approx = 1.. Even with a large sample size, I figured this value would converge approximately to the solution of ~1.01 shown in the video. However, this is not the case, even though all parameters for initializing the true and surrogate posteriors are the same as in the TensorFlow implementation. Any thoughts?
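For context, a sketch of the kind of handwritten Monte Carlo KL estimate I am using in PyTorch is below (the mixture parameters are placeholders for illustration, not necessarily the exact ones from my implementation):

import torch
from torch.distributions import Categorical, MixtureSameFamily, Normal

# True posterior: a mixture of Gaussians (placeholder parameters)
true_posterior = MixtureSameFamily(
    mixture_distribution=Categorical(probs=torch.tensor([0.3, 0.4, 0.3])),
    component_distribution=Normal(
        loc=torch.tensor([-1.3, 2.2, 4.0]),
        scale=torch.tensor([2.3, 1.5, 4.4]),
    ),
)

# Surrogate posterior: a single Gaussian with the initial parameters
surrogate_posterior = Normal(loc=torch.tensor(0.0), scale=torch.tensor(1.0))

sample_size = 10_000
samples = true_posterior.sample((sample_size,))

# Monte Carlo average of log p(x) - log q(x) over samples from the true posterior
kl_estimate = (
    true_posterior.log_prob(samples) - surrogate_posterior.log_prob(samples)
).mean()
print(f"KL estimate: {kl_estimate.item():.3f}")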

I will be making a pull request into the contrib directory once finished.

Thanks!

Ceyron commented 2 years ago

Hi Gabriel,

Thanks for opening the issue, and sorry for my delayed response. For some reason, I did not get a push notification for it 😅.

Regarding your first point: That's correct. The KL values are precomputed using an approximation, and they can then be indexed based on the values chosen by the user.

The second point is fascinating: Since I created the video and the corresponding visualization close to the start of the channel, I unfortunately no longer have the script with which I created the KL values. It's also been some time since I last worked with TFP. There is a good chance I made a mistake there. I tried to reproduce my values with the following MWE:

import tensorflow as tf
import tensorflow_probability as tfp

# Parameters of the true posterior: a mixture of three Gaussians
mixture_probs = [0.3, 0.4, 0.3]
mus = [-1.3, 2.2, 4.0]
sigmas = [2.3, 1.5, 4.4]

# Parameters of the surrogate posterior: a single Gaussian
mu_approx = 0.0
sigma_approx = 1.0

sample_size = 10_000

true_posterior = tfp.distributions.MixtureSameFamily(
    mixture_distribution=tfp.distributions.Categorical(probs=mixture_probs),
    components_distribution=tfp.distributions.Normal(loc=mus, scale=sigmas)
)
surrogate_posterior = tfp.distributions.Normal(loc=mu_approx, scale=sigma_approx)

# Monte Carlo estimate of the variational loss (negative ELBO),
# with the samples drawn from the surrogate posterior
kl_approximated = tfp.vi.monte_carlo_variational_loss(
    true_posterior.log_prob, surrogate_posterior, sample_size=sample_size
)

# float(...) so that the f-string format spec works on the eager tensor
print(f"KL Approximated: {float(kl_approximated):1.3f}")

which prints something around 0.902. That's different from the value I have in the Streamlit visualization, and also different from your implementation, which is odd.

On the other hand, when I implement a "handwritten" KL approximation similar to your script

# "Handwritten" Monte Carlo estimate: average of log p(x) - log q(x)
# over samples drawn from the true posterior
true_posterior_samples = true_posterior.sample(sample_size)
kl_approximated = tf.reduce_mean(
    true_posterior.log_prob(true_posterior_samples)
    - surrogate_posterior.log_prob(true_posterior_samples)
)

that prints ~6, as you correctly note.

If I remember correctly, I used the tfp.vi.monte_carlo_variational_loss function back then, so the mistake is probably associated with that. I still have to do some more investigation and then update the visualization file.

Please let me know your thoughts, and thanks again for bringing up this important point! :)

P.S.: I will also look into the PR in a minute.

GStechschulte commented 1 year ago

Hey @Ceyron, I also apologize for the late response. I think the discrepancy results from the tfp.vi.monte_carlo_variational_loss function, and the values differ between runs possibly because no PRNG seed is set. However, the gap between the handwritten KL approximation (~6) and the Monte Carlo variational loss (~0.902) is still interesting.
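For reference, a minimal sketch of how both estimates could be seeded to make the comparison reproducible (it assumes the true_posterior, surrogate_posterior, and sample_size from your MWE above are in scope; the seed value is arbitrary):

import tensorflow as tf
import tensorflow_probability as tfp

seed = 42

# Seeded Monte Carlo variational loss (samples drawn from the surrogate posterior)
kl_mc_loss = tfp.vi.monte_carlo_variational_loss(
    true_posterior.log_prob,
    surrogate_posterior,
    sample_size=sample_size,
    seed=seed,
)

# Seeded handwritten estimate (samples drawn from the true posterior)
true_posterior_samples = true_posterior.sample(sample_size, seed=seed)
kl_handwritten = tf.reduce_mean(
    true_posterior.log_prob(true_posterior_samples)
    - surrogate_posterior.log_prob(true_posterior_samples)
)

print(f"MC variational loss: {float(kl_mc_loss):1.3f}")
print(f"Handwritten estimate: {float(kl_handwritten):1.3f}")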

I am closing the issue now since neither of us has worked on this problem in the past months.