GStechschulte closed this issue 1 year ago.
Hi Gabriel,
thanks for opening the issue, and sorry for my delayed response. For some reason, I did not get a push notification for it 😅.
Regarding your first point: that's correct. The KL values are precomputed using an approximation and can then be indexed based on the values the user chooses.
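In other words, something along the lines of the following sketch; the grid ranges, the placeholder array, and the lookup helper are just illustrative assumptions on my side, not the actual contents of the script:

import numpy as np

# Hypothetical grid of surrogate parameters selectable in the app (assumed ranges).
MU_RANGE = np.linspace(-5.0, 5.0, 101)
SIGMA_RANGE = np.linspace(0.1, 5.0, 50)

# In the real script, each entry would hold the precomputed KL estimate for the
# parameter pair (MU_RANGE[i], SIGMA_RANGE[j]); here it is only a placeholder.
KL_VALUES = np.zeros((len(MU_RANGE), len(SIGMA_RANGE)))

def lookup_kl(mu_approx, sigma_approx):
    # Index the precomputed table with the grid point closest to the user's choice.
    i = int(np.argmin(np.abs(MU_RANGE - mu_approx)))
    j = int(np.argmin(np.abs(SIGMA_RANGE - sigma_approx)))
    return KL_VALUES[i, j]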
The second point is fascinating: since I created the video and the corresponding visualization close to the start of the channel, I unfortunately no longer have the script with which I created the KL values. It has also been some time since I last worked with TFP, so there is a good chance I made a mistake there. I tried to reproduce my values with the following MWE:
import tensorflow as tf
import tensorflow_probability as tfp

# "True" posterior: a three-component Gaussian mixture.
mixture_probs = [0.3, 0.4, 0.3]
mus = [-1.3, 2.2, 4.0]
sigmas = [2.3, 1.5, 4.4]

# Surrogate posterior: a single Gaussian with the initial parameters.
mu_approx = 0.0
sigma_approx = 1.0

sample_size = 10_000

true_posterior = tfp.distributions.MixtureSameFamily(
    mixture_distribution=tfp.distributions.Categorical(probs=mixture_probs),
    components_distribution=tfp.distributions.Normal(loc=mus, scale=sigmas),
)
surrogate_posterior = tfp.distributions.Normal(loc=mu_approx, scale=sigma_approx)

# Monte Carlo estimate of the variational loss between surrogate and true posterior.
kl_approximated = tfp.vi.monte_carlo_variational_loss(
    true_posterior.log_prob, surrogate_posterior, sample_size=sample_size
)

print(f"KL Approximated: {kl_approximated:1.3f}")
which prints something around 0.902. That's different from the value I have in the Streamlit visualization, and also different from your implementation; weird.
On the other hand, when I implement a "handwritten" KL approximation similar to your script
# Forward Monte Carlo estimate: sample from the true posterior and
# average the log-density difference log p(x) - log q(x).
true_posterior_samples = true_posterior.sample(sample_size)
kl_approximated = tf.reduce_mean(
    true_posterior.log_prob(true_posterior_samples)
    - surrogate_posterior.log_prob(true_posterior_samples)
)
that prints ~6, as you correctly note.
If I remember correctly, I must have used the tfp.vi.monte_carlo_variational_loss function, so the mistake is probably associated with that. I still have to do some more investigation and then update the visualization file.
Please let me know your thoughts, and thanks again for bringing up this important point! :)
P.S.: I will also look into the PR in a minute.
Hey @Ceyron, I also apologize for the late response. I think it results from the tfp.vi.monte_carlo_variational_loss function, and the results may differ because no PRNG seed is set. However, the difference between the handwritten KL approximation and the Monte Carlo variational loss, ~6 vs. 0.902, is still interesting.
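One possible explanation, though this is an assumption on my side and not something I have verified against the TFP source: KL-Divergence is asymmetric. The handwritten estimate samples from the true posterior (the forward direction), while tfp.vi.monte_carlo_variational_loss, as far as I understand the docs, samples from the surrogate (the reverse direction, i.e. the negative ELBO), so the two numbers estimate different quantities. A small sketch that reuses the distributions from the MWE above and estimates both directions by hand:

import tensorflow as tf
import tensorflow_probability as tfp

# Same true and surrogate posteriors as in the MWE above.
true_posterior = tfp.distributions.MixtureSameFamily(
    mixture_distribution=tfp.distributions.Categorical(probs=[0.3, 0.4, 0.3]),
    components_distribution=tfp.distributions.Normal(
        loc=[-1.3, 2.2, 4.0], scale=[2.3, 1.5, 4.4]
    ),
)
surrogate_posterior = tfp.distributions.Normal(loc=0.0, scale=1.0)

sample_size = 10_000

# Forward KL, KL(p || q): expectation of log p(x) - log q(x) under the true posterior p.
p_samples = true_posterior.sample(sample_size)
kl_forward = tf.reduce_mean(
    true_posterior.log_prob(p_samples) - surrogate_posterior.log_prob(p_samples)
)

# Reverse KL, KL(q || p): expectation of log q(x) - log p(x) under the surrogate q.
q_samples = surrogate_posterior.sample(sample_size)
kl_reverse = tf.reduce_mean(
    surrogate_posterior.log_prob(q_samples) - true_posterior.log_prob(q_samples)
)

print(f"Forward KL estimate: {float(kl_forward):1.3f}")
print(f"Reverse KL estimate: {float(kl_reverse):1.3f}")

If that reading is correct, the forward estimate should land near the ~6 from the handwritten version and the reverse estimate near the ~0.902 from monte_carlo_variational_loss, with the seed only affecting the Monte Carlo noise rather than the gap between the two.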
I am closing the issue now since neither of us has worked on this problem in the past months.
In the elbo_interactive_plot.py script, the values in the variable KL_VALUES were computed using an approximate KL-Divergence.

In the tutorial on approximating the KL-Divergence, scalar values were returned, not multiple values as in the Python script linked above. It seems to me these are the pre-computed KL-Divergence values for every pair of mu_approx and sigma_approx that can be chosen by the user in the Streamlit app, and indexing is used to get the KL value associated with the user's inputs. Is this thinking correct?

I am implementing the material from the video Variational Inference | Evidence Lower Bound (ELBO) | Intuition & Visualization in a PyTorch version similar to the elbo_interactive_plot file linked above. The initial implementation can be found here. One peculiarity I have found is that the KL-Divergence value is ~6 with the initial parameters mu_approx = 0. and sigma_approx = 1. (a minimal sketch of this computation is shown below). Even with a large sample size, I figured this value would converge approximately to the solution in the video of ~1.01. However, this is not the case, and all parameters for initializing the true and surrogate posteriors are the same as in the TensorFlow implementation. Any thoughts?

I will be making a pull request into the contrib directory once finished. Thanks!
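A minimal PyTorch sketch of this handwritten forward-KL Monte Carlo estimate, with the distribution parameters taken from the TensorFlow example above (this is only an illustration of the computation, not the exact code in the linked repository):

import torch
import torch.distributions as dist

# "True" posterior: the same three-component Gaussian mixture as in the TF example.
true_posterior = dist.MixtureSameFamily(
    mixture_distribution=dist.Categorical(probs=torch.tensor([0.3, 0.4, 0.3])),
    component_distribution=dist.Normal(
        loc=torch.tensor([-1.3, 2.2, 4.0]), scale=torch.tensor([2.3, 1.5, 4.4])
    ),
)

# Surrogate posterior with the initial parameters mu_approx = 0. and sigma_approx = 1.
surrogate_posterior = dist.Normal(loc=torch.tensor(0.0), scale=torch.tensor(1.0))

sample_size = 100_000

# Monte Carlo estimate of the forward KL: average of log p(x) - log q(x) over samples from p.
samples = true_posterior.sample((sample_size,))
kl_estimate = (
    true_posterior.log_prob(samples) - surrogate_posterior.log_prob(samples)
).mean()

print(f"Estimated KL: {kl_estimate.item():.3f}")  # comes out around ~6 for these parameters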