tillahoffmann closed this issue 5 years ago.
Good one. Doesn't look like there's any way to use the software as it stands to do that. The best I have for you right now is that monte_carlo_csiszar_f_divergence is not a very complex function, so you could write your own variant.
For the future, I will edit this issue into a feature request to add the capability. If you feel like submitting a PR, we would welcome it!
@axch, thanks for the update. Do you have recommendations on how to perform variational inference using tensorflow probability for multivariate models?
I'm not actually very familiar with that suite of capabilities. Perhaps @jvdillon could comment?
Caveat: AFAIK this module was a very early addition to TFP and is probably in need of some dusting off and additional love.
I believe the commonly imagined use case was probably something like structured mean-field VI, where q would be a (block) diagonal MVN.
The main thing this method is good for is specifying one term in the elbo, namely the expected log-likelihood E_q[log p(x | z)].
If you want richer structure in the variational distribution, q, you probably need to do some scribbling of maths and then invoke this function multiple times, once for each E_q[log p(x | z)] term in your elbo.
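For concreteness, here is a sketch (in plain NumPy/SciPy rather than TFP, and with a hypothetical conjugate normal-normal model) of assembling the elbo by hand from its pieces: the expected log-likelihood term, the log-prior term, and the entropy term. Because the model is conjugate, setting q to the exact posterior makes the elbo equal the log evidence, which gives a built-in check:

```python
import numpy as np
from scipy.stats import norm

rng = np.random.default_rng(0)

# Hypothetical toy conjugate model: z ~ N(0, 1), x | z ~ N(z, 1), x observed.
x = 0.7

# Exact posterior is N(x/2, sqrt(1/2)); use it as the variational q(z).
q_loc, q_scale = x / 2.0, np.sqrt(0.5)

def elbo_estimate(q_loc, q_scale, n_samples=100_000):
    """Monte Carlo elbo: E_q[log p(x|z)] + E_q[log p(z)] - E_q[log q(z)]."""
    z = rng.normal(q_loc, q_scale, size=n_samples)
    log_lik = norm.logpdf(x, loc=z, scale=1.0)      # E_q[log p(x | z)] term
    log_prior = norm.logpdf(z, loc=0.0, scale=1.0)  # E_q[log p(z)] term
    log_q = norm.logpdf(z, loc=q_loc, scale=q_scale)  # entropy term
    return np.mean(log_lik + log_prior - log_q)

# With q equal to the exact posterior, the elbo equals the log evidence
# log p(x), i.e. norm.logpdf(x, 0, sqrt(2)) for this model.
print(elbo_estimate(q_loc, q_scale))
```

With a structured q, the same pattern applies: you would sum one expected log-likelihood term per factor of the model, which is where invoking the divergence function multiple times comes in.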
Finally, note that monte_carlo_csiszar_f_divergence does not compute the KL[q(z) || p(z)] term in the elbo, so you need to throw that in by hand. We have analytic KLs between many distributions out of the box; otherwise you can use Monte Carlo: sample a z from q(z) and compute log q(z) - log p(z).
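A minimal sketch of that Monte Carlo fallback, using SciPy and a hypothetical pair of univariate normals (so the estimate can be checked against the well-known analytic KL):

```python
import numpy as np
from scipy.stats import norm

rng = np.random.default_rng(0)

# Hypothetical example: q = N(1, 0.5), p = N(0, 1).
q_loc, q_scale = 1.0, 0.5
p_loc, p_scale = 0.0, 1.0

# Monte Carlo estimate of KL[q || p]: sample z ~ q, average log q(z) - log p(z).
z = rng.normal(q_loc, q_scale, size=500_000)
kl_mc = np.mean(norm.logpdf(z, q_loc, q_scale) - norm.logpdf(z, p_loc, p_scale))

# Analytic KL between univariate normals, for comparison.
kl_exact = (np.log(p_scale / q_scale)
            + (q_scale**2 + (q_loc - p_loc)**2) / (2.0 * p_scale**2) - 0.5)
```

The two values agree to within Monte Carlo noise; the analytic form is what an out-of-the-box KL registration would return.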
Hope this helps! Happy to try to help more, if you can give more details about your problem setup.
I'm gonna remove the 'good first issue' tag, because I'm fairly sure solving this completely will require some thoughtful and non-trivial API design discussion. Namely, it's a question of defining and aligning the true and variational model, when the structure and alignment are not as simple as, say, mean-field Gaussian.
This was addressed in 499827efa11b55f44fa0d5ef0432f3e1eebeff01 (and several changes leading up to that one). It's now possible to pass JointDistribution objects to tfp.vi.monte_carlo_variational_loss, which is a renamed and updated version of monte_carlo_csiszar_f_divergence; the updated code will be included in the upcoming TFP 0.8 release.
Closing this issue; feel free to reopen if needed.
Using monte_carlo_csiszar_f_divergence to perform inference for a single distribution works as expected (see example below). However, I couldn't figure out how to make use of monte_carlo_csiszar_f_divergence for more complex models that involve different distributions (e.g. one normal, one gamma). Any advice would be much appreciated. #147 may be related?